簡體   English   中英

為什么 sklearn.metrics.confusion_matrix 和 sklearn.metrics.plot_confusion_matrix 的 function 定義不一致?

[英]Why do sklearn.metrics.confusion_matrix and sklearn.metrics.plot_confusion_matrix have inconsistent function defintions?

我正在使用 sklearn,我注意到sklearn.metrics.plot_confusion_matrixsklearn.metrics.confusion_matrix的 arguments 不一致。 plot_confusion_matrix使用estimatorX來構造y_pred ,而confusion_matrix直接將y_pred作為參數。

這種不一致的原因可能是什么?

部分 function 定義:

  • sklearn.metrics.plot_confusion_matrix(estimator, X, y_true, ...) [其中 X 應該是 X_test]
  • sklearn.metrics.confusion_matrix(y_true, y_pred, ...)

資料來源:

是的,你是對的,沒有一致的 API 設計,但這里有一個關於這個問題的持續討論。

一種快速的解決方法是ConfusionMatrixDisplay

例子:

from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)

clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0))
clf.fit(X_train, y_train)

from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay

y_pred = clf.predict(X_test)
cm = confusion_matrix(y_test, y_pred)

cm_display = ConfusionMatrixDisplay(cm, [0,1]).plot()

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM