簡體   English   中英

tensorflow keras 中多 class 分類的召回率和精度指標

[英]Recall and precision metrics for multi class classification in tensorflow keras

我們在分類三個類別時遇到了問題。 我們想弄清楚每個 class 的召回率和精度指標。 我們發現tf.keras.metrics中有內置的精度和召回指標。 但這些指標似乎只適用於二元分類。 在我們的 model 中,最后一層是具有活動 function 'softmax' 的密集層。 損失 function 是sparse_categorical_crossentropy ,因為我們對 y 使用了 class label。

output = Dense(3, activation='softmax')(attention_mul)
model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy, optimizer='Adam', metrics=['accuracy'])

The output of the prediction result is a vector of the probability of the each class, eg [0.3, 0.5, 0.2].To get the class label, we need to apply np.argmax() for the prediction results. 而內置的召回率和精度指標接受 class label 作為輸入。

m = tf.keras.metrics.Recall()
m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
m.result().numpy()

是否有任何解決方案來獲得精度和召回指標,並在訓練的每個時期進行監控?

Keras 中的 Precision 和 Recall 不能用於多類分類問題是有原因的。 由於指標是按批次計算的,因此這兩個指標的結果可能不准確。 實際上 Keras 有一個精度和召回的實現,正是出於這個原因決定刪除。

但是,如果您真的想要,您可以為精確度和召回率創建自定義指標,並將它們傳遞給編譯。

Keras GitHub ,刪除的指標:

def precision(y_true, y_pred):
    """Precision metric.
    Only computes a batch-wise average of precision.
    Computes the precision, a metric for multi-label classification of
    how many selected items are relevant.
    """
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision


def recall(y_true, y_pred):
    """Recall metric.
    Only computes a batch-wise average of recall.
    Computes the recall, a metric for multi-label classification of
    how many relevant items are selected.
    """
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

添加要compile的指標:

model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy, optimizer='Adam', metrics=['accuracy', precision, recall])

這樣您就可以按照您的要求在每個時期監控兩個指標。

有 keras 指標項目https://github.com/netrack/keras-metrics 但是,對於當前版本的 Tensorflow,例如 2.7,它沒有維護並且已經過時。 受到項目的啟發,我終於找到了解決方案:我們可以自定義度量函數。 這是代碼:

    def recall(y_true, y_pred, c):
        y_true = K.flatten(y_true)
        pred_c = K.cast(K.equal(K.argmax(y_pred, axis=-1), c), K.floatx())
        true_c = K.cast(K.equal(y_true, c), K.floatx())
        true_positives = K.sum(pred_c * true_c)
        possible_postives = K.sum(true_c)
        return true_positives / (possible_postives + K.epsilon())


    def precision(y_true, y_pred, c):
        y_true = K.flatten(y_true)
        pred_c = K.cast(K.equal(K.argmax(y_pred, axis=-1), c), K.floatx())
        true_c = K.cast(K.equal(y_true, c), K.floatx())
        true_positives = K.sum(pred_c * true_c)
        pred_positives = K.sum(pred_c)
        return true_positives / (pred_positives + K.epsilon())

    def recall_c1(y_true, y_pred):
        return recall(y_true, y_pred, 1)

    def precision_c1(y_true, y_pred):
        return precision(y_true, y_pred, 1)
    
    def recall_c2(y_true, y_pred):
        return recall(y_true, y_pred, 2)

    def precision_c2(self, y_true, y_pred):
        return precision(y_true, y_pred, 2)

我們可以使用precision_c1、recall_c1 來指示類別1 的精度和召回指標,並使用precision_c2、recall_c2 來指示類別2。通過將class_id 值c 傳遞給function 召回(和),還可以支持更多類別。 以下是 model 訓練期間的示例 output:

Epoch 2/2000
24/24 - 35s - loss: 1.1322 - accuracy: 0.0675 - recall_c1: 0.9962 - precision_c1: 0.0676 - recall_c2: 0.0054 - precision_c2: 0.0402 - val_loss: 1.1263 - val_accuracy: 0.0357 - val_recall_c1: 1.0000 - val_precision_c1: 0.0344 - val_recall_c2: 0.0000e+00 - val_precision_c2: 0.0000e+00 - 35s/epoch - 1s/step
Epoch 3/2000
24/24 - 35s - loss: 1.1321 - accuracy: 0.0678 - recall_c1: 0.9873 - precision_c1: 0.0679 - recall_c2: 0.0178 - precision_c2: 0.0876 - val_loss: 1.1254 - val_accuracy: 0.0382 - val_recall_c1: 0.8761 - val_precision_c1: 0.0346 - val_recall_c2: 0.2432 - val_precision_c2: 0.0948 - 35s/epoch - 1s/step

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM