
How do I add a legend to a Matplotlib scatter plot whose points are colour-coded according to an array of 0s and 1s?

I am working in a Python 3 Jupyter Notebook. I have a 3-column table (price, size and view). I have created a scatter plot of "price" against "size", with the dots colour-coded according to the "view" column, which contains only 0s and 1s. Now I want to add a legend showing that the red dots represent "No view" and the blue dots represent "View". This is what I have tried:

plt.scatter(data1["size"], data1["price"], c = data1["view"], cmap = "bwr_r")
plt.xlabel("Size", fontsize = 25, c = "green")
plt.ylabel("Price", fontsize = 25, c = "green")
plt.legend(["No view", "View"], bbox_to_anchor= (1.05, 0.5), loc= "lower left")
plt.show()

After running the code above, everything works fine except that the legend only shows "No view" for the red dots; "View" for the blue dots does not appear.

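The legend is incomplete because a single plt.scatter call creates only one artist, so plt.legend has only one handle to attach your two labels to and the second label is dropped. To fix this, map each view value (0 or 1) to an explicit colour with pandas Series.map, so that every point gets the intended colour. Then build one custom legend handle per colour and pass those handles to plt.legend together with the matching labels, so that both red and blue appear with the correct text. I have used random numbers as data to produce the plot, keeping as much of your code as possible.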
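The one-artist, one-entry behaviour can be checked directly. The sketch below builds a small placeholder frame (the data1 values are made up, not the asker's data), counts the entries in a labels-only legend on a single scatter call, and then shows a further alternative that is not part of the answer's approach: one scatter call per group, each with its own label, which gives one legend entry per group.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs outside Jupyter
import matplotlib.pyplot as plt
import pandas as pd

# Placeholder data with the same three columns as in the question
data1 = pd.DataFrame({
    "size": [50, 70, 90, 110],
    "price": [100, 150, 200, 260],
    "view": [0, 1, 0, 1],
})

# One scatter call -> one artist -> the labels-only legend keeps one entry
fig, ax = plt.subplots()
ax.scatter(data1["size"], data1["price"], c=data1["view"], cmap="bwr_r")
single = ax.legend(["No view", "View"])
print(len(single.get_texts()))  # 1

# Alternative: one labelled scatter call per group -> one entry per group
fig2, ax2 = plt.subplots()
for value, label, color in [(0, "No view", "red"), (1, "View", "blue")]:
    group = data1[data1["view"] == value]
    ax2.scatter(group["size"], group["price"], c=color, label=label)
grouped = ax2.legend(bbox_to_anchor=(1.05, 0.5), loc="lower left")
print([t.get_text() for t in grouped.get_texts()])  # ['No view', 'View']
```

In a notebook, drop the matplotlib.use("Agg") line and call plt.show() as usual.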

Code

import random

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

size = []
price = []
view = []

# Random sample data standing in for the real table:
# 100 rows, with "view" randomly set to 0 or 1
for i in range(0, 100):
    size.append(round(random.random(), 3))
    price.append(round(random.random(), 3))
    view.append(int(random.random()*10 % 2))
df = pd.DataFrame({'size': size, 'price': price, 'view': view})

# Map each view value to an explicit colour: 0 -> red, 1 -> blue
colors = {0: 'red', 1: 'blue'}
plt.scatter(x=df['size'], y=df['price'], c=df['view'].map(colors))
plt.xlabel("Size", fontsize=25, c="green")
plt.ylabel("Price", fontsize=25, c="green")

# One proxy legend handle per colour (red first, then blue);
# these artists are never drawn on the axes, only in the legend
custom = [Line2D([], [], marker='.', color=v, linestyle='None', markersize=8)
          for v in colors.values()]

plt.legend(handles=custom, labels=['No View', 'View'], bbox_to_anchor=(1.05, 0.5), loc="lower left")
plt.show()
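If you would rather keep the original c=data1["view"], cmap="bwr_r" call, PathCollection.legend_elements() (added in Matplotlib 3.1) can generate one proxy handle per distinct value of the colour array, which you then pair with your own labels. This is a different technique from the hand-built Line2D handles above; the sketch uses placeholder data in place of data1.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; not needed in a Jupyter notebook
import matplotlib.pyplot as plt
import pandas as pd

# Placeholder data standing in for the asker's data1
data1 = pd.DataFrame({
    "size": [50, 70, 90, 110],
    "price": [100, 150, 200, 260],
    "view": [0, 1, 0, 1],
})

sc = plt.scatter(data1["size"], data1["price"], c=data1["view"], cmap="bwr_r")
plt.xlabel("Size", fontsize=25, c="green")
plt.ylabel("Price", fontsize=25, c="green")

# legend_elements() returns one handle per unique value in the colour array
# (here 0 and 1, in ascending order) plus auto-generated labels we discard
handles, _ = sc.legend_elements()
leg = plt.legend(handles, ["No view", "View"], bbox_to_anchor=(1.05, 0.5), loc="lower left")
plt.show()
```

Because the handles take their colours from the same colormap as the points, the legend stays in sync if you later change cmap.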

[Image: output graph]
