I'm trying to create dataloaders that use only a specific digit from the PyTorch MNIST dataset.
I already tried to create my own Sampler, but it doesn't work and I'm not sure I'm using the mask correctly.
class YourSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, mask):
        self.mask = mask

    def __iter__(self):
        return (self.indices[i] for i in torch.nonzero(self.mask))

    def __len__(self):
        return len(self.mask)
mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform=transform)
mask = [True if mnist[i][1] == 5 else False for i in range(len(mnist))]
mask = torch.tensor(mask)
sampler = YourSampler(mask)
trainloader = torch.utils.data.DataLoader(mnist, batch_size=4, sampler=sampler,
                                          shuffle=False, num_workers=2)
So far I've had many different types of errors. With this implementation, it's "StopIteration". I feel like this should be easy, but I can't find a simple way to do it. Thank you for your help!
Thank you for your help. After a while I figured out a solution (though it might not be the best one):
class YourSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, mask, data_source):
        self.mask = mask
        self.data_source = data_source

    def __iter__(self):
        # use self.mask (not the global mask) so the sampler is self-contained
        return iter([i.item() for i in torch.nonzero(self.mask)])

    def __len__(self):
        # report the number of selected samples so len(dataloader) is accurate
        return int(self.mask.sum())
mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform=transform)
mask = [1 if mnist[i][1] == 5 else 0 for i in range(len(mnist))]
mask = torch.tensor(mask)
sampler = YourSampler(mask, mnist)
trainloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, sampler=sampler,
                                          shuffle=False, num_workers=workers)
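To check the idea without downloading MNIST, here is a minimal sketch of the same sampler on a hypothetical synthetic dataset (the tensors and labels below are made up for illustration; `__len__` here reports the number of selected samples so `len(dataloader)` stays accurate):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

class YourSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, mask, data_source):
        self.mask = mask
        self.data_source = data_source

    def __iter__(self):
        # yield only the indices where the mask is nonzero
        return iter([i.item() for i in torch.nonzero(self.mask)])

    def __len__(self):
        # number of selected samples, not the full dataset length
        return int(self.mask.sum())

# Hypothetical stand-in for MNIST: 10 samples with assorted labels
data = torch.arange(10).float().unsqueeze(1)
targets = torch.tensor([0, 5, 2, 5, 4, 5, 6, 7, 5, 9])
dataset = TensorDataset(data, targets)

mask = torch.tensor([1 if t == 5 else 0 for t in targets])
sampler = YourSampler(mask, dataset)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

labels = torch.cat([y for _, y in loader])
print(labels.tolist())  # only 5s come out: [5, 5, 5, 5]
```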
The easiest option I can think of is to reduce the dataset in-place:
indices = dataset.targets == 5 # if you want to keep images with the label 5
dataset.data, dataset.targets = dataset.data[indices], dataset.targets[indices]
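The boolean-mask indexing this relies on can be sketched on synthetic tensors (the shapes below are hypothetical stand-ins for MNIST's `data` and `targets` attributes):

```python
import torch

# Hypothetical stand-ins for dataset.data / dataset.targets
data = torch.randn(8, 28, 28)
targets = torch.tensor([0, 5, 1, 5, 2, 5, 3, 5])

indices = targets == 5                     # boolean mask over all samples
data, targets = data[indices], targets[indices]

print(data.shape)   # torch.Size([4, 28, 28])
print(targets)      # tensor([5, 5, 5, 5])
```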
You could also use torch.utils.data.Subset as follows:
from torch.utils.data import Subset

# For labels 5, 6 and 7
indices = [idx for idx, target in enumerate(dataset.targets) if target in [5, 6, 7]]
dataloader = torch.utils.data.DataLoader(Subset(dataset, indices),
                                         batch_size=BATCH_SIZE,
                                         drop_last=True)
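As a self-contained sketch of the Subset approach on a hypothetical synthetic dataset (no MNIST download; `drop_last=False` here so the small example keeps its last batch):

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical small dataset standing in for MNIST
targets = torch.tensor([0, 5, 6, 7, 2, 6, 5, 9])
data = torch.randn(len(targets), 1, 28, 28)
dataset = TensorDataset(data, targets)

# Keep only samples whose label is 5, 6, or 7
indices = [idx for idx, target in enumerate(targets) if target in [5, 6, 7]]
loader = DataLoader(Subset(dataset, indices), batch_size=5, drop_last=False)

labels = torch.cat([y for _, y in loader])
print(sorted(labels.tolist()))  # [5, 5, 6, 6, 7]
```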
StopIteration
is raised when your iterator is exhausted. Are you sure your mask is working correctly? It seems like you pass a list of boolean values, yet torch.nonzero expects floats or ints.
You should write:
mask = [1 if mnist[i][1] == 5 else 0 for i in range(len(mnist))]
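On a small standalone example, torch.nonzero on such a 0/1 mask returns the positions of the ones:

```python
import torch

mask = torch.tensor([0, 1, 0, 1, 1])
idx = torch.nonzero(mask)           # one row per nonzero entry, shape (3, 1)
print(idx.squeeze(1).tolist())      # [1, 3, 4]
```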
You also need to pass the dataset to your sampler, for example:
sampler = YourSampler(dataset, mask=mask)
with this class definition:
class YourSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, dataset, mask):
        self.mask = mask
        self.dataset = dataset
    ...
For more details, you can refer to the PyTorch documentation (which shows the source code) to see how more advanced samplers are implemented: https://pytorch.org/docs/stable/_modules/torch/utils/data/sampler.html#SequentialSampler
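At its core, a sampler is just an object whose __iter__ yields dataset indices. A minimal sketch, similar in spirit to the built-in SequentialSampler (the class name here is made up):

```python
import torch

class SimpleSequentialSampler(torch.utils.data.Sampler):
    """Yields indices 0..len(data_source)-1 in order."""
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

sampler = SimpleSequentialSampler(list("abcd"))
print(list(sampler))  # [0, 1, 2, 3]
```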