![](/img/trans.png)
[英]How to modify 2d Scatterplot to display color based off third array in csv file?
[英]Connecting dots in a 2D scatterplot with a color as a third dimension
假设我有以下数据集:
x = np.arange(150000,550000,100000)
y = np.random.rand(7*4)
z = [0.6,0.6,0.6,0.6,0.7,0.7,0.7,0.7,0.8,0.8,0.8,0.8,0.9,0.9,0.9,0.9,1.0,1.0,1.0,1.0,1.1,1.1,1.1,1.1,1.2,1.2,1.2,1.2]
x_ = np.hstack([x,x,x,x,x,x,x])
我正在做一个散点图:
plt.figure()
plt.scatter(x_,y,c=z)
plt.colorbar()
plt.set_cmap('jet')
plt.xlim(100000,500000)
plt.show()
但是,我想连接相同颜色的点。 我尝试只使用plt.plot
和相同的变量,但是它连接所有的点,而不仅仅是黄点和黄点。
任何帮助表示赞赏。
编辑:
z 轴是离散的,我事先知道这些值。 我也知道我可以多次调用plt.plot()
方法来绘制线条,但我希望可能有另一种方法来做到这一点。
如果我理解正确的话,你有一个 x 值列表,对于每个 x,一些 y 是关联的,每个 x,y 都有一个特定的 z 值。 z 值属于有限集(或可以四舍五入以确保只有有限集)。
因此,我创建了 x、y 和 z 的副本,并通过 z 同时对它们进行了排序。 然后,循环遍历 z 数组,收集 x、y 的值,每次 z 更改时,都可以绘制属于该颜色的所有线。 为了不需要特殊的步骤来绘制最后一组线,我附加了一个哨兵。
import numpy as np
import matplotlib.pyplot as plt
x = np.arange(150000,550000,100000)
y = np.random.rand(7*4)
z = [0.6,0.6,0.6,0.6,0.7,0.7,0.7,0.7,0.8,0.8,0.8,0.8,0.9,0.9,0.9,0.9,1.0,1.0,1.0,1.0,1.1,1.1,1.1,1.1,1.2,1.2,1.2,1.2]
z_min = min(z)
z_max = max(z)
x_ = np.hstack([x,x,x,x,x,x,x])
zs = [round(c, 1) for c in z] # make sure almost equal z-values are exactly equal
zs, xs, ys = zip( *sorted( zip(z, x_, y) ) ) # sort the x,y via z
zs += (1000,) # add a sentinel at the end that can be used to stop the line drawing
xs += (None, )
ys += (None, )
plt.set_cmap('plasma')
cmap = plt.get_cmap() # get the color map of the current plot call with `plt.get_cmap('jet')`
norm = plt.Normalize(z_min, z_max) # needed to map the z-values between 0 and 1
plt.scatter(x_, y, c=z, zorder=10) # z-order: plot the scatter dots on top of the lines
prev_x, prev_y, prev_z = None, None, None
x1s, y1s, x2s, y2s = [], [], [], []
for x0, y0, z0 in zip(xs, ys, zs):
if z0 == prev_z:
x1s.append(prev_x)
y1s.append(prev_y)
x2s.append(x0)
y2s.append(y0)
elif prev_z is not None: # the z changed, draw the lines belonging to the previous z
print(x1s, y1s, x2s, y2s)
plt.plot(x1s, y1s, x2s, y2s, color=cmap(norm(prev_z)))
x1s, y1s, x2s, y2s = [], [], [], []
prev_x, prev_y, prev_z = x0, y0, z0
plt.colorbar()
plt.show()
使用LineCollection
更容易:
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import numpy as np; np.random.seed(42)
x = np.arange(150000,550000,100000)
y = np.random.rand(7*4)
z = [0.6,0.6,0.6,0.6,0.7,0.7,0.7,0.7,0.8,0.8,0.8,0.8,0.9,0.9,0.9,0.9,1.0,1.0,1.0,1.0,1.1,1.1,1.1,1.1,1.2,1.2,1.2,1.2]
x_ = np.tile(x, 7)
segs = np.stack((x_, y), axis=1).reshape(7, 4, 2)
plt.figure()
sc = plt.scatter(x_,y,c=z, cmap="plasma")
plt.colorbar()
lc = LineCollection(segs, cmap="plasma", array=np.unique(z))
plt.gca().add_collection(lc)
plt.show()
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.