
Seaborn Matplotlib: Get custom legend outside of plot

I have a function that generates up to 4 different plots at once. The legend needs to be saved separately from the plot.

In my code I collect all of the labels and even create some for the peaks in the bar graphs.

I then display the legend separately, but for some reason the legend figure comes out blank.
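The intended result (a legend rendered in a figure of its own and saved to its own file) can be sketched with Matplotlib alone, as long as every handle carries a label. This is an illustrative sketch, not the poster's code: the helper name save_standalone_legend, the labels, colors and the temporary file path are assumptions, and the non-interactive Agg backend is used so nothing needs a display.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # render off-screen; no display needed
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches


def save_standalone_legend(handles, path, figsize=(2.3, 2.3)):
    """Hypothetical helper: draw only a legend built from `handles` and save it to `path`."""
    fig_leg = plt.figure(figsize=figsize)
    # With handles= and no labels, Matplotlib reads each label from its handle.
    legend = fig_leg.legend(handles=handles, loc="center", frameon=False)
    fig_leg.savefig(path, dpi=200, bbox_inches="tight")
    plt.close(fig_leg)
    return [text.get_text() for text in legend.get_texts()]


handles = [
    mpatches.Patch(color="red", label="Current Peak (Kw)"),
    mpatches.Patch(color="purple", label="Current Kw"),
]
path = os.path.join(tempfile.gettempdir(), "legend_sketch.png")
drawn_labels = save_standalone_legend(handles, path)
print(drawn_labels)  # ['Current Peak (Kw)', 'Current Kw']
```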

Code:

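import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

degree = ' \u2109 '  # Fahrenheit sign

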
def generate_graph_image():
    filename = 'test_bar.png'
    legend_file_name = 'legend.png'
    bar = True

    unit = 'Kw'
    time = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'June', 'July', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

    current_temperatures = [35, 45, 55, 65, 75, 85, 95, 100, 85, 65, 45, 35]
    # list is not always provided
    historic_temperatures = [35, 85,  35, 45, 55, 65, 75, 95, 100, 85, 65, 45,]

    current_readings = [.99, .75, .55, .10, .35, .05, .05, .08, .20, .55, .60, .85]
    # list is not always provided
    historic_readings = [.50, .05, .05, .08, .20,  .75, .55, .10, .35,.45, .65, .49, ]

    swap = True if sum(historic_readings) > sum(current_readings) else False

    time_label = 'Month'

    temp_label = f'Temperatures {degree}'
   
    current_data = {time_label: time, unit: current_readings, temp_label: current_temperatures}
    historic_data = {time_label: time, unit: historic_readings, temp_label: historic_temperatures}

    current_data_frame = pd.DataFrame(current_data)
    historic_data_frame = pd.DataFrame(historic_data)

    fig, current_ax = plt.subplots(figsize=(10, 6))
    current_color = 'blue'
    current_palette = "Reds"
    historic_color = 'black'
    historic_palette = "Blues_r"

    historic_ax = current_ax.twinx()
    historic_ax.axes.xaxis.set_visible(False)
    historic_ax.axes.yaxis.set_visible(False)
    historic_ax.axes.set_ylim(current_ax.axes.get_ylim())

    current_ax.set_xlabel('Time', fontsize=16)
    current_ax.set_ylabel(unit, fontsize=16, )

    current_peak = max(current_readings)
    current_peak_index = current_readings.index(current_peak)

    historic_peak = max(historic_readings)
    historic_peak_index = historic_readings.index(historic_peak)

    current_ax = sns.barplot(ax=current_ax, x=time_label, y=unit, data=current_data_frame, palette=current_palette, color=current_color, )
    current_ax.patches[current_peak_index].set_color('red')
    current_ax.patches[historic_peak_index].set_alpha(0.3)

    historic_ax = sns.barplot(ax=historic_ax, x=time_label, y=unit, data=historic_data_frame, palette=historic_palette,   color=historic_color, alpha=.7)
    historic_ax.patches[historic_peak_index].set_color('black')

    temperature_ax = current_ax.twinx()
    current_color = 'green'
    historic_color = 'orange'

    temperature_ax.set_ylabel(f'Temperature {degree}', fontsize=16,)
    temperature_ax = sns.lineplot(x=time_label, y=temp_label, data=current_data_frame, sort=False, color=current_color)
    temperature_ax.tick_params(axis='y', color=current_color)
    temperature_ax = sns.lineplot(x=time_label, y=temp_label, data=historic_data_frame, sort=False, color=historic_color)
    temperature_ax.tick_params(axis='y', color=historic_color)

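    # note: styles only affect artists created after this call; since Matplotlib 3.6
    # this style is named 'seaborn-v0_8-poster'
    plt.style.use('seaborn-poster')
    plt.style.use('ggplot')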

    plt.savefig(fname=filename, dpi=200)

    figsize = (2.3, 2.3)
    fig_leg = plt.figure(figsize=figsize)
    fig_leg.set_size_inches(2.3, 2.3, forward=True)
    ax_leg = fig_leg.add_subplot(111)

    current_peak_reading_label = mpatches.Patch(color='red', label=f'Current Peak ({unit})')
    current_reading_label = mpatches.Patch(color='purple', label=f'Current {unit}')
    historic_peak_reading_label = mpatches.Patch(color='pink', label=f'Historic Peak ({unit})')
    historic_reading_label = mpatches.Patch(color='yellow', label=f'Historic {unit}')

    handles, labels = current_ax.get_legend_handles_labels()
    handles += [current_reading_label, current_peak_reading_label, historic_reading_label, historic_peak_reading_label]

    historic_handles, historic_labels = historic_ax.get_legend_handles_labels()
    handles += historic_handles
    labels += historic_labels
    temp_handles, temp_labels = temperature_ax.get_legend_handles_labels()
    handles += temp_handles
    labels += temp_labels
    ax_leg.legend(handles, labels, loc='center', frameon=False)
    # hide the axes frame and the x/y labels
    ax_leg.axis('off')
    fig_leg.savefig(legend_file_name, dpi=200, bbox_inches='tight')

    plt.show()

Output: [image]

Fundamentally, unless you use hue, seaborn plots will not render a legend. Therefore, your handles and labels start out empty. Additionally, while the mpatches section adds to handles, labels stays empty. Matplotlib pairs handles and labels element by element, so with the two lists unequal in length (and labels empty), no legend is rendered in the end.
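The length mismatch can be reproduced with plain Matplotlib. In this illustrative sketch (not the original code), four proxy patches are passed together with an empty labels list, so legend() has nothing to pair them with; taking each label from its own handle fixes it. The warnings filter is there because recent Matplotlib versions may warn about the mismatch.

```python
import warnings

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

fig, ax = plt.subplots()

handles = [
    mpatches.Patch(color="red", label="Current Peak (Kw)"),
    mpatches.Patch(color="purple", label="Current Kw"),
    mpatches.Patch(color="pink", label="Historic Peak (Kw)"),
    mpatches.Patch(color="yellow", label="Historic Kw"),
]
labels = []  # what the question's code ends up with

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence a possible length-mismatch warning
    broken = ax.legend(handles, labels)
print(len(broken.get_texts()))  # 0 -> no entries are drawn

# Keep both lists the same length by reading each label from its handle.
labels = [h.get_label() for h in handles]
fixed = ax.legend(handles, labels)
print([t.get_text() for t in fixed.get_texts()])
plt.close(fig)
```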

Consider adjusting your current code section by section (full code linked at the bottom). The snippets below are for demonstration only; adapt the process with the hue and mpatches label issues above in mind.

Data (add a period column for the hue argument)

current_data_frame = pd.DataFrame(current_data).assign(period='current')
#    Month    Kw  Temperatures  ℉   period
# 0    Jan  0.99                35  current
# 1    Feb  0.75                45  current
# 2    Mar  0.55                55  current
# 3    Apr  0.10                65  current
# 4    May  0.35                75  current
# 5   June  0.05                85  current
# 6   July  0.05                95  current
# 7    Aug  0.08               100  current
# 8    Sep  0.20                85  current
# 9    Oct  0.55                65  current
# 10   Nov  0.60                45  current
# 11   Dec  0.85                35  current

historic_data_frame = pd.DataFrame(historic_data).assign(period='historic')
#    Month    Kw  Temperatures  ℉    period
# 0    Jan  0.50                35  historic
# 1    Feb  0.05                85  historic
# 2    Mar  0.05                35  historic
# 3    Apr  0.08                45  historic
# 4    May  0.20                55  historic
# 5   June  0.75                65  historic
# 6   July  0.55                75  historic
# 7    Aug  0.10                95  historic
# 8    Sep  0.35               100  historic
# 9    Oct  0.45                85  historic
# 10   Nov  0.65                65  historic
# 11   Dec  0.49                45  historic
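As an alternative sketch (not what this answer does), the two frames can be stacked with pandas.concat so that a single hue='period' call covers both periods and gets one legend entry per level. Only three months are used here for brevity, with values taken from the tables above; the seaborn call is shown as a comment because it assumes seaborn is installed.

```python
import pandas as pd

time_label, unit = "Month", "Kw"
months = ["Jan", "Feb", "Mar"]

current = pd.DataFrame({time_label: months, unit: [0.99, 0.75, 0.55]}).assign(period="current")
historic = pd.DataFrame({time_label: months, unit: [0.50, 0.05, 0.05]}).assign(period="historic")

# One long-format frame: every row is tagged with the period it belongs to.
combined = pd.concat([current, historic], ignore_index=True)
print(combined)
print(combined["period"].unique().tolist())  # ['current', 'historic']

# With seaborn installed, one call would then draw both periods, e.g.:
# sns.barplot(data=combined, x=time_label, y=unit, hue="period")
```

Note that seaborn dodges hue levels side by side by default, so this changes the look from overlaid bars on twin axes to grouped bars.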

Barplot (call .get_legend().remove() to remove the legend from the original plot)

current_ax = sns.barplot(ax=current_ax, x=time_label, y=unit, hue='period', data=current_data_frame, palette=current_palette, color=current_color, )
current_ax.patches[current_peak_index].set_color('red')
current_ax.patches[historic_peak_index].set_alpha(0.3)
current_ax.get_legend().remove()

historic_ax = sns.barplot(ax=historic_ax, x=time_label, y=unit, hue='period', data=historic_data_frame, palette=historic_palette, color=historic_color, alpha=.7)
historic_ax.patches[historic_peak_index].set_color('black')
historic_ax.get_legend().remove()
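Removing the on-axes legend does not lose the entries: get_legend_handles_labels() reads the labels from the plotted artists, not from the Legend object. A minimal Matplotlib-only check, with illustrative bar values and label:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.bar(["Jan", "Feb", "Mar"], [0.99, 0.75, 0.55], label="current")
ax.legend()

ax.get_legend().remove()  # hide the legend drawn on the plot itself
print(ax.get_legend())    # None
handles, labels = ax.get_legend_handles_labels()
print(labels)             # ['current'] -> still available for a separate legend
plt.close(fig)
```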

Line plots (call .get_legend().remove() to remove the legend from the original plot)

temperature_ax.set_ylabel(f'Temperature {degree}', fontsize=16,)

temperature_ax = sns.lineplot(x=time_label, y=temp_label, hue='period', data=current_data_frame, sort=False, palette=['g'])
temperature_ax.tick_params(axis='y', color=current_color)
temperature_ax.get_legend().remove()

temperature_ax = sns.lineplot(x=time_label, y=temp_label, hue='period', data=historic_data_frame, sort=False, palette=['orange'])
temperature_ax.get_legend().remove()
temperature_ax.tick_params(axis='y', color=historic_color)
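Because twinx() creates a separate Axes, each twin keeps its own artists, so legend entries must be gathered from every one of them. A sketch that loops over fig.axes instead of naming each axis; the helper name collect_handles_labels and the data are hypothetical:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


def collect_handles_labels(fig):
    """Hypothetical helper: merge legend entries from every Axes in `fig`, twins included."""
    handles, labels = [], []
    for ax in fig.axes:
        h, l = ax.get_legend_handles_labels()
        handles += h
        labels += l
    return handles, labels


fig, bar_ax = plt.subplots()
bar_ax.bar([0, 1], [0.99, 0.75], label="Current Kw")
temp_ax = bar_ax.twinx()  # second y-axis sharing the x-axis
temp_ax.plot([0, 1], [35, 45], color="green", label="current")
temp_ax.plot([0, 1], [35, 85], color="orange", label="historic")

handles, labels = collect_handles_labels(fig)
print(labels)  # ['Current Kw', 'current', 'historic']
plt.close(fig)
```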

Mpatches (add a label for each handle)

#... same mpatches code ...

handles, labels = current_ax.get_legend_handles_labels()

# ADD LABEL TO CORRESPOND TO HANDLE
labels += ['current_reading_label', 'current_peak_reading_label', 'historic_reading_label', 'historic_peak_reading_label']
handles += [current_reading_label, current_peak_reading_label, historic_reading_label, historic_peak_reading_label]

historic_handles, historic_labels = historic_ax.get_legend_handles_labels()
handles += historic_handles
labels += historic_labels

temp_handles, temp_labels = temperature_ax.get_legend_handles_labels()
handles += [temp_handles[1]] + [temp_handles[3]]   # SKIP 1ST AND 3RD ITEMS (LEGEND TITLE, 'period')
labels += [temp_labels[1]] + [temp_labels[3]]      # SKIP 1ST AND 3RD ITEMS (LEGEND TITLE, 'period')
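The fixed indexes [1] and [3] assume the hue name 'period' shows up as an extra entry ahead of each line, which may depend on the seaborn version. A sketch that drops entries by label instead of by position; the skip set is an assumption to adapt, and strings stand in for the artist handles:

```python
def drop_entries(handles, labels, skip=("period",)):
    """Keep only (handle, label) pairs whose label is not in `skip`."""
    pairs = [(h, l) for h, l in zip(handles, labels) if l not in skip]
    if not pairs:
        return [], []
    kept_handles, kept_labels = zip(*pairs)
    return list(kept_handles), list(kept_labels)


# Stand-in values shaped like the temperature axis output.
temp_handles = ["title_1", "line_current", "title_2", "line_historic"]
temp_labels = ["period", "current", "period", "historic"]
print(drop_entries(temp_handles, temp_labels))
# (['line_current', 'line_historic'], ['current', 'historic'])
```

If a seaborn version does not add the 'period' entries, the same call simply keeps everything.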

Full Code


Legend Plot (colors/palette need adjustment): [image]

Bar and Line Plot (colors/palette need adjustment): [image]
