
Why do sklearn.metrics.confusion_matrix and sklearn.metrics.plot_confusion_matrix have inconsistent function definitions?

I am using sklearn and I noticed that the arguments of sklearn.metrics.plot_confusion_matrix and sklearn.metrics.confusion_matrix are inconsistent: plot_confusion_matrix takes an estimator and X and computes y_pred itself, while confusion_matrix takes y_pred directly as an argument.

What may be the reason for this inconsistency?

Partial function signatures:

  • sklearn.metrics.plot_confusion_matrix(estimator, X, y_true, ...) [where X should be X_test]
  • sklearn.metrics.confusion_matrix(y_true, y_pred, ...)
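Conceptually, the estimator-first signature only moves the predict call inside the function. A minimal sketch of that relationship; confusion_matrix_from_estimator is a hypothetical helper written for illustration, not a scikit-learn function:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix


def confusion_matrix_from_estimator(estimator, X, y_true, **kwargs):
    """Hypothetical helper: an estimator-style signature built on confusion_matrix.

    It just calls estimator.predict(X) internally; any extra keyword
    arguments (e.g. labels, normalize) are forwarded unchanged.
    """
    y_pred = estimator.predict(X)
    return confusion_matrix(y_true, y_pred, **kwargs)


X, y = make_classification(random_state=1)
clf = LogisticRegression().fit(X, y)

cm_a = confusion_matrix(y, clf.predict(X))         # predictions-first style
cm_b = confusion_matrix_from_estimator(clf, X, y)  # estimator-first style
print(np.array_equal(cm_a, cm_b))
```

Both calls end up computing the same matrix; the difference is only who is responsible for producing y_pred.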


Yes, you are right that there isn't a consistent API design for this, but there is an ongoing discussion of this issue here.

One quick workaround is to compute the matrix yourself with confusion_matrix and pass it to ConfusionMatrixDisplay.

Example:

from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Toy binary classification data, split into train and test sets
X, y = make_classification(random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)

# Fit a scaled logistic regression pipeline
clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0))
clf.fit(X_train, y_train)

# Compute y_pred yourself, then build the matrix from (y_true, y_pred)
y_pred = clf.predict(X_test)
cm = confusion_matrix(y_test, y_pred)

# display_labels is keyword-only in recent scikit-learn versions
cm_display = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=clf.classes_).plot()
plt.show()
