Understanding tf.keras.metrics.Precision and Recall for multiclass classification
I am building a model for a multiclass classification problem, so I want to evaluate the model performance using Recall and Precision. I have 4 classes in the dataset, and they are provided in one-hot representation.
I was reading the Precision and Recall tf.keras documentation and have some questions:

1. Are they computed as macro or micro averages? It is not specified in the documentation, as it is in Scikit-learn.
2. Is it better to use class_id for each label to do one_vs_rest or binary classification, like what I have done in the code below?
3. Can I use the argument top_k with the value top_k=2? Would that be helpful here, or is it not suitable for my classification of 4 classes only?
4. While I am measuring the performance of each class, what could be the difference when I set top_k=1 and when I don't set top_k at all?
model.compile(
    optimizer='sgd',
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy(),
             ## class 0
             tf.keras.metrics.Precision(class_id=0, top_k=2),
             tf.keras.metrics.Recall(class_id=0, top_k=2),
             ## class 1
             tf.keras.metrics.Precision(class_id=1, top_k=2),
             tf.keras.metrics.Recall(class_id=1, top_k=2),
             ## class 2
             tf.keras.metrics.Precision(class_id=2, top_k=2),
             tf.keras.metrics.Recall(class_id=2, top_k=2),
             ## class 3
             tf.keras.metrics.Precision(class_id=3, top_k=2),
             tf.keras.metrics.Recall(class_id=3, top_k=2),
             ])
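To see what each Precision(class_id=c, top_k=2) entry in a setup like this reports, here is a plain-Python sketch of the same counting rule (the function name precision_for_class and the sample data are my own for illustration, not Keras code): a sample counts as a positive prediction for class c whenever c lands among its top-2 scores.

```python
# Hypothetical sketch of what Precision(class_id=c, top_k=2) measures
# for one-hot labels; this is an illustration, not the Keras implementation.

def precision_for_class(y_true, y_pred, class_id, top_k=2):
    """y_true: one-hot rows; y_pred: per-class score rows."""
    tp = fp = 0
    for truth, probs in zip(y_true, y_pred):
        # indices of the top_k highest scores in this sample
        top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
        if class_id in top:          # class predicted positive for this sample
            if truth[class_id] == 1:
                tp += 1              # true positive
            else:
                fp += 1              # false positive
    return tp / (tp + fp) if tp + fp else 0.0

y_true = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]
y_pred = [[0.6, 0.3, 0.05, 0.05],
          [0.1, 0.2, 0.6, 0.1],
          [0.2, 0.1, 0.5, 0.2]]
print(precision_for_class(y_true, y_pred, class_id=2, top_k=2))  # 0.5
```

Class 2 reaches the top-2 in the last two samples but is only correct in the last one, hence 1 TP / (1 TP + 1 FP) = 0.5.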
Any clarification of this function would be appreciated. Thanks in advance.
3. Can I use the argument top_k with the value top_k=2? Would it be helpful here, or is it not suitable for my classification of 4 classes only?
According to the description, if you use this argument it will only keep the top_k predictions (via the _filter_top_k helper) and turn all other predictions to False.
The example from the official documentation: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Precision
You may also want to read the original code: https://github.com/keras-team/keras/blob/07e13740fd181fc3ddec7d9a594d8a08666645f6/keras/utils/metrics_utils.py#L487 With top_k=2, it will calculate the precision over y_true[:2] and y_pred[:2]:
m = tf.keras.metrics.Precision(top_k=2)
m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
m.result().numpy()
0.0
As the note in the example shows, it will only consider y_true[:2] and y_pred[:2], which means the precision is calculated over only the top 2 predictions (the rest of y_pred is set to 0).
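The doc example can be re-enacted in plain Python (the helper name filter_top_k is my own; the real logic lives in metrics_utils._filter_top_k): keep only the top_k scores, treat the rest as 0, then count TP / (TP + FP) over what was kept.

```python
# Illustration only, not the Keras source: mimic the top_k filtering step.

def filter_top_k(scores, k):
    # indices of the k largest scores; on ties the earliest index wins,
    # matching how tf.math.top_k breaks ties between equal values
    kept = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return [s if i in kept else 0.0 for i, s in enumerate(scores)]

y_true = [0, 0, 1, 1]
y_pred = filter_top_k([1.0, 1.0, 1.0, 1.0], k=2)   # -> [1.0, 1.0, 0.0, 0.0]
tp = sum(1 for t, p in zip(y_true, y_pred) if p > 0.5 and t == 1)
fp = sum(1 for t, p in zip(y_true, y_pred) if p > 0.5 and t == 0)
print(tp / (tp + fp))   # 0.0, the same result as m.result().numpy() above
```

The two surviving predictions both land where y_true is 0, so precision is 0/2 = 0.0.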
If you want 4-class classification, the class_id argument may be enough.
4. While I am measuring the performance of each class, what could be the difference when I set top_k=1 and when I don't set top_k at all?

If you don't set a top_k value, the function will calculate the precision across all the predictions your model makes. If you just want to measure per-class performance, top_k may work better for other models than for a classification model.
1. Is it macro or micro?

To be precise, all the metrics are reset at the beginning of every epoch, and at the beginning of every validation run if there is one. So I guess we can call it macro.
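Within one epoch, though, the TP/FP counts are pooled over every batch until the reset. A minimal streaming-metric sketch (StreamingPrecision is my own toy class, not the Keras implementation) shows why the per-epoch reset matters:

```python
# Toy streaming precision: counts accumulate across update_state() calls
# until reset_state(), which Keras invokes at the start of each epoch
# (and of each validation run).

class StreamingPrecision:
    def __init__(self):
        self.tp = self.fp = 0

    def update_state(self, y_true, y_pred, threshold=0.5):
        for t, p in zip(y_true, y_pred):
            if p > threshold:          # counted as a positive prediction
                if t == 1:
                    self.tp += 1
                else:
                    self.fp += 1

    def result(self):
        return self.tp / (self.tp + self.fp) if self.tp + self.fp else 0.0

    def reset_state(self):
        self.tp = self.fp = 0

m = StreamingPrecision()
m.update_state([1, 0], [0.9, 0.8])   # batch 1: 1 TP, 1 FP
m.update_state([1, 1], [0.9, 0.7])   # batch 2: 2 more TP
print(m.result())    # 3 / 4 = 0.75, pooled over both batches
m.reset_state()      # what happens between epochs
print(m.result())    # 0.0 again
```

Inside an epoch this pooling behaves like a micro average over all predictions seen so far; the reset just restarts that accumulation each epoch.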
2. Class-specific precision and recall?

You can take a look at tf.compat.v1.metrics.precision_at_k and tf.compat.v1.metrics.recall_at_k. It seems that they compute, respectively, the precision and the recall at k for a specific class.
https://www.tensorflow.org/api_docs/python/tf/compat/v1/metrics/precision_at_k
https://www.tensorflow.org/api_docs/python/tf/compat/v1/metrics/recall_at_k