簡體   English   中英

CNN 中的 Class 權重

[英]Class weights in CNN

我有一個非常不平衡的數據集。 首先,我將這個數據集分為訓練數據集(80%)和驗證數據集(20%)。 我使用了StratifiedShuffleSplit ,因此兩個數據集都保留了每個 class 百分比。

為了解決兩個數據集不平衡的問題,我使用了class_weight 這是我使用的代碼:

class_weight = {0: 70.,
                1: 110.,
                2: 82.,
                3: 17.,
                4: 9.}


model.fit(train_generator, epochs = 5, class_weight=(class_weight), validation_data=(x_val, y_val))  

變量class_weight目前有整個數據集的每個class的圖像數量,即訓練和驗證數據集的組合。 應該那樣做嗎? 還是應該有訓練數據集的圖像?

我有另一個問題。 假設我進行數據增強,我怎么能確定每個 class 的圖像數量? 有自動計算器之類的嗎?

您似乎為您的課程硬編碼了一些權重值。 但是,您可以使用sklearn.utils.class_weight.compute_class_weight進行 class 加權來處理不平衡的數據集。 它將根據類的出現計算適當的權重值。

# imports 
from sklearn.utils import class_weight

# compute class weight 
# based on appearance of each class in y_trian
cls_wgts = class_weight.compute_class_weight('balanced',
                                             sorted(np.unique(y_train)),
                                             y_train)
# dict mapping
cls_wgts = {i : cls_wgts[i] for i, label in enumerate(sorted(np.unique(y_train)))}

# pass it to fit
model.fit(..., class_weight=cls_wgts)

就您的第二個查詢而言,如果我理解正確,我們通常不知道在訓練時間內每個 class 會發生多少增強。 但是我們可以控制數據生成器中的設置,與major類相比, minor類會得到更多的增強。 此外,您還可以在此處使用加權交叉熵損失 function 來處理 class 不平衡。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM