
Is a PyTorch Dataset Accessed by Multiple DataLoader Workers?

When using more than one DataLoader worker in PyTorch, does every worker access the same Dataset instance, or does each worker get its own copy of the Dataset?

from torch.utils.data import DataLoader, Dataset

class NumbersDataset(Dataset):
    def __init__(self):
        self.samples = list(range(1, 1001))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

dataset = NumbersDataset()
train_loader = DataLoader(dataset, num_workers=4)

It seems like they are accessing the same instance. I tried adding a class-level counter to the dataset class and incrementing it every time a new instance is created. The code can be found below.

from torch.utils.data import DataLoader, Dataset


class NumbersDataset(Dataset):
    i = 0

    def __init__(self):
        NumbersDataset.i += 1
        self.samples = list(range(1, 1001))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


dataset_1 = NumbersDataset()
train_loader = DataLoader(dataset_1, num_workers=4)

for i, data in enumerate(train_loader):
    pass

dataset_2 = NumbersDataset()
train_loader = DataLoader(dataset_2, num_workers=4)

for i, data in enumerate(train_loader):
    pass

print(NumbersDataset.i)

The output is 2. Hope it helps:D
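A class-level counter like this can be misleading, though: on Linux the DataLoader workers are forked from the main process, so `__init__` never runs again inside a worker, and even if a worker did increment the counter, the change would stay in that worker's process. A minimal sketch using only the standard `multiprocessing` module (no DataLoader involved, and assuming the fork start method) illustrates why:

```python
import multiprocessing as mp


class NumbersDataset:
    i = 0

    def __init__(self):
        NumbersDataset.i += 1


def create_in_child(q):
    # Runs in a forked child: it inherits i from the parent, but this
    # increment only touches the child's private copy of the class.
    NumbersDataset()
    q.put(NumbersDataset.i)


ctx = mp.get_context("fork")  # DataLoader workers also fork on Linux
NumbersDataset()              # parent process: i == 1
q = ctx.Queue()
procs = [ctx.Process(target=create_in_child, args=(q,)) for _ in range(2)]
for p in procs:
    p.start()
child_counts = sorted(q.get() for _ in range(2))
for p in procs:
    p.join()
print(child_counts)      # [2, 2]: each child incremented its own copy
print(NumbersDataset.i)  # 1: the parent never sees those increments
```

So the counter printing 2 only tells us that two instances were constructed in the main process; it cannot reveal what happens inside the workers.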

I stumbled upon this great thread while trying to figure out whether PyTorch's DataLoader makes deep copies of dataset instances. @yilmazdoga's answer is not quite right, but by modifying that code slightly we can get the result we need:

from torch.utils.data import DataLoader, Dataset


class NumbersDataset(Dataset):
    id_set = set()

    def __init__(self):
        self.samples = list(range(1, 100))

    def __len__(self):
        NumbersDataset.id_set.add(id(self.samples))
        return len(self.samples)

    def __getitem__(self, idx):
        NumbersDataset.id_set.add(id(self.samples))
        return self.samples[idx]


dataset = NumbersDataset()
train_loader = DataLoader(dataset, num_workers=4)

for i, data in enumerate(train_loader):
    pass

print(NumbersDataset.id_set)

What happened? We store the id of the self.samples object. If the DataLoader made a deep copy for each worker, the ids would differ and NumbersDataset.id_set would end up with 4 elements. In reality it contains only one element, so each worker holds a shallow, copy-on-write copy of the dataset rather than a deep copy (which is logical: on Linux the workers are forked from the main process, so inherited objects keep their ids). Note that class attributes updated inside a worker process are not sent back to the parent, so the printed set reflects the ids recorded in the main process.
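The same behavior can be reproduced with plain `multiprocessing`, without a DataLoader at all. Assuming the fork start method (the Linux default, which DataLoader workers use there), a child process sees the same id() for the inherited list, yet its mutations stay local thanks to copy-on-write:

```python
import multiprocessing as mp


class NumbersDataset:
    def __init__(self):
        self.samples = list(range(1, 100))


def mutate_in_child(dataset, q):
    # Forked child: the inherited object keeps the parent's id(),
    # but it lives in a copy-on-write clone of the address space.
    q.put(id(dataset.samples))
    dataset.samples[0] = -1  # this mutation never reaches the parent


ctx = mp.get_context("fork")
ds = NumbersDataset()
q = ctx.Queue()
p = ctx.Process(target=mutate_in_child, args=(ds, q))
p.start()
child_id = q.get()
p.join()
print(child_id == id(ds.samples))  # True: same id, matching the id_set result
print(ds.samples[0])               # 1: the parent's list is untouched
```

This is why "same id" and "same instance" are not quite the same claim: the ids match across the fork, but each worker process still has its own independent copy of the data.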
