繁体   English   中英

tensorflow keras 中多 class 分类的召回率和精度指标

[英]Recall and precision metrics for multi class classification in tensorflow keras

我们在分类三个类别时遇到了问题。 我们想弄清楚每个 class 的召回率和精度指标。 我们发现tf.keras.metrics中有内置的精度和召回指标。 但这些指标似乎只适用于二元分类。 在我们的 model 中,最后一层是具有活动 function 'softmax' 的密集层。 损失 function 是sparse_categorical_crossentropy ,因为我们对 y 使用了 class label。

output = Dense(3, activation='softmax')(attention_mul)
model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy, optimizer='Adam', metrics=['accuracy'])

The output of the prediction result is a vector of the probability of the each class, eg [0.3, 0.5, 0.2].To get the class label, we need to apply np.argmax() for the prediction results. 而内置的召回率和精度指标接受 class label 作为输入。

m = tf.keras.metrics.Recall()
m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
m.result().numpy()

是否有任何解决方案来获得精度和召回指标,并在训练的每个时期进行监控?

Keras 中的 Precision 和 Recall 不能用于多类分类问题是有原因的。 由于指标是按批次计算的,因此这两个指标的结果可能不准确。 实际上 Keras 有一个精度和召回的实现,正是出于这个原因决定删除。

但是,如果您真的想要,您可以为精确度和召回率创建自定义指标,并将它们传递给编译。

Keras GitHub ,删除的指标:

def precision(y_true, y_pred):
    """Precision metric.
    Only computes a batch-wise average of precision.
    Computes the precision, a metric for multi-label classification of
    how many selected items are relevant.
    """
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision


def recall(y_true, y_pred):
    """Recall metric.
    Only computes a batch-wise average of recall.
    Computes the recall, a metric for multi-label classification of
    how many relevant items are selected.
    """
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

添加要compile的指标:

model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy, optimizer='Adam', metrics=['accuracy', precision, recall])

这样您就可以按照您的要求在每个时期监控两个指标。

有 keras 指标项目https://github.com/netrack/keras-metrics 但是,对于当前版本的 Tensorflow,例如 2.7,它没有维护并且已经过时。 受到项目的启发,我终于找到了解决方案:我们可以自定义度量函数。 这是代码:

    def recall(y_true, y_pred, c):
        y_true = K.flatten(y_true)
        pred_c = K.cast(K.equal(K.argmax(y_pred, axis=-1), c), K.floatx())
        true_c = K.cast(K.equal(y_true, c), K.floatx())
        true_positives = K.sum(pred_c * true_c)
        possible_postives = K.sum(true_c)
        return true_positives / (possible_postives + K.epsilon())


    def precision(y_true, y_pred, c):
        y_true = K.flatten(y_true)
        pred_c = K.cast(K.equal(K.argmax(y_pred, axis=-1), c), K.floatx())
        true_c = K.cast(K.equal(y_true, c), K.floatx())
        true_positives = K.sum(pred_c * true_c)
        pred_positives = K.sum(pred_c)
        return true_positives / (pred_positives + K.epsilon())

    def recall_c1(y_true, y_pred):
        return recall(y_true, y_pred, 1)

    def precision_c1(y_true, y_pred):
        return precision(y_true, y_pred, 1)
    
    def recall_c2(y_true, y_pred):
        return recall(y_true, y_pred, 2)

    def precision_c2(self, y_true, y_pred):
        return precision(y_true, y_pred, 2)

我们可以使用precision_c1、recall_c1 来指示类别1 的精度和召回指标,并使用precision_c2、recall_c2 来指示类别2。通过将class_id 值c 传递给function 召回(和),还可以支持更多类别。 以下是 model 训练期间的示例 output:

Epoch 2/2000
24/24 - 35s - loss: 1.1322 - accuracy: 0.0675 - recall_c1: 0.9962 - precision_c1: 0.0676 - recall_c2: 0.0054 - precision_c2: 0.0402 - val_loss: 1.1263 - val_accuracy: 0.0357 - val_recall_c1: 1.0000 - val_precision_c1: 0.0344 - val_recall_c2: 0.0000e+00 - val_precision_c2: 0.0000e+00 - 35s/epoch - 1s/step
Epoch 3/2000
24/24 - 35s - loss: 1.1321 - accuracy: 0.0678 - recall_c1: 0.9873 - precision_c1: 0.0679 - recall_c2: 0.0178 - precision_c2: 0.0876 - val_loss: 1.1254 - val_accuracy: 0.0382 - val_recall_c1: 0.8761 - val_precision_c1: 0.0346 - val_recall_c2: 0.2432 - val_precision_c2: 0.0948 - 35s/epoch - 1s/step

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM