How do I display the line colour in the legend when using seaborn's 2D kdeplot?

I want to overlay several 2D density plots on the same axes using seaborn's kdeplot() function, but the colour of the contours doesn't appear in the legend. How can I make the legend show the colour of each distribution?

Code example:

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
sns.kdeplot(x = np.random.random(30), y = np.random.random(30), label = "dist1", ax=ax)
sns.kdeplot(x = np.random.random(30) + 1, y = np.random.random(30) + 1, label = "dist2", ax=ax)
ax.legend()
plt.show()

Output plot

I'm using seaborn v0.12.0

I found a way to work around the issue. By taking the next colour from the axes' colour cycle, you can pass it explicitly to kdeplot() and also build the legend handles manually from the same colour.

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.lines as mlines

fig, ax = plt.subplots()
handles = []
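# Extract the next colour in the axes' colour cycle.
# Note: _get_lines.prop_cycler is a private matplotlib attribute, so it may break
# between releases (it is gone as of matplotlib 3.8).
color = next(ax._get_lines.prop_cycler)["color"]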
sns.kdeplot(x = np.random.random(30), y = np.random.random(30), color = color, label = "dist1", ax=ax)
handles.append(mlines.Line2D([], [], color=color, label="dist1"))

color = next(ax._get_lines.prop_cycler)["color"]
sns.kdeplot(x = np.random.random(30) + 1, y = np.random.random(30) + 1, color = color, label = "dist2", ax=ax)
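handles.append(mlines.Line2D([], [], color=color, label="dist2"))

# Pass the manually built handles so the legend shows one coloured line per distribution
ax.legend(handles = handles)
plt.show()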

Output plot

  • It's easier to create a pandas.DataFrame with a label column, and let the plot API handle the colors.
  • Tested in python 3.10, pandas 1.4.3, matplotlib 3.5.2, seaborn 0.12.0

Create a DataFrame

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

np.random.seed(2022)  # for the same sample data each time

# create a dataframe with a label for each set, and then concat the sets
n = 30  # number of points in each set
df = pd.concat([pd.DataFrame({'x': np.random.random(n), 'y': np.random.random(n), 'label': ['d1']*n}),
                pd.DataFrame({'x': np.random.random(n) + 1, 'y': np.random.random(n) + 1, 'label': ['d2']*n})], ignore_index=True)

# display(df.head())
          x         y label
0  0.009359  0.564672    d1
1  0.499058  0.349429    d1
2  0.113384  0.975909    d1
3  0.049974  0.037820    d1
4  0.685408  0.794270    d1

sns.displot

# plot the dataframe in a figure level plot
g = sns.displot(kind='kde', data=df, x='x', y='y', hue='label')

Output plot

sns.kdeplot

# plot the dataframe in an axes level plot
fig, ax = plt.subplots(figsize=(7, 5))
sns.kdeplot(data=df, x='x', y='y', hue='label', ax=ax)

Output plot
