torch - subsample each dataset differently and concatenate them
I have two datasets, but one is larger than the other, and I want to subsample it (resampling in each epoch).
I probably cannot use the DataLoader sampler argument, since I would be passing the already-concatenated dataset to the DataLoader.
How do I achieve this simply?
I think one solution would be to write a class SubsampledDataset(IterableDataset) that would resample every time __iter__ is called (i.e., each epoch).
(Or, better, use a map-style dataset, but is there a hook that gets called every epoch, the way __iter__ is?)
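For reference, a minimal sketch of what that SubsampledDataset(IterableDataset) idea could look like (my own untested illustration; the randperm-based subsampling without replacement is an assumption):

import torch
from torch.utils.data import IterableDataset

class SubsampledDataset(IterableDataset):
    """Draws a fresh random subsample of `dataset` every time __iter__ is
    called, i.e. once per epoch (single-process loading assumed)."""
    def __init__(self, dataset, num_samples):
        self.dataset = dataset
        self.num_samples = num_samples

    def __iter__(self):
        # A new permutation per epoch: subsample without replacement.
        indices = torch.randperm(len(self.dataset))[:self.num_samples]
        return (self.dataset[i] for i in indices)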
You can stop the iteration by raising StopIteration. This error is caught by the DataLoader, which simply stops the iteration. So you can do something like this:
import torch
from torch.utils.data import DataLoader, Dataset

class SubDataset(Dataset):
    """SubDataset class."""
    def __init__(self, dataset, length):
        self.dataset = dataset
        self.elem = 0
        self.length = length

    def __getitem__(self, index):
        self.elem += 1
        if self.elem > self.length:
            self.elem = 0          # reset the counter so the next epoch yields items again
            raise StopIteration    # caught by DataLoader, which ends the epoch
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

if __name__ == '__main__':
    torch.manual_seed(0)
    dataloader = DataLoader(SubDataset(torch.arange(10), 5), shuffle=True)
    for x in dataloader:
        print(x)  # 6, 7, 1, 4, 2
    print(len(dataloader))  # 10!!
Note that setting __len__ to self.length would cause a problem, because the DataLoader would then only use indices between 0 and length - 1 (which is not what you want). Unfortunately, I found no way to set the actual length without getting this behaviour (due to a DataLoader restriction). So be careful: len(dataset) is the original length, and dataset.length is the new length.
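To make that mismatch concrete, a quick check (my own illustration, using the SubDataset above):

sub = SubDataset(torch.arange(10), 5)
print(len(sub))    # 10: the original length, which the DataLoader draws indices from
print(sub.length)  # 5: the effective number of items yielded per epoch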
This is what I have so far (untested). Usage:
dataset1: Any = ...
# subsample original_dataset2, so that it is equally large in each epoch
dataset2 = RandomSampledDataset(original_dataset2, num_samples=len(dataset1))
concat_dataset = ConcatDataset([dataset1, dataset2])
data_loader = torch.utils.data.DataLoader(
    concat_dataset,
    sampler=RandomSamplerWithNewEpochHook(dataset2.new_epoch_hook, concat_dataset)
)
The result is that concat_dataset is shuffled each epoch (RandomSampler); in addition, the dataset2 component is a fresh sample of the (possibly larger) original_dataset2, different in each epoch.
You can add more datasets to be subsampled by replacing this:
sampler=RandomSamplerWithNewEpochHook(dataset2.new_epoch_hook
with this (note that each hook must actually be called, not just referenced):
sampler=RandomSamplerWithNewEpochHook(lambda: (dataset2.new_epoch_hook(), dataset3.new_epoch_hook(), dataset4.new_epoch_hook()), ...
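Equivalently (my own variant, not from the original answer), a small helper avoids the long lambda:

def combined_new_epoch_hook(*datasets):
    """Return a hook that refreshes the per-epoch sample of every given dataset."""
    def hook():
        for d in datasets:
            d.new_epoch_hook()
    return hook

sampler = RandomSamplerWithNewEpochHook(combined_new_epoch_hook(dataset2, dataset3, dataset4), concat_dataset)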
Code:
from typing import Callable, Optional, Sized

import torch
from torch.utils.data import Dataset, RandomSampler

class RandomSamplerWithNewEpochHook(RandomSampler):
    """Wraps torch.utils.data.RandomSampler and calls the supplied new_epoch_hook before each epoch."""
    def __init__(self, new_epoch_hook: Callable, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None):
        super().__init__(data_source, replacement, num_samples, generator)
        self.new_epoch_hook = new_epoch_hook

    def __iter__(self):
        # DataLoader calls iter(sampler) once per epoch, so the hook fires once per epoch.
        self.new_epoch_hook()
        return super().__iter__()


class RandomSampledDataset(Dataset):
    """Subsamples a dataset. The sample is different in each epoch.

    This helps when concatenating datasets, as the subsampling rate can be different for each dataset.
    Call new_epoch_hook before each epoch. (This can be done using e.g. RandomSamplerWithNewEpochHook.)
    This would arguably be harder to achieve with a concatenated dataset and a sampler argument to
    DataLoader: the sampler would have to be aware of the indices of the subdatasets' items in the
    concatenated dataset, and of the subsampling for each subdataset."""
    def __init__(self, dataset, num_samples, transform=lambda im: im):
        self.dataset = dataset
        self.transform = transform
        self.num_samples = num_samples
        # note: num_samples with the default replacement=False requires a recent PyTorch version
        self.sampler = RandomSampler(dataset, num_samples=num_samples)
        self.current_epoch_samples = None

    def new_epoch_hook(self):
        # Materialize this epoch's indices (torch.tensor needs a list, not a bare iterator).
        self.current_epoch_samples = torch.tensor(list(self.sampler), dtype=torch.int)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, item):
        if item < 0 or item >= len(self):
            raise IndexError
        img = self.dataset[self.current_epoch_samples[item].item()]
        return self.transform(img)
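For completeness, here is a hedged end-to-end sketch of how these pieces could be wired together (my own toy example; the tensors and sizes are made up, and, as noted above, RandomSampler with num_samples and replacement=False needs a recent PyTorch version):

from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

if __name__ == '__main__':
    torch.manual_seed(0)
    dataset1 = TensorDataset(torch.arange(4))
    original_dataset2 = TensorDataset(torch.arange(100, 200))  # the larger dataset

    # dataset2 contributes len(dataset1) freshly sampled items per epoch
    dataset2 = RandomSampledDataset(original_dataset2, num_samples=len(dataset1))
    concat_dataset = ConcatDataset([dataset1, dataset2])
    data_loader = DataLoader(
        concat_dataset,
        sampler=RandomSamplerWithNewEpochHook(dataset2.new_epoch_hook, concat_dataset),
    )

    for epoch in range(2):
        # dataset1's items appear in every epoch; dataset2's part changes between epochs
        print([int(x) for (x,) in data_loader])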