[英]Class weights in CNN
我有一個非常不平衡的數據集。 首先,我將這個數據集分為訓練數據集(80%)和驗證數據集(20%)。 我使用了StratifiedShuffleSplit
,因此兩個數據集都保留了每個 class 百分比。
為了解決兩個數據集不平衡的問題,我使用了class_weight
。 這是我使用的代碼:
class_weight = {0: 70.,
1: 110.,
2: 82.,
3: 17.,
4: 9.}
model.fit(train_generator, epochs = 5, class_weight=(class_weight), validation_data=(x_val, y_val))
變量class_weight
目前有整個數據集的每個class的圖像數量,即訓練和驗證數據集的組合。 應該那樣做嗎? 還是應該有訓練數據集的圖像?
我有另一個問題。 假設我進行數據增強,我怎么能確定每個 class 的圖像數量? 有自動計算器之類的嗎?
您似乎為您的課程硬編碼了一些權重值。 但是,您可以使用sklearn.utils.class_weight.compute_class_weight
進行 class 加權來處理不平衡的數據集。 它將根據類的出現計算適當的權重值。
# imports
from sklearn.utils import class_weight
# compute class weight
# based on appearance of each class in y_trian
cls_wgts = class_weight.compute_class_weight('balanced',
sorted(np.unique(y_train)),
y_train)
# dict mapping
cls_wgts = {i : cls_wgts[i] for i, label in enumerate(sorted(np.unique(y_train)))}
# pass it to fit
model.fit(..., class_weight=cls_wgts)
就您的第二個查詢而言,如果我理解正確,我們通常不知道在訓練時間內每個 class 會發生多少增強。 但是我們可以控制數據生成器中的設置,與major
類相比, minor
類會得到更多的增強。 此外,您還可以在此處使用加權交叉熵損失 function 來處理 class 不平衡。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.