PyTorch DataLoader fails when the number of examples is not exactly divisible by the batch size

I have written a custom data loader class (a Dataset subclass) in PyTorch, but it fails when iterating through all the batches in an epoch. For example, suppose I have 100 data examples and my batch size is 9. It fails on the last iteration, complaining that the batch size is different: the final batch has a batch size of 1 instead of 9. My custom data loader is below, followed by how I extract the data from the loader inside the for loop.

import os

import librosa
import torch
import torch.utils.data.dataset as tdata


class FlatDirectoryAudioDataset(tdata.Dataset):  # customized dataloader

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.files = self.__setup_files()

    def __len__(self):
        """
        compute the length of the dataset
        :return: len => length of dataset
        """
        return len(self.files)

    def __setup_files(self):

        file_names = os.listdir(self.data_dir)
        files = []  # initialize to empty list

        for file_name in file_names:

            possible_file = os.path.join(self.data_dir, file_name)
            if os.path.isfile(possible_file) and file_name.lower().endswith(('.wav', '.mp3')):
                files.append(possible_file)

        # return the files list
        return files


    def __getitem__(self, index):
        # load the audio resampled to 16 kHz; sr is passed by keyword
        sample, _ = librosa.load(self.files[index], sr=16000)

        if self.transform:
            sample = self.transform(sample)

        sample = torch.from_numpy(sample)
        return sample


from torch.utils.data import DataLoader

my_dataset = FlatDirectoryAudioDataset(source_directory, source_folder, source_label, transform=None, label=True)

dataloader_my = DataLoader(
    my_dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True)


for i, batch in enumerate(dataloader_my, 0):
    print(i)
    if batch.shape[0] != 16:
        print(batch.shape)
        assert batch.shape[0] == 16, "Something wrong with the batch size"



Use drop_last=True:

utils.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

https://pytorch.org/docs/stable/data.html

Short answer

Set drop_last=True on the DataLoader to drop the last incomplete batch.

Long answer

I made a reduced version of your dataset based on your code, and it raises no error about batch sizes.

With a batch_size of 9 and 100 items, the last batch has just one item. Running the code below produces the following output.

With drop_last=False, the last two lines (batch 11 and the 'Different batch size' message) are also printed; with drop_last=True they are omitted.

0 <class 'torch.Tensor'> torch.Size([9, 1])
1 <class 'torch.Tensor'> torch.Size([9, 1])
2 <class 'torch.Tensor'> torch.Size([9, 1])
3 <class 'torch.Tensor'> torch.Size([9, 1])
4 <class 'torch.Tensor'> torch.Size([9, 1])
5 <class 'torch.Tensor'> torch.Size([9, 1])
6 <class 'torch.Tensor'> torch.Size([9, 1])
7 <class 'torch.Tensor'> torch.Size([9, 1])
8 <class 'torch.Tensor'> torch.Size([9, 1])
9 <class 'torch.Tensor'> torch.Size([9, 1])
10 <class 'torch.Tensor'> torch.Size([9, 1])
# depends on drop_last=True|False
11 <class 'torch.Tensor'> torch.Size([1, 1])
Different batch size (last batch) torch.Size([1, 1])

So the batches add up to 100 items in total: eleven batches of 9 plus a final batch of 1.

from torch.utils.data import DataLoader
import os
import numpy as np
import torch
import torch.utils.data.dataset as tdata


class FlatDirectoryAudioDataset(tdata.Dataset):  # customized dataloader

    def __init__(self):
        self.files = self.__setup_files()

    def __len__(self):
        return len(self.files)

    def __setup_files(self):
        return np.array(range(100))

    def __getitem__(self, index):
        file = self.files[index]
        sample = np.array([file])
        sample = torch.from_numpy(sample)
        return sample


my_dataset = FlatDirectoryAudioDataset()

batch_size = 9

dataloader_my = DataLoader(
    my_dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True,
    drop_last=True)  # set to False to keep the final, smaller batch

for i, sample in enumerate(dataloader_my, 0):
    print(i, type(sample), sample.shape)
    if sample.shape[0] != batch_size:
        print("Different batch size (last batch)", sample.shape)
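
The batch counts in this answer come down to simple arithmetic. As a rough sketch in plain Python (it mirrors the counting only, not PyTorch's actual implementation, and batch_sizes is a hypothetical helper name), the batch sizes for 100 items and a batch size of 9 can be worked out like this:

```python
def batch_sizes(num_items, batch_size, drop_last=False):
    """Return the size of each batch a loader would yield (hypothetical helper)."""
    full_batches, remainder = divmod(num_items, batch_size)
    sizes = [batch_size] * full_batches
    # The leftover items form one smaller batch unless drop_last is set.
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


keep = batch_sizes(100, 9, drop_last=False)
drop = batch_sizes(100, 9, drop_last=True)

print(len(keep), keep[-1], sum(keep))  # 12 1 100
print(len(drop), drop[-1], sum(drop))  # 11 9 99
```

With drop_last=True the single leftover item is skipped in that epoch. Because the loader shuffles, a different item is typically left out each epoch.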

I wrote a library called nonechucks to do exactly this (in case your batches are falling short not because the dataset size isn't exactly divisible by the batch size, but because of bad samples). It lets you deal with bad samples in your dataset dynamically (including automatically fixing the batch size). You can simply wrap your existing PyTorch Dataset in a SafeDataset as follows:

bad_dataset = Dataset(...)

import nonechucks as nc
dataset = nc.SafeDataset(bad_dataset)
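
To illustrate the general idea of skipping bad samples while keeping batches full, here is a minimal plain-Python sketch. It is not nonechucks's implementation or API: load_sample, safe_batches, and the choice to treat any exception as a bad sample are assumptions made for illustration.

```python
def safe_batches(indices, load_sample, batch_size):
    """Yield batches of batch_size, skipping indices whose sample fails to load.

    Hypothetical helper: any exception from load_sample marks a bad sample.
    Only the last batch may be smaller than batch_size.
    """
    batch = []
    for index in indices:
        try:
            batch.append(load_sample(index))
        except Exception:
            continue  # bad sample: skip it and keep filling the batch
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # leftover good samples


def load_sample(index):
    # Toy loader (assumption): every 4th item is "corrupt".
    if index % 4 == 0:
        raise ValueError(f"cannot decode sample {index}")
    return index


batches = list(safe_batches(range(10), load_sample, batch_size=3))
print(batches)  # [[1, 2, 3], [5, 6, 7], [9]]
```

Here the bad samples (0, 4, 8) never shrink a batch in the middle of the epoch. Only the final batch can come up short, and that is the case drop_last addresses.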
