
How to annotate only the diagonal elements of a seaborn heatmap

I am using a Seaborn heatmap to plot a large confusion matrix. Because the diagonal elements represent the correct predictions, they are the most important ones to label with the count or correct rate. How can I annotate only the diagonal entries of a heatmap?

I have consulted the example at https://seaborn.pydata.org/examples/many_pairwise_correlations.html , but it does not show how to annotate only the diagonal entries. I hope somebody can help with that. Thank you in advance!

Does this get you what you have in mind? The heatmap in the example you linked has its main diagonal masked out, so I annotated the diagonal just below the main diagonal instead. To annotate the main diagonal of your confusion matrix, adapt my code by changing both -1 offsets in np.diag(np.diag(..., -1), -1) to 0.

Note the additional parameter fmt='' that I added to sns.heatmap(...). It is needed because the elements of my annot matrix are strings.
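As a small illustration of the nested np.diag call mentioned above, here is a minimal sketch on a toy 3x3 array. The helper name keep_diagonal is hypothetical (it is not part of NumPy or of the answer's code), and the sketch assumes a square matrix.

```python
import numpy as np

# Toy square matrix: [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
a = np.arange(1, 10).reshape(3, 3)


def keep_diagonal(m, k=0):
    """Return a matrix that keeps only the k-th diagonal of square m.

    The inner np.diag extracts the k-th diagonal as a 1-D array;
    the outer np.diag places it back on the k-th diagonal of a
    zero matrix of the same size.
    """
    return np.diag(np.diag(m, k), k)


main_only = keep_diagonal(a, 0)    # main diagonal (offset 0)
below_only = keep_diagonal(a, -1)  # diagonal just below the main one
```

With offset 0 only the entries 1, 5, 9 survive; with offset -1 only 4 and 8 do. Everything else becomes zero, which is why the answer's code later blanks out the '0.0' strings.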

Code

from string import ascii_letters
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

sns.set(style="white")

# Generate a large random dataset
rs = np.random.RandomState(33)
y = rs.normal(size=(100, 26))
d = pd.DataFrame(data=y,
                 columns=list(ascii_letters[26:]))

# Compute the correlation matrix
corr = d.corr()

# Generate a mask for the upper triangle
mask = np.zeros_like(corr, dtype=bool)  # np.bool was removed in NumPy 1.24; use the builtin bool
mask[np.triu_indices_from(mask)] = True

# Set up the matplotlib figure
f, ax = plt.subplots(figsize=(11, 9))

# Generate a custom diverging colormap
cmap = sns.diverging_palette(220, 10, as_cmap=True)

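# Generate the annotation: keep only the diagonal just below the main one
# (offset -1); use offset 0 in both np.diag calls for the main diagonal
annot = np.diag(np.diag(corr.values, -1), -1)
annot = np.round(annot, 2)
annot = annot.astype('str')
# Blank out the zero entries so only the chosen diagonal is labelled
annot[annot == '0.0'] = ''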

# Draw the heatmap with the mask and correct aspect ratio
sns.heatmap(corr, mask=mask, cmap=cmap, vmax=.3, center=0,
            square=True, linewidths=.5, cbar_kws={"shrink": .5}, annot=annot, fmt='')

plt.show()

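For the confusion-matrix case from the question, the same idea can be applied to the main diagonal directly. The following is only a sketch under stated assumptions: the matrix cm is made-up example data, and the helper names diagonal_annotations and plot_confusion are hypothetical. It builds the string annotations with np.eye and np.where rather than the nested np.diag call, so legitimate zero counts on the diagonal are not blanked out.

```python
import numpy as np


def diagonal_annotations(cm):
    """Return a string array: cm's values on the main diagonal, '' elsewhere."""
    cm = np.asarray(cm)
    on_diagonal = np.eye(*cm.shape, dtype=bool)
    return np.where(on_diagonal, cm.astype(str), "")


def plot_confusion(cm):
    """Draw cm as a heatmap, annotating only the main diagonal."""
    # Imported here so diagonal_annotations stays usable without plotting libraries
    import matplotlib.pyplot as plt
    import seaborn as sns

    fig, ax = plt.subplots(figsize=(5, 4))
    # fmt='' is required because the annotations are strings, not numbers
    sns.heatmap(cm, annot=diagonal_annotations(cm), fmt="",
                cmap="Blues", square=True, ax=ax)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    return ax


# Made-up 3-class confusion matrix, for illustration only
cm = np.array([[50, 2, 1],
               [3, 45, 4],
               [0, 5, 48]])
annot = diagonal_annotations(cm)
```

To display the figure, call plot_confusion(cm) followed by matplotlib.pyplot.show(). To show correct rates instead of counts, normalize each row of cm before building the annotations.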

In a related question, someone asked how to annotate the diagonal elements with strings. Here is an example:

from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np

flights = sns.load_dataset('flights')
# Keyword arguments are required by pandas >= 2.0
flights = flights.pivot(index='year', columns='month', values='passengers')
corr_data = np.corrcoef(flights.to_numpy())

up_triang = np.triu(np.ones_like(corr_data)).astype(bool)
ax = sns.heatmap(corr_data, cmap='flare', xticklabels=False, yticklabels=False, square=True,
                 linecolor='white', linewidths=0.5,
                 cbar=True, mask=up_triang, cbar_kws={'shrink': 0.6, 'pad': 0.02, 'label': 'correlation'})
ax.invert_xaxis()
for i, label in enumerate(flights.index):
    ax.text(i + 0.2, i + 0.5, label, ha='right', va='center')
plt.show()

Heatmap with annotated diagonal

