Imbalanced Image Dataset (Tensorflow2)

I'm working on a binary image classification problem. The two classes are heavily imbalanced (~590 and ~5900 instances for class 1 and class 2, respectively) but still quite distinct.

Is there any way I can fix this? I want to try SMOTE or random weighted oversampling.

I've tried a lot of different things but I'm stuck. I've tried using class_weights=[10,1], [5900,590], and [1/5900,1/590], and my model still only predicts class 2. I've tried using tf.data.experimental.sample_from_datasets but I couldn't get it to work. I've even tried using sigmoid focal cross-entropy loss, which helped a lot but not enough.

I want to be able to oversample class 1 by a factor of 10. The only thing I have tried that has kind of worked is manual oversampling, i.e. copying the train directory's class 1 instances until they match the number of instances in class 2.

Is there not an easier way of doing this? I'm using Google Colab, so copying files like this is extremely inefficient.

Is there a way to specify SMOTE params / oversampling within the data generator or similar?

data/
...class_1/
........image_1.jpg
........image_2.jpg
...class_2/
........image_1.jpg
........image_2.jpg

My data is in the form shown above.

TRAIN_DATAGEN = ImageDataGenerator(rescale = 1./255.,
                                   rotation_range = 40,
                                   width_shift_range = 0.2,
                                   height_shift_range = 0.2,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)

TEST_DATAGEN = ImageDataGenerator(rescale = 1.0/255.)

TRAIN_GENERATOR = TRAIN_DATAGEN.flow_from_directory(directory = TRAIN_DIR,
                                                    batch_size = BATCH_SIZE,
                                                    class_mode = 'binary',
                                                    target_size = (IMG_HEIGHT, IMG_WIDTH),
                                                    subset = 'training',
                                                    seed = DATA_GENERATOR_SEED)

VALIDATION_GENERATOR = TEST_DATAGEN.flow_from_directory(directory = VALIDATION_DIR,
                                                        batch_size = BATCH_SIZE,
                                                        class_mode = 'binary',
                                                        target_size = (IMG_HEIGHT, IMG_WIDTH),
                                                        subset = 'validation',
                                                        seed = DATA_GENERATOR_SEED)
...
...
...

HISTORY = MODEL.fit(TRAIN_GENERATOR,
                    validation_data = VALIDATION_GENERATOR,
                    epochs = EPOCHS,
                    verbose = 2,
                    callbacks = [EARLY_STOPPING],
                    class_weight = CLASS_WEIGHT)

I'm relatively new to Tensorflow but I have some experience with ML as a whole. I've been tempted to switch to PyTorch several times, as its data loaders have a parameter that automatically over- or undersamples: sampler=WeightedRandomSampler.

Note: I've looked at many tutorials about how to oversample, but none of them deal with image classification. I want to stick with TF/Keras as it allows for easy transfer learning. Could you guys help out?

You can use this strategy to calculate weights based on the imbalance:

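from sklearn.utils import class_weight
import numpy as np

# Weights inversely proportional to class frequency.
# Recent scikit-learn versions require keyword arguments here.
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_generator.classes),
    y=train_generator.classes)

# Keras expects a dict mapping class index -> weight.
train_class_weights = dict(enumerate(class_weights))
# fit_generator is deprecated in TF2; fit accepts generators directly.
model.fit(..., class_weight=train_class_weights)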
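To make the 'balanced' option above concrete, here is a small worked sketch of the formula it uses, n_samples / (n_classes * count_per_class), applied to the approximate class counts from the question. The mapping of class_1 to index 0 and class_2 to index 1 is an assumption based on flow_from_directory ordering class folders alphanumerically.

```python
import numpy as np

# Approximate class counts from the question (index 0 = class_1, index 1 = class_2).
counts = np.array([590, 5900])

# 'balanced' heuristic: n_samples / (n_classes * count_per_class)
weights = counts.sum() / (len(counts) * counts)

# Keras expects class_weight as a dict of class index -> weight, not a list.
class_weight = {i: float(w) for i, w in enumerate(weights)}
print(class_weight)  # {0: 5.5, 1: 0.55}
```

The ratio between the two weights is 10:1, the same as the imbalance, so this is equivalent in spirit to {0: 10, 1: 1} but scaled so the average per-sample weight stays near 1.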

In Python you can apply SMOTE using the imblearn library as follows:

from imblearn.over_sampling import SMOTE

oversample = SMOTE()
X, y = oversample.fit_resample(X, y)
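Note that fit_resample expects X as a 2-D array of shape (n_samples, n_features), so images would have to be loaded into memory and flattened before SMOTE can interpolate between them. A lighter alternative is plain random oversampling of file paths, which reproduces the manual copying described in the question without duplicating any files. The sketch below is not SMOTE; the file lists are hypothetical stand-ins for the contents of data/class_1 and data/class_2.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical file lists; in practice these would come from listing the class folders.
minority_paths = [f"data/class_1/image_{i}.jpg" for i in range(1, 591)]
majority_paths = [f"data/class_2/image_{i}.jpg" for i in range(1, 5901)]

# Random oversampling: draw minority paths with replacement until both
# classes have the same number of entries.
resampled_minority = rng.choice(minority_paths, size=len(majority_paths), replace=True)

paths = np.concatenate([resampled_minority, majority_paths])
labels = np.concatenate([
    np.zeros(len(resampled_minority), dtype=int),  # class_1 -> label 0
    np.ones(len(majority_paths), dtype=int),       # class_2 -> label 1
])

# Shuffle paths and labels together so the classes are interleaved.
order = rng.permutation(len(paths))
paths, labels = paths[order], labels[order]
```

The resulting paths and labels could then be fed to a loader that reads files lazily, for example ImageDataGenerator.flow_from_dataframe or a tf.data pipeline, so random augmentation still makes each repeated minority image look slightly different.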

As you already define your class_weight as a dictionary, e.g. {0: 10, 1: 1}, you might try augmenting the minority class. See "balancing an imbalanced dataset with keras image generator" and the tutorial mentioned there at https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html
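Since the question mentions tf.data.experimental.sample_from_datasets, here is a minimal sketch of how that kind of resampling can look. It assumes TensorFlow 2.7 or later, where the function is also available as tf.data.Dataset.sample_from_datasets, and it uses small synthetic tensors in place of real images; a real pipeline would build one dataset per class from the image files.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-ins for the two classes (assumption: real code would
# load and decode the images from each class directory instead).
minority = tf.data.Dataset.from_tensor_slices(
    (tf.zeros([59, 8, 8, 3]), tf.zeros([59], dtype=tf.int32)))
majority = tf.data.Dataset.from_tensor_slices(
    (tf.ones([590, 8, 8, 3]), tf.ones([590], dtype=tf.int32)))

# Repeat both datasets so the minority class never runs out, then draw
# from each with equal probability.
balanced = tf.data.Dataset.sample_from_datasets(
    [minority.shuffle(59).repeat(), majority.shuffle(590).repeat()],
    weights=[0.5, 0.5],
    seed=0)

# Inspect the label mix over 50 batches of 32 samples.
batches = balanced.batch(32).take(50)
labels = np.concatenate([y.numpy() for _, y in batches])
minority_fraction = float(np.mean(labels == 0))
print(f"fraction of minority samples: {minority_fraction:.2f}")
```

Because both inputs are repeated, the resulting dataset is infinite, so model.fit would need an explicit steps_per_epoch. Applying random augmentation to the minority dataset before sampling would keep the repeated images from being exact duplicates.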

