
Calculating F1 score, precision and recall in the TensorFlow Hub retraining script

I am using TensorFlow Hub for an image classification retraining task. By default, the TensorFlow script retrain.py calculates only cross_entropy and accuracy:

train_accuracy, cross_entropy_value = sess.run(
    [evaluation_step, cross_entropy],
    feed_dict={bottleneck_input: train_bottlenecks,
               ground_truth_input: train_ground_truth})

I would like to get the F1 score, precision, recall and confusion matrix. How can I get these values using this script?

Below is a method for calculating the desired metrics with the scikit-learn package.

You can calculate the F1 score, precision and recall with the precision_recall_fscore_support function, and the confusion matrix with the confusion_matrix function:

from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

Both functions take two 1D array-like objects, which store the ground-truth and the predicted labels, respectively.
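To make the expected input format concrete, here is a minimal standalone sketch on made-up labels (the y_true and y_pred values are hypothetical, not taken from retrain.py). Both arguments are plain lists of integer class indices:

```python
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

# Hypothetical ground-truth and predicted class indices for a 3-class problem.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

# The fourth return value (support) is None when an average is requested.
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true=y_true, y_pred=y_pred, average='micro')

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_true=y_true, y_pred=y_pred)

print(precision, recall, f1)  # all three are 4/6 here: 4 of 6 labels are correct
print(cm)
```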

In the code provided, the ground-truth labels for the training data are stored in the train_ground_truth variable, which is defined in lines 1054 and 1060, while validation_ground_truth stores the ground-truth labels for the validation data and is defined in line 1087.

The tensor that computes the predicted class labels is defined and returned by the add_evaluation_step function. You can modify line 1034 to capture that tensor object:

evaluation_step, prediction = add_evaluation_step(final_tensor, ground_truth_input)
# now prediction stores the tensor object that 
# calculates predicted class labels
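For context, as far as I know add_evaluation_step in the TF Hub version of retrain.py computes prediction as tf.argmax(result_tensor, 1), i.e. the index of the highest-scoring class for each image; please verify this against your copy of the script. A NumPy sketch of the same step, using hypothetical scores in place of the output of final_tensor:

```python
import numpy as np

# Hypothetical softmax scores for a batch of 3 images and 4 classes,
# standing in for what final_tensor would produce.
final_scores = np.array([
    [0.10, 0.70, 0.15, 0.05],
    [0.60, 0.20, 0.10, 0.10],
    [0.05, 0.05, 0.10, 0.80],
])

# Taking the argmax along the class axis (axis=1) yields one integer label per
# image: the same 1D format that scikit-learn's metric functions expect.
predicted_labels = np.argmax(final_scores, axis=1)
print(predicted_labels)  # [1 0 3]
```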

Now you can update line 1076 to also evaluate prediction when calling sess.run():

train_accuracy, cross_entropy_value, train_predictions = sess.run(
    [evaluation_step, cross_entropy, prediction],
    feed_dict={bottleneck_input: train_bottlenecks,
               ground_truth_input: train_ground_truth})

# train_predictions now stores the class labels predicted by the model

# calculate precision, recall and F1 score
(train_precision,
 train_recall,
 train_f1_score, _) = precision_recall_fscore_support(y_true=train_ground_truth,
                                                      y_pred=train_predictions,
                                                      average='micro')
# calculate confusion matrix
train_confusion_matrix = confusion_matrix(y_true=train_ground_truth,
                                          y_pred=train_predictions)
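If it helps to see how these numbers relate to the confusion matrix, the sketch below recomputes per-class precision, recall and F1 directly from it, using hypothetical labels in place of train_ground_truth and train_predictions, and compares them with average=None, which makes precision_recall_fscore_support return one value per class. It assumes every class is both present and predicted at least once; otherwise the divisions hit zero, a case scikit-learn handles through its zero_division parameter.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

# Hypothetical labels standing in for train_ground_truth / train_predictions.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])

# In scikit-learn, cm[i, j] counts samples whose true class is i
# and whose predicted class is j.
cm = confusion_matrix(y_true, y_pred)
true_positives = np.diag(cm)

# Precision of class j: correct predictions of j / all predictions of j (column sum).
per_class_precision = true_positives / cm.sum(axis=0)
# Recall of class i: correct predictions of i / all true samples of i (row sum).
per_class_recall = true_positives / cm.sum(axis=1)
# F1 is the harmonic mean of precision and recall.
per_class_f1 = (2 * per_class_precision * per_class_recall
                / (per_class_precision + per_class_recall))

# average=None returns one value per class instead of a single aggregate.
sk_precision, sk_recall, sk_f1, support = precision_recall_fscore_support(
    y_true, y_pred, average=None)

print(per_class_precision, sk_precision)
print(per_class_recall, sk_recall)
print(per_class_f1, sk_f1)
```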

Similarly, you can compute the metrics for the validation subset by modifying line 1095:

validation_summary, validation_accuracy, validation_predictions = sess.run(
    [merged, evaluation_step, prediction],
    feed_dict={bottleneck_input: validation_bottlenecks,
               ground_truth_input: validation_ground_truth})

# validation_predictions now stores the class labels predicted by the model

# calculate precision, recall and F1 score
(validation_precision,
 validation_recall,
 validation_f1_score, _) = precision_recall_fscore_support(y_true=validation_ground_truth,
                                                           y_pred=validation_predictions,
                                                           average='micro')
# calculate confusion matrix
validation_confusion_matrix = confusion_matrix(y_true=validation_ground_truth,
                                               y_pred=validation_predictions)
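Since the same two calls are repeated for the training, validation and test subsets, you could factor them into a small helper. The function name compute_classification_metrics is hypothetical, not part of retrain.py or scikit-learn:

```python
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix


def compute_classification_metrics(y_true, y_pred, average='micro', labels=None):
    """Hypothetical helper returning precision, recall, F1 and the confusion matrix.

    y_true and y_pred are 1D sequences of integer class labels, e.g.
    validation_ground_truth and validation_predictions from retrain.py.
    """
    precision, recall, f1_score, _ = precision_recall_fscore_support(
        y_true=y_true, y_pred=y_pred, average=average, labels=labels)
    return {
        'precision': precision,
        'recall': recall,
        'f1_score': f1_score,
        'confusion_matrix': confusion_matrix(
            y_true=y_true, y_pred=y_pred, labels=labels),
    }


# Usage with made-up labels: 3 of 4 predictions are correct.
metrics = compute_classification_metrics([0, 1, 1, 2], [0, 1, 2, 2])
for name, value in metrics.items():
    print(name, value, sep='\n')
```

Passing labels=list(range(n_classes)), where n_classes is the number of classes, keeps the confusion matrix the same shape even when a particular batch contains no examples of some class; without it, scikit-learn only uses the labels that appear in y_true or y_pred.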

Finally, the code calls run_final_eval to evaluate the trained model on the test data. In this function, prediction and test_ground_truth are already defined, so you only need to add the code that calculates the required metrics:

test_accuracy, predictions = eval_session.run(
    [evaluation_step, prediction],
    feed_dict={
        bottleneck_input: test_bottlenecks,
        ground_truth_input: test_ground_truth
    })

# calculate precision, recall and F1 score
(test_precision,
 test_recall,
 test_f1_score, _) = precision_recall_fscore_support(y_true=test_ground_truth,
                                                     y_pred=predictions,
                                                     average='micro')
# calculate confusion matrix
test_confusion_matrix = confusion_matrix(y_true=test_ground_truth,
                                         y_pred=predictions)
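If you also want a per-class breakdown of the test results, scikit-learn's classification_report formats precision, recall, F1 and support for every class in one call. A sketch with hypothetical class names and labels; in retrain.py the names would have to be listed in the same order as the label indices, which, as far as I know, follows the order of image_lists.keys() (worth double-checking):

```python
from sklearn.metrics import classification_report

# Hypothetical class names and labels standing in for test_ground_truth / predictions.
class_names = ['daisy', 'roses', 'tulips']
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

# Human-readable text table: one row per class plus averaged rows.
print(classification_report(y_true, y_pred, target_names=class_names))

# The same numbers as a nested dict, convenient for logging.
report = classification_report(
    y_true, y_pred, target_names=class_names, output_dict=True)
print(report['roses']['recall'])  # 1.0: both 'roses' samples were predicted correctly
```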

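Note that the code above calculates global (micro-averaged) precision, recall and F1 score by setting average='micro'. For single-label multiclass data, these micro-averaged values all equal accuracy, so average='macro' or average='weighted' may be more informative. The different averaging methods supported by scikit-learn are described in its User Guide.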
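To see how the averaging modes differ, here is a small sketch on made-up, imbalanced labels, using f1_score and accuracy_score from sklearn.metrics (the data are hypothetical):

```python
from sklearn.metrics import accuracy_score, f1_score

# Hypothetical imbalanced labels: class 0 dominates, classes 1 and 2 are rare.
y_true = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 0, 0, 0, 0, 1, 0, 0, 2]

scores = {average: f1_score(y_true, y_pred, average=average)
          for average in ('micro', 'macro', 'weighted')}
for average, score in scores.items():
    print(f'{average:>8}: {score:.4f}')

# With one label per sample, micro-averaged F1 equals plain accuracy (0.8 here).
# 'macro' gives the rare classes, which are only half recalled, the same weight
# as class 0, so it comes out lower; 'weighted' weights classes by their support.
print(f'accuracy: {accuracy_score(y_true, y_pred):.4f}')
```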
