簡體   English   中英

防止 Matplotlib Legend 在垂直子圖上塗抹 h-padding?

[英]Prevent Matplotlib Legend from smushing h-padding on vertical subplots?

我有一個腳本,我在一個 5x1 的網格中繪制了幾個變量。 我注意到,當我有數據使我的圖例更短時,子圖本身具有可接受的高度和水平填充。 當我有數據使我的圖例更大(垂直)時,子圖被壓扁,在圖之間留下額外的水平填充。

有沒有辦法防止這種情況? 要將圖例分配與軸對象分開並獨立於圖例間距繪制每個圖?

下面是一個最小的可重現示例來說明我的意思:

#!/usr/bin/env python3
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def plotter(df, var_cols):
    dfa = df.query("accepted == 'accepted'")
    dfr = df.query("accepted != 'accepted'")
    colors = {v: c for v, c in zip(['accepted', 'rejected', 'rerun'],
                                   ['darkgreen', 'firebrick', 'steelblue'])}
    fig, axes = plt.subplots(nrows=len(var_cols), sharex=True)

    for var, ax in zip(var_cols, axes):
        for k, d in dfr.groupby('accepted'):
            ax.scatter(d.iteration, d[var], label=k, alpha=0.8, c=d.accepted.map(colors))
        ax.plot(dfa.iteration, dfa[var], '-o', label='accepted', color=colors['accepted'])

    # Grab 3rd axes because I want the legend to be towards the center
    handles, labels = axes[2].get_legend_handles_labels()
    # Sort legend labels to put 'accepted' on top
    labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
    axes[2].legend(handles, labels, markerscale=1.2, bbox_to_anchor=(1, 0.5))
    fig.tight_layout()


def main():
    states = {1: 'accepted', 2: 'rejected', 3: 'rerun'}
    np.random.seed(666)
    dat1 = pd.DataFrame({
        'iteration': [0, 1, 2],
        'accepted': ['accepted']*3,
        'h_cap': [10.1, 6.5, 12.2],
        'h_stor': [500, 410, 0],
        'h_mark': [10, 6, 1],
        'bid': [500, 100, 50],
        'npv': [2.278, 2.6, 2.85]
    })

    dat2 = pd.DataFrame({
        'iteration': range(10),
        'accepted': [states[num] for num in np.random.randint(1, 4, size=10)],
        'h_cap': np.random.rand(10),
        'h_stor': np.random.rand(10),
        'h_mark': np.random.rand(10),
        'bid': np.random.rand(10),
        'npv': np.random.rand(10)
    })

    var_cols = ['h_cap', 'h_stor', 'h_mark', 'bid', 'npv']
    plotter(dat1, var_cols)
    plt.savefig(Path('~/Desktop/nonsmushed.png').expanduser())

    plotter(dat2, var_cols)
    plt.savefig(Path('~/Desktop/smushed.png').expanduser())


if __name__ == '__main__':
    main()

未塗抹.png

在此處輸入圖片說明

被弄臟的.png

在此處輸入圖片說明

因為您的圖例“屬於” axes[2]tight_layout()調整間距,以便相鄰軸不覆蓋圖例。

我認為最簡單的解決方案是創建一個“圖形級”圖例( fig.legend() ),但問題在於tight_layout()不考慮該圖例,您必須調整正確的手動保證金(如果需要,可能有一種方法可以自動計算,但這可能會變得混亂)

(...)
labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
fig.legend(handles, labels, markerscale=1.2, bbox_to_anchor=(1, 0.5))
fig.tight_layout()
fig.subplots_adjust(right=0.75)  # adjust value as needed

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM