
Is it possible to add extra columns to a confusion matrix?

I created a multi-class classifier, and now I want to show the confusion matrix and the accuracy of each class in a clean way.

I already found a function in sklearn that lets me plot the confusion matrix, sklearn.metrics.plot_confusion_matrix, but I don't see a way to add an extra column where I can put the accuracy of each class (row).

This is an example of how to plot the confusion matrix:

import matplotlib.pyplot as plt  
from sklearn.datasets import make_classification
from sklearn.metrics import plot_confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = SVC(random_state=0)
clf.fit(X_train, y_train)
plot_confusion_matrix(clf, X_test, y_test)  
plt.show() 

In the following picture, I drew a mock-up in Paint to show what I mean by "add an extra column":

Is there a way to change this example and add the extra column? Or are there other libraries that support what I want to do?

It doesn't look like anything does this out of the box, so I wrote a function:

import numpy as np


def plot_class_accuracies(plotted_cm, axis, display_labels=None, cmap="viridis"):
    """Draw the per-class accuracy (diagonal / row sum) as a one-column heatmap.

    plotted_cm : instance of `ConfusionMatrixDisplay`
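        Result of `sklearn.metrics.plot_confusion_matrix` (scikit-learn < 1.2) or of
        `ConfusionMatrixDisplay.from_estimator` / `from_predictions` (scikit-learn >= 1.0)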
    axis : matplotlib `AxesSubplot`
        Result of `fig, (ax1, ax2) = plt.subplots(1, 2)`
    display_labels : list of labels or None
        Human-readable class names
    cmap : colormap, optional
        Optional colormap
    """
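    cmatrix = plotted_cm.confusion_matrix
    # Per-class accuracy (recall): correct predictions / true samples in each row.
    # A class with no true samples gives 0/0 = nan (NumPy emits a RuntimeWarning).
    normalized_cmatrix = np.diag(cmatrix) / np.sum(cmatrix, axis=1)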
    n_classes = len(normalized_cmatrix)

    cmap_min, cmap_max = plotted_cm.im_.cmap(0), plotted_cm.im_.cmap(256)
    thresh = (normalized_cmatrix.max() + normalized_cmatrix.min()) / 2.0

    if display_labels is None:
        labels = np.arange(n_classes)
    else:
        labels = display_labels

    axis.imshow(
        normalized_cmatrix.reshape(n_classes, 1),
        interpolation="nearest",
        cmap=cmap,
    )

    for i, value in enumerate(normalized_cmatrix):
        color = cmap_min if value > thresh else cmap_max
        axis.text(0, i, format(value, ".2g"), ha="center", va="center", color=color)

    axis.set(
        yticks=np.arange(len(normalized_cmatrix)),
        ylabel="True label",
        xlabel="Class accuracy",
        yticklabels=labels,
    )
    axis.tick_params(
        axis="x", bottom=False, labelbottom=False,
    )
    axis.set_ylim((len(normalized_cmatrix) - 0.5, -0.5))

Assuming this function is saved in a file named cmatrix.py, it can be used like this:

from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import plot_confusion_matrix

# Import `plot_class_accuracies` from `cmatrix.py`
from cmatrix import plot_class_accuracies

if __name__ == "__main__":

    class ExampleClassifier(LogisticRegression):
        def __init__(self):
            self.classes_ = None
        def predict(self, X_test):
            self.classes_ = np.unique(X_test)
            return X_test

    X_test = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 2])
    y_test = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3])

    fig, (ax1, ax2) = plt.subplots(1, 2)
    clf = ExampleClassifier()

    disp = plot_confusion_matrix(
        clf, X_test, y_test, ax=ax1, cmap=plt.cm.Blues, normalize="true"
    )

    plot_class_accuracies(disp, ax2, cmap=plt.cm.Blues)
    plt.show()

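Result:

[Figure: the confusion matrix on the left and a plot of the per-class accuracies on the right. The diagonal of the left matrix is the same as the right column.]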

And here is an example based on the confusion matrix example from the sklearn documentation:

import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import plot_confusion_matrix

from cmatrix import plot_class_accuracies

iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
classifier = svm.SVC(kernel='linear', C=0.01).fit(X_train, y_train)

fig, (ax1, ax2) = plt.subplots(1, 2)

disp = plot_confusion_matrix(classifier, X_test, y_test,
                             display_labels=class_names,
                             ax=ax1,
                             cmap=plt.cm.Blues)

plot_class_accuracies(disp, ax2, display_labels=class_names, cmap=plt.cm.Blues)

plt.show()

Result:

[Figure: same idea as the previous image, showing the performance on the setosa, versicolor and virginica classes of the iris dataset.]
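If a text table is enough, scikit-learn's classification_report already lists the per-class recall (the same numbers as the "Class accuracy" column) together with precision, F1 score and support, with no plotting code. A sketch on the same iris split as above; output_dict=True returns the numbers as a dictionary if you want to build your own table from them:

```python
from sklearn import datasets, svm
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
classifier = svm.SVC(kernel="linear", C=0.01).fit(X_train, y_train)
y_pred = classifier.predict(X_test)

# Text report; the "recall" column is the per-class accuracy plotted above
print(classification_report(y_test, y_pred, target_names=iris.target_names))

# The same numbers as a nested dict keyed by class name
report = classification_report(
    y_test, y_pred, target_names=iris.target_names, output_dict=True
)
```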
