
How to overlay seaborn heatmap on matplotlib figure

I am trying to overlay a heatmap on top of a matplotlib figure of a football pitch.

This is the image of the matplotlib pitch created by the code block below:

[Image: the football pitch drawn by the code below]


import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as plt_p
import numpy as np

def draw_pitch(ax):
    # size of the pitch is 120, 80

    #Pitch Outline & Centre Line
    plt.plot([0,0],[0,80], color="black")
    plt.plot([0,120],[80,80], color="black")
    plt.plot([120,120],[80,0], color="black")
    plt.plot([120,0],[0,0], color="black")
    plt.plot([60,60],[0,80], color="black")

    #Left Penalty Area
    plt.plot([14.6,14.6],[57.8,22.2],color="black")
    plt.plot([0,14.6],[57.8,57.8],color="black")
    plt.plot([0,14.6],[22.2,22.2],color="black")

    #Right Penalty Area
    plt.plot([120,105.4],[57.8,57.8],color="black")
    plt.plot([105.4,105.4],[57.8,22.2],color="black")
    plt.plot([120, 105.4],[22.2,22.2],color="black")

    #Left 6-yard Box
    plt.plot([0,4.9],[48,48],color="black")
    plt.plot([4.9,4.9],[48,32],color="black")
    plt.plot([0,4.9],[32,32],color="black")

    #Right 6-yard Box
    plt.plot([120,115.1],[48,48],color="black")
    plt.plot([115.1,115.1],[48,32],color="black")
    plt.plot([120,115.1],[32,32],color="black")

    #Prepare Circles
    centreCircle = plt.Circle((60,40),8.1,color="black",fill=False)
    centreSpot = plt.Circle((60,40),0.71,color="black")
    leftPenSpot = plt.Circle((9.7,40),0.71,color="black")
    rightPenSpot = plt.Circle((110.3,40),0.71,color="black")

    #Draw Circles
    ax.add_patch(centreCircle)
    ax.add_patch(centreSpot)
    ax.add_patch(leftPenSpot)
    ax.add_patch(rightPenSpot)

    #Prepare Arcs
    # arguments for arc
    # x, y coordinate of centerpoint of arc
    # width, height as arc might not be circle, but oval
    # angle: degree of rotation of the shape, anti-clockwise
    # theta1, theta2, start and end location of arc in degree
    leftArc = plt_p.Arc((9.7,40),height=16.2,width=16.2,angle=0,theta1=310,theta2=50,color="black")
    rightArc = plt_p.Arc((110.3,40),height=16.2,width=16.2,angle=0,theta1=130,theta2=230,color="black")

    #Draw Arcs
    ax.add_patch(leftArc)
    ax.add_patch(rightArc)

fig=plt.figure()
fig.set_size_inches(7, 5)
ax=fig.add_subplot(1,1,1)
draw_pitch(ax)
plt.axis('off')
plt.show()

As recommended in previous posts, I have tried passing the ax argument to sns.heatmap() and lowering alpha to make the heatmap more transparent. However, the heatmap still covers the entire figure and the football pitch is not visible.

When I run the code below, I get the following result:

[Image: the heatmap fills the whole figure and the pitch is not visible]

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as plt_p
import numpy as np

#DUMMY DATA 
df_test = pd.DataFrame(np.array([[43.2, 22.4, 0], [-5.1,-53.2,1], [33.5,-19.2,0],
                                 [23.2, 32.4, 1], [-5.3,-53.2,1], [33.5,-69.2,0],
                                 [53.2, -42.4, 0], [-5.4,-53.2,0], [-3.5,-39.2,0],
                                 [63.2, 62.4, 1], [-52,-53.2,0], [37.5,-11.2,1],
                                 [113.2, 72.4, 0], [-34.2,-53.2,0], [42.5,-119.2,1]]),
                  columns=['x', 'y', 'outcome'])

#CREATES THE HEATMAP OVERLAY ON THE FOOTBALL PITCH
def pass_comp_map(df):
    df['x_bands'] = pd.qcut(df['x'],4,labels=False)
    df['y_bands'] = pd.qcut(df['y'],3,labels=False)
    df_pass = df[['x_bands','y_bands','outcome']]
    df_sum = df_pass.groupby(['x_bands','y_bands'], as_index=False).sum() # get total number of completed passes 
    df_count = df_pass.groupby(['x_bands','y_bands'], as_index=False).count() #get total number passes
    df_agg = pd.merge(df_sum, df_count['outcome'].to_frame(), how ='left',left_index=True,right_index=True)
    df_agg['pass_comp'] = df_agg['outcome_x'] / df_agg['outcome_y']
    data = df_agg[['x_bands','y_bands','pass_comp']]
    data_pivot = data.pivot_table(index='y_bands', columns='x_bands', values='pass_comp')
    data_pivot = data_pivot.fillna(0)

    #OVERLAY FIGURE CREATED HERE
    fig=plt.figure()
    fig.set_size_inches(7, 5)
    ax=fig.add_subplot(1,1,1)
    draw_pitch(ax)
    plt.axis('off')
    sns.heatmap(data_pivot,cbar=False, xticklabels=False, yticklabels=False,annot=True,alpha = 0.5,ax=ax)
    plt.show()

pass_comp_map(df_test)

How can I overlay the heatmap on the football pitch so that the pitch is still visible?
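A side note on the data preparation in pass_comp_map: because outcome is 1 for a completed pass and 0 otherwise, dividing the per-bin sum by the per-bin count is just the per-bin mean. The sketch below, using the hypothetical name pass_completion_grid, computes the same y_bands x x_bands table with groupby and unstack instead of two groupbys, a merge and pivot_table. It does not change anything about the plotting problem itself.

```python
import pandas as pd

def pass_completion_grid(df, x_bins=4, y_bins=3):
    """Pass-completion rate per (y band, x band), shaped like data_pivot in the question."""
    # Same quantile banding as pass_comp_map, but without adding columns to df
    x_bands = pd.qcut(df["x"], x_bins, labels=False).rename("x_bands")
    y_bands = pd.qcut(df["y"], y_bins, labels=False).rename("y_bands")
    # outcome is 1 (completed) or 0, so the per-bin mean is completed / total
    rate = df.groupby([y_bands, x_bands])["outcome"].mean()
    # y bands become rows, x bands become columns; bins with no passes get 0
    return rate.unstack("x_bands").fillna(0)
```

With the question's dummy data, data_pivot = pass_completion_grid(df_test) should give the same 3 x 4 table, and df_test is left unmodified.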

As already mentioned in the comments, I would recommend not using seaborn.heatmap, because it's pretty much impossible to scale it to the size of the pitch.

Without any other changes to your code, just replace the line sns.heatmap(...) with

ax.imshow(data_pivot.values, zorder=0, aspect="auto", extent=(0,120,0,80), 
          cmap=sns.cubehelix_palette(light=1, as_cmap=True))

The plot then looks like this:

[Image: the heatmap drawn with imshow underneath the pitch markings]
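For reference, here is a self-contained sketch of the same idea: extent stretches the array over the pitch coordinates, aspect="auto" allows non-square cells, and zorder=0 keeps the image below the line markings (Line2D objects default to zorder 2). To avoid depending on seaborn it uses Matplotlib's built-in "Blues" colormap instead of sns.cubehelix_palette, it draws only a simplified outline, and the function name heatmap_under_pitch is hypothetical.

```python
import numpy as np
import matplotlib.pyplot as plt

def heatmap_under_pitch(ax, values, width=120, height=80):
    """Stretch a 2-D array over a width x height pitch and draw an outline on top."""
    # extent maps the whole array onto pitch coordinates; aspect="auto" lets the
    # cells be non-square; zorder=0 keeps the image below the lines
    im = ax.imshow(values, zorder=0, aspect="auto",
                   extent=(0, width, 0, height), cmap="Blues", alpha=0.8)
    # Simplified markings: outline plus halfway line (default zorder 2)
    ax.plot([0, 0, width, width, 0], [0, height, height, 0, 0], color="black")
    ax.plot([width / 2, width / 2], [0, height], color="black")
    ax.set_axis_off()
    return im
```

In the question's setting you would call heatmap_under_pitch(ax, data_pivot.values) on an Axes from plt.subplots(figsize=(7, 5)), or keep the full draw_pitch(ax) for the markings instead of the simplified outline.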

You can still adjust alpha, use a different colormap, etc. To annotate the heatmap as well, you can use

scale = np.array([120,80])
ax.imshow(data_pivot.values, zorder=0, aspect="auto", extent=(0,scale[0],0,scale[1]), 
          cmap=sns.cubehelix_palette(light=1, as_cmap=True), origin="lower")
offs = np.array([scale[0]/data_pivot.values.shape[1], scale[1]/data_pivot.values.shape[0]])
for pos, val in np.ndenumerate(data_pivot.values):
    ax.annotate(f"{val:.2f}", xy=np.array(pos)[::-1]*offs+offs/2, ha="center", va="center")
ax.invert_yaxis()

[Image: the annotated heatmap underneath the pitch markings]
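The annotation loop relies on each cell being width / n_cols by height / n_rows pitch units, so the centre of cell (row i, column j) is at ((j + 0.5) * dx, (i + 0.5) * dy); that is what np.array(pos)[::-1]*offs + offs/2 computes. The hypothetical helper cell_centres below makes that arithmetic explicit, assuming row 0 sits at y = 0 as with origin="lower".

```python
import numpy as np

def cell_centres(n_rows, n_cols, width=120.0, height=80.0):
    """(x, y) centre of every heatmap cell in pitch coordinates.

    Assumes row 0 sits at y = 0, as with imshow(..., origin="lower").
    Returns an array of shape (n_rows, n_cols, 2).
    """
    dx, dy = width / n_cols, height / n_rows  # size of one cell in pitch units
    # cols[i, j] == j and rows[i, j] == i
    cols, rows = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    return np.stack([cols * dx + dx / 2, rows * dy + dy / 2], axis=-1)
```

With centres = cell_centres(*data_pivot.shape), the loop body becomes ax.annotate(f"{val:.2f}", xy=centres[i, j], ha="center", va="center") for each (i, j), val from np.ndenumerate(data_pivot.values).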

Look at your scales: the pitch is on a completely different scale from the heatmap. If you zoom out, you will see something like this:

[Image: zoomed-out view of the plot]

As demonstrated in @simon-rogers' answer, you have a scale mismatch between your pitch drawing and your heatmap.

Seaborn's heatmap is drawn on an Axes with limits [0, number of columns] in x and [0, number of rows] in y. With your example dataframe, the resulting plot is therefore 4x3, while your pitch is 120x80.
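This is easy to check: a minimal sketch, assuming a recent seaborn, draws a heatmap of a given shape and reads back the axis limits. The helper name heatmap_limits is hypothetical.

```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def heatmap_limits(n_rows, n_cols):
    """Draw an n_rows x n_cols heatmap and return its (xlim, ylim)."""
    fig, ax = plt.subplots()
    data = np.arange(n_rows * n_cols).reshape(n_rows, n_cols)
    sns.heatmap(data, cbar=False, ax=ax)
    # Convert to plain floats so the result compares and prints cleanly
    limits = (tuple(map(float, ax.get_xlim())), tuple(map(float, ax.get_ylim())))
    plt.close(fig)
    return limits
```

For the question's 3 x 4 data_pivot, heatmap_limits(3, 4) should report x running from 0 to 4 and y running from 3 down to 0 (seaborn inverts the y axis): one unit per cell, far smaller than the 120 x 80 pitch.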

The solution is to draw the pitch at a 4x3 scale. Better yet, draw the pitch with parametrized dimensions so it can fit various sized heatmaps.

I've started working on the function, but I didn't have time to work out the ratios for the circles; I leave that as an exercise for you.

def draw_pitch(ax, width=120, height=80):
    # size of the pitch is width x height

    #Pitch Outline & Centre Line
    plt.plot([0,0],[0,height], color="black")
    plt.plot([0,width],[height,height], color="black")
    plt.plot([width,width],[height,0], color="black")
    plt.plot([width,0],[0,0], color="black")
    plt.plot([width/2,width/2],[0,height], color="black")

    #Left Penalty Area
    plt.plot([width*0.12,width*0.12],[height*0.72,height*0.28],color="black")
    plt.plot([0,width*0.12],[height*0.72,height*0.72],color="black")
    plt.plot([0,width*0.12],[height*0.28,height*0.28],color="black")

    #Right Penalty Area
    plt.plot([width,width*0.88],[height*0.72,height*0.72],color="black")
    plt.plot([width*0.88,width*0.88],[height*0.72,height*0.28],color="black")
    plt.plot([width, width*0.88],[height*0.28,height*0.28],color="black")

    #Left 6-yard Box
    plt.plot([0,width*0.04],[height*0.6,height*0.6],color="black")
    plt.plot([width*0.04,width*0.04],[height*0.6,height*0.4],color="black")
    plt.plot([0,width*0.04],[height*0.4,height*0.4],color="black")

    #Right 6-yard Box
    plt.plot([width,width*0.96],[height*0.6,height*0.6],color="black")
    plt.plot([width*0.96,width*0.96],[height*0.6,height*0.4],color="black")
    plt.plot([width,width*0.96],[height*0.4,height*0.4],color="black")

    #Prepare Circles
    centreCircle = plt.Circle((width/2,40),8.1,color="black",fill=False)
    centreSpot = plt.Circle((width/2,40),0.71,color="black")
    leftPenSpot = plt.Circle((9.7,40),0.71,color="black")
    rightPenSpot = plt.Circle((110.3,40),0.71,color="black")

    #Draw Circles
    ax.add_patch(centreCircle)
    ax.add_patch(centreSpot)
    ax.add_patch(leftPenSpot)
    ax.add_patch(rightPenSpot)

    #Prepare Arcs
    # arguments for arc
    # x, y coordinate of centerpoint of arc
    # width, height as arc might not be circle, but oval
    # angle: degree of rotation of the shape, anti-clockwise
    # theta1, theta2, start and end location of arc in degree
    leftArc = plt_p.Arc((9.7,40),height=16.2,width=16.2,angle=0,theta1=310,theta2=50,color="black")
    rightArc = plt_p.Arc((110.3,40),height=16.2,width=16.2,angle=0,theta1=130,theta2=230,color="black")

    #Draw Arcs
    ax.add_patch(leftArc)
    ax.add_patch(rightArc)

[Image: the pitch drawn at the heatmap's scale]
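One possible way to finish the exercise, as a sketch: scale x and y separately from the question's 120 x 80 reference coordinates and draw the circles as matplotlib.patches.Ellipse, so they stretch along with the lines when the target aspect ratio differs. The name draw_pitch_scaled and the REF_W / REF_H constants are hypothetical; the marking coordinates are the ones from the question's draw_pitch.

```python
import matplotlib.pyplot as plt
from matplotlib.patches import Arc, Ellipse

# Reference dimensions of the pitch drawn in the question
REF_W, REF_H = 120.0, 80.0

def draw_pitch_scaled(ax, width=REF_W, height=REF_H, color="black"):
    """Draw the question's 120 x 80 pitch markings rescaled to width x height.

    x and y are scaled independently, so circles become ellipses unless
    width / height == 120 / 80.
    """
    sx, sy = width / REF_W, height / REF_H

    def line(xs, ys):
        # Polyline given in reference (120 x 80) coordinates
        ax.plot([x * sx for x in xs], [y * sy for y in ys], color=color)

    line([0, 0, 120, 120, 0], [0, 80, 80, 0, 0])              # outline
    line([60, 60], [0, 80])                                   # halfway line
    line([0, 14.6, 14.6, 0], [57.8, 57.8, 22.2, 22.2])        # left penalty area
    line([120, 105.4, 105.4, 120], [57.8, 57.8, 22.2, 22.2])  # right penalty area
    line([0, 4.9, 4.9, 0], [48, 48, 32, 32])                  # left 6-yard box
    line([120, 115.1, 115.1, 120], [48, 48, 32, 32])          # right 6-yard box

    def circle(cx, cy, r, fill=True):
        # Circle of radius r in reference units, drawn as an ellipse after scaling
        ax.add_patch(Ellipse((cx * sx, cy * sy), 2 * r * sx, 2 * r * sy,
                             color=color, fill=fill))

    circle(60, 40, 8.1, fill=False)  # centre circle
    circle(60, 40, 0.71)             # centre spot
    circle(9.7, 40, 0.71)            # left penalty spot
    circle(110.3, 40, 0.71)          # right penalty spot

    # Penalty arcs with the same angles as the question; when sx != sy the arc
    # ends may not land exactly on the penalty-box line
    for cx, theta1, theta2 in [(9.7, 310, 50), (110.3, 130, 230)]:
        ax.add_patch(Arc((cx * sx, 40 * sy), 16.2 * sx, 16.2 * sy,
                         theta1=theta1, theta2=theta2, color=color))
```

For example, after ax = sns.heatmap(data_pivot, cbar=False, annot=True, alpha=0.5) you could call draw_pitch_scaled(ax, width=data_pivot.shape[1], height=data_pivot.shape[0]). seaborn inverts the y axis, which does not matter here because the markings are symmetric about the pitch's horizontal centre line.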
