
Class weights in CNN

I have a very unbalanced dataset. First, I split it into a training set (80%) and a validation set (20%). I used StratifiedShuffleSplit so that both sets preserve each class's percentage.

To deal with the imbalance in both sets, I am using class_weight. This is the code I use for that:

class_weight = {0: 70.,
                1: 110.,
                2: 82.,
                3: 17.,
                4: 9.}


model.fit(train_generator, epochs=5, class_weight=class_weight, validation_data=(x_val, y_val))

The variable class_weight currently holds the number of images in each class across the whole dataset, that is, the training and validation sets combined. Is that the right way to do it? Or should it hold the counts from the training set only?

I have another question. Supposing I do data augmentation, how can I know for sure the number of images per class? Is there an automatic calculator or something of the sort?

It looks like you hard-coded the weight values for your classes. Instead, you can use sklearn.utils.class_weight.compute_class_weight to handle an imbalanced dataset. With the 'balanced' option it computes weights inversely proportional to how often each class occurs, so rarer classes get larger weights.

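# imports
import numpy as np
from sklearn.utils import class_weight

# compute class weights based on how often each class appears in y_train
# (y_train is assumed to be a 1-D array of integer class labels, not one-hot)
classes = np.unique(y_train)  # np.unique already returns sorted labels
cls_wgts = class_weight.compute_class_weight(class_weight='balanced',
                                             classes=classes,
                                             y=y_train)

# dict mapping: class index -> weight
# (assumes the labels are the integers 0..n_classes-1, as Keras expects)
cls_wgts = dict(enumerate(cls_wgts))

# pass it to fit
model.fit(..., class_weight=cls_wgts)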
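To make the 'balanced' heuristic concrete, here is a minimal sketch in plain Python of the formula scikit-learn documents for it, n_samples / (n_classes * count). The name train_counts and the helper balanced_weights are hypothetical. The counts reuse the numbers from the question purely for illustration and are assumed here to come from the training split only.

```python
# Hypothetical per-class image counts, assumed to be from the training split.
# The values reuse the numbers from the question for illustration only.
train_counts = {0: 70, 1: 110, 2: 82, 3: 17, 4: 9}


def balanced_weights(counts):
    """Return n_samples / (n_classes * count) for every class label."""
    n_samples = sum(counts.values())
    n_classes = len(counts)
    return {label: n_samples / (n_classes * n) for label, n in counts.items()}


weights = balanced_weights(train_counts)

# Rare classes end up with the largest weights, common classes the smallest.
for label, w in sorted(weights.items()):
    print(f"class {label}: count={train_counts[label]:3d}  weight={w:.3f}")
```

Note that passing the raw counts as weights, as in the question, has the opposite effect: the most frequent class would receive the largest weight.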

As for your second question, if I understand it correctly, we usually don't know how much augmentation each class will receive at training time. We can, however, configure the data generator so that minority classes are augmented more than majority classes. Additionally, you can use a weighted cross-entropy loss function to handle the class imbalance.
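As a rough illustration of the weighted cross-entropy idea, the sketch below scales the usual per-sample loss by the weight of the true class. It is plain Python, not the Keras implementation. The function name weighted_cross_entropy and the weight values are hypothetical.

```python
import math


def weighted_cross_entropy(probs, true_label, class_weights):
    """Per-sample loss: -w[true_label] * log(p[true_label])."""
    p = max(probs[true_label], 1e-12)  # guard against log(0)
    return -class_weights[true_label] * math.log(p)


# Hypothetical weights: class 1 is the minority class and gets a larger weight.
example_weights = {0: 0.5, 1: 4.0}

# The model is equally unsure in both cases (probability 0.5 for the true class).
loss_majority = weighted_cross_entropy([0.5, 0.5], 0, example_weights)
loss_minority = weighted_cross_entropy([0.5, 0.5], 1, example_weights)

# The same mistake costs more when the true class is the minority class.
print(f"majority-class loss: {loss_majority:.4f}")
print(f"minority-class loss: {loss_minority:.4f}")
```

With these assumed weights, an error on the minority class contributes eight times as much to the loss as the same error on the majority class, which pushes the model to pay more attention to rare classes.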


 