
How to change the color of boxes in a confusion matrix using sklearn?

Here is my code snippet to produce a confusion matrix. How can I change the color of the boxes that are not on the diagonal, as in a heatmap, using sklearn?

import numpy
import torch
import matplotlib.pyplot as plt

nb_classes = 15
confusion_matrix = torch.zeros(nb_classes, nb_classes)

with torch.no_grad():
    for i, (inputs, target, classes, im_path) in enumerate(dataLoaders['test']):

        inputs = inputs.to(device)
        target = target.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        for t, p in zip(target.view(-1), preds.view(-1)):
                confusion_matrix[t.long(), p.long()] += 1

num_classes = 15
class_names = ['A2CH', 'A3CH', 'A4CH_LV', 'A4CH_RV', 'A5CH', 'Apical_MV_LA_IAS',
                 'OTHER', 'PLAX_TV', 'PLAX_full', 'PLAX_valves', 'PSAX_AV', 'PSAX_LV',
                 'Subcostal_IVC', 'Subcostal_heart', 'Suprasternal']                

plt.figure()
plt.imshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Blues)

tick_marks = numpy.arange(num_classes)
classNames = class_names

thresh = confusion_matrix.max() / 2.
for i in range(confusion_matrix.shape[0]):
    for j in range(confusion_matrix.shape[1]):
        plt.text(j, i, format(confusion_matrix[i, j]),
                ha="center", va="center",
                color="white" if  confusion_matrix[i, j] == 0 or confusion_matrix[i, j] > thresh else "black") 
plt.tight_layout()
plt.colorbar()
plt.show()


Use seaborn's heatmap to plot the confusion matrix:

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
array = [[33,2,0,0,0,0,0,0,0,1,3], 
    [3,31,0,0,0,0,0,0,0,0,0], 
    [0,4,41,0,0,0,0,0,0,0,1], 
    [0,1,0,30,0,6,0,0,0,0,1], 
    [0,0,0,0,38,10,0,0,0,0,0], 
    [0,0,0,3,1,39,0,0,0,0,4], 
    [0,2,2,0,4,1,31,0,0,0,2],
    [0,1,0,0,0,0,0,36,0,2,0], 
    [0,0,0,0,0,0,1,5,37,5,1], 
    [3,0,0,0,0,0,0,0,0,39,0], 
    [0,0,0,0,0,0,0,0,0,0]]
df_cm = pd.DataFrame(array, index = [i for i in "ABCDEFGHIJK"],
              columns = [i for i in "ABCDEFGHIJK"])
plt.figure(figsize = (10,7))
sn.heatmap(df_cm, annot=True, cmap="OrRd")
plt.show()

heatmap accepts an extra argument, cmap, that changes the colors of the matrix. These are some possible values for cmap:

cmap = [Accent, Accent_r, Blues, Blues_r, BrBG, BrBG_r, BuGn, BuGn_r, BuPu, 
BuPu_r, CMRmap, CMRmap_r, Dark2, Dark2_r, GnBu, GnBu_r, Greens, Greens_r, 
Greys, Greys_r, OrRd, OrRd_r, Oranges, Oranges_r, PRGn, PRGn_r, Paired, Paired_r, 
Pastel1, Pastel1_r, Pastel2, Pastel2_r, PiYG, PiYG_r, PuBu, PuBuGn, PuBuGn_r, 
PuBu_r, PuOr, PuOr_r, PuRd, PuRd_r, Purples, Purples_r, RdBu, RdBu_r, RdGy, RdGy_r, 
RdPu, RdPu_r, RdYlBu, RdYlBu_r, RdYlGn, RdYlGn_r, Reds, Reds_r, Set1, Set1_r, 
Set2, Set2_r, Set3, Set3_r, Spectral, Spectral_r, Wistia, Wistia_r, YlGn, YlGnBu, 
YlGnBu_r, YlGn_r, YlOrBr, YlOrBr_r, YlOrRd, YlOrRd_r, afmhot, afmhot_r, autumn,
autumn_r, binary, binary_r, bone, bone_r, brg, brg_r, bwr, bwr_r, cividis, 
cividis_r, cool, cool_r, coolwarm, coolwarm_r, copper, copper_r, cubehelix, 
cubehelix_r, flag, flag_r, gist_earth, gist_earth_r, gist_gray, gist_gray_r,
gist_heat, gist_heat_r, gist_ncar, gist_ncar_r, gist_rainbow, gist_rainbow_r, 
gist_stern, gist_stern_r, gist_yarg, gist_yarg_r, gnuplot, gnuplot2, gnuplot2_r, 
gnuplot_r, gray, gray_r, hot, hot_r, hsv, hsv_r, icefire, icefire_r, inferno, 
inferno_r, jet, jet_r, magma, magma_r, mako, mako_r, nipy_spectral, nipy_spectral_r,
ocean, ocean_r, pink, pink_r, plasma, plasma_r, prism, prism_r, rainbow, rainbow_r, 
rocket, rocket_r, seismic, seismic_r, spring, spring_r, summer, summer_r, tab10, 
tab10_r, tab20, tab20_r, tab20b, tab20b_r, tab20c, tab20c_r, terrain, terrain_r, 
viridis, viridis_r, vlag, vlag_r, winter, winter_r]
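The names available depend on the installed versions, so it can be easier to query them than to copy a list. A minimal sketch, assuming only matplotlib is installed; a few names in the list above (for example mako, rocket, icefire, vlag) are registered by seaborn and only appear after seaborn has been imported:

```python
import matplotlib.pyplot as plt

# plt.colormaps() returns the names of all registered colormaps.
names = plt.colormaps()

# Every base colormap has a reversed twin whose name ends in "_r".
base = [n for n in names if not n.endswith("_r")]
reversed_variants = [n for n in names if n.endswith("_r")]

print(len(names), "colormaps registered")
print("first few base names:", base[:10])
```

Any of these names can be passed as a string to the cmap argument.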

Example values (plots omitted): cmap = "OrRd", cmap = "Greens_r", cmap = "OrRd_r"
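A single cmap colors every cell on the same scale, so it does not by itself separate the diagonal from the off-diagonal cells. One possible way to do that, sketched here with plain matplotlib as in the question's code, is to draw the matrix twice with complementary masks. The function name plot_two_tone_confusion_matrix and the small 3x3 matrix are made up for illustration; a torch tensor like the one in the question would need converting first, for example with confusion_matrix.numpy().

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_two_tone_confusion_matrix(cm, diag_cmap="Blues", off_cmap="OrRd"):
    """Hypothetical helper: diagonal and off-diagonal cells get separate colormaps."""
    cm = np.asarray(cm)
    on_diag = np.eye(cm.shape[0], cm.shape[1], dtype=bool)

    # Masked cells are drawn transparent, so the two images overlay cleanly.
    diag_part = np.ma.masked_where(~on_diag, cm)  # keep only the diagonal
    off_part = np.ma.masked_where(on_diag, cm)    # keep only the off-diagonal

    fig, ax = plt.subplots()
    ax.imshow(diag_part, interpolation="nearest", cmap=diag_cmap)
    ax.imshow(off_part, interpolation="nearest", cmap=off_cmap)

    # Annotate every cell with its count.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, str(cm[i, j]), ha="center", va="center")
    return fig, ax, diag_part, off_part


# Illustrative data only, not taken from the question.
cm = np.array([[33, 2, 0],
               [3, 31, 1],
               [0, 4, 41]])
fig, ax, diag_part, off_part = plot_two_tone_confusion_matrix(cm)
# plt.show()  # uncomment when running interactively
```

Each layer is normalized on its own unmasked values, so small off-diagonal counts still get visible color instead of being washed out by the large diagonal counts.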

def plot_confusion_matrix(y_true, y_pred, classes,
                      normalize=False,
                      title=None,
                      cmap=plt.cm.Blues):

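You can change the name in cmap=plt.cm.Blues to the color you want, such as green, red, or orange. The names of these sequential colormaps are plural, so remember the trailing s: Greens, Reds, Oranges. In addition, each colormap comes in two forms, the normal one and a reversed one with the suffix _r. For example, for green: 1. Greens maps high values to dark green, so the diagonal of a mostly correct classifier is the darkest. 2. Greens_r reverses the mapping, so low values are dark green and the cells outside the diagonal are the darkest.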
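The effect of the _r suffix can be checked directly by sampling a colormap at its two ends. A small sketch, assuming matplotlib and numpy are installed:

```python
import matplotlib.pyplot as plt
import numpy as np

greens = plt.get_cmap("Greens")
greens_r = plt.get_cmap("Greens_r")

# Calling a colormap with a float in [0, 1] returns an RGBA tuple.
low_color = greens(0.0)    # color used for the smallest value
high_color = greens(1.0)   # color used for the largest value

# The reversed map swaps the two ends.
ends_swapped = (np.allclose(greens_r(0.0), high_color)
                and np.allclose(greens_r(1.0), low_color))

print("Greens low :", low_color)
print("Greens high:", high_color)
print("Greens_r is Greens reversed:", ends_swapped)
```

The low end of Greens is close to white and the high end is dark green, which is why cells with large counts stand out with Greens and cells with small counts stand out with Greens_r.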

Hopefully this helps.
