简体   繁体   English

使用 matshow 绘图时不完整的混淆矩阵

[英]Incomplete confusion matrix when plotting with matshow

I'm trying to plot this confusión matrix:我正在尝试 plot 这个混淆矩阵:

[[25940  2141    84    19     3     0     0     1   184     4]
 [ 3525  6357   322    41     5     1     3     0   242     2]
 [  410  1484  1021    80     5     6     0     0   282     0]
 [   98   285   189   334     9     9     5     1   140     0]
 [   26    64    55    50   112    15     4     1    75     0]
 [   11    45    20    24     5   118     8     0    79     0]
 [    1     8     8     5     0    10    62     1    55     0]
 [    2     0     0     0     0     0     2     0     6     0]
 [  510   524   103    55     5     7     7     1 65350     0]
 [   62    13     2     1     0     0     1     0    11    13]]

Therefore, 10x10.因此,10x10。 Those 10 labels are:这10个标签是:

[ 5  6  7  8  9 10 11 12 14 15]

I use the following code:我使用以下代码:

Get the confusion matrix获取混淆矩阵

cm = confusion_matrix(y_test, y_pred, labels=labels)
print('Confusion Matrix of {} is:\n{}'.format(clf_name, cm))
print(labels)
plt.matshow(cm, interpolation='nearest')
ax = plt.gca()
ax.set_xticklabels([''] + labels.astype(str).tolist())
ax.set_yticklabels([''] + labels.astype(str).tolist())
plt.title('Confusion matrix of the {} classifier'.format(clf_name))
plt.colorbar(mat, extend='both')
plt.clim(0, 100)

And I only get a plot with labels from 5 to 9:我只得到一个标签从 5 到 9 的 plot:

在此处输入图像描述

What's the problem here?这里有什么问题?

Relevant imports and configuration (I'm working with Jupyter, btw):相关导入和配置(我正在使用 Jupyter,顺便说一句):

import matplotlib.pyplot as plt
import matplotlib as mpl
%matplotlib inline
plt.style.use('seaborn')
mpl.rcParams['figure.figsize'] = 8, 6

I tried downgrading to matplotlib 3.1.0, as I read that something went wrong on 3.1.1 about seaborn, but anyway the result is the same (and also if I change style to ggplot).我尝试降级到 matplotlib 3.1.0,因为我读到关于 seaborn 的 3.1.1 出现问题,但无论如何结果是相同的(如果我将样式更改为 ggplot)。

Matplotlib doesn't put a label at every tick (to prevent overlapping ticks in case they would be longer). Matplotlib 不会在每个刻度上放置 label(以防止重叠刻度以防它们更长)。 You can force ticks at every column with ax.set_xticks(range(10)) .您可以使用ax.set_xticks(range(10))在每一列强制刻度。

Here is some example code, with calls adapted to matplotlib's "object oriented" interface.这是一些示例代码,调用适应于 matplotlib 的“面向对象”接口。 Also, some extra padding prevents the title not to bounce with the top tick labels.此外,一些额外的填充可防止标题与顶部刻度标签一起反弹。 Note that the labels can be numerically, matplotlib automatically interprets them as the corresponding strings.请注意,标签可以是数字,matplotlib 会自动将它们解释为相应的字符串。 ax.tick_params() can help to remove the tick marks at bottom and top (or, alternatively, also get them left and/or right). ax.tick_params()可以帮助删除底部和顶部的刻度线(或者,也可以让它们向左和/或向右)。 The sample code also uses a grid on the minor xticks to make separations.示例代码还使用次要 xticks 上的网格来进行分隔。

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import numpy as np

cm = np.random.randint(0, 25000, (10, 10)) * np.random.randint(0, 2, (10, 10))
labels = np.array([5, 6, 7, 8, 9, 10, 11, 12, 14, 15])

fig, ax = plt.subplots()
mat = ax.matshow(cm, interpolation='nearest')
mat.set_clim(0, 100)
ax.set_xticks(range(10))
ax.set_yticks(range(10))
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)
ax.tick_params(axis='x', which='both', bottom=False, top=False)

ax.grid(b=False, which='major', axis='both')
ax.xaxis.set_minor_locator(MultipleLocator(0.5))
ax.yaxis.set_minor_locator(MultipleLocator(0.5))
ax.grid(b=True, which='minor', axis='both', lw=2, color='white')

ax.set_title('Confusion matrix of the {} classifier'.format('clf_name'), pad=20)
plt.colorbar(mat, extend='both')
plt.show()

示例图

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM