簡體   English   中英

如何在 pytorch MNIST 數據集中 select 特定標簽

[英]How to select specific labels in pytorch MNIST dataset

我正在嘗試僅使用 PyTorch Mnist 數據集中的特定數字來創建數據加載器

我已經嘗試創建自己的采樣器,但它不起作用,而且我不確定我是否正確使用了遮罩。

class YourSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, mask):

        self.mask = mask


    def __iter__(self):

        return (self.indices[i] for i in torch.nonzero(self.mask))


    def __len__(self):

        return len(self.mask)


mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform = transform)   

mask = [True if mnist[i][1] == 5 else False for i in range(len(mnist))]

mask = torch.tensor(mask)   

sampler = YourSampler(mask)

trainloader = torch.utils.data.DataLoader(mnist, batch_size=4, sampler = sampler, shuffle=False, num_workers=2)

到目前為止,我有許多不同類型的錯誤。 對於此實現,它是“停止迭代”。 我覺得這很容易/很愚蠢,但我找不到簡單的方法來做到這一點。 謝謝您的幫助!

感謝您的幫助。 過了一會兒,我想出了一個解決方案(但可能根本不是最好的):

class YourSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, mask, data_source):
        self.mask = mask
        self.data_source = data_source

    def __iter__(self):
        return iter([i.item() for i in torch.nonzero(mask)])

    def __len__(self):
        return len(self.data_source)

mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform = transform)    
mask = [1 if mnist[i][1] == 5 else 0 for i in range(len(mnist))]
mask = torch.tensor(mask)   
sampler = YourSampler(mask, mnist)
trainloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size,sampler = sampler, shuffle=False, num_workers=workers)

我能想到的最簡單的選擇是就地減少數據集:

indices = dataset.targets == 5 # if you want to keep images with the label 5
dataset.data, dataset.targets = dataset.data[indices], dataset.targets[indices]

您還可以使用torch.utils.data.Subset如下:

# For indices 5, 6 and 7
indices = [idx for idx, target in enumerate(dataset.targets) if target in [5, 6, 7]]
dataloader = torch.utils.data.DataLoader(Subset(dataset, indices),
                                         batch_size=BATCH_SIZE, 
                                         drop_last=True)

當您的迭代器耗盡時,會引發StopIteration 你確定你的面具工作正常嗎? 似乎您傳遞了布爾值列表,但 torch.nonzero 會期望浮點數或整數。

你應該寫:

mask = [1 if mnist[i][1] == 5 else 0 for i in range(len(mnist))]

您還應該需要將數據集傳遞給您的采樣器,例如:

sampler = YourSampler(dataset, mask=mask)

使用這個類定義

class YourSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, dataset, mask):

        self.mask = mask
        self.dataset = dataset
...

有關更多詳細信息,您可以參考 pytorch 文檔(其中顯示了源代碼)以了解他們如何實現更高級的采樣器: https ://pytorch.org/docs/stable/_modules/torch/utils/data/sampler.html# 順序采樣器

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM