
Set scatter plot legend labels with legend_elements

I just upgraded matplotlib to version 3.1.1 and I am experimenting with using legend_elements.

I am making a scatter plot of the top two PCA components of a dataset of 30,000 flattened grayscale images. Each image is labeled with one of four master categories (Accessories, Apparel, Footwear, Personal Care). I have color-coded the plot by master category by creating a colors column with values from 0 to 3.

I have read the documentation for PathCollection.legend_elements, but I haven't successfully incorporated the 'func' or 'fmt' parameters. https://matplotlib.org/3.1.1/api/collections_api.html#matplotlib.collections.PathCollection.legend_elements

Also, I have tried to follow examples provided: https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/scatter_with_legend.html

### create column for color codes
masterCat_codes = {'Accessories':0,'Apparel':1, 'Footwear':2, 'Personal Care':3}
df['colors'] = df['masterCategory'].apply(lambda x: masterCat_codes[x])

### create scatter plot
fig, ax = plt.subplots(figsize=(8,8))
scatter = ax.scatter(*full_pca.T, s=0.1, c=df['colors'], label=df['masterCategory'], cmap='viridis')

### using legend_elements
legend1 = ax.legend(*scatter.legend_elements(num=[0,1,2,3]), loc="upper left", title="Category Codes")
ax.add_artist(legend1)
plt.show()

The resulting legend labels are 0, 1, 2, 3. (This happens whether or not I specify label = df['masterCategory'] when defining 'scatter'). I would like labels to say Accessories, Apparel, Footwear, Personal Care.

Is there a way to accomplish this with legend_elements?

Note: As the dataset is large and the preprocessing is computationally heavy, I have written an example that is simpler to reproduce:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

fake_data = np.array([[1,1],[1,2],[1,3],[2,1],[2,2],[2,3],[3,1],[3,2],[3,3]])
fake_df = pd.DataFrame(fake_data, columns=['X', 'Y'])
groups = np.array(['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'])
fake_df['Group'] = groups

# map each group name to an integer color code: {'A': 0, 'B': 1, 'C': 2}
group_codes = {k:idx for idx, k in enumerate(fake_df.Group.unique())}
fake_df['colors'] = fake_df['Group'].apply(lambda x: group_codes[x])
fig, ax = plt.subplots()
scatter = ax.scatter(fake_data[:,0], fake_data[:,1], c=fake_df['colors'])
legend = ax.legend(*scatter.legend_elements(num=[0,1,2]), loc="upper left", title="Group \nCodes")
ax.add_artist(legend)
plt.show()


Solution (thanks to ImportanceOfBeingErnest)

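  • .legend_elements returns a tuple of legend handles and labels for a PathCollection.
    • Take only the handles, e.g. handles = scatter.legend_elements(num=[0,1,2,3])[0] for the four-category data, because the handles are the first object returned by the method. Then pass your own labels to ax.legend.
  • Also see "Scatter plots with a legend", the matplotlib gallery example linked in the question.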
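# continues the simplified example from the question (fake_data and fake_df are defined there)
group_codes = {k:idx for idx, k in enumerate(fake_df.Group.unique())}
fake_df['colors'] = fake_df['Group'].apply(lambda x: group_codes[x])
fig, ax = plt.subplots(figsize=(8,8))
scatter = ax.scatter(fake_data[:,0], fake_data[:,1], c=fake_df['colors'])

# extract the handles from the existing scatter plot; the toy data has three codes (0, 1, 2)
handles = scatter.legend_elements(num=[0,1,2])[0]

# pair the handles with the group names, which are in the same order as the codes
ax.legend(title='Group\nCodes', handles=handles, labels=list(group_codes.keys()))
plt.show()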
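
Since the question asks specifically about the 'fmt' parameter, here is a rough sketch of an alternative that keeps everything inside legend_elements. It relies on the documented behavior that fmt may be a matplotlib.ticker.Formatter as well as a format string. The helper code_to_name and the use of np.unique to build the codes are my own assumptions for illustration, not part of the accepted solution.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter

# Same toy data as the simplified example in the question.
fake_data = np.array([[1, 1], [1, 2], [1, 3], [2, 1], [2, 2], [2, 3], [3, 1], [3, 2], [3, 3]])
groups = np.array(['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'])

# names: sorted unique group names; codes: for each point, its index into names.
names, codes = np.unique(groups, return_inverse=True)

fig, ax = plt.subplots()
scatter = ax.scatter(fake_data[:, 0], fake_data[:, 1], c=codes)


def code_to_name(code, pos=None):
    """Hypothetical helper: map a numeric color code back to its group name."""
    return str(names[int(round(code))])


# A FuncFormatter calls code_to_name(value, pos) for every legend entry,
# so the returned labels are group names instead of the raw codes.
handles, labels = scatter.legend_elements(fmt=FuncFormatter(code_to_name))
legend = ax.legend(handles, labels, loc="upper left", title="Group")
```

Call plt.show() afterwards to display the figure. Compared with passing labels to ax.legend by hand, this keeps the code-to-name mapping in one place, at the cost of an extra formatter object.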

