Class weights in CNN

I have a very unbalanced dataset. First, I divided it into a training dataset (80%) and a validation dataset (20%). I used StratifiedShuffleSplit so that both datasets preserve each class's percentage.
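For illustration, the split described above can be sketched with scikit-learn's StratifiedShuffleSplit. This is a minimal sketch, not the asker's actual code: the feature array X is a placeholder, and the label array y is assumed to be built from the per-class image counts listed later in the question (70, 110, 82, 17, 9). It also counts images per class in the training split alone with np.bincount.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Illustrative labels (assumption): per-class image counts taken from the question.
counts = [70, 110, 82, 17, 9]
y = np.repeat(np.arange(len(counts)), counts)
# Placeholder features; only the number of rows matters for the split.
X = np.zeros((len(y), 1))

# One stratified 80/20 split; each class keeps roughly its share in both parts.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, val_idx = next(sss.split(X, y))
y_train, y_val = y[train_idx], y[val_idx]

# Per-class counts in each split (minlength keeps every class in the output).
train_counts = np.bincount(y_train, minlength=len(counts))
val_counts = np.bincount(y_val, minlength=len(counts))

print("overall:", np.round(np.bincount(y) / len(y), 3))
print("train:  ", np.round(train_counts / len(y_train), 3))
print("val:    ", np.round(val_counts / len(y_val), 3))
print("train counts per class:", train_counts)
```

The training-split counts (train_counts) are the ones that describe what the model actually sees during fitting.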

To deal with the imbalance in both datasets, I am using class_weight. This is the code I use:

class_weight = {0: 70.,
                1: 110.,
                2: 82.,
                3: 17.,
                4: 9.}


model.fit(train_generator, epochs=5, class_weight=class_weight, validation_data=(x_val, y_val))

The variable class_weight currently holds the number of images of each class in the whole dataset, that is, the training and validation datasets combined. Should it be done like that? Or should it hold the image counts of the training dataset only?

I have another question. Suppose I use data augmentation: how can I know for sure the number of images per class? Is there an automatic calculator or something of the sort?

It seems like you hard-coded some weight values for your classes. Instead, you can compute class weights with sklearn.utils.class_weight.compute_class_weight to tackle the imbalanced dataset. It computes appropriate weight values based on how often each class occurs, and it should be given the training labels (y_train) only. Note that these weights are inversely proportional to class frequency, whereas raw image counts used as weights would give the largest weight to the most common class.

# imports
import numpy as np
from sklearn.utils import class_weight

# compute class weights
# based on how often each class appears in y_train (integer class labels)
classes = np.unique(y_train)  # np.unique already returns sorted labels
cls_wgts = class_weight.compute_class_weight(class_weight='balanced',
                                             classes=classes,
                                             y=y_train)
# dict mapping: class label -> weight, as Keras expects
cls_wgts = {int(label): float(w) for label, w in zip(classes, cls_wgts)}

# pass it to fit
model.fit(..., class_weight=cls_wgts)
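To see what 'balanced' computes, the sketch below reproduces it by hand. scikit-learn documents the 'balanced' weights as n_samples / (n_classes * np.bincount(y)), so rarer classes get larger weights. The labels here are illustrative only (an assumption built from the counts in the question), and the resulting weights run in the opposite direction to the raw counts that were hard-coded there.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Illustrative training labels (assumption): built from the per-class counts in the question.
counts = np.array([70, 110, 82, 17, 9])
y_train = np.repeat(np.arange(len(counts)), counts)

classes = np.unique(y_train)
sk_weights = compute_class_weight(class_weight="balanced", classes=classes, y=y_train)

# The same formula written out: n_samples / (n_classes * count_per_class).
manual_weights = len(y_train) / (len(classes) * np.bincount(y_train))

# Keras expects a dict {class_index: weight}.
class_weight_dict = {int(c): float(w) for c, w in zip(classes, sk_weights)}
for c, w in class_weight_dict.items():
    print(f"class {c}: count={counts[c]:3d}  weight={w:.3f}")
```

The smallest class (9 images) ends up with the largest weight, which is the behaviour class weighting is meant to produce.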

As for your second question, if I understand it properly, we usually don't know how much augmentation will happen per class at training time. But we can control the settings of the data generator so that minority classes get more augmentation than majority classes. Additionally, you can use a weighted cross-entropy loss function here to handle the class imbalance.
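A weighted cross-entropy scales each sample's loss by the weight of its true class, which is roughly the per-sample effect that Keras's class_weight has. The sketch below writes the formula in NumPy so it can be checked outside a training loop. The function name weighted_categorical_crossentropy and the toy predictions are hypothetical; inside a Keras model the same expression would have to be written with TensorFlow ops and passed as the loss.

```python
import numpy as np

def weighted_categorical_crossentropy(y_true, y_pred, class_weights, eps=1e-7):
    """Mean cross-entropy where each sample is scaled by its true class's weight.

    y_true: one-hot labels, shape (batch, n_classes)
    y_pred: predicted probabilities, shape (batch, n_classes)
    class_weights: weight per class, shape (n_classes,)
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    per_sample_ce = -np.sum(y_true * np.log(y_pred), axis=-1)
    per_sample_w = y_true @ class_weights  # picks the weight of each sample's true class
    return float(np.mean(per_sample_w * per_sample_ce))

# Toy batch (hypothetical): 4 samples, 3 classes, class 2 treated as the minority class.
y_true = np.eye(3)[[0, 1, 2, 2]]
y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.3, 0.3, 0.4],
                   [0.2, 0.2, 0.6]])
uniform = weighted_categorical_crossentropy(y_true, y_pred, np.ones(3))
upweighted = weighted_categorical_crossentropy(y_true, y_pred, np.array([1.0, 1.0, 5.0]))
print(uniform, upweighted)
```

With all weights equal to 1 this reduces to ordinary categorical cross-entropy; raising the minority class's weight makes its mistakes cost more.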
