
Showing a correct legend when drawing a scatter plot with a palette

Stupid way to plot a scatter plot

Suppose I have data with 3 classes. The following code gives me a perfect graph with a correct legend, because I plot the data class by class.

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_blobs
import numpy as np

X, y = make_blobs()

X0 = X[y==0]
X1 = X[y==1]
X2 = X[y==2]

ax = plt.subplot(1,1,1)
ax.scatter(X0[:,0],X0[:,1], lw=0, s=40)
ax.scatter(X1[:,0],X1[:,1], lw=0, s=40)
ax.scatter(X2[:,0],X2[:,1], lw=0, s=40)
ax.legend(['0','1','2'])


Better way to plot a scatter plot

However, if I have a dataset with 3000 classes, the above method doesn't work anymore. (You wouldn't expect me to write 3000 lines, one for each class, right?) So I came up with the following plotting code.

num_classes = len(set(y))
palette = np.array(sns.color_palette("hls", num_classes))

ax = plt.subplot(1,1,1)
ax.scatter(X[:,0], X[:,1], lw=0, s=40, c=palette[y.astype(int)])
ax.legend(['0','1','2'])

This code is perfect: we can plot all the classes with only 1 line. However, the legend is not shown correctly this time.

Question

How can I keep a correct legend when plotting with the following call?

ax.scatter(X[:,0], X[:,1], lw=0, s=40, c=palette[y.astype(int)])

plt.legend() works best when you have multiple "artists" on the plot. That is the case in your first example, which is why calling plt.legend(labels) works effortlessly.

If you are worried about writing lots of lines of code, you can take advantage of for loops.

As we can see in this example using 5 classes:

import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
import numpy as np

X, y = make_blobs(centers=5)
ax = plt.subplot(1,1,1)

for c in np.unique(y):
    ax.scatter(X[y==c,0],X[y==c,1],label=c)

ax.legend()

np.unique() returns a sorted array of the unique elements of y. By looping through these and plotting each class with its own artist, plt.legend() can easily provide a legend.
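If the colors should still come from a palette, as in the question, the loop can take one color per class. The following is only a sketch under some assumptions: the random data stands in for make_blobs(centers=5), and a matplotlib colormap stands in for the seaborn palette so that the snippet depends only on NumPy and matplotlib.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Stand-in data (assumption): 5 classes, 40 points each, in place of make_blobs(centers=5)
rng = np.random.default_rng(0)
y = np.repeat(np.arange(5), 40)
X = rng.normal(size=(y.size, 2)) + 4 * y[:, None]

classes = np.unique(y)
# One RGBA row per class; the colormap is a stand-in for sns.color_palette("hls", n)
palette = plt.get_cmap("hsv")(np.linspace(0, 1, len(classes), endpoint=False))

fig, ax = plt.subplots()
for c, color in zip(classes, palette):
    mask = y == c
    # Each call creates its own artist, so each class gets its own legend entry
    ax.scatter(X[mask, 0], X[mask, 1], lw=0, s=40, color=tuple(color), label=str(c))
legend = ax.legend()
```

The number of lines of code stays the same no matter how many classes there are, although a legend with thousands of entries is unlikely to be readable.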

Edit:

You can also assign labels to the plots as you make them, which is probably safer.

plt.scatter(..., label=c) followed by plt.legend()
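If you would rather keep the single scatter call from the question, there is only one artist on the plot, so one option is to build the legend handles by hand (often called "proxy artists"). This is a sketch rather than the only way to do it; the random data and the colormap-based palette are stand-ins for make_blobs() and sns.color_palette(), and it assumes the class labels are the integers 0..n-1 so that they can index the palette directly.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np

# Stand-in data (assumption): 3 classes, 50 points each, in place of make_blobs()
rng = np.random.default_rng(0)
y = np.repeat(np.arange(3), 50)
X = rng.normal(size=(y.size, 2)) + 4 * y[:, None]

classes = np.unique(y)
# One RGBA row per class; the colormap is a stand-in for sns.color_palette("hls", n)
palette = plt.get_cmap("hsv")(np.linspace(0, 1, len(classes), endpoint=False))

fig, ax = plt.subplots()
# Single call, single artist: every point is colored by looking up its class in the palette
ax.scatter(X[:, 0], X[:, 1], lw=0, s=40, c=palette[y])

# Proxy artists: marker-only lines that are never drawn on the axes,
# used purely to give the legend one entry per class
handles = [
    Line2D([], [], linestyle="none", marker="o", color=tuple(palette[i]), label=str(c))
    for i, c in enumerate(classes)
]
legend = ax.legend(handles=handles)
```

Because the handles are created from the same palette rows used for the points, the legend colors match the plot even though the legend is not derived from the scatter artist itself.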

Why not simply do the following?

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_blobs
import numpy as np

X, y = make_blobs()
ngroups = 3

ax = plt.subplot(1, 1, 1)
for i in range(ngroups):
    ax.scatter(X[y==i][:,0], X[y==i][:,1], lw=0, s=40, label=i)
ax.legend()
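Another option that keeps a single scatter call is the legend_elements() method of the object returned by scatter, which as far as I know is available from matplotlib 3.1 onward. The sketch below assumes the classes are passed as numbers through c= together with a colormap, instead of a precomputed palette, and uses random stand-in data in place of make_blobs().

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Stand-in data (assumption): 3 classes, 50 points each, in place of make_blobs()
rng = np.random.default_rng(0)
y = np.repeat(np.arange(3), 50)
X = rng.normal(size=(y.size, 2)) + 4 * y[:, None]

fig, ax = plt.subplots()
# Colors come from mapping the numeric class labels through a colormap
sc = ax.scatter(X[:, 0], X[:, 1], lw=0, s=40, c=y, cmap="viridis")

# Ask the scatter artist for one legend handle and label per class value
handles, labels = sc.legend_elements()
legend = ax.legend(handles, labels, title="class")
```

With only a few classes the default settings give one entry per unique value. With many classes the defaults may show only a subset of values, so the num argument of legend_elements() is worth checking in the documentation before relying on this.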
