
Create a single legend for multiple plots in matplotlib, seaborn

I am making boxplots from the "iris.csv" data. I am trying to split the data into multiple dataframes by measurement (i.e. petal-length, petal-width, sepal-length, sepal-width) and then draw the boxplots in a for loop, adding one subplot per measurement.

Finally, I want to add a single common legend for all the boxplots at once, but I am not able to do it. I have tried several tutorials and methods from several Stack Overflow questions, but I am not able to fix it.

Here is my code:

import pandas as pd
import seaborn as sns
from matplotlib import pyplot
from matplotlib import patches as mpatches
from pandas import read_csv

iris_data = "iris.csv"
names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']
dataset = read_csv(iris_data, names=names)


# Reindex the dataset by species so it can be pivoted for each species 
reindexed_dataset = dataset.set_index(dataset.groupby('class').cumcount())
cols_to_pivot = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width']

# collect one pivoted frame per measurement, then concatenate them
pivoted_frames = []
for var_name in cols_to_pivot:
    pivoted_dataset = reindexed_dataset.pivot(columns='class', values=var_name).rename_axis(None, axis=1)
    pivoted_dataset['measurement'] = var_name
    pivoted_frames.append(pivoted_dataset)
reshaped_dataset = pd.concat(pivoted_frames, ignore_index=True)


## Now, let's split the dataframe into groups by measurement.
grouped_dfs_02 = []
for group in reshaped_dataset.groupby('measurement'):
    grouped_dfs_02.append(group[1])


## make the box plot of several measured variables, compared between species 

pyplot.figure(figsize=(20, 5), dpi=80)
pyplot.suptitle('Distribution of floral traits in the species of iris')

sp_name=['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
setosa = mpatches.Patch(color='red')
versi = mpatches.Patch(color='green')
virgi = mpatches.Patch(color='blue')

my_pal = {"Iris-versicolor": "g", "Iris-setosa": "r", "Iris-virginica":"b"}
plt_index = 0


# for i, df in enumerate(grouped_dfs_02):
for group_name, df in reshaped_dataset.groupby('measurement'):

    axi = pyplot.subplot(1, len(grouped_dfs_02), plt_index + 1)
    sp_name=['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    df_melt = df.melt('measurement', var_name='species', value_name='values')

    sns.boxplot(data=df_melt, x='species', y='values', ax = axi, orient="v", palette=my_pal)
    pyplot.title(group_name)
    plt_index += 1


# Move the legend to an empty part of the plot
pyplot.legend(title='species', labels = sp_name, 
         handles=[setosa, versi, virgi], bbox_to_anchor=(19, 4),
           fancybox=True, shadow=True, ncol=5)


pyplot.show()

Here is the plot: (image not included)

How do I add a common legend to the main figure, outside the main frame, next to the main suptitle?

To position the legend, it is important to set the loc parameter, which selects the anchor point of the legend box (the corner or edge that gets placed at bbox_to_anchor). The default loc is 'best', which means you don't know beforehand where the legend will end up. The positions are measured in the coordinates of the current ax, from 0,0 at its lower left to 1,1 at its upper right. This doesn't include the padding for titles etc., so the values can go a bit outside the 0 to 1 range. The "current ax" is the last one that was activated (here, the last subplot created in the loop).
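A minimal sketch of that coordinate rule, using plain Matplotlib on a single throwaway axes (the colors, labels and the anchor point `(1, 1.02)` are illustrative choices, not taken from the question): `loc='lower right'` puts the legend's lower-right corner on the anchor, and because `bbox_to_anchor` is given in axes coordinates, a y value just above 1 lifts the legend above the plot area.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib import patches as mpatches

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot([0, 1], [0, 1])  # placeholder content

# Proxy artists for the legend entries (illustrative colors and labels)
handles = [mpatches.Patch(color=c, label=l)
           for c, l in [("red", "setosa"), ("green", "versicolor"), ("blue", "virginica")]]

# (1, 1.02) is in axes coordinates: x=1 is the right edge, y=1.02 just above the top.
# loc='lower right' says which corner of the legend box is pinned to that point.
leg = ax.legend(handles=handles, loc="lower right", bbox_to_anchor=(1, 1.02), ncol=3)

fig.canvas.draw()  # positions are only final after a draw
leg_box = leg.get_window_extent()  # display (pixel) coordinates
ax_box = ax.get_window_extent()
print(leg_box.y0 > ax_box.y1)  # True: the legend sits entirely above the axes
```

Changing `loc` to `'upper right'` with the same anchor would instead hang the legend's top edge from that point, so it would overlap the top of the axes.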

Note that instead of plt.legend (which uses an axes), you could also use plt.gcf().legend, which uses the "figure". Then the coordinates are 0,0 in the lower left corner of the complete plot (the "figure") and 1,1 in the upper right. A drawback is that no extra space is created for the legend, so you'd need to set a top padding manually (e.g. plt.gcf().subplots_adjust(top=0.8)). Another drawback is that you can't use plt.tight_layout() anymore, and that it is harder to align the legend with the axes.
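A sketch of the figure-level variant, with empty placeholder plots standing in for the boxplots (the `top=0.8` value is the one mentioned above; the figure size, colors and labels are assumptions for illustration): the legend's top center is anchored at the top edge of the figure in figure coordinates, and `subplots_adjust` makes room for it by hand.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib import patches as mpatches

fig, axes = plt.subplots(1, 4, figsize=(12, 4))
for ax in axes:
    ax.plot([0, 1], [0, 1])  # placeholder content instead of the boxplots

# Proxy artists for the legend entries (illustrative colors and labels)
handles = [mpatches.Patch(color=c, label=l)
           for c, l in [("red", "setosa"), ("green", "versicolor"), ("blue", "virginica")]]

# fig.legend works in figure coordinates: (0, 0) is the lower left of the whole
# figure, (1, 1) the upper right. Here the legend's top center sits at the top edge.
leg = fig.legend(handles=handles, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=3)

# No room is reserved for a figure legend, so shrink the subplot area by hand.
fig.subplots_adjust(top=0.8)

fig.canvas.draw()  # compute the final positions
print(leg.get_window_extent().y0 > max(ax.get_window_extent().y1 for ax in axes))  # True
```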

import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import patches as mpatches
import pandas as pd

dataset = sns.load_dataset("iris")

# Reindex the dataset by species so it can be pivoted for each species
reindexed_dataset = dataset.set_index(dataset.groupby('species').cumcount())
cols_to_pivot = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

# collect one pivoted frame per measurement, then concatenate them
pivoted_frames = []
for var_name in cols_to_pivot:
    pivoted_dataset = reindexed_dataset.pivot(columns='species', values=var_name).rename_axis(None, axis=1)
    pivoted_dataset['measurement'] = var_name
    pivoted_frames.append(pivoted_dataset)
reshaped_dataset = pd.concat(pivoted_frames, ignore_index=True)

## Now, let's split the dataframe into groups by measurement.
grouped_dfs_02 = []
for group in reshaped_dataset.groupby('measurement'):
    grouped_dfs_02.append(group[1])

## make the box plot of several measured variables, compared between species
plt.figure(figsize=(20, 5), dpi=80)
plt.suptitle('Distribution of floral traits in the species of iris')

sp_name = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
setosa = mpatches.Patch(color='red')
versi = mpatches.Patch(color='green')
virgi = mpatches.Patch(color='blue')

my_pal = {"versicolor": "g", "setosa": "r", "virginica": "b"}
plt_index = 0

# for i, df in enumerate(grouped_dfs_02):
for group_name, df in reshaped_dataset.groupby('measurement'):
    axi = plt.subplot(1, len(grouped_dfs_02), plt_index + 1)
    sp_name = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    df_melt = df.melt('measurement', var_name='species', value_name='values')

    sns.boxplot(data=df_melt, x='species', y='values', ax=axi, orient="v", palette=my_pal)
    plt.title(group_name)
    plt_index += 1

# Move the legend to an empty part of the plot
plt.legend(title='species', labels=sp_name,
           handles=[setosa, versi, virgi], bbox_to_anchor=(1, 1.23),
           fancybox=True, shadow=True, ncol=5, loc='upper right')
plt.tight_layout()
plt.show()

Resulting plot: (image not included)
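One possible refinement, sketched here as a suggestion rather than part of the answer: the code above keeps the colors in `my_pal` and the legend labels in `sp_name` as two separate hand-written lists, so they can drift apart. Building the proxy patches from the palette dict itself keeps labels and colors in sync, and the legend order then follows the dict order. The boxplot drawing is left as a comment so the sketch runs without seaborn.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get an interactive window
import matplotlib.pyplot as plt
from matplotlib import patches as mpatches

# Same species -> color mapping as the answer, reordered so the legend reads setosa first
my_pal = {"setosa": "r", "versicolor": "g", "virginica": "b"}

# One proxy patch per species, built from the palette so labels and colors always match
handles = [mpatches.Patch(color=color, label=species) for species, color in my_pal.items()]

fig, axes = plt.subplots(1, 4, figsize=(20, 5), dpi=80)
# ... draw one sns.boxplot(..., palette=my_pal) per axes here, as in the answer ...

# Each entry's text comes from the handle's label, so no separate labels list is needed
leg = fig.legend(handles=handles, title="species", loc="upper right", ncol=3)
```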

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# load iris data
iris = sns.load_dataset("iris")
print(iris.head())
#    sepal_length  sepal_width  petal_length  petal_width species
# 0           5.1          3.5           1.4          0.2  setosa
# 1           4.9          3.0           1.4          0.2  setosa
# 2           4.7          3.2           1.3          0.2  setosa
# 3           4.6          3.1           1.5          0.2  setosa
# 4           5.0          3.6           1.4          0.2  setosa

# create figure
fig = plt.figure(figsize=(20, 5), dpi=80)

# add subplots
for i, col in enumerate(iris.columns[:-1], 1):
    plt.subplot(1, 4, i)
    ax = sns.boxplot(x='species', y=col, data=iris, hue='species')
    ax.get_legend().remove()
    plt.title(col)

# add legend
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper right', ncol=3, bbox_to_anchor=(.75, 0.98))

# add subtitle
fig.suptitle('Distribution of floral traits in the species of iris')

plt.show()

Resulting plot: (image not included)
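Newer Matplotlib releases also offer a way to avoid both the hand-tuned `bbox_to_anchor` above and the manual top padding: with the constrained layout engine, `fig.legend` accepts `loc` strings prefixed with `outside`, and the layout reserves room for the legend outside the axes. A sketch, assuming Matplotlib 3.7 or later (where the `outside` prefix was introduced), with empty titled axes and placeholder patches standing in for the boxplots:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib import patches as mpatches

# layout="constrained" enables the constrained layout engine, which can reserve
# space for figure legends placed with an "outside ..." location (Matplotlib 3.7+).
fig, axes = plt.subplots(1, 4, figsize=(20, 5), dpi=80, layout="constrained")
for ax, name in zip(axes, ["sepal_length", "sepal_width", "petal_length", "petal_width"]):
    ax.set_title(name)  # empty axes stand in for the boxplots

# Proxy artists for the legend entries (illustrative colors and labels)
handles = [mpatches.Patch(color=c, label=l)
           for c, l in [("red", "setosa"), ("green", "versicolor"), ("blue", "virginica")]]

fig.suptitle("Distribution of floral traits in the species of iris")
# "outside upper right": above the axes, aligned with the right side of the figure
leg = fig.legend(handles=handles, loc="outside upper right", ncols=3)

fig.canvas.draw()  # constrained layout runs at draw time
```

`"outside right upper"` would instead reserve a column to the right of the subplots, which can suit legends with many entries.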

The technical posts on this site are licensed under CC BY-SA 4.0. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM