
PyTorch DataLoader shuffle

I did an experiment and I did not get the result I was expecting.

For the first part, I am using

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=False, num_workers=0)

I save trainloader.dataset.targets to the variable a and trainloader.dataset.data to the variable b before training my model. Then I train the model using trainloader.
After the training is finished, I save trainloader.dataset.targets to the variable c and trainloader.dataset.data to the variable d. Finally, I check a == c and b == d, and they both give True, which was expected because the shuffle parameter of the DataLoader is False.

For the second part, I am using

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=True, num_workers=0)

I save trainloader.dataset.targets to the variable e and trainloader.dataset.data to the variable f before training my model. Then I train the model using trainloader. After the training is finished, I save trainloader.dataset.targets to the variable g and trainloader.dataset.data to the variable h. I expect e == g and f == h to both be False since shuffle=True, but they give True again. What am I missing from the definition of the DataLoader class?
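For reference, a minimal sketch of the check described above, assuming trainset is a torchvision dataset such as CIFAR-10 that exposes .data and .targets (the actual training step is omitted):

import torch
import torchvision

# Assumption: CIFAR-10 as the trainset; any torchvision dataset with .data/.targets behaves the same
trainset = torchvision.datasets.CIFAR10('./data', train=True, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=0)

# Snapshot before training, train, then snapshot again (training omitted here)
e, f = trainloader.dataset.targets, trainloader.dataset.data
# ... model training with trainloader would go here ...
g, h = trainloader.dataset.targets, trainloader.dataset.data

print(e == g)          # True, even with shuffle=True
print((f == h).all())  # True, even with shuffle=True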

I believe that the data stored directly in trainloader.dataset.data or trainloader.dataset.targets will not be shuffled; the data is only shuffled when the DataLoader is iterated (i.e. used as a generator/iterator).

You can check this by calling next(iter(trainloader)) a few times, once without shuffling and once with shuffling, and the two should give different results.

import torch
import torchvision

transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# Two loaders over the same MNIST test set: one unshuffled, one shuffled
MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/', download=True, train=False,
                                           transform=transform)
dataLoader = torch.utils.data.DataLoader(MNIST_dataset,
                                         batch_size=128,
                                         shuffle=False,
                                         num_workers=10)
target = dataLoader.dataset.targets

MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/', download=True, train=False,
                                           transform=transform)
dataLoader_shuffled = torch.utils.data.DataLoader(MNIST_dataset,
                                                  batch_size=128,
                                                  shuffle=True,
                                                  num_workers=10)
target_shuffled = dataLoader_shuffled.dataset.targets

# Comparing the stored .targets attributes: identical, so shuffle has no effect here
print(target == target_shuffled)

# Comparing the labels of the first batch drawn from each loader: the order differs
_, target = next(iter(dataLoader))
_, target_shuffled = next(iter(dataLoader_shuffled))

print(target == target_shuffled)

This will give:

tensor([True, True, True,  ..., True, True, True])
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False,  True, False,
        False, False, False, False, False, False, False, False])

However, the data and labels stored in dataset.data and dataset.targets are fixed arrays/lists; since you are accessing them directly, they will never appear shuffled.
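To make this concrete, here is a small follow-up sketch (assuming the dataLoader_shuffled defined above): collecting the labels batch by batch recovers the shuffled order, which is a permutation of dataset.targets, while the attribute itself keeps its original order.

# Concatenate the label batches in the order the shuffled loader yields them
shuffled_order = torch.cat([labels for _, labels in dataLoader_shuffled])

print(torch.equal(shuffled_order, dataLoader_shuffled.dataset.targets))   # False: order differs
print(torch.equal(shuffled_order.sort().values,
                  dataLoader_shuffled.dataset.targets.sort().values))     # True: same labels, new order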

I was facing a similar issue while loading the data using a custom Dataset class. I stopped loading the data that way and instead use the following code, which works fine for me:

X = torch.from_numpy(X)
y = torch.from_numpy(y)

train_data = torch.utils.data.TensorDataset(X, y)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)

where X & y are numpy array from csv file.
