How can I add titles to my subplots from a list or dictionary?

I'm writing code to plot heatmaps for different meteorological seasons, so that I can look at the correlation between different renewable energy resources and demand. I created a list called "corr_matrix" in which each element is the correlation matrix for a different season.

Currently, with the way I've set up my lists and loop, I am having trouble with the titles of my subplots. When I use a list of titles in the ax.set_title... line I get an error:

raise ValueError(f"Must pass 2-d input. shape={values.shape}")
ValueError: Must pass 2-d input. shape=()

Could someone please help me fix this?

Below is a copy of my code:


import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

#load df and prepare data
df=pd.read_excel("df3_demand_cfs.xlsx", header=0,index_col=0)
df=df.drop_duplicates(subset="Dates")
df['del_dates']=df["Dates"]
df=df.set_index("Dates")

#####
df.drop(columns=["demand_MWh","hour","monthday",'month'],axis=1,inplace=True)
df=df.rename(columns={"wind capacity":'OnSW',"solar capacity":"Solar","osw capacity":"OSW","norm_demand":'Demand'})
df.sort_index(inplace = True)


# In[3]:
mon=[*range(0,17)]
dates=[*range(0,17)]
corr_matrix=[*range(0,17)]
corr_matrix_list=["JF 15", 'MAM','JJA','SON','DJF 15/16','MAM','JJA','SON','DJF 16/17','MAM','JJA','SON','DJF 17/18','MAM','JJA','SON','D 18']



#mon[0] = pd.DataFrame(df.loc['2015-01':'2015-02'])
mon[0] = df.loc['2015-01':'2015-02']
mon[1] = df.loc['2015-03':'2015-05']
mon[2] = df.loc['2015-06':'2015-08']
mon[3] = df.loc['2015-09':'2015-11']
mon[4] = df.loc['2015-12':'2016-02']
mon[5] = df.loc['2016-03':'2016-05']
mon[6] = df.loc['2016-06':'2016-08']
mon[7] = df.loc['2016-09':'2016-11']
mon[8] = df.loc['2016-12':'2017-02']
mon[9] = df.loc['2017-03':'2017-05']
mon[10] = df.loc['2017-06':'2017-08']
mon[11] = df.loc['2017-09':'2017-11']
mon[12] = df.loc['2017-12':'2018-02']
mon[13] = df.loc['2018-03':'2018-05']
mon[14] = df.loc['2018-06':'2018-08']
mon[15] = df.loc['2018-09':'2018-11']
mon[16] = df.loc['2018-12':'2019-02']

# In[4]:
for i in range(len(mon)):
    mon[i]=mon[i].drop(columns=['del_dates'],axis=1)
    corr_matrix[i] = mon[i].corr(method="spearman")
    corr_matrix[i] = corr_matrix[i].round(2)

# In[5]:
fs1=10
# plt.rcParams["font.family"] = "Helvetica"
font_title = {'family': 'serif',
        'color':  'black',
        'weight': 'bold',
        'size': fs1,
        }

font_label = {'family': 'serif',
        'color':  'black',
        'weight': 'bold',
        'size': fs1,
        }

fs2=14
plt.rc('xtick', labelsize=fs2)
plt.rc('ytick', labelsize=fs2)
plt.rcParams["font.family"] = "serif"

fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(25,15))
for i,ax in enumerate(axes.flatten()):
    ax.set_title(i, fontsize = fs1, loc='center',fontdict=font_title)
    sns.heatmap(corr_matrix_list[i],
            ax=ax,
            annot=True,
            linewidth=.01,
            square=True,
            annot_kws={"size": 16},
            vmin=-1,
            vmax=1,
            cmap="RdBu")
    ax.tick_params(axis='y', labelrotation=0)
    ax.tick_params(axis='x', labelrotation=45)

#plt.tight_layout()
plt.tight_layout(pad=1, h_pad=1, w_pad=1)
plt.savefig('heatmaps_monthly.png')
plt.show()

I answered my own question, though I'm not sure why this fixed the problem or what was wrong initially.

I added axes=axes.flatten() and changed the line with the for loop to: for i,ax in enumerate(axes):
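A possible explanation, offered as a sketch rather than a confirmed diagnosis: flattening the axes before the loop iterates over the same Axes objects as enumerate(axes.flatten()), so that change alone is unlikely to be the fix. The loop in the question passes corr_matrix_list[i] (a title string) to sns.heatmap, whereas the working version passes corr_matrix[i] (a correlation matrix). The traceback text matches what pandas raises when asked to build a DataFrame from a 0-d array, which is what a single string becomes under np.asarray. The snippet below reproduces the message with pandas and NumPy only. The helper to_frame and the random data are hypothetical stand-ins, and the assumption that seaborn coerces non-DataFrame input this way is inferred from the error message.

```python
import numpy as np
import pandas as pd

# Title strings, playing the role of corr_matrix_list in the post.
titles = ["JF 15", "MAM", "JJA"]

# Stand-in data (hypothetical values) with the column names used in the post.
rng = np.random.default_rng(0)
frame = pd.DataFrame(rng.random((48, 4)), columns=["OnSW", "Solar", "OSW", "Demand"])
matrix = frame.corr(method="spearman").round(2)


def to_frame(data):
    """Hypothetical helper: coerce heatmap input to a 2-D DataFrame.

    Assumption: this mirrors what a heatmap function does with input
    that is not already a DataFrame.
    """
    if isinstance(data, pd.DataFrame):
        return data
    return pd.DataFrame(np.asarray(data))


# A correlation matrix is already 2-D, so it passes through unchanged.
matrix_shape = to_frame(matrix).shape
print(matrix_shape)

# A title string becomes a 0-d array, which pandas refuses to wrap.
try:
    to_frame(titles[0])
    message = None
except ValueError as exc:
    message = str(exc)
print(message)
```

If this reading is right, the titles belong only in ax.set_title, and the heatmap call should always receive the matrix list.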

So now the last part of my code looks like this:

fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(25,15))
axes=axes.flatten()
for i,ax in enumerate(axes):
    ax.set_title(str(corr_matrix_list[i]), fontsize = fs1, loc='center',fontdict=font_title)
    sns.heatmap(corr_matrix[i],
            ax=ax,
            annot=True,
            linewidth=.01,
            square=True,
            annot_kws={"size": 16},
            vmin=-1,
            vmax=1,
            cmap="RdBu")
    ax.tick_params(axis='y', labelrotation=0)
    ax.tick_params(axis='x', labelrotation=45)

#plt.tight_layout()
plt.tight_layout(pad=1, h_pad=1, w_pad=1)
plt.savefig('heatmaps_monthly.png')
plt.show()
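One way to keep titles and matrices from being mixed up is to iterate over them together with zip instead of indexing two lists by position. The sketch below assumes made-up data and a 2x2 grid. The names titles and matrices are hypothetical, and ax.imshow stands in for the sns.heatmap call so the example needs only matplotlib, pandas and NumPy. A list is used rather than a dictionary keyed by title because titles such as 'MAM' repeat across years, and repeated dictionary keys would overwrite each other.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical titles and stand-in correlation matrices (random data).
titles = ["JF 15", "MAM", "JJA", "SON"]
columns = ["OnSW", "Solar", "OSW", "Demand"]
rng = np.random.default_rng(0)
matrices = [
    pd.DataFrame(rng.random((48, 4)), columns=columns).corr(method="spearman").round(2)
    for _ in titles
]

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 8))

# zip stops at the shortest input, so spare axes or spare matrices are skipped.
for ax, title, matrix in zip(axes.flat, titles, matrices):
    ax.set_title(title, fontsize=10, fontweight="bold", loc="center")
    # Stand-in for: sns.heatmap(matrix, ax=ax, annot=True, vmin=-1, vmax=1, cmap="RdBu")
    ax.imshow(matrix.to_numpy(), vmin=-1, vmax=1, cmap="RdBu")
    ax.set_xticks(range(len(columns)))
    ax.set_xticklabels(columns, rotation=45)
    ax.set_yticks(range(len(columns)))
    ax.set_yticklabels(columns)

fig.tight_layout()
shown_titles = [ax.get_title() for ax in axes.flat]
print(shown_titles)
plt.close(fig)
```

Because each loop variable has a single meaning, passing a title where a matrix is expected becomes harder to do by accident.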
