
Creating a dataset with a PyTorch DataLoader

I have a folder with subfolders (classes), with images inside each subfolder.

data
  |_ classe1
        |_ image1
        |_ image2
  |_ classe2
        |_ ...

My goal is to create a dataset (train + test set) to train a PyTorch ResNet model. I get an error that I don't know how to solve, because I don't really understand the DataLoader structure. Here is what I have:

dataset = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['data']}

batch_size = 32
validation_split = .3
shuffle_dataset = True
random_seed= 42

# Creating data indices for training and validation splits:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=valid_sampler)

dataloaders_dict = {'train': train_loader, 'val': validation_loader}

But when I try to run my model, I get this error:

Epoch 0/99
----------
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-79-8c30eb5e6a01> in <module>()
      3 
      4 # Train and evaluate
----> 5 model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=False)

4 frames
<ipython-input-56-9421c2d39473> in train_model(model, dataloaders, criterion, optimizer, num_epochs, is_inception)
     22 
     23             # Iterate over data.
---> 24             for inputs, labels in dataloaders[phase]:
     25                 inputs = inputs.to(device)
     26                 labels = labels.to(device)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    361 
    362     def __next__(self):
--> 363         data = self._next_data()
    364         self._num_yielded += 1
    365         if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    401     def _next_data(self):
    402         index = self._next_index()  # may raise StopIteration
--> 403         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    404         if self._pin_memory:
    405             data = _utils.pin_memory.pin_memory(data)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

KeyError: 0

Any suggestions? Can you spot the error?

The problem most likely comes from your first line, where your dataset is actually a dict containing one element (a PyTorch dataset). This would be better:

x = 'data'
dataset = datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])

I assume data_transforms['data'] is a transformation of the expected type (as detailed here).

The KeyError is raised when PyTorch tries to index your "dataset" (the dict, which merely contains one element) with an integer.
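A minimal reproduction (with a plain dict standing in for the ImageFolder dataset) shows this is just ordinary dict indexing failing on an integer key:

```python
# A dict wrapping the real dataset, as in the question's first line.
dataset = {'data': ['image1', 'image2']}

# The DataLoader's fetcher does dataset[idx] with integer indices,
# but the dict only has the string key 'data', so indexing with 0 fails.
try:
    dataset[0]
except KeyError as err:
    print('KeyError:', err)  # KeyError: 0
```

This is exactly the `KeyError: 0` in your traceback: `fetch` runs `self.dataset[idx]` and the dict has no key `0`.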

By the way, PyTorch provides `torch.utils.data.random_split` so that you don't have to do the train/test split yourself. You may want to look it up.
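A sketch of that approach, using a toy `TensorDataset` in place of `ImageFolder` (the sizes, split ratio, and seed here are illustrative, matching the question's 30% split):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Toy dataset standing in for datasets.ImageFolder(...).
full_dataset = TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))

# 30% validation split, seeded for reproducibility.
n_val = int(0.3 * len(full_dataset))
n_train = len(full_dataset) - n_val
train_set, val_set = random_split(
    full_dataset, [n_train, n_val],
    generator=torch.Generator().manual_seed(42))

# random_split returns Subset objects, which DataLoader accepts directly,
# so no SubsetRandomSampler is needed.
dataloaders_dict = {
    'train': DataLoader(train_set, batch_size=32, shuffle=True),
    'val': DataLoader(val_set, batch_size=32),
}
print(len(train_set), len(val_set))  # 7 3
```

This replaces the manual index shuffling and the two `SubsetRandomSampler`s in your snippet with a single call.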
