![](/img/trans.png)
[英]Suppress scientific notation in sklearn.metrics.plot_confusion_matrix
[英]Why do sklearn.metrics.confusion_matrix and sklearn.metrics.plot_confusion_matrix have inconsistent function defintions?
我正在使用 sklearn,我注意到sklearn.metrics.plot_confusion_matrix
和sklearn.metrics.confusion_matrix
的 arguments 不一致。 plot_confusion_matrix
使用estimator
和X
来构造y_pred
,而confusion_matrix
直接将y_pred
作为参数。
这种不一致的原因可能是什么?
部分 function 定义:
sklearn.metrics.plot_confusion_matrix(estimator, X, y_true, ...)
[其中 X 应该是 X_test]sklearn.metrics.confusion_matrix(y_true, y_pred, ...)
资料来源:
是的,你是对的,没有一致的 API 设计,但这里有一个关于这个问题的持续讨论。
一种快速的解决方法是ConfusionMatrixDisplay
。
例子:
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
X, y = make_classification(random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)
clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0))
clf.fit(X_train, y_train)
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
y_pred = clf.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
cm_display = ConfusionMatrixDisplay(cm, [0,1]).plot()
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.