
Plot Seaborn heatmaps side by side with a for loop

I'm trying to loop over multiple company sizes to create a 4x2 grid of heatmaps for two metrics, percent_engaged and avg_sales, with the two metric heatmaps side by side for each company_size. Ideally, I would also like to order the rows by company size, from small at the top to extra large at the bottom. Any help is appreciated.

Right now the heatmaps don't display anything. The goal is to show the 4 company sizes for each of the 2 metrics, with the metrics side by side.

company_sizes: 'small', 'medium', 'large', 'extra large'

import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(4, 2, figsize=(30, 50))
metrics = ['avg_sales', 'percent_engaged']
for company_size, ax in zip(df.company_size.unique(), axes.flat):
    df2 = df[df.company_size == company_size]
    for metric in metrics:
        # both metrics are drawn onto the same ax in this inner loop
        sns.heatmap(df2.pivot(index='months_since_purchase', columns='year', values=metric),
                    annot=True, cmap='YlGnBu', linewidths=2.5, linecolor='white', cbar=False, ax=ax)
    ax.set_title("Purchases per company for: " + company_size)
    ax.set_xlabel("Year")
    ax.set_ylabel("Months Since First Purchase")
    ax.tick_params(axis='both', which='major', labelsize=12)
plt.show()

Sample Data:

Columns: ['year', 'company_size', 'months_since_purchase', 'percent_engaged', 'avg_sales']

2019-01-01  small   0   1.00    2.00
2019-01-01  small   1   0.90    3.00
2019-01-01  small   2   0.86    2.94
2019-01-01  small   3   0.81    2.88
2019-01-01  small   4   0.77    2.82
2019-01-01  small   5   0.73    2.77
2019-01-01  small   6   0.70    2.71
2019-01-01  small   7   0.66    2.66
2019-01-01  small   8   0.63    2.60
2019-01-01  small   9   0.60    2.55
2019-01-01  small   10  0.57    2.50
2019-01-01  small   11  0.54    2.45
2019-01-01  small   12  0.51    2.40
2020-01-01  small   0   1.00    2.00
2020-01-01  small   1   0.90    3.00
2020-01-01  small   2   0.86    2.76
2020-01-01  small   3   0.81    2.54
2020-01-01  small   4   0.77    2.34
2020-01-01  small   5   0.73    2.15
2020-01-01  small   6   0.70    1.98
2020-01-01  small   7   0.66    1.82
2020-01-01  small   8   0.63    1.67
2020-01-01  small   9   0.60    1.54
2020-01-01  small   10  0.57    1.42
2020-01-01  small   11  0.54    1.30
2020-01-01  small   12  0.51    1.20

This is very similar to your approach with two for-loops. I kept axs as a 2D array and indexed into it using the counters from enumerate in the two loops.
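A minimal sketch of just the indexing idea, separate from the plotting: squeeze=False makes plt.subplots return a 2D array of Axes even when there is only one row, so the two enumerate counters can always index it as axs[i][j]. The titles are placeholders and nothing is drawn; the Agg backend is an assumption so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs headless
import matplotlib.pyplot as plt

# With one row, plt.subplots would normally return a 1D array of Axes;
# squeeze=False keeps it 2D so axs[i][j] indexing works for any grid size.
fig, axs = plt.subplots(1, 2, squeeze=False)
print(axs.shape)  # (1, 2)

company_sizes = ['small']
metrics = ['avg_sales', 'percent_engaged']
for i, company_size in enumerate(company_sizes):   # i picks the row
    for j, metric in enumerate(metrics):            # j picks the column
        axs[i][j].set_title(f'{company_size} {metric}')

plt.close(fig)
```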


import io

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

#Read in the example table
df = pd.read_csv(io.StringIO(
"""2019-01-01  small   0   1.00    2.00
2019-01-01  small   1   0.90    3.00
2019-01-01  small   2   0.86    2.94
2019-01-01  small   3   0.81    2.88
2019-01-01  small   4   0.77    2.82
2019-01-01  small   5   0.73    2.77
2019-01-01  small   6   0.70    2.71
2019-01-01  small   7   0.66    2.66
2019-01-01  small   8   0.63    2.60
2019-01-01  small   9   0.60    2.55
2019-01-01  small   10  0.57    2.50
2019-01-01  small   11  0.54    2.45
2019-01-01  small   12  0.51    2.40
2020-01-01  small   0   1.00    2.00
2020-01-01  small   1   0.90    3.00
2020-01-01  small   2   0.86    2.76
2020-01-01  small   3   0.81    2.54
2020-01-01  small   4   0.77    2.34
2020-01-01  small   5   0.73    2.15
2020-01-01  small   6   0.70    1.98
2020-01-01  small   7   0.66    1.82
2020-01-01  small   8   0.63    1.67
2020-01-01  small   9   0.60    1.54
2020-01-01  small   10  0.57    1.42
2020-01-01  small   11  0.54    1.30
2020-01-01  small   12  0.51    1.20"""
), sep=r'\s+', header=None)

df.columns = ['year', 'company_size', 'months_since_purchase', 'percent_engaged', 'avg_sales']

#Melt percent_engaged/avg_sales to be a metric column so we can groupby on it later
df = df.melt(
    id_vars = ['year','company_size','months_since_purchase'],
    value_vars = ['percent_engaged','avg_sales'],
    var_name = 'metric',
    value_name = 'value',
)

#Determine the number of rows/columns of the figure
num_company_sizes = df['company_size'].nunique()
num_metrics = df['metric'].nunique()

#Create the figure
fig, axs = plt.subplots(
    num_company_sizes,
    num_metrics,
    squeeze=False, #makes sure a 2D array of axs is always returned even if only 1 row/column
    figsize = (4*num_metrics, 4*num_company_sizes), #have fig-size depend on the number of company sizes/metrics
)

#Loop through sizes and metrics
for i,(company_size,size_df) in enumerate(df.groupby('company_size')):
    for j,(metric,size_metric_df) in enumerate(size_df.groupby('metric')):
        piv_df = size_metric_df.pivot(index='months_since_purchase', columns='year', values='value')
        sns.heatmap(
            piv_df,
            ax = axs[i][j],
        )
        axs[i][j].set_title(f'{company_size} {metric}')
        
plt.show()
plt.close()
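The groupby above sorts company sizes alphabetically ('extra large', 'large', 'medium', 'small'), so it does not give the small-to-extra-large row order asked for in the question. One way to get it, sketched under the assumption that the size labels are exactly the four listed in the question, is to convert company_size to an ordered pandas Categorical before grouping. The small frame below is hypothetical and only mimics the shape of the melted data.

```python
import pandas as pd

# Hypothetical long-format frame shaped like the output of the melt step.
df = pd.DataFrame({
    'company_size': ['large', 'small', 'extra large', 'medium'],
    'metric': ['avg_sales'] * 4,
    'value': [1.0, 2.0, 3.0, 4.0],
})

# Desired top-to-bottom row order, taken from the question.
size_order = ['small', 'medium', 'large', 'extra large']
df['company_size'] = pd.Categorical(df['company_size'], categories=size_order, ordered=True)

# Grouping on an ordered categorical yields groups in category order.
# observed=True skips categories with no rows (e.g. the sample data only has 'small').
groups = [size for size, _ in df.groupby('company_size', observed=True)]
print(groups)  # ['small', 'medium', 'large', 'extra large']
```

In the answer's code, the conversion would go right after the melt step, with the outer loop using df.groupby('company_size', observed=True); row i of axs then follows size_order.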
