
How to Return a Matplotlib Figure with its Corresponding Legend

I'm trying to define a function that returns a pre-styled figure with a certain grid, style, width, and other properties. However, when I return the figure and its axes, the legend is missing. Here's a simplified example:

import matplotlib as mpl
import matplotlib.pyplot as plt

def getfig():
    plt.style.use('default')
    fig, axs = plt.subplots(1, 1, figsize=(1, 1), sharey=False)

    # A single subplot comes back as a bare Axes; wrap it so it can be iterated
    if isinstance(axs, mpl.axes.Axes):
        axs = [axs]

    for ax in axs:
        ax.grid(color='grey', axis='both', linestyle='-.', linewidth=0.4)
        ax.legend(loc=9, bbox_to_anchor=(0.5, -0.3), ncol=2)

    return fig, axs

fig, axs = getfig()
axs[0].plot(range(10), label="label")


What am I missing?

Thanks!


UPDATE:

This is what I'm using so far, but I think there really should be a way to force all future legends associated with a figure to have a certain style.

import matplotlib as mpl
import matplotlib.pyplot as plt

def fig_new(rows=1, columns=1, figsize=(1, 1)):
    plt.style.use('default')
    fig, axs = plt.subplots(rows, columns, figsize=figsize, sharey=False)

    # A single subplot comes back as a bare Axes; wrap it so it can be iterated
    if isinstance(axs, mpl.axes.Axes):
        axs = [axs]

    for ax in axs:
        ax.grid(color='grey', axis='both', linestyle='-.', linewidth=0.4)

    return fig, axs

def fig_leg(fig):
    # Apply the common legend style to every axes of the figure
    for ax in fig.get_axes():
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.3), ncol=5)

fig, axs = fig_new()
axs[0].plot(range(10), label="label")
fig_leg(fig)

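You need to call ax.legend() after an artist with a label has been plotted to the axes. legend() only picks up the labeled artists that exist at the moment it is called, so calling it inside getfig(), before anything is plotted, produces an empty legend. One option is to let the function return the legend arguments and apply them after plotting.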

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

def getfig():
    plt.style.use('default')
    fig, axs = plt.subplots(1, 1, figsize=(1, 1), sharey=False)

    # A single subplot comes back as a bare Axes; wrap it in an array
    if isinstance(axs, mpl.axes.Axes):
        axs = np.array([axs])

    legendkw = []
    for ax in axs:
        ax.grid(color='grey', axis='both', linestyle='-.', linewidth=0.4)
        # Store the legend arguments instead of calling ax.legend() here
        legendkw.append(dict(loc=9, bbox_to_anchor=(0.5, -0.3), ncol=2))

    return fig, axs, legendkw

fig, axs, kw = getfig()
axs[0].plot(range(10), label="label")
# Create the legends only now that the labeled artists exist
for i, ax in enumerate(axs.flat):
    ax.legend(**kw[i])
plt.show()
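
To see why the order matters, here is a minimal sketch that counts the legend entries before and after plotting. The variable names (legend_style, early, late) are only for illustration, and the Agg backend is selected just so it runs without a display; neither is part of the original question.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, only so this runs without a display
import matplotlib.pyplot as plt

# Illustrative name: the legend styling kept in one place
legend_style = dict(loc="upper center", bbox_to_anchor=(0.5, -0.3), ncol=2)

fig, ax = plt.subplots()

# Legend requested before anything is plotted: there are no labeled artists
# yet, so the legend is created without any entries.
early = ax.legend(**legend_style)
n_early = len(early.get_texts())

ax.plot(range(10), label="label")

# Legend requested after plotting: the labeled line is picked up.
late = ax.legend(**legend_style)
n_late = len(late.get_texts())
labels_late = [t.get_text() for t in late.get_texts()]

print(n_early, n_late, labels_late)
```

The second call replaces the empty legend on the axes, which is why deferring ax.legend() until after plotting (as fig_leg does in the question's update) works.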
