
How to select specific labels in the PyTorch MNIST dataset

I'm trying to create DataLoaders that use only a specific digit from the PyTorch MNIST dataset.

I have already tried to create my own Sampler, but it doesn't work and I'm not sure I'm using the mask correctly.

class YourSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, mask):
        self.mask = mask

    def __iter__(self):
        return (self.indices[i] for i in torch.nonzero(self.mask))

    def __len__(self):
        return len(self.mask)

mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform=transform)
mask = [True if mnist[i][1] == 5 else False for i in range(len(mnist))]
mask = torch.tensor(mask)
sampler = YourSampler(mask)
trainloader = torch.utils.data.DataLoader(mnist, batch_size=4, sampler=sampler, shuffle=False, num_workers=2)

So far I have run into many different kinds of errors. With this implementation it's a StopIteration. I feel like this should be very easy, but I can't find a simple way to do it. Thank you for your help!

Thank you for your help. After a while I figured out a solution (though it might not be the best one):

class YourSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, mask, data_source):
        self.mask = mask
        self.data_source = data_source

    def __iter__(self):
        # iterate over the indices of the samples selected by the mask
        return iter([i.item() for i in torch.nonzero(self.mask)])

    def __len__(self):
        # note: this reports the full dataset size, not the number of selected samples
        return len(self.data_source)

mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform=transform)
mask = [1 if mnist[i][1] == 5 else 0 for i in range(len(mnist))]
mask = torch.tensor(mask)
sampler = YourSampler(mask, mnist)
trainloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, sampler=sampler, shuffle=False, num_workers=workers)
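As a quick sanity check (this snippet is only illustrative, not part of the original solution), you can draw one batch from the loader and confirm every label is 5:

images, labels = next(iter(trainloader))
print(labels)                  # expected: tensor([5, 5, 5, ...])
assert (labels == 5).all()     # every sample drawn through the sampler has label 5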

The easiest option I can think of is to reduce the data set in-place:

indices = dataset.targets == 5 # if you want to keep images with the label 5
dataset.data, dataset.targets = dataset.data[indices], dataset.targets[indices]
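For context, a minimal end-to-end sketch of this approach could look like the following (the root path and the ToTensor() transform are placeholders, not taken from the question):

import torch
from torchvision import datasets, transforms

dataset = datasets.MNIST(root="./data", train=True, download=True,
                         transform=transforms.ToTensor())

indices = dataset.targets == 5              # boolean mask over all training targets
dataset.data = dataset.data[indices]        # keep only the matching images
dataset.targets = dataset.targets[indices]  # and their labels

loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
images, labels = next(iter(loader))         # labels should all be 5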

You could also use torch.utils.data.Subset as follows:

from torch.utils.data import Subset

# For labels 5, 6 and 7
indices = [idx for idx, target in enumerate(dataset.targets) if target in [5, 6, 7]]
dataloader = torch.utils.data.DataLoader(Subset(dataset, indices),
                                         batch_size=BATCH_SIZE,
                                         drop_last=True)
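If dataset.targets is a tensor (as it is for torchvision's MNIST), the index list can also be built without a Python loop; this is a sketch assuming PyTorch >= 1.10 for torch.isin:

wanted = torch.tensor([5, 6, 7])
indices = torch.isin(dataset.targets, wanted).nonzero(as_tuple=True)[0].tolist()
dataloader = torch.utils.data.DataLoader(Subset(dataset, indices),
                                         batch_size=BATCH_SIZE,
                                         drop_last=True)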

StopIteration is raised when your iterator is exhausted. Are you sure your mask is working correctly? It seems like you are passing a list of boolean values, yet torch.nonzero would expect floats or ints.

You should write:

mask = [1 if mnist[i][1] == 5 else 0 for i in range(len(mnist))]

You also need to pass the dataset to your sampler, like so:

sampler = YourSampler(dataset, mask=mask)

with a class definition like this:

class YourSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, dataset, mask):
        self.mask = mask
        self.dataset = dataset
    ...

For more details, you can refer to the PyTorch documentation (which shows the source code) to see how more advanced samplers are implemented: https://pytorch.org/docs/stable/_modules/torch/utils/data/sampler.html#SequentialSampler
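Putting those pieces together, a corrected sampler might look like the sketch below (the class name and the as_tuple usage are mine, not from the original post):

import torch
from torch.utils.data import Sampler

class SingleDigitSampler(Sampler):
    """Yields only the indices whose entry in the mask is non-zero."""

    def __init__(self, dataset, mask):
        self.dataset = dataset
        self.mask = mask  # 0/1 tensor with one entry per sample

    def __iter__(self):
        # positions of the selected samples
        return iter(torch.nonzero(self.mask, as_tuple=True)[0].tolist())

    def __len__(self):
        # number of selected samples, not the size of the full dataset
        return int(self.mask.sum())

mask = torch.tensor([1 if mnist[i][1] == 5 else 0 for i in range(len(mnist))])
sampler = SingleDigitSampler(mnist, mask)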
