[英]How to plot precision and recall of multiclass classifier?
I'm using scikit learn, and I want to plot the precision and recall curves.我正在使用 scikit learn,我想绘制精度和召回曲线。 the classifier I'm using is
RandomForestClassifier
.我使用的分类器是
RandomForestClassifier
。 All the resources in the documentations of scikit learn uses binary classification. scikit learn 文档中的所有资源都使用二进制分类。 Also, can I plot a ROC curve for multiclass?
另外,我可以为多类绘制 ROC 曲线吗?
Also, I only found for SVM for multilabel and it has a decision_function
which RandomForest
doesn't have另外,我只找到了多
RandomForest
SVM,它有一个RandomForest
没有的decision_function
RandomForest
From scikit-learn documentation:来自 scikit-learn 文档:
Precision-recall curves are typically used in binary classification to study the output of a classifier.
Precision-Recall 曲线通常用于二元分类以研究分类器的输出。 In order to extend the precision-recall curve and average precision to multi-class or multi-label classification, it is necessary to binarize the output.
为了将精度-召回曲线和平均精度扩展到多类或多标签分类,需要对输出进行二值化。 One curve can be drawn per label, but one can also draw a precision-recall curve by considering each element of the label indicator matrix as a binary prediction (micro-averaging).
每个标签可以绘制一条曲线,但也可以通过将标签指标矩阵的每个元素视为二元预测(微平均)来绘制精确召回曲线。
ROC curves are typically used in binary classification to study the output of a classifier.
ROC 曲线通常用于二元分类以研究分类器的输出。 In order to extend ROC curve and ROC area to multi-class or multi-label classification, it is necessary to binarize the output.
为了将ROC曲线和ROC面积扩展到多类或多标签分类,需要对输出进行二值化。 One ROC curve can be drawn per label, but one can also draw a ROC curve by considering each element of the label indicator matrix as a binary prediction (micro-averaging).
每个标签可以绘制一条 ROC 曲线,但也可以通过将标签指标矩阵的每个元素视为二元预测(微平均)来绘制 ROC 曲线。
Therefore, you should binarize the output and consider precision-recall and roc curves for each class.因此,您应该对输出进行二值化并考虑每个类的 precision-recall 和 roc 曲线。 Moreover, you are going to use
predict_proba
to get class probabilities.此外,您将使用
predict_proba
来获取类别概率。
I divide the code into three parts:我把代码分成三部分:
1. general settings, learning and prediction 1. 一般设置、学习和预测
from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
#%matplotlib inline
mnist = fetch_mldata("MNIST original")
n_classes = len(set(mnist.target))
Y = label_binarize(mnist.target, classes=[*range(n_classes)])
X_train, X_test, y_train, y_test = train_test_split(mnist.data,
Y,
random_state = 42)
clf = OneVsRestClassifier(RandomForestClassifier(n_estimators=50,
max_depth=3,
random_state=0))
clf.fit(X_train, y_train)
y_score = clf.predict_proba(X_test)
2. precision-recall curve 2. 精确召回曲线
# precision recall curve
precision = dict()
recall = dict()
for i in range(n_classes):
precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
y_score[:, i])
plt.plot(recall[i], precision[i], lw=2, label='class {}'.format(i))
plt.xlabel("recall")
plt.ylabel("precision")
plt.legend(loc="best")
plt.title("precision vs. recall curve")
plt.show()
3. ROC curve 3.ROC曲线
# roc curve
fpr = dict()
tpr = dict()
for i in range(n_classes):
fpr[i], tpr[i], _ = roc_curve(y_test[:, i],
y_score[:, i]))
plt.plot(fpr[i], tpr[i], lw=2, label='class {}'.format(i))
plt.xlabel("false positive rate")
plt.ylabel("true positive rate")
plt.legend(loc="best")
plt.title("ROC curve")
plt.show()
This looks good, but I am using a SGDClassifier for multiclass using a pipeline setting. 这看起来不错,但是我正在使用SGDClassifier通过管道设置来实现多类。 And I am finding it difficult to fit the output of the predict method into plotting the precision recall curve.
而且我发现很难将预测方法的输出拟合到绘制精确召回曲线中。
Here you are using OneVsRestRestClassifier and you are able to binarize the y_test, but it needed for SGDClassifier. 在这里,您使用的是OneVsRestRestClassifier,并且可以对y_test进行二值化,但是SGDClassifier需要它。
Please let me know how I can create one for SGDClassifier output. 请让我知道如何为SGDClassifier输出创建一个。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.