簡體   English   中英

當示例數未完全除以批處理大小時,Pytorch DataLoader失敗

[英]Pytorch DataLoader fails when the number of examples are not exactly divided by the batch size

我在pytorch中編寫了一個自定義數據加載器類。 但是,當在一個紀元內遍歷所有批次時,它將失敗。 例如,假設我有100個數據示例,我的批處理大小為9。它將在第10次迭代中失敗,原因是批處理大小不同,這將使批處理大小為1而不是10。我將自定義數據加載器放在下面。 我還介紹了如何從for循環內的加載程序中提取數據。

class FlatDirectoryAudioDataset(tdata.Dataset): #customized dataloader

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.files = self.__setup_files()

    def __len__(self):
        """
        compute the length of the dataset
        :return: len => length of dataset
        """
        return len(self.files)

    def __setup_files(self):

        file_names = os.listdir(self.data_dir)
        files = []  # initialize to empty list

        for file_name in file_names:

            possible_file = os.path.join(self.data_dir, file_name)
            if os.path.isfile(possible_file) and (file_name.lower().endswith('.wav') or file_name.lower().endswith('.mp3')): #&& (possible_file.lower().endswith('.wav') or possible_file.lower().endswith('.mp3')):
                files.append(possible_file)

        # return the files list
        return files


    def __getitem__ (self,index):
        sample, _ = librosa.load(self.files[index], 16000)

        if self.transform:
            sample=self.transform(sample)

        sample = torch.from_numpy(sample)    
        return sample


from torch.utils.data import DataLoader 

    my_dataset=FlatDirectoryAudioDataset(source_directory,source_folder,source_label,transform = None,label=True)

dataloader_my = DataLoader(
        my_dataset,
        batch_size=batch_size,
        num_workers=0,
        shuffle=True)


for (i,batch) in enumerate(dataloader_my,0):  
       print(i)
       if batch.shape[0]!=16:
          print(batch.shape)
          assert batch.shape[0]==16,"Something wrong with the batch size"



使用drop_last = True utils.DataLoader(數據集,batch_size = batch_size,隨機播放= True,drop_last = True)

https://pytorch.org/docs/stable/data.html

簡短答案

設置drop_last=True刪除最后一個不完整的批次

長答案

根據您的代碼制作的Dataloader精簡版,批次大小沒有錯誤。

使用9作為batch_size並具有100個項目,最后一個批次只有一個項目。 運行下面的代碼即可生成。

設置drop_last = False,最后一行被打印,並且'exception'被打印。

0 <class 'torch.Tensor'> torch.Size([9, 1])
1 <class 'torch.Tensor'> torch.Size([9, 1])
2 <class 'torch.Tensor'> torch.Size([9, 1])
3 <class 'torch.Tensor'> torch.Size([9, 1])
4 <class 'torch.Tensor'> torch.Size([9, 1])
5 <class 'torch.Tensor'> torch.Size([9, 1])
6 <class 'torch.Tensor'> torch.Size([9, 1])
7 <class 'torch.Tensor'> torch.Size([9, 1])
8 <class 'torch.Tensor'> torch.Size([9, 1])
9 <class 'torch.Tensor'> torch.Size([9, 1])
10 <class 'torch.Tensor'> torch.Size([9, 1])
# depends on drop_last=True|False
11 <class 'torch.Tensor'> torch.Size([1, 1])
Different batch size (last batch) torch.Size([1, 1])

因此該批次產生了足夠好的批次物料,使其總數達到100

from torch.utils.data import DataLoader
import os
import numpy as np
import torch
import torch.utils.data.dataset as tdata


class FlatDirectoryAudioDataset(tdata.Dataset):  # customized dataloader

    def __init__(self):
        self.files = self.__setup_files()

    def __len__(self):
        return len(self.files)

    def __setup_files(self):
        return np.array(range(100))

    def __getitem__(self, index):
        file = self.files[index]
        sample = np.array([file])
        sample = torch.from_numpy(sample)
        return sample


data = FlatDirectoryAudioDataset()

my_dataset = FlatDirectoryAudioDataset()

batch_size = 9

dataloader_my = DataLoader(
    my_dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True,
    drop_last=True)

for i, sample in enumerate(dataloader_my, 0):
    print(i, print(type(sample), sample.shape)
    if sample.shape[0] != batch_size:
        print("Different batch size (last batch)", sample.shape)

我編寫了一個名為nonechucks的庫來精確地做到這一點(以防萬一您的批處理量不足,而不是因為無法精確划分而存在的不良樣本)。 它使您可以動態處理數據集中的不良樣品(包括自動確定批次大小)。 你可以簡單地包裝現有的PyTorch Dataset與周圍SafeDataset如下:

bad_dataset = Dataset(...)

import nonechucks as nc
dataset = nc.SafeDataset(bad_dataset)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM