
How to change the color of boxes in a confusion matrix using sklearn?

Here is my code snippet for producing a confusion matrix. I am wondering how, using sklearn, I can change the color of the boxes that are not on the diagonal, the way a heatmap does.

import numpy as np
import torch
import matplotlib.pyplot as plt

# model, device and dataLoaders are defined elsewhere in the training script
nb_classes = 15
confusion_matrix = torch.zeros(nb_classes, nb_classes)

with torch.no_grad():
    for i, (inputs, target, classes, im_path) in enumerate(dataLoaders['test']):

        inputs = inputs.to(device)
        target = target.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        # Rows are true labels, columns are predicted labels
        for t, p in zip(target.view(-1).cpu(), preds.view(-1).cpu()):
            confusion_matrix[t.long(), p.long()] += 1

class_names = ['A2CH', 'A3CH', 'A4CH_LV', 'A4CH_RV', 'A5CH', 'Apical_MV_LA_IAS',
               'OTHER', 'PLAX_TV', 'PLAX_full', 'PLAX_valves', 'PSAX_AV', 'PSAX_LV',
               'Subcostal_IVC', 'Subcostal_heart', 'Suprasternal']

# Integer counts as a NumPy array, ready for matplotlib
cm = confusion_matrix.numpy().astype(int)

plt.figure()
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.colorbar()

tick_marks = np.arange(nb_classes)
plt.xticks(tick_marks, class_names, rotation=90)
plt.yticks(tick_marks, class_names)

# Zero cells get white text (invisible on the white background);
# dark cells get white text so their counts stay readable
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(j, i, format(cm[i, j], 'd'),
                 ha="center", va="center",
                 color="white" if cm[i, j] == 0 or cm[i, j] > thresh else "black")
plt.tight_layout()
plt.show()
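The question asks how to give the off-diagonal boxes their own colors. One way to do that while keeping the plt.imshow approach (a sketch, not something from the original post) is to draw the matrix twice with NumPy masked arrays: once with only the diagonal visible in Blues, and once with only the off-diagonal cells visible in Reds. Masked cells are left transparent, which is matplotlib's default for masked values. The function name plot_cm_two_colors and the toy 3 × 3 matrix are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; remove it to use plt.show()
import matplotlib.pyplot as plt
import numpy as np


def plot_cm_two_colors(cm, class_names, diag_cmap="Blues", off_cmap="Reds"):
    """Hypothetical helper: diagonal cells use one colormap, all other cells another."""
    cm = np.asarray(cm)
    on_diag = np.eye(cm.shape[0], dtype=bool)

    fig, ax = plt.subplots()
    # Masked cells are not painted, so each imshow call only colors its own cells
    ax.imshow(np.ma.masked_where(~on_diag, cm), cmap=diag_cmap, interpolation="nearest")
    ax.imshow(np.ma.masked_where(on_diag, cm), cmap=off_cmap, interpolation="nearest")

    ticks = np.arange(len(class_names))
    ax.set_xticks(ticks)
    ax.set_xticklabels(class_names, rotation=90)
    ax.set_yticks(ticks)
    ax.set_yticklabels(class_names)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")

    # Write the count in the middle of every cell
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, str(cm[i, j]), ha="center", va="center")
    fig.tight_layout()
    return fig, ax


# Toy example; with the question's data you would pass cm and class_names instead
fig, ax = plot_cm_two_colors([[5, 1, 0], [2, 7, 1], [0, 0, 9]], ["A", "B", "C"])
```

Each imshow call scales its colors independently, so diagonal and off-diagonal shades are not on a shared scale; pass the same vmin and vmax to both calls if you want them comparable. Finish with plt.show() or fig.savefig(...).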


Use a heatmap to plot the confusion matrix:

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

array = [[33,2,0,0,0,0,0,0,0,1,3],
         [3,31,0,0,0,0,0,0,0,0,0],
         [0,4,41,0,0,0,0,0,0,0,1],
         [0,1,0,30,0,6,0,0,0,0,1],
         [0,0,0,0,38,10,0,0,0,0,0],
         [0,0,0,3,1,39,0,0,0,0,4],
         [0,2,2,0,4,1,31,0,0,0,2],
         [0,1,0,0,0,0,0,36,0,2,0],
         [0,0,0,0,0,0,1,5,37,5,1],
         [3,0,0,0,0,0,0,0,0,39,0],
         # NOTE: this last row has only 10 values; pandas fills the missing
         # K/K cell with NaN, so heatmap leaves that cell blank
         [0,0,0,0,0,0,0,0,0,0]]
df_cm = pd.DataFrame(array, index=list("ABCDEFGHIJK"),
                     columns=list("ABCDEFGHIJK"))
plt.figure(figsize=(10, 7))
sns.heatmap(df_cm, annot=True, cmap="OrRd")
plt.show()
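seaborn's heatmap also takes a mask argument (cells where the mask is True are not drawn), so the heatmap can be drawn twice on the same Axes: once showing only the diagonal and once showing only the off-diagonal cells, each with its own cmap. This is a sketch of that common workaround, not a built-in seaborn option; the 3 × 3 matrix is the top-left corner of the array above, and Greens/OrRd are arbitrary choices.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook or GUI session
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Top-left 3 x 3 corner of the example matrix (rows = true, columns = predicted)
labels = list("ABC")
df_cm = pd.DataFrame([[33, 2, 0], [3, 31, 0], [0, 4, 41]], index=labels, columns=labels)

diag = np.eye(len(df_cm), dtype=bool)
fig, ax = plt.subplots(figsize=(5, 4))
# mask=True hides a cell, so each call colors only its own part of the matrix
sns.heatmap(df_cm, mask=~diag, cmap="Greens", annot=True, fmt="d", cbar=False, ax=ax)
sns.heatmap(df_cm, mask=diag, cmap="OrRd", annot=True, fmt="d", cbar=False, ax=ax)
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
```

Both colorbars are turned off because two separate color scales on one plot are easy to misread; add cbar=True to one call if you need a legend for that part.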

heatmap accepts an extra argument, cmap, that changes the colors of the matrix. These are some possible values for cmap:

cmap = [Accent, Accent_r, Blues, Blues_r, BrBG, BrBG_r, BuGn, BuGn_r, BuPu, 
BuPu_r, CMRmap, CMRmap_r, Dark2, Dark2_r, GnBu, GnBu_r, Greens, Greens_r, 
Greys, Greys_r, OrRd, OrRd_r, Oranges, Oranges_r, PRGn, PRGn_r, Paired, Paired_r, 
Pastel1, Pastel1_r, Pastel2, Pastel2_r, PiYG, PiYG_r, PuBu, PuBuGn, PuBuGn_r, 
PuBu_r, PuOr, PuOr_r, PuRd, PuRd_r, Purples, Purples_r, RdBu, RdBu_r, RdGy, RdGy_r, 
RdPu, RdPu_r, RdYlBu, RdYlBu_r, RdYlGn, RdYlGn_r, Reds, Reds_r, Set1, Set1_r, 
Set2, Set2_r, Set3, Set3_r, Spectral, Spectral_r, Wistia, Wistia_r, YlGn, YlGnBu, 
YlGnBu_r, YlGn_r, YlOrBr, YlOrBr_r, YlOrRd, YlOrRd_r, afmhot, afmhot_r, autumn,
autumn_r, binary, binary_r, bone, bone_r, brg, brg_r, bwr, bwr_r, cividis, 
cividis_r, cool, cool_r, coolwarm, coolwarm_r, copper, copper_r, cubehelix, 
cubehelix_r, flag, flag_r, gist_earth, gist_earth_r, gist_gray, gist_gray_r,
gist_heat, gist_heat_r, gist_ncar, gist_ncar_r, gist_rainbow, gist_rainbow_r, 
gist_stern, gist_stern_r, gist_yarg, gist_yarg_r, gnuplot, gnuplot2, gnuplot2_r, 
gnuplot_r, gray, gray_r, hot, hot_r, hsv, hsv_r, icefire, icefire_r, inferno, 
inferno_r, jet, jet_r, magma, magma_r, mako, mako_r, nipy_spectral, nipy_spectral_r,
ocean, ocean_r, pink, pink_r, plasma, plasma_r, prism, prism_r, rainbow, rainbow_r, 
rocket, rocket_r, seismic, seismic_r, spring, spring_r, summer, summer_r, tab10, 
tab10_r, tab20, tab20_r, tab20b, tab20b_r, tab20c, tab20c_r, terrain, terrain_r, 
viridis, viridis_r, vlag, vlag_r, winter, winter_r]

(Figure: the confusion matrix drawn with cmap = "OrRd")

(Figures: the same matrix drawn with cmap = "Greens_r" and with cmap = "OrRd_r")

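import matplotlib.pyplot as plt

def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):  # swap Blues for another colormap here
    ...  # body omitted in the original answer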

You can replace Blues in cmap=plt.cm.Blues with the color you want, such as green, red, or orange. Don't forget the trailing "s" in these single-color colormap names: Greens, Reds, Oranges. In addition, each of these colormaps comes in two forms. For green, for example:
1. Greens: high counts, which are usually on the diagonal, are drawn in dark green.
2. Greens_r: the reversed colormap, so low counts, which are usually the off-diagonal cells, are drawn in dark green.
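The relationship between Greens and Greens_r can be checked directly: the _r variant is the same colormap with its color order reversed, so which cells come out dark depends on their values, not on their position in the matrix. A small check, assuming Matplotlib 3.5 or later for the matplotlib.colormaps registry:

```python
import matplotlib
import numpy as np

greens = matplotlib.colormaps["Greens"]      # light for low values, dark for high values
greens_r = matplotlib.colormaps["Greens_r"]  # the same colors in reverse order

# The lowest value in Greens_r gets the color of the highest value in Greens
print(np.allclose(greens_r(0.0), greens(1.0)))  # True
# Any Colormap object can be reversed the same way
print(np.allclose(greens.reversed()(0.0), greens(1.0)))  # True
```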

Hopefully, this is helpful.

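Statement: the technical posts on this site are published under the CC BY-SA 4.0 license. If you republish one, please credit this site's URL or the original post's URL. For any questions, contact yoyou2525@163.com.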

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM