I'm attempting to loop over multiple company_sizes to create a 4x2 grid of heatmaps for the two metrics, percent_engaged and avg_sales, with the two metric heatmaps side by side for each company_size. Ideally, I would also like to sort the company sizes from small at the top to extra large at the bottom. I'd appreciate any help.
Right now the heatmap is not displaying anything. I'm trying to show the 4 company sizes as rows, with the 2 metrics side by side in each row.
company_sizes: 'small', 'medium', 'large', 'extra large'
import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(4, 2, figsize=(30, 50))
metrics = ['avg_sales', 'percent_engaged']
for company_size, ax in zip(df.company_size.unique(), axes.flat):
    df2 = df[df.company_size == company_size]
    for metric in metrics:
        # both metrics are drawn onto the same ax in this version
        sns.heatmap(df2.pivot(index='months_since_purchase', columns='year', values=metric), annot=True, cmap='YlGnBu', linewidth=2.5, linecolor='white', cbar=False, ax=ax)
    ax.set_title("Purchases per company for: " + company_size)
    ax.set_xlabel("Year")
    ax.set_ylabel("Months Since First Purchase")
    ax.tick_params(axis='both', which='major', labelsize=12)
plt.show()
Sample Data:
Columns = ['year', 'company_size', 'months_since_purchase', 'percent_engaged', 'avg_sales']
2019-01-01 small 0 1.00 2.00
2019-01-01 small 1 0.90 3.00
2019-01-01 small 2 0.86 2.94
2019-01-01 small 3 0.81 2.88
2019-01-01 small 4 0.77 2.82
2019-01-01 small 5 0.73 2.77
2019-01-01 small 6 0.70 2.71
2019-01-01 small 7 0.66 2.66
2019-01-01 small 8 0.63 2.60
2019-01-01 small 9 0.60 2.55
2019-01-01 small 10 0.57 2.50
2019-01-01 small 11 0.54 2.45
2019-01-01 small 12 0.51 2.40
2020-01-01 small 0 1.00 2.00
2020-01-01 small 1 0.90 3.00
2020-01-01 small 2 0.86 2.76
2020-01-01 small 3 0.81 2.54
2020-01-01 small 4 0.77 2.34
2020-01-01 small 5 0.73 2.15
2020-01-01 small 6 0.70 1.98
2020-01-01 small 7 0.66 1.82
2020-01-01 small 8 0.63 1.67
2020-01-01 small 9 0.60 1.54
2020-01-01 small 10 0.57 1.42
2020-01-01 small 11 0.54 1.30
2020-01-01 small 12 0.51 1.20
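As a quick check of what each panel should contain, here is a minimal sketch of the pivot step on a few of the sample rows above (assumption: only months 0 to 2 are kept to stay short). Each panel is one company_size and one metric, with months_since_purchase as rows and year as columns. Note that values= must name a single existing metric column such as 'avg_sales', not the literal string 'metrics'.

```python
import pandas as pd

# A few of the sample rows above (assumption: only months 0-2 kept for brevity)
df = pd.DataFrame({
    'year': ['2019-01-01'] * 3 + ['2020-01-01'] * 3,
    'company_size': ['small'] * 6,
    'months_since_purchase': [0, 1, 2, 0, 1, 2],
    'percent_engaged': [1.00, 0.90, 0.86, 1.00, 0.90, 0.86],
    'avg_sales': [2.00, 3.00, 2.94, 2.00, 3.00, 2.76],
})

# One panel = one company_size and one metric column, pivoted so that
# rows are months_since_purchase and columns are years
small = df[df['company_size'] == 'small']
panel = small.pivot(index='months_since_purchase', columns='year', values='avg_sales')
print(panel)
```

The resulting frame is what sns.heatmap draws for a single subplot, so a 4x2 grid needs eight of these, one per (company_size, metric) pair.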
This is very similar to your approach of two for-loops. I kept axs as a 2D array and indexed into it using the indices that enumerate produces in the two for-loops.
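To make the indexing concrete, here is a small sketch using a hypothetical long-format frame shaped like the melted df in the code below (the values are made up). The outer enumerate supplies the row index i and the inner one supplies the column index j. Because groupby sorts its keys alphabetically by default, 'large' ends up in row 0 and 'small' in row 1.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical long-format data, same columns as the melted df in the answer
df = pd.DataFrame({
    'company_size': ['small', 'small', 'large', 'large'],
    'metric': ['avg_sales', 'percent_engaged'] * 2,
    'value': [2.0, 1.0, 3.0, 0.5],
})

fig, axs = plt.subplots(2, 2, squeeze=False)
placement = {}
# enumerate gives the row index i for each company_size group ...
for i, (company_size, size_df) in enumerate(df.groupby('company_size')):
    # ... and the column index j for each metric within that group
    for j, (metric, _) in enumerate(size_df.groupby('metric')):
        axs[i][j].set_title(f'{company_size} {metric}')
        placement[(company_size, metric)] = (i, j)

print(placement)
```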
import io

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Read in the example table
df = pd.read_csv(io.StringIO(
"""2019-01-01 small 0 1.00 2.00
2019-01-01 small 1 0.90 3.00
2019-01-01 small 2 0.86 2.94
2019-01-01 small 3 0.81 2.88
2019-01-01 small 4 0.77 2.82
2019-01-01 small 5 0.73 2.77
2019-01-01 small 6 0.70 2.71
2019-01-01 small 7 0.66 2.66
2019-01-01 small 8 0.63 2.60
2019-01-01 small 9 0.60 2.55
2019-01-01 small 10 0.57 2.50
2019-01-01 small 11 0.54 2.45
2019-01-01 small 12 0.51 2.40
2020-01-01 small 0 1.00 2.00
2020-01-01 small 1 0.90 3.00
2020-01-01 small 2 0.86 2.76
2020-01-01 small 3 0.81 2.54
2020-01-01 small 4 0.77 2.34
2020-01-01 small 5 0.73 2.15
2020-01-01 small 6 0.70 1.98
2020-01-01 small 7 0.66 1.82
2020-01-01 small 8 0.63 1.67
2020-01-01 small 9 0.60 1.54
2020-01-01 small 10 0.57 1.42
2020-01-01 small 11 0.54 1.30
2020-01-01 small 12 0.51 1.20"""
), sep=r'\s+', header=None)  # whitespace-delimited; delim_whitespace=True is deprecated in newer pandas
df.columns = ['year', 'company_size', 'months_since_purchase', 'percent_engaged', 'avg_sales']

# Melt percent_engaged/avg_sales into a single metric column so we can groupby on it later
df = df.melt(
    id_vars=['year', 'company_size', 'months_since_purchase'],
    value_vars=['percent_engaged', 'avg_sales'],
    var_name='metric',
    value_name='value',
)

# Determine the number of rows/columns of the figure
num_company_sizes = df['company_size'].nunique()
num_metrics = df['metric'].nunique()

# Create the figure
fig, axs = plt.subplots(
    num_company_sizes,
    num_metrics,
    squeeze=False,  # always return a 2D array of axes, even with only 1 row/column
    figsize=(4 * num_metrics, 4 * num_company_sizes),  # scale the figure with the number of companies/metrics
)

# Loop through sizes (rows) and metrics (columns)
for i, (company_size, size_df) in enumerate(df.groupby('company_size')):
    for j, (metric, size_metric_df) in enumerate(size_df.groupby('metric')):
        piv_df = size_metric_df.pivot(index='months_since_purchase', columns='year', values='value')
        sns.heatmap(
            piv_df,
            ax=axs[i][j],
        )
        axs[i][j].set_title(f'{company_size} {metric}')

plt.show()
plt.close()
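The groupby loop above orders the rows alphabetically ('extra large', 'large', 'medium', 'small'), not from small to extra large as the question asks. One way to get the requested order is to loop over an explicit list of sizes instead of groupby. The sketch below assumes the question's wide format (percent_engaged and avg_sales as separate columns) and builds hypothetical stand-in data with all four sizes; SIZE_ORDER and METRICS are names introduced here for illustration, and annot/cmap/cbar are carried over from the question's call.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend for this sketch; drop it in a notebook
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Desired top-to-bottom row order (from the question) and left-to-right metric order
SIZE_ORDER = ['small', 'medium', 'large', 'extra large']
METRICS = ['percent_engaged', 'avg_sales']

# Hypothetical stand-in data in the question's wide format, sizes deliberately unsorted
records = []
for size in ['extra large', 'large', 'medium', 'small']:
    for year in ['2019-01-01', '2020-01-01']:
        for month in range(3):
            records.append({'year': year, 'company_size': size,
                            'months_since_purchase': month,
                            'percent_engaged': 1.0 - 0.1 * month,
                            'avg_sales': 2.0 + 0.5 * month})
df = pd.DataFrame(records)

# Keep the requested order, skipping any size that is absent from the data
sizes = [s for s in SIZE_ORDER if s in set(df['company_size'])]

fig, axs = plt.subplots(len(sizes), len(METRICS), squeeze=False,
                        figsize=(4 * len(METRICS), 4 * len(sizes)))
for i, company_size in enumerate(sizes):      # row i follows SIZE_ORDER
    size_df = df[df['company_size'] == company_size]
    for j, metric in enumerate(METRICS):      # column j follows METRICS
        piv_df = size_df.pivot(index='months_since_purchase',
                               columns='year', values=metric)
        sns.heatmap(piv_df, ax=axs[i][j], annot=True, cmap='YlGnBu', cbar=False)
        axs[i][j].set_title(f'{company_size} {metric}')
fig.tight_layout()
```

Iterating over an explicit list keeps the row order independent of how the strings sort; an ordered pd.Categorical on company_size combined with groupby would be an alternative way to get the same order.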