[英]How to plot confusion matrix with gridding without color using matplotlib?
我在如何使用字符串轴而不是python中的整数绘制混淆矩阵时发现了类似的问题。 但答案并不完全是我想要的。 因为它不包含网格(例如,数字不在小方块中)并且有背景颜色来显示不是我想要的数字。
import numpy as np
import matplotlib.pyplot as plt
conf_arr = [[33,2,0,0,0,0,0,0,0,1,3],
[3,31,0,0,0,0,0,0,0,0,0],
[0,4,41,0,0,0,0,0,0,0,1],
[0,1,0,30,0,6,0,0,0,0,1],
[0,0,0,0,38,10,0,0,0,0,0],
[0,0,0,3,1,39,0,0,0,0,4],
[0,2,2,0,4,1,31,0,0,0,2],
[0,1,0,0,0,0,0,36,0,2,0],
[0,0,0,0,0,0,1,5,37,5,1],
[3,0,0,0,0,0,0,0,0,39,0],
[0,0,0,0,0,0,0,0,0,0,38]]
norm_conf = []
for i in conf_arr:
a = 0
tmp_arr = []
a = sum(i, 0)
for j in i:
tmp_arr.append(float(j)/float(a))
norm_conf.append(tmp_arr)
fig = plt.figure()
plt.clf()
ax = fig.add_subplot(111)
ax.set_aspect(1)
res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet,
interpolation='nearest')
width, height = conf_arr.shape
for x in xrange(width):
for y in xrange(height):
ax.annotate(str(conf_arr[x][y]), xy=(y, x),
horizontalalignment='center',
verticalalignment='center')
cb = fig.colorbar(res)
alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
plt.xticks(range(width), alphabet[:width])
plt.yticks(range(height), alphabet[:height])
plt.savefig('confusion_matrix.png', format='png')
通过对相当优秀的代码提案进行一些更改(我赞成它,考虑也这样做),您可以获得您描述的数字。
您将通过调用ax
对象的hlines
和vlines
方法来获取网格,这将分别添加水平和垂直线。 当你然后删除对imshow
的调用时,颜色消失了。 像这样:
import numpy as np
import matplotlib.pyplot as plt
conf_arr = np.array([[33,2,0,0,0,0,0,0,0,1,3],
[3,31,0,0,0,0,0,0,0,0,0],
[0,4,41,0,0,0,0,0,0,0,1],
[0,1,0,30,0,6,0,0,0,0,1],
[0,0,0,0,38,10,0,0,0,0,0],
[0,0,0,3,1,39,0,0,0,0,4],
[0,2,2,0,4,1,31,0,0,0,2],
[0,1,0,0,0,0,0,36,0,2,0],
[0,0,0,0,0,0,1,5,37,5,1],
[3,0,0,0,0,0,0,0,0,39,0],
[0,0,0,0,0,0,0,0,0,0,38]])
height, width = conf_arr.shape
fig = plt.figure('confusion matrix')
ax = fig.add_subplot(111, aspect='equal')
for x in range(width):
for y in range(height):
ax.annotate(str(conf_arr[x][y]), xy=(y, x), ha='center', va='center')
offset = .5
ax.set_xlim(-offset, width - offset)
ax.set_ylim(-offset, height - offset)
ax.hlines(y=np.arange(height+1)- offset, xmin=-offset, xmax=width-offset)
ax.vlines(x=np.arange(width+1) - offset, ymin=-offset, ymax=height-offset)
alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
plt.xticks(range(width), alphabet[:width])
plt.yticks(range(height), alphabet[:height])
plt.savefig('confusion_matrix.png', format='png')
注意当您删除对imshow
的调用时,您需要明确设置x和y限制,如上所示,否则您将只看到左下区域( imshow
根据您通过的内容自动更新限制它)。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.