PyTorch DataLoader not splitting data into batches
I have a dataset class like this:
class LoadDataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __len__(self):
        dlen = len(self.data)
        return dlen

    def __getitem__(self, index):
        return self.data, self.label
Then I load an image dataset with shape [485, 1, 32, 32]:
train_dataset = LoadDataset(xtrain, ytrain)
print(train_dataset.__len__())
# output: 485
Then I load the data with a DataLoader:
train_loader = DataLoader(train_dataset, batch_size=32)
Then I iterate over the data:
for epoch in range(num_epoch):
    for inputs, labels in train_loader:
        print(inputs.shape)
This prints torch.Size([32, 485, 1, 32, 32]), but it should be torch.Size([32, 1, 32, 32]). Can anyone help me?
The __getitem__ method should return a single sample, but you are returning the entire dataset. The DataLoader's default collate function stacks batch_size of whatever __getitem__ returns, so stacking 32 copies of the full [485, 1, 32, 32] tensor produces [32, 485, 1, 32, 32].
Try this:
class LoadDataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __len__(self):
        dlen = len(self.data)
        llen = len(self.label)  # different here
        return min(dlen, llen)  # different here

    def __getitem__(self, index):
        return self.data[index], self.label[index]  # different here
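To see that the fix works end to end, here is a minimal self-contained sketch. The xtrain/ytrain tensors below are stand-ins (random data with the shapes from the question, not your actual data):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class LoadDataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __len__(self):
        # Length is the number of samples, guarded against mismatched inputs
        return min(len(self.data), len(self.label))

    def __getitem__(self, index):
        # Return ONE sample per call; the DataLoader does the batching
        return self.data[index], self.label[index]

# Stand-in data with the shapes from the question: 485 images of 1x32x32
xtrain = torch.randn(485, 1, 32, 32)
ytrain = torch.zeros(485, dtype=torch.long)

train_loader = DataLoader(LoadDataset(xtrain, ytrain), batch_size=32)
inputs, labels = next(iter(train_loader))
print(inputs.shape)  # torch.Size([32, 1, 32, 32])
print(labels.shape)  # torch.Size([32])
```

With 485 samples and batch_size=32, the DataLoader yields 15 full batches of 32 and a final partial batch of 5 (pass drop_last=True if you want to skip the partial one).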