简体   繁体   English

使用来自 Sklearn 的 ConfusionMatrixDisplay 对用于绘制混淆矩阵的颜色图进行归一化

[英]Normalizing a color map for plotting a Confusion Matrix with ConfusionMatrixDisplay from Sklearn

I am trying to create a color map for my 10x10 confusion matrix that is provided by sklearn .我正在尝试为sklearn提供的 10x10 混淆矩阵创建颜色图。 I would like to be able to customize the color map to be normalized between [0,1] but I have had no success.我希望能够自定义颜色映射以在 [0,1] 之间进行标准化,但我没有成功。 I am trying to use ax_ and matplotlib.colors.Normalize but am struggling to get something to work since ConfusionMatrixDisplay is a sklearn object that creates a different than usual matplotlib plot.我正在尝试使用ax_matplotlib.colors.Normalize但我正在努力使某些东西起作用,因为 ConfusionMatrixDisplay 是一个 sklearn 对象,它创建了与通常的matplotlib图不同的图。

My code is the following:我的代码如下:

from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

train_confuse_matrix = confusion_matrix(y_true = ytrain, y_pred = y_train_pred_labels)
print(train_confuse_matrix)
cm_display = ConfusionMatrixDisplay(train_confuse_matrix, display_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])
print(cm_display)
cm_display.plot(cmap = 'Greens')
plt.show()
plt.clf()


[[3289   56   84   18   55    7   83   61   48  252]
 [   2 3733    0    1    2    1   16    1    3  220]
 [  81   15 3365   64   81   64  273   18    6   17]
 [  17   37   71 3015  127  223  414   44    6   64]
 [   3    1   43   27 3659   24  225   35    0    3]
 [   5   23   38  334  138 3109  224   80    4   25]
 [   3    1   19   10   12    7 3946    1    1    5]
 [   4    7   38   69  154   53   89 3615    2   27]
 [  62   67   12    7   25    3   62    4 3595  153]
 [   2   30    1    2    4    0   15    2    0 3957]]


在此处输入图片说明

Let's try imshow and annotate manually:让我们尝试imshow并手动annotate

accuracies = conf_mat/conf_mat.sum(1)
fig, ax = plt.subplots(figsize=(10,8))
cb = ax.imshow(accuracies, cmap='Greens')
plt.xticks(range(len(classes)), classes,rotation=90)
plt.yticks(range(len(classes)), classes)

for i in range(len(classes)):
    for j in range(len(classes)):
        color='green' if accuracies[i,j] < 0.5 else 'white'
        ax.annotate(f'{conf_mat[i,j]}', (i,j), 
                    color=color, va='center', ha='center')

plt.colorbar(cb, ax=ax)
plt.show()

Output:输出:

在此处输入图片说明

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM