Augmenting data proportionally
I'm facing a classification problem with 2 classes. Currently I augment the dataset using this code:
aug_train_data_gen = ImageDataGenerator(rotation_range=0,
                                        height_shift_range=40,
                                        width_shift_range=40,
                                        zoom_range=0,
                                        horizontal_flip=True,
                                        vertical_flip=True,
                                        fill_mode='reflect',
                                        rescale=1/255.)

aug_train_gen = aug_train_data_gen.flow_from_directory(directory=training_dir,
                                                       target_size=(96, 96),
                                                       color_mode='rgb',
                                                       classes=None,  # can be set to labels
                                                       class_mode='categorical',
                                                       batch_size=64,
                                                       shuffle=False  # set to False if you need to compare images
                                                       )
But I think that increasing the amount of class1 data via augmentation would improve performance: at the moment there are 6x as many class2 images as class1 images, so the CNN tends to classify everything as class2. How can I do so?
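One way to do this is to oversample the minority class offline: generate several augmented copies of each class1 image so the two classes end up roughly the same size. Below is a minimal NumPy-only sketch that mimics the flip/shift settings from the generator above; the `augment` helper, the stand-in images, and the 5-copies factor are illustrative assumptions, not part of the original post.

```python
import numpy as np

rng = np.random.default_rng(0)

def augment(img, shift=40):
    """Random horizontal/vertical flip plus a reflect-padded random shift,
    mimicking the ImageDataGenerator settings above."""
    if rng.random() < 0.5:
        img = img[:, ::-1]          # horizontal flip
    if rng.random() < 0.5:
        img = img[::-1, :]          # vertical flip
    dy, dx = rng.integers(-shift, shift + 1, size=2)
    padded = np.pad(img, ((shift, shift), (shift, shift), (0, 0)), mode='reflect')
    h, w = img.shape[:2]
    return padded[shift + dy:shift + dy + h, shift + dx:shift + dx + w]

# Stand-in class1 images; oversample each one 5x to narrow the 6:1 gap.
class1 = [rng.random((96, 96, 3)) for _ in range(10)]
augmented = [augment(img) for img in class1 for _ in range(5)]
```

The augmented copies can then be written back into the class1 training folder so that `flow_from_directory` sees a more balanced dataset.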
In order to get balanced batches you can use the attached class.
On init you supply a list of datasets, one dataset per class, so the number of datasets equals the number of classes.
At runtime, __getitem__() chooses a class at random and then a sample from within that class.
Best
from typing import Iterable

from numpy.random import randint
from torch.utils.data import Dataset, IterableDataset


class MultipleDataset(Dataset):
    """
    Choose randomly which dataset to return an item from on each call to __getitem__()
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(MultipleDataset, self).__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "MultipleDataset does not support IterableDataset"
        self.dataset_sizes = [len(d) for d in self.datasets]
        self.num_datasets = len(self.datasets)

    def __len__(self):
        return max(self.dataset_sizes)

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # Pick a class uniformly at random, then wrap idx into that dataset's range.
        dataset_idx = randint(self.num_datasets)
        sample_idx = idx % self.dataset_sizes[dataset_idx]
        return self.datasets[dataset_idx][sample_idx]
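If you'd rather not maintain one dataset per class, a standard PyTorch alternative is `WeightedRandomSampler`, which draws each sample with probability inversely proportional to its class frequency, so batches are roughly balanced in expectation. A minimal sketch; the tensor shapes and class counts below are made up for illustration:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler

# Hypothetical imbalanced dataset: 17 class1 samples vs. 100 class2 samples.
labels = torch.cat([torch.zeros(17, dtype=torch.long),
                    torch.ones(100, dtype=torch.long)])
data = torch.randn(117, 3)
ds = TensorDataset(data, labels)

# Weight each sample inversely to its class frequency.
class_counts = torch.bincount(labels).float()
weights = (1.0 / class_counts)[labels]
sampler = WeightedRandomSampler(weights, num_samples=len(ds), replacement=True)

loader = DataLoader(ds, batch_size=64, sampler=sampler)
xb, yb = next(iter(loader))  # roughly class-balanced in expectation
```

Note that `sampler` and `shuffle=True` are mutually exclusive in `DataLoader`; the sampler already randomizes the order.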