簡體   English   中英

具有不平衡數據集的多標簽分類

[英]Multilabel classification with imbalanced dataset

我正在嘗試做一個多標簽分類問題,它有一個不平衡的數據集。 樣本總數為 1130,在 1130 個樣本中,第一個 class 出現在其中 913 個中。 第二個 class 215 次,第三個 423 次。

在 model 架構中,我有 3 個 output 節點,並應用了 sigmoid 激活。

input_tensor = Input(shape=(256, 256, 3))
base_model = VGG16(input_tensor=input_tensor,weights='imagenet',pooling=None, include_top=False)

#base_model.summary()

x = base_model.output

x = GlobalAveragePooling2D()(x)

x = tf.math.reduce_max(x,axis=0,keepdims=True)

x = Dense(512,activation='relu')(x)

output_1 = Dense(3, activation='sigmoid')(x)

sagittal_model_abn = Model(inputs=base_model.input, outputs=output_1)

for layer in base_model.layers:
    layer.trainable = True

我正在使用二進制交叉熵損失,我使用這個 function 計算。 我正在使用加權損失來處理不平衡。

        if y_true[0]==1:
            loss_abn = -1*K.log(y_pred[0][0])*cwb[0][1]
        elif y_true[0]==0:
            loss_abn = -1*K.log(1-y_pred[0][0])*cwb[0][0]
        if y_true[1]==1:
            loss_acl = -1*K.log(y_pred[0][1])*cwb[1][1]
        elif y_true[1]==0:
            loss_acl = -1*K.log(1-y_pred[0][1])*cwb[1][0]
        if y_true[2]==1:
            loss_men = -1*K.log(y_pred[0][2])*cwb[2][1]
        elif y_true[2]==0:
            loss_men = -1*K.log(1-y_pred[0][2])*cwb[2][0]

        loss_value_ds = loss_abn + loss_acl + loss_men

cwb包含 class 權重。

y_true是長度為 3 的地面實況標簽。

y_pred是一個形狀為 (1,3) 的 numpy 數組

我將類單獨加權為類的出現和不出現。

就像,如果 label 為 1,我將其視為發生,如果為 0,則將其視為未發生。

所以,第一類的 label 1 在 1130 次中出現了 913 次

So the class weight of label 1 for the first class is 1130/913 which is about 1.23 and the weight of the label 0 for the first class is 1130/(1130-913)

當我訓練 model 時,精度會出現波動(或幾乎保持不變),並且損失會減少。

對於每個樣本,我都會得到這樣的預測

[[0.51018655 0.5010625 0.50482965]]

在所有類的每次迭代中,預測值都在 0.49 - 0.51 范圍內

嘗試更改 FC 層中的節點數,但它的行為仍然相同。

任何人都可以幫忙嗎?

使用tf.math,reduce_max會導致問題嗎? 使用@tf.function執行我正在使用tf.math.reduce_max進行的操作是否有益?

注意

我分別為每個 class 加權標簽 1 和 0。

cwb = {0: {0: 5.207373271889401, 1: 1.2376779846659365}, 
       1: {0: 1.2255965292841648, 1: 5.4326923076923075}, 
       2: {0: 1.5416098226466575, 1: 2.8463476070528966}}

編輯

我使用model.fit()訓練時的結果。

Epoch 1/20
1130/1130 [==============================] - 1383s 1s/step - loss: 4.1638 - binary_accuracy: 0.4558 - val_loss: 5.0439 - val_binary_accuracy: 0.3944
Epoch 2/20
1130/1130 [==============================] - 1397s 1s/step - loss: 4.1608 - binary_accuracy: 0.4165 - val_loss: 5.0526 - val_binary_accuracy: 0.5194
Epoch 3/20
1130/1130 [==============================] - 1402s 1s/step - loss: 4.1608 - binary_accuracy: 0.4814 - val_loss: 5.1469 - val_binary_accuracy: 0.6361
Epoch 4/20
1130/1130 [==============================] - 1407s 1s/step - loss: 4.1722 - binary_accuracy: 0.4472 - val_loss: 5.0501 - val_binary_accuracy: 0.5583
Epoch 5/20
1130/1130 [==============================] - 1397s 1s/step - loss: 4.1591 - binary_accuracy: 0.4991 - val_loss: 5.0521 - val_binary_accuracy: 0.6028
Epoch 6/20
1130/1130 [==============================] - 1375s 1s/step - loss: 4.1596 - binary_accuracy: 0.5431 - val_loss: 5.0515 - val_binary_accuracy: 0.5917
Epoch 7/20
1130/1130 [==============================] - 1370s 1s/step - loss: 4.1595 - binary_accuracy: 0.4962 - val_loss: 5.0526 - val_binary_accuracy: 0.6000
Epoch 8/20
1130/1130 [==============================] - 1387s 1s/step - loss: 4.1591 - binary_accuracy: 0.5316 - val_loss: 5.0523 - val_binary_accuracy: 0.6028
Epoch 9/20
1130/1130 [==============================] - 1391s 1s/step - loss: 4.1590 - binary_accuracy: 0.4909 - val_loss: 5.0521 - val_binary_accuracy: 0.6028
Epoch 10/20
1130/1130 [==============================] - 1400s 1s/step - loss: 4.1590 - binary_accuracy: 0.5369 - val_loss: 5.0519 - val_binary_accuracy: 0.6028
Epoch 11/20
1130/1130 [==============================] - 1397s 1s/step - loss: 4.1590 - binary_accuracy: 0.4808 - val_loss: 5.0519 - val_binary_accuracy: 0.6028
Epoch 12/20
1130/1130 [==============================] - 1394s 1s/step - loss: 4.1590 - binary_accuracy: 0.5469 - val_loss: 5.0522 - val_binary_accuracy: 0.6028

我會嘗試 label powerset 方法。

嘗試根據您的標簽和數據集將其設置為可能的組合總數,而不是 3 個 output 節點。 例如,對於具有 3 個不同類別的多標簽分類,有 7 個可能的輸出。

比如說,標簽是 A、B 和 C。 Map output 0 到 A,1 到 B,2 到 C,3 到 AB,4 到 AC 等等。

在訓練和測試之前使用一個簡單的轉換 function,這個問題可以轉換為一個多類、單一的 label 問題。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM