
Calculating F1 score, precision, recall in tfhub retraining script

I am using TensorFlow Hub for an image retraining classification task. The TensorFlow script retrain.py calculates cross_entropy and accuracy by default.

train_accuracy, cross_entropy_value = sess.run(
    [evaluation_step, cross_entropy],
    feed_dict={bottleneck_input: train_bottlenecks,
               ground_truth_input: train_ground_truth})

I would like to get the F1 score, precision, recall, and confusion matrix. How could I get these values using this script?

Below I include a method to calculate the desired metrics using the scikit-learn package.

You can calculate the F1 score, precision, and recall using the precision_recall_fscore_support method, and the confusion matrix using the confusion_matrix method:

from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

Both methods take two 1D array-like objects, which store the ground-truth and predicted labels respectively.
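For illustration, here is a minimal standalone example with toy labels (these labels are made up for the example, not taken from retrain.py):

from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

y_true = [0, 1, 2, 2, 1, 0]  # ground-truth class labels
y_pred = [0, 2, 2, 2, 1, 0]  # class labels predicted by a model

# with average='micro', precision, recall and F1 all equal overall
# accuracy in a single-label multi-class setting (here 5/6)
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true=y_true, y_pred=y_pred, average='micro')

# rows are true classes, columns are predicted classes
print(confusion_matrix(y_true=y_true, y_pred=y_pred))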

In the code provided, the ground-truth labels for the training data are stored in the train_ground_truth variable, which is defined in lines 1054 and 1060, while validation_ground_truth stores the ground-truth labels for the validation data and is defined in line 1087.

The tensor that calculates the predicted class labels is defined and returned by the add_evaluation_step function. You can modify line 1034 to capture that tensor object:

evaluation_step, prediction = add_evaluation_step(final_tensor, ground_truth_input)
# now prediction stores the tensor object that 
# calculates predicted class labels
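For reference, add_evaluation_step already derives the predicted labels with tf.argmax over the final softmax tensor; paraphrased from retrain.py, the function looks roughly like this:

def add_evaluation_step(result_tensor, ground_truth_tensor):
    """Inserts ops to evaluate accuracy and return predicted labels."""
    with tf.name_scope('accuracy'):
        with tf.name_scope('correct_prediction'):
            # predicted class = index of the largest softmax score
            prediction = tf.argmax(result_tensor, 1)
            correct_prediction = tf.equal(prediction, ground_truth_tensor)
        with tf.name_scope('accuracy'):
            evaluation_step = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', evaluation_step)
    return evaluation_step, prediction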

Now you can update line 1076 to evaluate prediction when calling sess.run():

train_accuracy, cross_entropy_value, train_predictions = sess.run(
    [evaluation_step, cross_entropy, prediction],
    feed_dict={bottleneck_input: train_bottlenecks,
               ground_truth_input: train_ground_truth})

# train_predictions now stores class labels predicted by model

# calculate precision, recall and F1 score
(train_precision,
 train_recall,
 train_f1_score, _) = precision_recall_fscore_support(y_true=train_ground_truth,
                                                      y_pred=train_predictions,
                                                      average='micro')
# calculate confusion matrix
train_confusion_matrix = confusion_matrix(y_true=train_ground_truth,
                                          y_pred=train_predictions)

Similarly, you can compute metrics for the validation subset by modifying line 1095:

validation_summary, validation_accuracy, validation_predictions = sess.run(
    [merged, evaluation_step, prediction],
    feed_dict={bottleneck_input: validation_bottlenecks,
               ground_truth_input: validation_ground_truth})

# validation_predictions now stores class labels predicted by model

# calculate precision, recall and F1 score
(validation_precision,
 validation_recall,
 validation_f1_score, _) = precision_recall_fscore_support(y_true=validation_ground_truth,
                                                           y_pred=validation_predictions,
                                                           average='micro')
# calculate confusion matrix
validation_confusion_matrix = confusion_matrix(y_true=validation_ground_truth,
                                               y_pred=validation_predictions)

Finally, the code calls run_final_eval to evaluate the trained model on the test data. In this function, prediction and test_ground_truth are already defined, so you only need to add code to calculate the required metrics:

test_accuracy, predictions = eval_session.run(
    [evaluation_step, prediction],
    feed_dict={
        bottleneck_input: test_bottlenecks,
        ground_truth_input: test_ground_truth
    })

# calculate precision, recall and F1 score
(test_precision,
 test_recall,
 test_f1_score, _) = precision_recall_fscore_support(y_true=test_ground_truth,
                                                     y_pred=predictions,
                                                     average='micro')
# calculate confusion matrix
test_confusion_matrix = confusion_matrix(y_true=test_ground_truth,
                                         y_pred=predictions)
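You can then report these values next to the accuracy that run_final_eval already logs; for example (retrain.py uses tf.logging.info for its own output, but a plain print works as well):

tf.logging.info('Test precision = %.4f, recall = %.4f, F1 = %.4f' %
                (test_precision, test_recall, test_f1_score))
tf.logging.info('Confusion matrix:\n%s' % test_confusion_matrix)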

Note that the provided code calculates global F1 scores by setting average='micro'. The different averaging methods supported by the scikit-learn package are described in the User Guide.
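As a quick illustration of the difference, micro averaging pools the true/false positives and negatives of all samples before computing the score, while macro averaging computes the score per class and then takes an unweighted mean (toy labels again, not from retrain.py):

from sklearn.metrics import precision_recall_fscore_support

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 0, 2, 1]

# micro: counts TP/FP/FN globally, so the result equals overall
# accuracy in a single-label multi-class setting (here 4/6 correct)
print(precision_recall_fscore_support(y_true, y_pred, average='micro'))

# macro: scores each class separately, then takes the unweighted mean,
# so rare classes weigh as much as frequent ones
print(precision_recall_fscore_support(y_true, y_pred, average='macro'))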

