簡體   English   中英

Keras 多類多標簽圖像分類:處理獨立和從屬標簽和非二進制的混合 output

[英]Keras Multi-class Multi-label image classification: handle a mix of independent and dependent labels & non-binary output

我正在嘗試從 Keras 訓練一個預訓練的 VGG16 model 用於多類多標簽分類任務。 這些圖像來自 NIH 的胸部 X 射線 8 數據集。 該數據集有 14 個標簽(14 種疾病)加上“未發現”label。

我知道對於獨立的標簽,比如 14 種疾病,我應該使用 sigmoid 激活 + binary_crossentropy loss function; 對於依賴標簽,我應該使用 softmax + categorical_crossentropy。

然而,在我總共 15 個標簽中,其中 14 個是獨立的,但一個“未發現”在技術上依賴於 rest 14 -->“未發現”和患有疾病的概率加起來應為1,但是應該獨立給出患有什么疾病的概率。 那么我應該使用什么損失呢?

此外,我的 output 是浮點數(概率)列表,每列是一個 label。

y_true:
[[0. 0. 0. ... 0. 0. 1.]
 [0. 0. 0. ... 0. 0. 1.]
 [0. 0. 1. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 1.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 1.]]

y_predict:
[[0.1749 0.0673 0.1046 ... 0.     0.     0.112 ]
 [0.     0.1067 0.2804 ... 0.     0.     0.722 ]
 [0.     0.     0.0686 ... 0.     0.     0.5373]
 ...
 [0.0571 0.0679 0.0815 ... 0.     0.     0.532 ]
 [0.0723 0.0555 0.2373 ... 0.     0.     0.4263]
 [0.0506 0.1305 0.4399 ... 0.     0.     0.2792]]

這樣的結果使得無法使用classification_report() function 來評估我的 model。 我正在考慮獲得一個將其轉換為二進制的閾值,但這將是更多的人工修改而不是 CNN 預測,因為我必須為 select 設置一個閾值。 所以我不確定我是否應該做一些硬編碼的東西,或者是否有任何其他已經存在的方法來處理這種情況?

我對 CNN 和分類很陌生,所以如果有人可以指導我或給我任何提示,我將非常感激。 謝謝!

主體代碼如下:

vgg16_model = VGG16()
last_layer = vgg16_model.get_layer('fc2').output

#I am treating them all as independent labels
out = Dense(15, activation='sigmoid', name='output_layer')(last_layer)
custom_vgg16_model = Model(inputs=vgg16_model.input, outputs=out)

for layer in custom_vgg16_model.layers[:-1]:
layer.trainable = False

custom_vgg16_model.compile(Adam(learning_rate=0.00001), 
                           loss = "binary_crossentropy", 
                           metrics = ['accuracy'])   # metrics=accuracy gives me very good result, 
                                                     # but I suppose it is due to the large amount 
                                                     # of 0 label(not-this-disease prediction),
                                                     # therefore I am thinking to change it to 
                                                     # recall and precision as metrics. If you have
                                                     # any suggestion on this I'd also like to hear!

關於我的項目的一些更新,實際上我已經設法解決了這個問題中提到的大部分問題。

首先,由於這是一個多類多標簽分類問題,我決定使用 ROC-AUC 分數而不是精確率或召回率作為評估指標。 它的優點是不涉及閾值——AUC 有點像一系列閾值下的性能平均值。 而且它只看正面預測,因此它減少了數據集中大多數 0 的影響。 在我的案例中,這可以更准確地預測模型的性能。

對於 output class,我決定使用 14 個類別而不是 15 個類別——如果所有標簽均為 0,則表示“未找到”。 然后我可以愉快地在我的 output 層中使用 sigmoid 激活。 盡管如此,我還是使用焦點損失而不是二元交叉熵,因為我的數據集高度不平衡。

我仍然面臨問題,因為我的 ROC 不好(非常接近 y=x,有時低於 y=x)。 但我希望我的進步能給任何發現這一點的人一些啟發。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM