簡體   English   中英

如何使用數據加載器從 MNIST 數據集中提取特定數字?

[英]How to extract a specific digit from the MNIST dataset with dataloader?

我正在輸入 MNIST 數據集以以下方式訓練我的神經網絡

indices = torch.arange(60000)
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
datasetsmall = data_utils.Subset(dataset, indices)
loader = DataLoader(datasetsmall, batch_size=batch_size, shuffle=True)

但是,由於訓練需要大量時間才能完成,我決定僅使用 MNIST 數據集中的特定數字(例如數字 4)來訓練 model。我怎樣才能提取數字 4 並將其輸入到我的神經網絡中一樣的方法。 訓練神經網絡的循環就像

for batch_idx, (real, _) in enumerate(loader):

現在我只想要加載器中的數字 4。 在這種情況下我應該如何進行?

這段代碼能解決你的問題嗎?

import torch
from torchvision import datasets
from torch.utils.data import TensorDataset, DataLoader
from torchvision.transforms import ToTensor

cls = 4 # needed class
batch_size = 32

dataset = datasets.MNIST(root="dataset/", download=True, transform=ToTensor())
dataset = list(filter(lambda i: i[1] == cls, dataset))
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

s = 0
for i in loader:
  s += 1

print(f'We\'ve got {s} batches with batch_size {batch_size} only for class {cls}')

# print(i) # uncomment this line if you want to examine last batch by yourself

結果:

We've got 183 batches with batch_size 32 only for class 4

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM