繁体   English   中英

Tensorflow 中多类分类的分类精度和召回率?

[英]Class wise precision and recall for multi class classification in Tensorflow?

在使用张量流进行多类分类时,有没有办法获得每类的精度或召回率。

例如,如果我有每个批次的 y_true 和 y_pred,如果我有 2 个以上的类,是否有一种功能性方法来获得每个类的精度或召回率。

这是一个对我有用的解决方案,用于解决 n=6 类的问题。 如果你有更多的类,这个解决方案可能很慢,你应该使用某种映射而不是循环。

假设你有张行一个热编码的等级标签, labels和张量logits(或后验) labels 然后,如果n是类的数量,试试这个:

y_true = tf.argmax(labels, 1)
y_pred = tf.argmax(logits, 1)

recall = [0] * n
update_op_rec = [[]] * n

for k in range(n):
    recall[k], update_op_rec[k] = tf.metrics.recall(
        labels=tf.equal(y_true, k),
        predictions=tf.equal(y_pred, k)
    )

请注意,在tf.metrics.recall ,变量labelspredictions被设置为布尔向量,就像在 2 变量情况下一样,这允许使用该函数。

2个事实:

  1. 正如其他答案中所述,Tensorflow 内置指标精度召回率不支持多类(文档说will be cast to bool

  2. 有通过使用获得一个抗所有得分方式precision_at_k通过指定class_id ,或者通过简单的铸造你的labels ,并predictionstf.bool以正确的方式。

因为这令人不满意tf_metrics完整,所以我编写了tf_metrics ,这是一个用于多类度量的简单包,您可以在github 上找到。 它支持多种平均方法,如scikit-learn

示例

import tensorflow as tf
import tf_metrics

y_true = [0, 1, 0, 0, 0, 2, 3, 0, 0, 1]
y_pred = [0, 1, 0, 0, 1, 2, 0, 3, 3, 1]
pos_indices = [1]        # Metrics for class 1 -- or
pos_indices = [1, 2, 3]  # Average metrics, 0 is the 'negative' class
num_classes = 4
average = 'micro'

# Tuple of (value, update_op)
precision = tf_metrics.precision(
    y_true, y_pred, num_classes, pos_indices, average=average)
recall = tf_metrics.recall(
    y_true, y_pred, num_classes, pos_indices, average=average)
f2 = tf_metrics.fbeta(
    y_true, y_pred, num_classes, pos_indices, average=average, beta=2)
f1 = tf_metrics.f1(
    y_true, y_pred, num_classes, pos_indices, average=average)

我相信你不能用tf.metrics.precision/recall函数做多类精度、召回、f1。 您可以像这样将 sklearn 用于 3 类场景:

from sklearn.metrics import precision_recall_fscore_support as score

prediction = [1,2,3,2] 
y_original = [1,2,3,3]

precision, recall, f1, _ = score(y_original, prediction)

print('precision: {}'.format(precision))
print('recall: {}'.format(recall))
print('fscore: {}'.format(f1))

这将打印一个精度数组,召回值,但可以根据需要对其进行格式化。

我被这个问题困扰了很长时间。 我知道这个问题可以通过 sklearn 来解决,但我真的很想通过 Tensorflow 的 API 来解决这个问题。 通过阅读它的代码,我终于弄清楚了这个 API 是如何工作的。

tf.metrics.precision_at_k(labels, predictions, k, class_id)
  • 首先,让我们假设这是一个4 类问题。
  • 其次,我们有两个样本,它们的标签是 3 和 1它们的预测是 [0.5,0.3,0.1,0.1], [0.5,0.3,0.1,0.1] 。根据我们的预测,我们可以得到两个结果样本已预测为1,1
  • 第三,如果你想得到class 1的精度,使用公式TP/(TP+FP) ,我们假设结果是1/(1+1)=0.5 因为两个样本都被预测为1 ,但其中一个实际上是3 ,所以TP为1FP为1结果为0.5
  • 最后,让我们使用这个 API 来验证我们的假设。

     import tensorflow as tf labels = tf.constant([[2],[0]],tf.int64) predictions = tf.constant([[0.5,0.3,0.1,0.1],[0.5,0.3,0.1,0.1]]) metric = tf.metrics.precision_at_k(labels, predictions, 1, class_id=0) sess = tf.Session() sess.run(tf.local_variables_initializer()) precision, update = sess.run(metric) print(precision) # 0.5

通知

  • k不是类的数量。 它表示我们要排序的数量,这意味着预测的最后一个维度必须与 k 的值匹配。

  • class_id表示我们想要二进制度量的类。

  • 如果k=1,意味着我们不会对预测进行排序,因为我们想要做的实际上是一个二元分类,而是指不同的类。 所以如果我们对预测进行排序, class_id 就会混淆,结果就会出错。

  • 还有更重要的一点是,如果我们想要得到正确的结果, label的输入应该是负1,因为class_id实际上代表的是label的索引,而label的下标是从0开始的

在 TensorFlow 中有一种方法可以做到这一点。

tf.metrics.precision_at_k(labels, predictions, k, class_id)

设置 k = 1 并设置相应的 class_id。 例如 class_id=0 计算第一类的精度。

我相信 TF 还没有提供这样的功能。 根据文档(https://www.tensorflow.org/api_docs/python/tf/metrics/precision ),它说标签和预测都将转换为 bool,因此它仅与二进制分类有关。 也许可以对示例进行单热编码并且它会起作用? 但不确定这一点。

这是从 Tensorflow 中的预测到通过 scikit-learn 报告的完整示例:

import tensorflow as tf
from sklearn.metrics import classification_report

# given trained model `model` and test vector `X_test` gives `y_test`
# where `y_test` and `y_predicted` are integers, who labels are indexed in 
# `labels`
y_predicted = tf.argmax(model.predict(X_test), axis=1)

# Confusion matrix
cf = tf.math.confusion_matrix(y_test, y_predicted)
plt.matshow(cf, cmap='magma')
plt.colorbar()
plt.xticks(np.arange(len(labels)), labels=labels, rotation=90)
plt.yticks(np.arange(len(labels)), labels=labels)
plt.clim(0, None)

# Report
print(classification_report(y_test, y_predicted, target_names=labels))

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM