
Imbalanced Image Dataset (Tensorflow2)

I'm working on a binary image classification problem. The two classes are quite distinct but heavily imbalanced: class 1 has ~590 instances and class 2 has ~5900.

Is there any way to fix this? I would like to try SMOTE or random weighted oversampling.

I've tried a lot of different things but I'm stuck. I've tried class_weight values of [10, 1], [5900, 590], and [1/5900, 1/590], and my model still only predicts class 2. I've tried tf.data.experimental.sample_from_datasets but couldn't get it to work. I've even tried sigmoid focal cross-entropy loss, which helped a lot but not enough.

I want to oversample class 1 by a factor of 10. The only thing I have tried that has somewhat worked is manual oversampling, i.e. copying the class 1 instances in the train directory until they match the number of instances in class 2.

Is there an easier way of doing this? I'm using Google Colab, so copying files like this is extremely inefficient.

Is there a way to specify SMOTE params / oversampling within the data generator or similar?

data/
...class_1/
........image_1.jpg
........image_2.jpg
...class_2/
........image_1.jpg
........image_2.jpg

My data is in the form shown above.

TRAIN_DATAGEN = ImageDataGenerator(rescale = 1./255.,
                                   rotation_range = 40,
                                   width_shift_range = 0.2,
                                   height_shift_range = 0.2,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)

TEST_DATAGEN = ImageDataGenerator(rescale = 1.0/255.)

TRAIN_GENERATOR = TRAIN_DATAGEN.flow_from_directory(directory = TRAIN_DIR,
                                                    batch_size = BATCH_SIZE,
                                                    class_mode = 'binary',
                                                    target_size = (IMG_HEIGHT, IMG_WIDTH),
                                                    subset = 'training',
                                                    seed = DATA_GENERATOR_SEED)

VALIDATION_GENERATOR = TEST_DATAGEN.flow_from_directory(directory = VALIDATION_DIR,
                                                        batch_size = BATCH_SIZE,
                                                        class_mode = 'binary',
                                                        target_size = (IMG_HEIGHT, IMG_WIDTH),
                                                        subset = 'validation',
                                                        seed = DATA_GENERATOR_SEED)
...
...
...

HISTORY = MODEL.fit(TRAIN_GENERATOR,
                    validation_data = VALIDATION_GENERATOR,
                    epochs = EPOCHS,
                    verbose = 2,
                    callbacks = [EARLY_STOPPING],
                    class_weight = CLASS_WEIGHT)

I'm relatively new to TensorFlow but I have some experience with ML as a whole. I've been tempted to switch to PyTorch several times, because its data loaders take a parameter that (over/under)samples automatically: sampler=WeightedRandomSampler.

Note: I've looked at many tutorials on oversampling, but none of them deal with image classification. I want to stick with TF/Keras because it allows for easy transfer learning. Could you help out?

You can use this strategy to calculate weights based on the imbalance:

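from sklearn.utils import class_weight
import numpy as np

# 'balanced' weights each class by n_samples / (n_classes * class_count).
# Recent scikit-learn versions require `classes` and `y` as keyword arguments.
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_generator.classes),
    y=train_generator.classes)

# Keras expects a dict mapping class index -> weight.
train_class_weights = dict(enumerate(class_weights))
# In TF2, model.fit accepts generators directly; fit_generator is deprecated.
model.fit(..., class_weight=train_class_weights)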
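To see what the 'balanced' heuristic produces for the counts in the question, here is a minimal NumPy-only sketch of the same formula. The label array is made up to match the ~590 / ~5900 split, and the class indices 0 and 1 are an assumption about how the generator numbers class_1 and class_2.

```python
import numpy as np

# Hypothetical labels matching the question's counts:
# 590 images with class index 0 and 5900 with class index 1.
labels = np.concatenate([np.zeros(590, dtype=int), np.ones(5900, dtype=int)])

counts = np.bincount(labels)                    # array([ 590, 5900])
# Balanced heuristic: n_samples / (n_classes * class_count)
weights = len(labels) / (len(counts) * counts)

class_weight_dict = {int(i): float(w) for i, w in enumerate(weights)}
print(class_weight_dict)
```

The two weights come out in a 10:1 ratio, the same ratio as the [10, 1] already tried in the question, so this is likely to help only if the earlier attempts failed because of the format (a list instead of a dict) or the index order rather than the weighting itself.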

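In Python you can implement SMOTE using the imblearn library as follows. Note that fit_resample expects X as a 2-D array of shape (n_samples, n_features), so the images have to be loaded into memory and flattened first: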

from imblearn.over_sampling import SMOTE

oversample = SMOTE()
X, y = oversample.fit_resample(X, y)

As you already define your class_weight as a dictionary, e.g. {0: 10, 1: 1}, you might try augmenting the minority class. See "balancing an imbalanced dataset with keras image generator" and the tutorial mentioned there: https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html
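One way to get more minority-class samples without copying files on disk is to oversample the list of file paths instead. The following is only a sketch under assumptions: oversample_paths is a hypothetical helper, not a Keras or imblearn function, and the file names are invented to mirror the directory layout in the question.

```python
import numpy as np

def oversample_paths(paths, labels, seed=0):
    """Resample every class, with replacement, up to the size of the
    largest class. Hypothetical helper, not part of any library."""
    rng = np.random.default_rng(seed)
    paths = np.asarray(paths)
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    target = counts.max()
    keep = []
    for cls, count in zip(classes, counts):
        idx = np.flatnonzero(labels == cls)
        # Keep every original sample, then draw the missing ones at random.
        extra = idx[rng.integers(0, len(idx), size=target - count)]
        keep.append(np.concatenate([idx, extra]))
    # Shuffle so the classes are not grouped together.
    keep = rng.permutation(np.concatenate(keep))
    return paths[keep], labels[keep]

# Made-up file lists with the question's class sizes.
paths = ([f"data/class_1/image_{i}.jpg" for i in range(590)]
         + [f"data/class_2/image_{i}.jpg" for i in range(5900)])
labels = ["class_1"] * 590 + ["class_2"] * 5900

balanced_paths, balanced_labels = oversample_paths(paths, labels)
print(len(balanced_paths))
```

The balanced lists could then be put into a pandas DataFrame and passed to ImageDataGenerator.flow_from_dataframe in place of flow_from_directory. Because the generator applies random augmentation every time an image is drawn, repeated class 1 paths would not produce identical training images.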


 