简体   繁体   中英

How to annotate and correctly place numbers in a heatmap

I'm having problems with heatmap.

I create the following function to show the analysis with heatmap

data = [ 0.00662896, -0.00213044, -0.00156812,  0.01450994, -0.00875174, -0.01561342, -0.00694762,  0.00476027,  0.00470659]

def plot_heatmap(pathOut, data, title, fileName, precis=2, show=False):
    from matplotlib import cm
    fig  = plt.figure()
    n       = int(np.sqrt(len(data)))
    data    = data.reshape(n,n)
    heatmap = plt.pcolor(data,cmap=cm.YlOrBr)
    xLabels = (np.linspace(1,n,n,dtype=int))
    yLabels = (np.linspace(1,n,n,dtype=int))
    xpos    = np.linspace(1,n,n)-0.5
    ypos    = np.linspace(1,n,n)-0.5

    for y in range(n):
        for x in range(n):
            plt.text(x + 0.5, y + 0.5, f'{data[y, x]:.{precis}f}',
                horizontalalignment='center',
                verticalalignment='center',
                )

    plt.colorbar(heatmap, format='%.2f')
    plt.xticks(xpos,xLabels)
    plt.yticks(ypos,yLabels)
    plt.title(f'{title}')
    if (show == False ):
        plt.close(fig)        
    elif (show == True):        
        plt.show()    
    fig.savefig(f'{pathOut}/{fileName}.pdf', format='pdf')   

When I call the function the heatmap is created but not correctly, because I would like to show values at a specific precision. I know how to define text precision and scale precision, but how to adjust data precision to generate the correct heatmap?

In the attached figure, I have 7 cells equal to 0, for my desired precision, but the data used has a larger precision what produce different colors.

数字

  • It is much easier to useseaborn.heatmap , which includes annotations and a colorbar. seaborn is a high-level API for matplotlib .
    • This significantly reduces the number of lines of code.
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
import seaborn as sns

def plot_heatmap(pathOut, fileName, data, title, precis=2, show=False):
    n = int(np.sqrt(len(data)))
    data = data.reshape(n, n)
    
    xy_labels = range(1, n+1)
    
    fig, ax = plt.subplots(figsize=(8, 6))
    p = sns.heatmap(data=data, annot=True, fmt=f'.{precis}g', ax=ax,
                    cmap=cm.YlOrBr, xticklabels=xy_labels, yticklabels=xy_labels)

    ax.invert_yaxis()  # invert the axis if desired
    ax.set_title(f'{title}')
    fig.savefig(f'{pathOut}/{fileName}.pdf', format='pdf') 
    if (show == False ):
        plt.close(fig)        
    elif (show == True):        
        plt.show()


data = np.array([ 0.00662896, -0.00213044, -0.00156812,  0.01450994, -0.00875174, -0.01561342, -0.00694762,  0.00476027,  0.00470659])

plot_heatmap('.', 'test', data, 'test', 4, True)

在此处输入图片说明

  • The f-string for plt.txt is not correct. It will be easier to round the value and convert it to a str type.
    • str(round(data[x, y], precis)) instead of f'{data[y, x]:.{precis}f}'
  • data[x, y] should be data[y, x]
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np

def plot_heatmap(pathOut, fileName, data, title, precis=2, show=False):
    fig  = plt.figure(figsize=(8, 6))
    n       = int(np.sqrt(len(data)))
    data    = data.reshape(n, n)
    heatmap = plt.pcolor(data, cmap=cm.YlOrBr)
    xLabels = (np.linspace(1,n,n,dtype=int))
    yLabels = (np.linspace(1,n,n,dtype=int))
    xpos    = np.linspace(1,n,n)-0.5
    ypos    = np.linspace(1,n,n)-0.5

    for y in range(n):
        for x in range(n):
            s = str(round(data[y, x], precis))  # added s for plt.txt and reverse x and y for data addressing
            plt.text(x + 0.5, y + 0.5, s,
                horizontalalignment='center',
                verticalalignment='center',
                )

    plt.colorbar(heatmap, format=f'%.{precis}f')  # add precis to the colorbar
    plt.xticks(xpos,xLabels)
    plt.yticks(ypos,yLabels)
    plt.title(f'{title}')
    fig.savefig(f'{pathOut}/{fileName}.pdf', format='pdf')  # this should be before plt.show()
    if (show == False ):
        plt.close(fig)        
    elif (show == True):        
        plt.show()


# the function expects an array, not a list
data = np.array([ 0.00662896, -0.00213044, -0.00156812,  0.01450994, -0.00875174, -0.01561342, -0.00694762,  0.00476027,  0.00470659])

# function call
plot_heatmap('.', 'test', data, 'test', 4, True)

在此处输入图片说明

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM