[英]Color confusion matrix heatmap using percentage of correctness
这是 plot 的代码示例,一个用于多分类问题的混淆矩阵。
cf_matrix = np.array([[50, 2, 38],
[7, 43, 32],
[1, 0, 4]])
labels = ['col1','col2','col3']
df_confusion = pd.DataFrame(cf_matrix, index = labels, columns=labels)
df_confusion['TOTAL'] = df_confusion.sum(axis=1)
df_confusion.loc['TOTAL']= df_confusion.sum()
plt.figure(figsize=(24, 10))
sns.set(font_scale = 1.5)
ax = sns.heatmap(df_confusion, annot=True, cmap='Blues', fmt="d")
ax.set_title('Confusion Matrix\n\n',size=22)
ax.set_xlabel('\nPredicted Values',size=20)
ax.set_ylabel('Actual Values ', size=20)
plt.show()
如何更改颜色条以使颜色与元素数量无关,而是基于每个单元格的元素百分比除以该 class(行)的总实际元素。 例如,在这种情况下,第三个 class col3 将具有最高颜色,因为它相对于 col1 和 col2 具有 4/5 = 80% 的正确预测,分别具有:50/90 = 55% 和 43/82 = 52%正确的预测。
由于cmap
参数使用data
来应用渐变,因此您需要将data
更改为百分比,然后使用annot
参数将值覆盖为实际数字。
所以,我想你想要像下面这样的东西。 注意我已将df_percentages.TOTAL
的百分比设置为0
下面; 否则TOTAL
列显然会完全变成深蓝色。
无论如何,既然您知道了逻辑,我相信您会知道如何根据自己的喜好调整df_percentages
的值。
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
cf_matrix = np.array([[50, 2, 38],
[7, 43, 32],
[1, 0, 4]])
labels = ['col1','col2','col3']
df_confusion = pd.DataFrame(cf_matrix, index = labels, columns=labels)
df_confusion['TOTAL'] = df_confusion.sum(axis=1)
df_confusion.loc['TOTAL']= df_confusion.sum()
# get percentages
df_percentages = df_confusion.div(df_confusion.TOTAL, axis=0)
df_percentages.TOTAL = 0
# =============================================================================
# col1 col2 col3 TOTAL
# col1 0.555556 0.022222 0.422222 0
# col2 0.085366 0.524390 0.390244 0
# col3 0.200000 0.000000 0.800000 0
# TOTAL 0.327684 0.254237 0.418079 0
# =============================================================================
plt.figure(figsize=(24, 10))
sns.set(font_scale = 1.5)
# cmap using data for color, taking values from annot
ax = sns.heatmap(data=df_percentages, annot=df_confusion, cmap='Blues', fmt="d",
cbar_kws={'label': 'percentages'})
ax.set_title('Confusion Matrix\n\n',size=22)
ax.set_xlabel('\nPredicted Values',size=20)
ax.set_ylabel('Actual Values ', size=20)
plt.show()
结果:
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.