Create a single legend for multiple plots in matplotlib, seaborn

I am making box plots using the "iris.csv" data. I am trying to break the data into multiple dataframes by measurement (i.e. petal-length, petal-width, sepal-length, sepal-width) and then make a box plot for each one in a for loop, adding each as a subplot.

Finally, I want to add one common legend for all the box plots at once, but I am not able to do it. I have tried several tutorials and the methods from several Stack Overflow questions, but I have not been able to fix it.

Here is my code:

import pandas as pd
import seaborn as sns
from matplotlib import pyplot
from matplotlib import patches as mpatches

iris_data = "iris.csv"
names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']
dataset = pd.read_csv(iris_data, names=names)


# Reindex the dataset by species so it can be pivoted for each species 
reindexed_dataset = dataset.set_index(dataset.groupby('class').cumcount())
cols_to_pivot = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width']

# collect one pivoted frame per measurement, then concatenate them
# (DataFrame.append was removed in pandas 2.0)
pivoted_frames = []
for var_name in cols_to_pivot:
    pivoted_dataset = reindexed_dataset.pivot(columns='class', values=var_name).rename_axis(None, axis=1)
    pivoted_dataset['measurement'] = var_name
    pivoted_frames.append(pivoted_dataset)
reshaped_dataset = pd.concat(pivoted_frames, ignore_index=True)


## Now, let's split the dataframe into groups by measurement.
grouped_dfs_02 = []
for group in reshaped_dataset.groupby('measurement') :
    grouped_dfs_02.append(group[1])


## make the box plot of several measured variables, compared between species 

pyplot.figure(figsize=(20, 5), dpi=80)
pyplot.suptitle('Distribution of floral traits in the species of iris')

sp_name=['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
setosa = mpatches.Patch(color='red')
versi = mpatches.Patch(color='green')
virgi = mpatches.Patch(color='blue')

my_pal = {"Iris-versicolor": "g", "Iris-setosa": "r", "Iris-virginica":"b"}
plt_index = 0


# for i, df in enumerate(grouped_dfs_02):
for group_name, df in reshaped_dataset.groupby('measurement'):

    axi = pyplot.subplot(1, len(grouped_dfs_02), plt_index + 1)
    sp_name=['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    df_melt = df.melt('measurement', var_name='species', value_name='values')

    sns.boxplot(data=df_melt, x='species', y='values', ax = axi, orient="v", palette=my_pal)
    pyplot.title(group_name)
    plt_index += 1


# Move the legend to an empty part of the plot
pyplot.legend(title='species', labels = sp_name, 
         handles=[setosa, versi, virgi], bbox_to_anchor=(19, 4),
           fancybox=True, shadow=True, ncol=5)


pyplot.show()

Here is the plot: (image)

How do I add a common legend to the main figure, outside the main frame, beside the main suptitle?

To position the legend, it is important to set the loc parameter, which names the anchor point of the legend box. (The default loc is 'best', which means you don't know beforehand where it will end up.) Positions are measured in the coordinates of the current axes, from 0,0 at the lower left to 1,1 at the upper right. This doesn't include the padding for titles etc., so the values can go a bit outside the 0, 1 range. The "current axes" is the one that was activated last.
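As a minimal sketch of how loc and bbox_to_anchor interact, the example below pins one legend to the last axes of a row of empty subplots. The function name legend_above_last_axes and the use of empty subplots are assumptions made for illustration; the colours and the anchor values mirror the ones used in this answer.

```python
import matplotlib.pyplot as plt
from matplotlib import patches as mpatches


def legend_above_last_axes(ncols=4):
    """Hypothetical helper: draw empty subplots and anchor a legend to the last axes."""
    fig, axes = plt.subplots(1, ncols, figsize=(12, 3), squeeze=False)

    # Proxy handles: coloured patches standing in for the three species.
    handles = [
        mpatches.Patch(color="red", label="setosa"),
        mpatches.Patch(color="green", label="versicolor"),
        mpatches.Patch(color="blue", label="virginica"),
    ]

    # loc names the corner of the legend box that is pinned to bbox_to_anchor.
    # (1, 1.23) is in axes coordinates of the last subplot: its right edge,
    # and 23% of its height above its top edge.
    legend = axes[0, -1].legend(
        handles=handles,
        title="species",
        ncol=3,
        loc="upper right",
        bbox_to_anchor=(1, 1.23),
    )
    return fig, legend
```

Calling fig, legend = legend_above_last_axes() followed by plt.show() should display the legend just above the rightmost subplot. Changing loc to 'upper left' would pin the legend's upper-left corner to the same point instead, pushing the box to the right of the axes.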

Note that instead of plt.legend (which uses an axes), you could also use plt.gcf().legend, which uses the "figure". Then the coordinates are 0,0 in the lower left corner of the complete plot (the "figure") and 1,1 in the upper right. One drawback is that no extra space is created for the legend, so you'd need to set a top padding manually (e.g. plt.gcf().subplots_adjust(top=0.8)). Another drawback is that you can't use plt.tight_layout() anymore, and that it is harder to align the legend with the axes.
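The figure-level variant could look roughly like the sketch below. The function name figure_level_legend and the empty subplots are assumptions for illustration; only Figure.legend and subplots_adjust come from the text above.

```python
import matplotlib.pyplot as plt
from matplotlib import patches as mpatches


def figure_level_legend(ncols=4):
    """Hypothetical helper: attach one legend to the figure instead of an axes."""
    fig, axes = plt.subplots(1, ncols, figsize=(12, 3), squeeze=False)
    fig.suptitle("Distribution of floral traits in the species of iris")

    # Proxy handles: coloured patches standing in for the three species.
    handles = [
        mpatches.Patch(color="red", label="setosa"),
        mpatches.Patch(color="green", label="versicolor"),
        mpatches.Patch(color="blue", label="virginica"),
    ]

    # Figure coordinates: (0, 0) is the lower-left corner of the whole figure,
    # (1, 1) the upper-right corner. The legend's upper-right corner goes there.
    legend = fig.legend(
        handles=handles,
        title="species",
        ncol=3,
        loc="upper right",
        bbox_to_anchor=(1, 1),
    )

    # No room is reserved automatically, so shrink the subplots from the top.
    fig.subplots_adjust(top=0.8)
    return fig, legend
```

Here the legend sits in the top-right corner of the figure, next to the suptitle, and the top=0.8 value is the knob to adjust if the legend overlaps the subplot titles.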

import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import patches as mpatches
import pandas as pd

dataset = sns.load_dataset("iris")

# Reindex the dataset by species so it can be pivoted for each species
reindexed_dataset = dataset.set_index(dataset.groupby('species').cumcount())
cols_to_pivot = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

# collect one pivoted frame per measurement, then concatenate them
# (DataFrame.append was removed in pandas 2.0)
pivoted_frames = []
for var_name in cols_to_pivot:
    pivoted_dataset = reindexed_dataset.pivot(columns='species', values=var_name).rename_axis(None, axis=1)
    pivoted_dataset['measurement'] = var_name
    pivoted_frames.append(pivoted_dataset)
reshaped_dataset = pd.concat(pivoted_frames, ignore_index=True)

## Now, let's split the dataframe into groups by measurement.
grouped_dfs_02 = []
for group in reshaped_dataset.groupby('measurement'):
    grouped_dfs_02.append(group[1])

## make the box plot of several measured variables, compared between species
plt.figure(figsize=(20, 5), dpi=80)
plt.suptitle('Distribution of floral traits in the species of iris')

sp_name = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
setosa = mpatches.Patch(color='red')
versi = mpatches.Patch(color='green')
virgi = mpatches.Patch(color='blue')

my_pal = {"versicolor": "g", "setosa": "r", "virginica": "b"}
plt_index = 0

# for i, df in enumerate(grouped_dfs_02):
for group_name, df in reshaped_dataset.groupby('measurement'):
    axi = plt.subplot(1, len(grouped_dfs_02), plt_index + 1)
    sp_name = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    df_melt = df.melt('measurement', var_name='species', value_name='values')

    sns.boxplot(data=df_melt, x='species', y='values', ax=axi, orient="v", palette=my_pal)
    plt.title(group_name)
    plt_index += 1

# Move the legend to an empty part of the plot
plt.legend(title='species', labels=sp_name,
           handles=[setosa, versi, virgi], bbox_to_anchor=(1, 1.23),
           fancybox=True, shadow=True, ncol=5, loc='upper right')
plt.tight_layout()
plt.show()

Resulting plot

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# load iris data
iris = sns.load_dataset("iris")

   sepal_length  sepal_width  petal_length  petal_width species
0           5.1          3.5           1.4          0.2  setosa
1           4.9          3.0           1.4          0.2  setosa
2           4.7          3.2           1.3          0.2  setosa
3           4.6          3.1           1.5          0.2  setosa
4           5.0          3.6           1.4          0.2  setosa

# create figure
fig = plt.figure(figsize=(20, 5), dpi=80)

# add subplots
for i, col in enumerate(iris.columns[:-1], 1):
    plt.subplot(1, 4, i)
    ax = sns.boxplot(x='species', y=col, data=iris, hue='species')
    ax.get_legend().remove()
    plt.title(col)

# add legend
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper right', ncol=3, bbox_to_anchor=(.75, 0.98))

# add subtitle
fig.suptitle('Distribution of floral traits in the species of iris')

plt.show()
