torch - subsample each dataset differently and concatenate them

I have two datasets, but one is larger than the other and I want to subsample it (resample in each epoch).

I probably cannot use the DataLoader sampler argument, as I would be passing the already concatenated dataset to the DataLoader.

How do I achieve this simply?

I think one solution would be to write a class SubsampledDataset(IterableDataset) which would resample every time __iter__ is called (each epoch).

(Or, better, use a map-style dataset, but is there a hook that gets called every epoch, the way __iter__ is?)
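For concreteness, a minimal untested sketch of that idea (the class name and the num_samples parameter are just illustrative):

import torch
from torch.utils.data import IterableDataset

class SubsampledDataset(IterableDataset):
    """Yields a fresh random subsample of `dataset` on every iteration (i.e. every epoch)."""
    def __init__(self, dataset, num_samples):
        self.dataset = dataset
        self.num_samples = num_samples

    def __iter__(self):
        # __iter__ is called once per epoch, so new indices are drawn each time.
        indices = torch.randperm(len(self.dataset))[:self.num_samples]
        return (self.dataset[i] for i in indices)

    def __len__(self):
        return self.num_samples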

You can stop the iteration by raising StopIteration. This exception is caught by the DataLoader and simply stops the iteration. So you can do something like this:

import torch
from torch.utils.data import Dataset, DataLoader


class SubDataset(Dataset):
    """Wraps a dataset and stops each epoch after `length` items."""
    def __init__(self, dataset, length):
        self.dataset = dataset
        self.elem = 0           # items served in the current epoch
        self.length = length    # effective epoch length

    def __getitem__(self, index):
        self.elem += 1
        if self.elem > self.length:
            self.elem = 0          # reset so the next epoch works too
            raise StopIteration    # caught by DataLoader, ends the epoch
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)  # keep the original length (see note below)


if __name__ == '__main__':
    torch.manual_seed(0)
    dataloader = DataLoader(SubDataset(torch.arange(10), 5), shuffle=True)
    for x in dataloader:
        print(x)  # 6, 7, 1, 4, 2
    print(len(dataloader))  # 10, not 5: __len__ reports the original length

Note that setting __len__ to self.length would cause a problem, because the DataLoader would then only use indices between 0 and length-1 (which is not what you want: the subsample would always be drawn from the same first length items). Unfortunately I found no way to set the actual length without getting this behaviour (due to a DataLoader restriction). Thus be careful: len(dataset) is the original length and dataset.length is the new length.
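To make the two lengths concrete:

sub = SubDataset(torch.arange(10), 5)
print(len(sub))     # 10: the original length, which the DataLoader samples indices over
print(sub.length)   # 5: the effective number of items served per epoch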

This is what I have so far (untested). Usage:

dataset1: Any = ...
# subsample original_dataset2, so that it is equally large in each epoch
dataset2 = RandomSampledDataset(original_dataset2, num_samples=len(dataset1))

concat_dataset = ConcatDataset([dataset1, dataset2])

data_loader = torch.utils.data.DataLoader(
    concat_dataset,
    sampler=RandomSamplerWithNewEpochHook(dataset2.new_epoch_hook, concat_dataset)
)

The result is that concat_dataset is shuffled each epoch (by the RandomSampler) and, in addition, the dataset2 component is a new sample of the (possibly larger) original_dataset2, different in each epoch.

You can add more datasets to be subsampled by replacing:

sampler=RandomSamplerWithNewEpochHook(dataset2.new_epoch_hook

with this:

sampler=RandomSamplerWithNewEpochHook(lambda: (dataset2.new_epoch_hook(), dataset3.new_epoch_hook(), dataset4.new_epoch_hook()), ...

(Note the parentheses: each hook must actually be called, not just referenced.)
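With several datasets, a small helper keeps this readable (chain_hooks is a hypothetical name, not part of the code below):

def chain_hooks(*datasets):
    """Return a hook that calls new_epoch_hook on every given dataset."""
    def hook():
        for d in datasets:
            d.new_epoch_hook()
    return hook

sampler = RandomSamplerWithNewEpochHook(
    chain_hooks(dataset2, dataset3, dataset4), concat_dataset)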

Code:

from typing import Callable, Optional, Sized

import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, RandomSampler


class RandomSamplerWithNewEpochHook(RandomSampler):
    """Wraps torch's RandomSampler and calls the supplied new_epoch_hook before each epoch."""

    def __init__(self, new_epoch_hook: Callable, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None):
        super().__init__(data_source, replacement, num_samples, generator)
        self.new_epoch_hook = new_epoch_hook

    def __iter__(self):
        self.new_epoch_hook()  # let wrapped datasets resample at the start of the epoch
        return super().__iter__()


class RandomSampledDataset(Dataset):
    """Subsamples a dataset. The sample is different in each epoch.

    This helps when concatenating datasets, as the subsampling rate can be different for each dataset.

    Call new_epoch_hook before each epoch (this can be done using e.g. RandomSamplerWithNewEpochHook).

    This would arguably be harder to achieve with a concatenated dataset and a sampler argument to
    DataLoader: the sampler would have to be aware of the indices of the subdatasets' items within the
    concatenated dataset, and of the subsampling for each subdataset."""
    def __init__(self, dataset, num_samples, transform=lambda im: im):
        self.dataset = dataset
        self.transform = transform
        self.num_samples = num_samples

        self.sampler = RandomSampler(dataset, num_samples=num_samples)
        self.current_epoch_samples = None

    def new_epoch_hook(self):
        # Draw a fresh set of indices for this epoch; list() materializes the sampler.
        self.current_epoch_samples = torch.tensor(list(self.sampler), dtype=torch.int)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, item):
        if item < 0 or item >= len(self):
            raise IndexError

        # Map the local index to this epoch's sampled index into the wrapped dataset.
        img = self.dataset[self.current_epoch_samples[item].item()]

        return self.transform(img)
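For completeness, a minimal end-to-end sketch of how the pieces fit together (the tensor datasets and sizes are made up for illustration; new_epoch_hook must run once before each epoch, which RandomSamplerWithNewEpochHook takes care of):

if __name__ == '__main__':
    torch.manual_seed(0)

    dataset1 = torch.utils.data.TensorDataset(torch.arange(5))
    original_dataset2 = torch.utils.data.TensorDataset(torch.arange(100, 200))

    # Subsample original_dataset2 so both components contribute 5 items per epoch.
    dataset2 = RandomSampledDataset(original_dataset2, num_samples=len(dataset1))
    concat_dataset = ConcatDataset([dataset1, dataset2])

    data_loader = DataLoader(
        concat_dataset,
        sampler=RandomSamplerWithNewEpochHook(dataset2.new_epoch_hook, concat_dataset),
    )

    for epoch in range(2):
        # dataset2's 5 items are redrawn from the 100 originals at the start of each epoch.
        print([x for (x,) in data_loader])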
