[英]How to select specific labels in pytorch MNIST dataset
我正在嘗試僅使用 PyTorch Mnist 數據集中的特定數字來創建數據加載器
我已經嘗試創建自己的采樣器,但它不起作用,而且我不確定我是否正確使用了遮罩。
class YourSampler(torch.utils.data.sampler.Sampler):
def __init__(self, mask):
self.mask = mask
def __iter__(self):
return (self.indices[i] for i in torch.nonzero(self.mask))
def __len__(self):
return len(self.mask)
mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform = transform)
mask = [True if mnist[i][1] == 5 else False for i in range(len(mnist))]
mask = torch.tensor(mask)
sampler = YourSampler(mask)
trainloader = torch.utils.data.DataLoader(mnist, batch_size=4, sampler = sampler, shuffle=False, num_workers=2)
到目前為止,我有許多不同類型的錯誤。 對於此實現,它是“停止迭代”。 我覺得這很容易/很愚蠢,但我找不到簡單的方法來做到這一點。 謝謝您的幫助!
感謝您的幫助。 過了一會兒,我想出了一個解決方案(但可能根本不是最好的):
class YourSampler(torch.utils.data.sampler.Sampler):
def __init__(self, mask, data_source):
self.mask = mask
self.data_source = data_source
def __iter__(self):
return iter([i.item() for i in torch.nonzero(mask)])
def __len__(self):
return len(self.data_source)
mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform = transform)
mask = [1 if mnist[i][1] == 5 else 0 for i in range(len(mnist))]
mask = torch.tensor(mask)
sampler = YourSampler(mask, mnist)
trainloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size,sampler = sampler, shuffle=False, num_workers=workers)
我能想到的最簡單的選擇是就地減少數據集:
indices = dataset.targets == 5 # if you want to keep images with the label 5
dataset.data, dataset.targets = dataset.data[indices], dataset.targets[indices]
您還可以使用torch.utils.data.Subset
如下:
# For indices 5, 6 and 7
indices = [idx for idx, target in enumerate(dataset.targets) if target in [5, 6, 7]]
dataloader = torch.utils.data.DataLoader(Subset(dataset, indices),
batch_size=BATCH_SIZE,
drop_last=True)
當您的迭代器耗盡時,會引發StopIteration
。 你確定你的面具工作正常嗎? 似乎您傳遞了布爾值列表,但 torch.nonzero 會期望浮點數或整數。
你應該寫:
mask = [1 if mnist[i][1] == 5 else 0 for i in range(len(mnist))]
您還應該需要將數據集傳遞給您的采樣器,例如:
sampler = YourSampler(dataset, mask=mask)
使用這個類定義
class YourSampler(torch.utils.data.sampler.Sampler):
def __init__(self, dataset, mask):
self.mask = mask
self.dataset = dataset
...
有關更多詳細信息,您可以參考 pytorch 文檔(其中顯示了源代碼)以了解他們如何實現更高級的采樣器: https ://pytorch.org/docs/stable/_modules/torch/utils/data/sampler.html# 順序采樣器
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.