簡體   English   中英

Pytorch Dataloader 未將數據拆分為批次

[英]Pytorch Dataloader not spliting data into batch

我有這樣的數據集 class:

class LoadDataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label
    def __len__(self):
        dlen = len(self.data)
        return dlen
    def __getitem__(self, index):
        return self.data, self.label

然后我加載具有 [485, 1, 32, 32] 形狀的圖像數據集

train_dataset = LoadDataset(xtrain, ytrain)
print(len(train_dataset))
# output 485

然后我用DataLoader加載數據

train_loader = DataLoader(train_dataset, batch_size=32)

然后我迭代數據:

for epoch in range(num_epoch):
        for inputs, labels in train_loader:   
            print(inputs.shape)

output 打印torch.Size([32, 485, 1, 32, 32]) ,它應該是torch.Size([32, 1, 32, 32])

誰能幫我?

__getitem__方法應該返回 1 個數據塊,你返回了所有的數據塊。

試試這個:

class LoadDataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label
    def __len__(self):
        dlen = len(self.data)
        llen = len(self.label)  # different here
        return min(dlen, llen)  # different here
    def __getitem__(self, index):
        return self.data[index], self.label[index]  # different here

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM