Custom legend for Seaborn regplot (Python 3)

I've been trying to follow the SO question "How to make custom legend in matplotlib", but I think a few things are getting lost in translation. I used a custom color mapping for the different classes of points in my plot, and I want to add a legend with those color-label pairs. I stored the mapping in a dictionary, D_color_label, and then made two parallel lists, colors and labels. I tried passing them to ax.legend, but it didn't work.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

np.random.seed(0)

# Create dataframe
DF_0 = pd.DataFrame(np.random.random((100,2)), columns=["x","y"])

# Label to colors
D_idx_color = {**dict(zip(range(0,25), ["#91FF61"]*25)),
               **dict(zip(range(25,50), ["#BA61FF"]*25)),
               **dict(zip(range(50,75), ["#916F61"]*25)),
               **dict(zip(range(75,100), ["#BAF1FF"]*25))}

D_color_label = {"#91FF61":"label_0",
                 "#BA61FF":"label_1",
                 "#916F61":"label_2",
                 "#BAF1FF":"label_3"}

# Add color column
DF_0["color"] = pd.Series(list(D_idx_color.values()), index=list(D_idx_color.keys()))

# Plot
fig, ax = plt.subplots(figsize=(8,8))
sns.regplot(data=DF_0, x="x", y="y", scatter_kws={"c":DF_0["color"]}, ax=ax)

# Add custom legend
colors = list(set(DF_0["color"]))
labels = [D_color_label[x] for x in set(DF_0["color"])]

# If I do this, I get the following warning:
# ax.legend(colors, labels)
# UserWarning: Legend does not support '#BA61FF' instances.
# A proxy artist may be used instead.

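According to the matplotlib legend guide (http://matplotlib.org/users/legend_guide.html), you have to pass legend the artists that should be labeled; plain color strings such as '#BA61FF' are not artists, which is what the warning is telling you. regplot draws all the points as one scatter artist, so to get a separate label per color you have to group your data by color and plot each group individually, giving every artist its own label: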

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

np.random.seed(0)

# Create dataframe: 4 blocks of 25 points, one color per block
DF_0 = pd.DataFrame(np.random.random((100, 2)), columns=["x", "y"])
DF_0["color"] = ["#91FF61"]*25 + ["#BA61FF"]*25 + ["#916F61"]*25 + ["#BAF1FF"]*25

D_color_label = {"#91FF61": "label_0", "#BA61FF": "label_1",
                 "#916F61": "label_2", "#BAF1FF": "label_3"}

fig, ax = plt.subplots(figsize=(8, 8))

# Let regplot draw only the regression line and confidence band;
# the points are drawn per color below
sns.regplot(data=DF_0, x="x", y="y", scatter=False, ax=ax)

# Make a legend: group by color and plot each group as its own artist,
# so every color gets its own label
for color, grp in DF_0.groupby("color"):
    ax.scatter(grp["x"], grp["y"], color=color, label=D_color_label[color])
ax.legend(loc=2)

plt.show()
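Alternatively, as the warning in the question hints, you can keep all the points in a single scatter call and hand legend a list of proxy artists instead of color strings. The sketch below is a minimal version of that idea, assuming the same DF_0 and D_color_label as in the question; it builds matplotlib.patches.Patch objects that are never drawn on the axes and exist only so legend has something to label. The Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import numpy as np
import pandas as pd
import seaborn as sns

np.random.seed(0)

# Same data as in the question: 4 blocks of 25 points, one color per block
DF_0 = pd.DataFrame(np.random.random((100, 2)), columns=["x", "y"])
DF_0["color"] = ["#91FF61"]*25 + ["#BA61FF"]*25 + ["#916F61"]*25 + ["#BAF1FF"]*25

D_color_label = {"#91FF61": "label_0", "#BA61FF": "label_1",
                 "#916F61": "label_2", "#BAF1FF": "label_3"}

fig, ax = plt.subplots(figsize=(8, 8))

# Regression line and confidence band only
sns.regplot(data=DF_0, x="x", y="y", scatter=False, ax=ax)

# All points in one scatter artist, each row colored by its hex string
ax.scatter(DF_0["x"], DF_0["y"], c=DF_0["color"].tolist())

# Proxy artists: one Patch per color/label pair, used only by the legend
handles = [Patch(facecolor=color, label=label)
           for color, label in D_color_label.items()]
legend = ax.legend(handles=handles, loc=2)
```

In an interactive session, finish with plt.show(). Unlike the group-by approach, the legend entries here follow the insertion order of D_color_label rather than the sorted order of the hex strings.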
