
Impact of using data shuffling in Pytorch dataloader

I implemented an image classification network for a dataset of 100 classes, using a pretrained AlexNet and replacing the final output layers. I noticed that when I loaded my data like this:

trainloader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=False)

the accuracy on the validation dataset stayed around 2-3% for about 10 epochs, but when I simply changed to shuffle=True and retrained the network, the accuracy jumped to 70% in the very first epoch.

I was wondering whether this happened because, in the first case, the network was shown examples of a single class one after another for several consecutive iterations, leading it to generalize poorly during training, or whether there is some other reason behind it.

Still, I did not expect it to have such a drastic impact.

PS: All the code and parameters were exactly the same in both cases; only the shuffle option changed.

Yes, it can absolutely affect the result! Shuffling the order of the data used to fit the classifier matters because it makes the batches differ from one epoch to the next.

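The DataLoader documentation says: "shuffle (bool, optional) – set to True to have the data reshuffled at every epoch"

In general, shuffling tends to make training more robust, because each batch becomes a more representative sample of the whole dataset. It is not, by itself, a guarantee against overfitting or underfitting.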

In your case, the large jump in accuracy is probably due to how the dataset is organised (I have not seen the dataset, so this is a guess). For example, if the samples are stored grouped by category, then without shuffling each batch contains a single category, and every epoch presents the same batches in the same order. The network keeps adapting to whichever category it saw most recently, which results in very poor accuracy at test time.
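To make this concrete, here is a small plain-Python sketch (no PyTorch needed). It assumes a hypothetical dataset stored sorted by class, with 100 classes and 50 samples per class; the question does not state the actual layout, so these numbers are purely illustrative. It counts how many distinct classes land in each batch of 32 with and without shuffling.

```python
import random

# Hypothetical layout (an assumption): 100 classes, 50 samples each, sorted by class.
NUM_CLASSES = 100
SAMPLES_PER_CLASS = 50
BATCH_SIZE = 32

labels = [c for c in range(NUM_CLASSES) for _ in range(SAMPLES_PER_CLASS)]


def classes_per_batch(order, labels, batch_size):
    """Return the number of distinct classes in each consecutive batch of `order`."""
    counts = []
    for start in range(0, len(order), batch_size):
        batch = order[start:start + batch_size]
        counts.append(len({labels[i] for i in batch}))
    return counts


# Index order seen with shuffle=False: 0, 1, 2, ...
sequential_order = list(range(len(labels)))

# Stand-in for shuffle=True: one random permutation of the same indices.
shuffled_order = sequential_order[:]
random.Random(0).shuffle(shuffled_order)

seq_counts = classes_per_batch(sequential_order, labels, BATCH_SIZE)
shuf_counts = classes_per_batch(shuffled_order, labels, BATCH_SIZE)

print("sequential: most classes in any batch =", max(seq_counts))
print("shuffled:   mean classes per batch    =", sum(shuf_counts) / len(shuf_counts))
```

Under these assumptions, a sequential batch never spans more than two classes, while a shuffled batch mixes many of them, so every gradient step reflects a broader slice of the dataset.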

PyTorch did many great things, and one of them is the DataLoader class.

The DataLoader class takes the dataset (data), sets the batch_size (how many samples to load per batch), and uses a sampler, which can be one of the following classes:

  • DistributedSampler
  • SequentialSampler
  • RandomSampler
  • SubsetRandomSampler
  • WeightedRandomSampler
  • BatchSampler
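As a rough illustration of how a sampler and batch_size fit together, the sketch below imitates in plain Python what a batch sampler does: it takes the stream of indices a sampler yields and groups them into lists of batch_size. The helper name batch_indices is made up for this example and is not part of PyTorch, although the real BatchSampler offers a similar drop_last option.

```python
def batch_indices(index_stream, batch_size, drop_last=False):
    """Group indices from `index_stream` into lists of length `batch_size`.

    Hypothetical helper that imitates the role of a batch sampler.
    """
    batch = []
    for idx in index_stream:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Emit the final, shorter batch unless the caller asked to drop it.
    if batch and not drop_last:
        yield batch


# Ten indices in sequential order, grouped four at a time.
batches = list(batch_indices(range(10), batch_size=4))
print(batches)

# Same stream, but discarding the incomplete last batch.
full_batches = list(batch_indices(range(10), batch_size=4, drop_last=True))
print(full_batches)
```

The order of the incoming indices fully determines which samples share a batch, which is why the choice of sampler matters so much.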


The key difference between samplers is how they implement the __iter__() method.

In the case of SequentialSampler, it looks like this:

def __iter__(self):
    return iter(range(len(self.data_source))) 

This returns an iterator over the indices 0 to len(data_source) - 1, one for every item in the data_source, always in the same order.
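The behaviour can be checked with a tiny plain-Python stand-in that wraps the same __iter__ shown above. The class name SequentialIndexSampler and the sample names are invented for this sketch; it is not the PyTorch class itself.

```python
class SequentialIndexSampler:
    """Plain-Python stand-in that reuses the __iter__ shown above."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Indices 0, 1, ..., len(data_source) - 1, always in this order.
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


data = ["cat_0", "cat_1", "dog_0", "dog_1"]  # made-up samples, sorted by class
sampler = SequentialIndexSampler(data)

epoch_1 = list(sampler)
epoch_2 = list(sampler)
print(epoch_1)
print(epoch_2)
```

Both epochs visit the samples in exactly the same order, so a dataset stored sorted by class is also presented to the network sorted by class.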

When you set shuffle=True, the DataLoader uses RandomSampler instead of SequentialSampler.

Presenting the samples in a new random order each epoch may improve the learning process.
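For comparison, here is a minimal sketch of what a random sampler does, written in plain Python. The class name RandomIndexSampler and its seed parameter are assumptions made for this example; the real RandomSampler relies on PyTorch's own random number generation and has more options.

```python
import random


class RandomIndexSampler:
    """Plain-Python stand-in: yields a fresh permutation of indices per epoch."""

    def __init__(self, data_source, seed=None):
        self.data_source = data_source
        # The seed exists only to make this sketch repeatable.
        self.rng = random.Random(seed)

    def __iter__(self):
        indices = list(range(len(self.data_source)))
        self.rng.shuffle(indices)  # new order each time iteration starts
        return iter(indices)

    def __len__(self):
        return len(self.data_source)


data = list(range(100))  # made-up dataset of 100 items
sampler = RandomIndexSampler(data, seed=0)

for epoch in range(3):
    order = list(sampler)
    print("epoch", epoch, "first five indices:", order[:5])
```

Every epoch still visits each index exactly once; only the order changes, which is what breaks up long runs of a single class.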

