繁体   English   中英

Pytorch:数据加载器如何从数据集中获取批次?

[英]Pytorch: How exactly dataloader get a batch from dataset?

我正在尝试使用 pytorch 来实现自我监督的对比学习。 有一个我无法理解的现象。 这是我从原始数据中获取两个增强视图的转换代码:

class ContrastiveTransformations:
  def __init__(self, base_transforms, n_views=2):
      self.base_transforms = base_transforms
      self.n_views = n_views
    
  def __call__(self, x):
      return [self.base_transforms(x) for i in range(self.n_views)]
contrast_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=96),
        transforms.ToTensor(),
    ]
)
data_set = CIFAR10(
    root='/home1/data',
    download=True,
    transform=ContrastiveTransformations(contrast_transforms, n_views=2),
)

正如ContrastiveTransformations的定义,我的数据集中的数据类型是一个包含两个张量[x_1, x_2]的列表。 据我了解,来自数据加载器的批次应该具有[data_batch, label_batch]的形式,并且data_batch中的每个项目都是[x_1, x_2] 但是实际上batch的形式是这样的: [[batch_x1, batch_x2], label_batch] ,这样计算infoNCE loss就方便多了。 我想知道DataLoader如何实现批次的提取。

我检查了pytorch中DataLoader的代码,似乎dataloader是这样获取数据的:

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

但是我仍然没有弄清楚数据加载器如何分别生成 x1 和 x2 的批次。

如果有人能给我一个解释,我将不胜感激。

为了将单独的数据集批次元素转换为组装批次,PyTorch 的数据加载器使用整理功能 这定义了数据加载器应如何将不同的元素组合在一起以形成小批量

您可以定义自己的 collat​​e 函数,并使用 collat collate_fn参数将其传递给您的data.DataLoader 默认情况下,数据加载器使用的 collat​​e 函数是在torch/utils/data/_utils/collate.py中定义的default_collate ​​e。

这是函数标题中描述的默认 collat​​e 函数的行为:

# Example with a batch of `int`s:
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])

# Example with a batch of `str`s:
>>> default_collate(['a', 'b', 'c'])
['a', 'b', 'c']

# Example with `Map` inside the batch:
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
{'A': tensor([  0, 100]), 'B': tensor([  1, 100])}

# Example with `NamedTuple` inside the batch:
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))

# Example with `Tuple` inside the batch:
>>> default_collate([(0, 1), (2, 3)])
[tensor([0, 2]), tensor([1, 3])]

# Example with `List` inside the batch:
>>> default_collate([[0, 1], [2, 3]])
[tensor([0, 2]), tensor([1, 3])]

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM