简体   繁体   English

Python 中的 3D 绘图 - 向散点图添加图例

[英]3D plotting in Python - Adding a Legend to Scatterplot

from mpl_toolkits.mplot3d import Axes3D

ax.scatter(X_lda[:,0], X_lda[:,1], X_lda[:,2], alpha=0.4, c=y_train, cmap='rainbow', s=20)

plt.legend()
plt.show()

Essentially I'd like to add a legend for the scatterplot that shows the unique values in y_train and what color point it corresponds to on the plot.本质上,我想为散点图添加一个图例,显示 y_train 中的唯一值以及它在图中对应的颜色点。

The output plot:输出图: 阴谋

Producing either a legend or a colorbar for a scatter is usually quite simple:为散点生成图例或颜色条通常非常简单:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

x,y,z = (np.random.normal(size=(300,4))+np.array([0,2,4,6])).reshape(3,400)
c = np.tile([1,2,3,4], 100)

fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
sc = ax.scatter(x,y,z, alpha=0.4, c=c, cmap='rainbow', s=20)

plt.legend(*sc.legend_elements())
plt.colorbar(sc)
plt.show()

在此处输入图片说明

Edit: After seeing @bigreddot's solution, I agree that this approach is somewhat more complicated than strictly necessary. 编辑:在看到@bigreddot的解决方案后,我同意这种方法比严格必要的方法更为复杂。 I leave it here in case somebody needs more fine-tuning for their colorbar or legend. 我将其留在此处,以防有人需要对其色条或图例进行更多的微调。

Here is a way to create both a custom legend and a custom colorbar for the 3D graph. 这是一种为3D图形创建自定义图例和自定义颜色条的方法。 So you can chose one or the other, depending on specific needs. 因此,您可以根据特定需求选择其中一个。 I'm not sure how the y_train is distributed; 我不确定y_train的分布方式; in the code some float values of a limited set are simulated. 在代码中,模拟了一组有限的浮点值。 Also, it is not clear what the labels should mention, so now they just put the value of y_train. 另外,还不清楚标签应提及什么,因此现在它们只是放置y_train的值。

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib as mpl

N = 1000
X_lda = np.random.gamma(9.0, 0.5, (N,3))
y_train = np.random.randint(0, 6, N)
X0 = np.random.gamma(5.0, 1.5, (N,3))
X1 = np.random.gamma(1.0, 1.5, (N,3))
for i in range(3):
    X_lda[:,i] = np.where (y_train == 0, X0[:,i], X_lda[:,i])
    X_lda[:,i] = np.where (y_train == 1, X1[:,i], X_lda[:,i])
y_train = np.sin(y_train*.2 + 10) * 10.0 + 20.0

fig = plt.figure(figsize = (15,15))
ax = fig.add_subplot(111, projection = '3d')
ax.scatter(X_lda[:,0], X_lda[:,1], X_lda[:,2], alpha=0.4, c=y_train, cmap='rainbow', s=20)

norm = mpl.colors.Normalize(np.min(y_train), np.max(y_train))
cmap = plt.get_cmap('rainbow')

y_unique = np.unique(y_train)
legend_lines = [mpl.lines.Line2D([0],[0], linestyle="none", marker='o', c=cmap(norm(y))) for y in y_unique]
legend_labels = [f'{y:.2f}' for y in y_unique]
ax.legend(legend_lines, legend_labels, numpoints = 1, title='Y-train')

sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
plt.colorbar(sm, ticks=y_unique, label='Y-train')
plt.show()

样例

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM