
PyTorch: How exactly does the DataLoader get a batch from the Dataset?

I am trying to use PyTorch to implement self-supervised contrastive learning. There is a behaviour that I can't understand. Here is my transformation code, which produces two augmented views from the original data:

from torchvision import transforms
from torchvision.datasets import CIFAR10


class ContrastiveTransformations:
    def __init__(self, base_transforms, n_views=2):
        self.base_transforms = base_transforms
        self.n_views = n_views

    def __call__(self, x):
        # Apply the same transform pipeline n_views times to get independent augmentations
        return [self.base_transforms(x) for _ in range(self.n_views)]


contrast_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=96),
        transforms.ToTensor(),
    ]
)

data_set = CIFAR10(
    root='/home1/data',
    download=True,
    transform=ContrastiveTransformations(contrast_transforms, n_views=2),
)

As defined in ContrastiveTransformations, each item in my dataset is a list containing two tensors, [x_1, x_2]. In my understanding, a batch from the dataloader should have the form [data_batch, label_batch], where each item in data_batch is an [x_1, x_2] pair. However, the batch actually has the form [[batch_x1, batch_x2], label_batch], which is much more convenient for computing the InfoNCE loss. I wonder how the DataLoader implements this batching.
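For example, inspecting one batch like this (a minimal sketch; batch_size=4 is just an example) shows the structure I am describing:

from torch.utils.data import DataLoader

loader = DataLoader(data_set, batch_size=4, shuffle=True)

views, labels = next(iter(loader))     # views is a list of two tensors, labels is a tensor
x1_batch, x2_batch = views             # two stacked view batches, not a list of per-sample pairs
print(type(views), len(views))         # <class 'list'> 2
print(x1_batch.shape, x2_batch.shape)  # torch.Size([4, 3, 96, 96]) for each view batch
print(labels.shape)                    # torch.Size([4])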

I have checked the DataLoader code in PyTorch; it seems that the dataloader fetches the data like this:

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

However, I still can't figure out how the dataloader produces separate batches of x1 and x2.
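If I understand the fetcher correctly, before self.collate_fn runs, data is just a plain Python list of whatever dataset[idx] returns. A sketch of my understanding for batch_size=2, with dummy tensors standing in for the CIFAR10 images:

import torch

# Each dataset item is ([x1, x2], label) because of ContrastiveTransformations
x1_a, x2_a = torch.rand(3, 96, 96), torch.rand(3, 96, 96)
x1_b, x2_b = torch.rand(3, 96, 96), torch.rand(3, 96, 96)
data = [
    ([x1_a, x2_a], 0),  # sample a: ([view1, view2], label)
    ([x1_b, x2_b], 1),  # sample b
]
# After self.collate_fn(data) the loader somehow yields
# [[batch_x1, batch_x2], batch_labels] -- this is the step I don't follow.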

I would be very thankful if someone could give me an explanation.

In order to convert the separate dataset elements into an assembled batch, PyTorch's data loaders use a collate function. This defines how the dataloader should assemble the different elements together to form a minibatch.

You can define your own collate function and pass it to your data.DataLoader with the collate_fn argument. By default, the collate function used by dataloaders is default_collate, defined in torch/utils/data/_utils/collate.py.
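For instance, if you wanted the per-sample structure back (one stacked tensor of views per sample instead of two stacked view batches), a custom collate function could look like the following sketch (the function name and behaviour are my own, purely illustrative):

import torch
from torch.utils.data import DataLoader

def pair_collate(batch):
    # batch is a list of (sample, label) tuples, where sample is the list [x1, x2]
    views = [torch.stack(sample) for sample, _ in batch]  # each element has shape (2, C, H, W)
    labels = torch.tensor([label for _, label in batch])
    return views, labels

loader = DataLoader(data_set, batch_size=4, collate_fn=pair_collate)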

This is the behaviour of the default collate function, as described in the function's docstring:

# Example with a batch of `int`s:
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])

# Example with a batch of `str`s:
>>> default_collate(['a', 'b', 'c'])
['a', 'b', 'c']

# Example with `Map` inside the batch:
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
{'A': tensor([  0, 100]), 'B': tensor([  1, 100])}

# Example with `NamedTuple` inside the batch:
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))

# Example with `Tuple` inside the batch:
>>> default_collate([(0, 1), (2, 3)])
[tensor([0, 2]), tensor([1, 3])]

# Example with `List` inside the batch:
>>> default_collate([[0, 1], [2, 3]])
[tensor([0, 2]), tensor([1, 3])]
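Applied to your dataset, each sample is ([x_1, x_2], label), so default_collate recurses: the outer (sample, label) tuple is collated element-wise, and the inner list of views is "transposed" into two stacked batches, exactly like the Tuple and List examples above. A small sketch with dummy tensors (shapes chosen to match your RandomResizedCrop(size=96); default_collate can be imported from torch.utils.data in recent PyTorch releases) reproduces the structure you observe:

import torch
from torch.utils.data import default_collate

# Two fake samples shaped like the dataset output: ([x1, x2], label)
sample_a = ([torch.rand(3, 96, 96), torch.rand(3, 96, 96)], 0)
sample_b = ([torch.rand(3, 96, 96), torch.rand(3, 96, 96)], 1)

batch = default_collate([sample_a, sample_b])
# The outer tuple is collated element-wise, the inner list is transposed:
# batch == [[batch_x1, batch_x2], batch_labels]
print(batch[0][0].shape)  # torch.Size([2, 3, 96, 96]) -> batch of first views
print(batch[0][1].shape)  # torch.Size([2, 3, 96, 96]) -> batch of second views
print(batch[1])           # tensor([0, 1])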
