简体   繁体   English

了解 tf.keras.metrics.Precision and Recall 进行多类分类

[英]Understanding tf.keras.metrics.Precision and Recall for multiclass classification

I am building a model for a multiclass classification problem.我正在为多类分类问题构建 model。 So I want to evaluate the model performance using the Recall and Precision.所以我想使用召回率和精度来评估 model 的性能。 I have 4 classes in the dataset and it is provided in one hot representation.我在数据集中有 4 个类,它以one hot表示形式提供。

I was reading the Precision and Recall tf.keras documentation, and have some questions:我正在阅读Precision and Recall tf.keras文档,并且有一些问题:

  1. When it calculating the Precision and Recall for the multi-class classification, how can we take the average of all of the labels, meaning the global precision & Recall?在计算多类分类的 Precision 和 Recall 时,我们如何取所有标签的平均值,即全局 Precision & Recall? is it calculated with macro or micro since it is not specified in the documentation as in the Sikit learn .它是用macro还是micro计算的,因为它没有像Sikit learn那样在文档中指定。
  2. If I want to calculate the precision & Recall for each label separately, can I use the argument class_id for each label to do one_vs_rest or binary classification.如果我想分别计算每个 label 的精度和召回率,我可以使用每个 label 的参数class_id来进行one_vs_restbinary分类。 Like what I have done in the code below?就像我在下面的代码中所做的那样?
  3. can I use the argument top_k with the value top_k=2 would be helpful here or it is not suitable for my classification of 4 classes only?我可以使用带有值top_k=2的参数top_k在这里会有所帮助还是不适合我的 4 类分类?
  4. While I am measuring the performance of each class, What could be the difference, when I set the top_k=1 and not setting top_k overall?当我测量每个 class 的性能时,当我设置top_k=1而不是整体设置top_k时,可能有什么区别?
model.compile(
      optimizer='sgd',
      loss=tf.keras.losses.CategoricalCrossentropy(),
      metrics=[tf.keras.metrics.CategoricalAccuracy(),
               ##class 0
               tf.keras.metrics.Precision(class_id=0,top_k=2), 
               tf.keras.metrics.Recall(class_id=0,top_k=2),
              ##class 1
               tf.keras.metrics.Precision(class_id=1,top_k=2), 
               tf.keras.metrics.Recall(class_id=1,top_k=2),
              ##class 2
               tf.keras.metrics.Precision(class_id=2,top_k=2), 
               tf.keras.metrics.Recall(class_id=2,top_k=2),
              ##class 3
               tf.keras.metrics.Precision(class_id=3,top_k=2), 
               tf.keras.metrics.Recall(class_id=3,top_k=2),
])

Any clarification of this function will be appreciated.对此 function 的任何澄清将不胜感激。 Thanks in advance提前致谢

3. can I use the argument top_k with the value top_k=2 would be helpful here or it is not suitable for my classification of 4 classes only? 3. 我是否可以使用值为 top_k=2 的参数 top_k 在这里会有所帮助,或者它不适合我的 4 类分类?

According to the description, it will only calculate top_k(with the function of _filter_top_k) predictions, and turn other predictions to False if you use this argument根据描述,它只会计算top_k(with the function of _filter_top_k)个预测,如果使用这个参数,其他预测为False

The example from official document link: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Precision来自官方文档链接的示例: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Precision

You may also want to read the original code: https://github.com/keras-team/keras/blob/07e13740fd181fc3ddec7d9a594d8a08666645f6/keras/utils/metrics_utils.py#L487 With top_k=2, it will calculate precision over y_true[:2] and y_pred[:2]您可能还想阅读原始代码: https://github.com/keras-team/keras/blob/07e13740fd181fc3ddec7d9a594d8a08666645f6/keras/utils/metrics_utils.py#L487使用 top_k=2,它将计算精度超过 _true[: 2] 和 y_pred[:2]

m = tf.keras.metrics.Precision(top_k=2)
m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
m.result().numpy()
0.0

As we can see the note posted in the example here, it will only calculate y_true[:2] and y_pred[:2], which means the precision will calculate only top 2 predictions (also turn the rest of y_pred to 0).正如我们在此处的示例中看到的注释,它只会计算 y_true[:2] 和 y_pred[:2],这意味着精度只会计算前 2 个预测(也将 y_pred 的 rest 设置为 0)。

If you want to use 4 classes classification, the argument of class_id maybe enough.如果要使用 4 类分类, class_id的参数可能就足够了。

4.While I am measuring the performance of each class, What could be the difference when I set the top_k=1 and not setting top_koverall? 4.当我在测量每个 class 的性能时,当我设置 top_k=1 而不设置 top_koverall 时会有什么区别? The function will calculate the precision across all the predictions your model make if you don't set top_k value.如果您不设置top_k值,function 将计算您的 model 所做的所有预测的精度。 If you want to measure the perfromance.如果你想衡量性能。

Top k may works for other model, not for classification model前 k 可能适用于其他 model,不适用于分类 model

1. Is it macro or micro? 1、是宏观还是微观?

To be precise, all the metrics are reset at the beginning of every epoch and at the beginning of every validation if there is.准确地说,所有指标在每个 epoch 开始时和每次验证开始时都会重置(如果有的话)。 So I guess, we can call it macro.所以我想,我们可以称之为宏。

2. Class specific precision and recall? 2. Class 具体准确率和召回率?

You can take a look at tf.compat.v1.metrics.precision_at_k and tf.compat.v1.metrics.recall_at_k .您可以查看tf.compat.v1.metrics.precision_at_ktf.compat.v1.metrics.recall_at_k It seems that it computes the respectivly the precision at the recall for a specific class k .它似乎分别计算了特定class k的召回精度。

https://www.tensorflow.org/api_docs/python/tf/compat/v1/metrics/precision_at_k https://www.tensorflow.org/api_docs/python/tf/compat/v1/metrics/precision_at_k

https://www.tensorflow.org/api_docs/python/tf/compat/v1/metrics/recall_at_k https://www.tensorflow.org/api_docs/python/tf/compat/v1/metrics/recall_at_k

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 Tensorflow:如何在多类分类中使用 tf.keras.metrics? - Tensorflow: How to use tf.keras.metrics in multiclass classification? Keras-精度和召回率大于1(多种分类) - Keras - Precision and Recall is greater than 1 (Multi classification) Keras分类器的Sklearn精度,召回率和FMeasure度量 - Sklearn Metrics of precision, recall and FMeasure on Keras classifier 我可以在多类分类问题中使用 tf.metrics.BinaryAccuracy 吗? - Can I use tf.metrics.BinaryAccuracy in a multiclass classification problem? 当我使用“tf.keras.metrics.Recall()”时,得到“ValueError: Shapes (None, 2) and (None, 1) are incompatible”进行二进制分类 - Getting "ValueError: Shapes (None, 2) and (None, 1) are incompatible" for binary classification when I am using "tf.keras.metrics.Recall()" 计算多 label 分类 keras 的召回精度和 F1 分数 - compute the recall precision and F1 score for a multi label classification keras TF.Keras 自定义 Scratch 训练中的多输出-多类分类 - Multioutput-Multiclass Classification in Custom Scratch Training in TF.Keras Keras 2.3.0 指标准确度、精度和召回率的值相同 - Same value for Keras 2.3.0 metrics accuracy, precision and recall Keras:多类分类 - Keras: multiclass classification Keras分类:预测和多类 - Classification with Keras: prediction and multiclass
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM