How to plot precision and recall of multiclass classifier?

I'm using scikit-learn and I want to plot precision-recall curves. The classifier I'm using is RandomForestClassifier. All the resources in the scikit-learn documentation use binary classification. Also, can I plot a ROC curve for the multiclass case?

Also, the only multilabel example I found uses an SVM, which has a decision_function that RandomForest doesn't have.

From the scikit-learn documentation:

Precision-recall curves are typically used in binary classification to study the output of a classifier. In order to extend the precision-recall curve and average precision to multi-class or multi-label classification, it is necessary to binarize the output. One curve can be drawn per label, but one can also draw a precision-recall curve by considering each element of the label indicator matrix as a binary prediction (micro-averaging).

ROC curves are typically used in binary classification to study the output of a classifier. In order to extend ROC curve and ROC area to multi-class or multi-label classification, it is necessary to binarize the output. One ROC curve can be drawn per label, but one can also draw a ROC curve by considering each element of the label indicator matrix as a binary prediction (micro-averaging).

Therefore, you should binarize the output and consider precision-recall and ROC curves for each class. Moreover, since RandomForestClassifier has no decision_function, you are going to use predict_proba to get class probabilities and use them as scores.
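To make the two ingredients concrete, here is a minimal sketch of what binarizing the labels and calling predict_proba produce. The iris dataset and the small forest are assumptions chosen only to keep the example quick; they are not part of the original answer.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import label_binarize

# Assumption: iris (3 classes) stands in for the real data.
X, y = load_iris(return_X_y=True)
classes = np.unique(y)

# One indicator column per class; each row has a single 1
# in the column of its true class.
Y = label_binarize(y, classes=classes)

# RandomForestClassifier has no decision_function, but predict_proba
# returns one probability column per class, ordered like clf.classes_.
clf = RandomForestClassifier(n_estimators=20, random_state=0).fit(X, y)
proba = clf.predict_proba(X)

print(Y.shape, proba.shape)
```

Column i of the indicator matrix and column i of the probability matrix together form one binary problem ("class i vs. the rest"), which is what precision_recall_curve and roc_curve expect.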

I divide the code into three parts:

  1. general settings, learning and prediction
  2. precision-recall curve
  3. ROC curve

1. general settings, learning and prediction

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.preprocessing import label_binarize

import matplotlib.pyplot as plt
#%matplotlib inline

# fetch_mldata has been removed from scikit-learn; fetch_openml provides MNIST instead
mnist = fetch_openml("mnist_784", version=1, as_frame=False)
target = mnist.target.astype(int)  # OpenML returns the labels as strings
n_classes = len(set(target))

# one indicator column per class
Y = label_binarize(target, classes=[*range(n_classes)])

X_train, X_test, y_train, y_test = train_test_split(mnist.data,
                                                    Y,
                                                    random_state = 42)

# one binary random forest per class
clf = OneVsRestClassifier(RandomForestClassifier(n_estimators=50,
                             max_depth=3,
                             random_state=0))
clf.fit(X_train, y_train)

# per-class scores, shape (n_samples, n_classes)
y_score = clf.predict_proba(X_test)

2. precision-recall curve

# precision recall curve
precision = dict()
recall = dict()
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
                                                        y_score[:, i])
    plt.plot(recall[i], precision[i], lw=2, label='class {}'.format(i))
    
plt.xlabel("recall")
plt.ylabel("precision")
plt.legend(loc="best")
plt.title("precision vs. recall curve")
plt.show()
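The micro-averaged curve mentioned in the documentation quote can be sketched as follows: flatten the label indicator matrix and the score matrix so that every element counts as one binary prediction. The digits dataset and the smaller forest are assumptions made to keep the sketch fast; the variable names `precision_micro`, `recall_micro` and `ap_micro` are illustrative.

```python
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize

# Assumption: the small digits dataset stands in for MNIST.
X, y = load_digits(return_X_y=True)
Y = label_binarize(y, classes=list(range(10)))
X_train, X_test, y_train, y_test = train_test_split(X, Y, random_state=42)

clf = OneVsRestClassifier(
    RandomForestClassifier(n_estimators=20, max_depth=3, random_state=0))
clf.fit(X_train, y_train)
y_score = clf.predict_proba(X_test)

# Micro-averaging: treat every (sample, class) cell as one binary prediction.
precision_micro, recall_micro, _ = precision_recall_curve(
    y_test.ravel(), y_score.ravel())

# Single-number summary of the micro-averaged curve.
ap_micro = average_precision_score(y_test, y_score, average="micro")
print("micro-averaged average precision: {:.3f}".format(ap_micro))
```

The resulting `recall_micro` and `precision_micro` arrays can be passed to plt.plot in the same way as the per-class arrays in the loop above.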

3. ROC curve

# roc curve
fpr = dict()
tpr = dict()

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i],
                                  y_score[:, i])
    plt.plot(fpr[i], tpr[i], lw=2, label='class {}'.format(i))

plt.xlabel("false positive rate")
plt.ylabel("true positive rate")
plt.legend(loc="best")
plt.title("ROC curve")
plt.show()
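A micro-averaged ROC curve and per-class areas under the curve can be obtained along the same lines. This is a sketch under the same assumptions as before (digits dataset and a small forest as stand-ins); `fpr_micro`, `tpr_micro`, `auc_micro` and `auc_per_class` are illustrative names.

```python
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import auc, roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize

# Assumption: the small digits dataset stands in for MNIST.
n_classes = 10
X, y = load_digits(return_X_y=True)
Y = label_binarize(y, classes=list(range(n_classes)))
X_train, X_test, y_train, y_test = train_test_split(X, Y, random_state=42)

clf = OneVsRestClassifier(
    RandomForestClassifier(n_estimators=20, max_depth=3, random_state=0))
clf.fit(X_train, y_train)
y_score = clf.predict_proba(X_test)

# Micro-average: flatten both matrices into one long binary problem.
fpr_micro, tpr_micro, _ = roc_curve(y_test.ravel(), y_score.ravel())
auc_micro = auc(fpr_micro, tpr_micro)

# Area under each per-class ("class i vs. rest") ROC curve.
auc_per_class = [roc_auc_score(y_test[:, i], y_score[:, i])
                 for i in range(n_classes)]

print("micro-averaged AUC: {:.3f}".format(auc_micro))
```

The per-class values are handy as legend labels, for example 'class {} (AUC = {:.2f})'.format(i, auc_per_class[i]).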


This looks good, but I am using an SGDClassifier for multiclass classification inside a pipeline, and I am finding it difficult to use the output of the predict method to plot the precision-recall curve.

Here you are using OneVsRestClassifier and you are able to binarize y_test, but I need the same for SGDClassifier.

Please let me know how I can create one for SGDClassifier output.
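One possible approach for the SGDClassifier case, sketched here and not taken from the original answer: predict returns hard labels, which carry no ranking information, so the curves need continuous scores. SGDClassifier provides decision_function, and a pipeline forwards that call to its final step. The digits dataset and the StandardScaler step are assumptions standing in for the real pipeline.

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, label_binarize

# Assumption: digits data and a scaler stand in for the real pipeline.
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Fit on the ORIGINAL multiclass labels; SGDClassifier trains
# one-vs-rest binary classifiers internally.
pipe = make_pipeline(StandardScaler(), SGDClassifier(random_state=0))
pipe.fit(X_train, y_train)

# Binarize only the test labels, using the class order the model learned.
classes = pipe.classes_
y_test_bin = label_binarize(y_test, classes=classes)

# Continuous scores, shape (n_samples, n_classes), instead of predict().
y_score = pipe.decision_function(X_test)

precision, recall = {}, {}
for i, label in enumerate(classes):
    precision[label], recall[label], _ = precision_recall_curve(
        y_test_bin[:, i], y_score[:, i])
```

Each precision[label] and recall[label] pair can then be plotted with plt.plot exactly as in the answer's precision-recall loop.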

