
Augmenting data proportionally

I'm facing a classification problem between two classes. Currently I augment the dataset using this code:

aug_train_data_gen = ImageDataGenerator(rotation_range=0,
                                    height_shift_range=40,
                                    width_shift_range=40,
                                    zoom_range=0,
                                    horizontal_flip=True,
                                    vertical_flip=True, 
                                    fill_mode='reflect',
                                    rescale=1/255.)

aug_train_gen = aug_train_data_gen.flow_from_directory(directory=training_dir,
                                                   target_size=(96,96),
                                                   color_mode='rgb',
                                                   classes=None, # can be set to labels
                                                   class_mode='categorical',
                                                   batch_size=64,
                                                   shuffle=False # set to False if you need to compare images
                                                   )

But I think that increasing the amount of class1 data through augmentation would improve my performance: at the moment there are about six times as many class2 images as class1 images, so the CNN tends to classify everything as class2. How can I do this?

In order to get balanced batches you can use the class below.

On init you supply a list of datasets, one dataset per class, so the number of datasets equals the number of classes.

At runtime, each call to __getitem__() first chooses a class at random and then a random sample from within that class.

Best

from typing import Iterable

from numpy.random import randint
from torch.utils.data import Dataset, IterableDataset

class MultipleDataset(Dataset):
    """
    Chooses at random which dataset an item is returned from
    on each call to __getitem__()
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(MultipleDataset, self).__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "MultipleDataset does not support IterableDataset"
        self.dataset_sizes = [len(d) for d in self.datasets]
        self.num_datasets = len(self.datasets)

    def __len__(self):
        # one epoch is as long as the largest per-class dataset
        return max(self.dataset_sizes)

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # pick a class uniformly at random, then wrap the index into that
        # class's dataset so smaller classes are effectively oversampled
        dataset_idx = randint(self.num_datasets)
        sample_idx = idx % self.dataset_sizes[dataset_idx]
        return self.datasets[dataset_idx][sample_idx]
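To show how this yields balanced batches, here is a minimal usage sketch with random tensors standing in for the two image classes (the 10/60 sizes mimic the 6:1 imbalance and are purely illustrative); the compact `MultipleDataset` here mirrors the definition above so the snippet runs on its own:

```python
import torch
from numpy.random import randint
from torch.utils.data import DataLoader, Dataset, TensorDataset

class MultipleDataset(Dataset):
    """Compact copy of the class above: random class per __getitem__ call."""
    def __init__(self, datasets):
        self.datasets = list(datasets)
        self.dataset_sizes = [len(d) for d in self.datasets]
        self.num_datasets = len(self.datasets)
    def __len__(self):
        return max(self.dataset_sizes)
    def __getitem__(self, idx):
        dataset_idx = randint(self.num_datasets)
        return self.datasets[dataset_idx][idx % self.dataset_sizes[dataset_idx]]

# dummy per-class datasets: 10 minority-class and 60 majority-class images
class1 = TensorDataset(torch.randn(10, 3, 96, 96), torch.zeros(10, dtype=torch.long))
class2 = TensorDataset(torch.randn(60, 3, 96, 96), torch.ones(60, dtype=torch.long))

loader = DataLoader(MultipleDataset([class1, class2]), batch_size=8, shuffle=True)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([8, 3, 96, 96])
```

Over many batches the labels come out roughly 50/50, because each sample draw picks a class uniformly at random regardless of the class sizes.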
