
How does PyTorch DataLoader interact with a PyTorch dataset to transform batches?

I'm creating a custom dataset for NLP-related tasks.

In the PyTorch custom dataset tutorial, we see that the __getitem__() method leaves room for a transform before it returns a sample:

def __getitem__(self, idx):
    if torch.is_tensor(idx):
        idx = idx.tolist()

    img_name = os.path.join(self.root_dir,
                            self.landmarks_frame.iloc[idx, 0])
    image = io.imread(img_name)

    ### SOME DATA MANIPULATION HERE ###
    # (`landmarks` is produced by the manipulation elided above)

    sample = {'image': image, 'landmarks': landmarks}
    if self.transform:
        sample = self.transform(sample)

    return sample

However, the code here:

        if torch.is_tensor(idx):
            idx = idx.tolist()

implies that multiple items can be retrieved at once, which leaves me wondering:

  1. How does that transform work on multiple items? Take the custom transforms in the tutorial, for example: they do not look like they could be applied to a batch of samples in a single call.

  2. Related, how does a DataLoader retrieve a batch of multiple samples in parallel and apply said transform if the transform can only be applied to a single sample?

  1. How does that transform work on multiple items? They work on multiple items through use of the data loader. By using a transform, you specify what should happen to a single emission of data (e.g., batch_size=1). The data loader takes your specified batch_size and makes n calls to the __getitem__ method of the torch dataset, applying the transform to each sample before it is sent into training/validation. It then collates those n samples into the single batch emitted by the data loader (see the sketch after this list).

  2. Related, how does a DataLoader retrieve a batch of multiple samples in parallel and apply said transform if the transform can only be applied to a single sample? Hopefully the above makes sense to you. Parallelization is done by the torch dataset class and the data loader, where you specify num_workers. Torch will pickle the dataset and spread it across the workers, each of which runs its own __getitem__ calls.
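Here is a minimal sketch of that flow. The ToyDataset class and the add_ten transform below are made up for illustration: the loader makes one __getitem__ call per sample, the transform runs inside each call, and the default collate function then stacks the n results into a single batch.

    import torch
    from torch.utils.data import Dataset, DataLoader

    def add_ten(sample):
        # A per-sample transform: it only ever sees ONE sample.
        return sample + 10

    class ToyDataset(Dataset):
        def __init__(self, transform=None):
            self.data = torch.arange(8, dtype=torch.float32)
            self.transform = transform

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            sample = self.data[idx]  # fetch exactly one sample per call
            if self.transform:
                sample = self.transform(sample)
            return sample

    # batch_size=4 -> four __getitem__ calls per emitted batch.
    # num_workers > 0 would pickle the dataset and run those calls
    # in worker processes instead of the main process.
    loader = DataLoader(ToyDataset(transform=add_ten), batch_size=4)
    for batch in loader:
        print(batch)  # tensor([10., 11., 12., 13.]), then tensor([14., 15., 16., 17.])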

From the documentation of torchvision transforms:

All transformations accept PIL Image, Tensor Image or batch of Tensor Images as input. Tensor Image is a tensor with (C, H, W) shape, where C is a number of channels, H and W are image height and width. Batch of Tensor Images is a tensor of (B, C, H, W) shape, where B is a number of images in the batch. Deterministic or random transformations applied on the batch of Tensor Images identically transform all the images of the batch.

This means that you can pass a batch of images, and the transform will be applied to the whole batch, as long as it respects the expected shape. As for the indexing, the list of indices is passed to the DataFrame's iloc, which accepts either a single index or a list of indices and returns the requested subset (a sketch of both points follows).
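Here is a minimal sketch of both points; the tensor shapes and the toy DataFrame are made up for illustration.

    import torch
    import pandas as pd
    from torchvision import transforms

    # One call transforms every image in the (B, C, H, W) batch identically.
    batch = torch.rand(8, 3, 32, 32)
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    print(normalize(batch).shape)  # torch.Size([8, 3, 32, 32])

    # iloc accepts a single index or a list of indices.
    df = pd.DataFrame({'file': ['a.png', 'b.png', 'c.png'], 'label': [0, 1, 0]})
    print(df.iloc[0, 0])        # 'a.png' -- a single item
    print(df.iloc[[0, 2], 0])   # rows 0 and 2 -- a Series subset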
