
Removing elements from list corresponding to numpy array

I have three lists of integers, xs, ys and zs, as well as a 3D NumPy array V that contains the value for each point. For example, the value of the point (xs[0], ys[0], zs[0]) is V[xs[0], ys[0], zs[0]]. I'm using these to create a 3D scatter plot with ax.scatter(xs, ys, zs, c=V).

I would like to plot only the points whose value in V is greater than 0.2. How can I remove the corresponding elements from xs, ys, zs and get V into the matching shape?

Edit: here is a brute force way of doing it:

xg = []
yg = []
zg = []
Vg = []
for x in xs:
    for y in ys:
        for z in zs:
            if V[x,y,z] > 0.2:
                xg.append(x)
                yg.append(y)
                zg.append(z)
                Vg.append(V[x,y,z])
ax.scatter(xg, yg, zg, c=Vg)

In the best case, V is ordered such that, once flattened, the value at index i belongs to the i-th entry of xs, ys, zs. If so, you can flatten V and filter all four arrays with the same boolean mask:

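X = np.array(xs); Y = np.array(ys); Z = np.array(zs)
values = V.ravel()      # flatten V so it lines up with xs, ys, zs
mask = values > 0.2     # compute the condition once
X = X[mask]
Y = Y[mask]
Z = Z[mask]
V = values[mask]

ax.scatter(X, Y, Z, c=V)   # ax must be a 3D axes (projection='3d')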
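A minimal, self-contained sketch of this per-point case. The 2x2x2 array and the coordinate lists are made up for illustration, and the lists are assumed to be in the same row-major order that V.ravel() produces:

```python
import numpy as np

# Hypothetical data: a 2x2x2 array of values
V = np.array([[[0.1, 0.5], [0.3, 0.0]],
              [[0.25, 0.05], [0.9, 0.2]]])

# Per-point coordinates, listed in the same (row-major) order as V.ravel()
xs = [0, 0, 0, 0, 1, 1, 1, 1]
ys = [0, 0, 1, 1, 0, 0, 1, 1]
zs = [0, 1, 0, 1, 0, 1, 0, 1]

values = V.ravel()   # flattened values, one per point
mask = values > 0.2  # boolean mask, one entry per point

X = np.asarray(xs)[mask]
Y = np.asarray(ys)[mask]
Z = np.asarray(zs)[mask]
Vf = values[mask]
```

The four filtered arrays stay aligned, so they can be passed straight to ax.scatter(X, Y, Z, c=Vf).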

If xs, ys, zs are not per-point coordinates but only the values along each axis, the full grid has to be built first:

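# indexing='ij' gives X[i,j,k] == xs[i], Y[i,j,k] == ys[j], Z[i,j,k] == zs[k],
# i.e. the same shape and axis order as V
X, Y, Z = np.meshgrid(xs, ys, zs, indexing='ij')
mask = V > 0.2
X = X[mask]
Y = Y[mask]
Z = Z[mask]
V = V[mask]
ax2.scatter(X, Y, Z, c=V)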
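The sketch below (hypothetical data) checks the grid approach against the brute-force loop from the question on a non-square 2x3x4 grid. With np.meshgrid's default indexing='xy', the grid arrays would have shape (3, 2, 4) instead of V's (2, 3, 4); indexing='ij' keeps the axes in V's order, and boolean masking returns the selected elements in the same i, j, k order as the nested loops:

```python
import numpy as np

# Hypothetical non-square grid: 2 x-values, 3 y-values, 4 z-values
xs = np.arange(2)
ys = np.arange(3)
zs = np.arange(4)
V = np.arange(24).reshape((2, 3, 4)) / 24.0  # V[i, j, k] is the value at (xs[i], ys[j], zs[k])

# indexing='ij' makes the grid arrays the same shape as V
X, Y, Z = np.meshgrid(xs, ys, zs, indexing='ij')
assert X.shape == V.shape

mask = V > 0.2
Xf, Yf, Zf, Vf = X[mask], Y[mask], Z[mask], V[mask]

# Brute-force reference, looping in the same i, j, k order
ref = [(x, y, z, V[x, y, z])
       for x in xs for y in ys for z in zs
       if V[x, y, z] > 0.2]
```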

A complete example, comparing the method from the question with this one:

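import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions

# 3x3x3 grid of values; V[i, j, k] belongs to the point (xs[i], ys[j], zs[k])
V = np.arange(27).reshape((3, 3, 3)) / 35.
xs = np.arange(3)
ys = np.arange(3)
zs = np.arange(3)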

fig = plt.figure()
ax = fig.add_subplot(121, projection='3d')
ax2 = fig.add_subplot(122, projection='3d')

# solution from the question
xg = []
yg = []
zg = []
Vg = []
for x in xs:
    for y in ys:
        for z in zs:
            if V[x,y,z] > 0.2:
                xg.append(x)
                yg.append(y)
                zg.append(z)
                Vg.append(V[x,y,z])
ax.scatter(xg, yg, zg, c=Vg)

# numpy solution
X, Y, Z = np.meshgrid(xs, ys, zs, indexing='ij')  # same shape and axis order as V
mask = V > 0.2
X = X[mask]
Y = Y[mask]
Z = Z[mask]
V = V[mask]
ax2.scatter(X, Y, Z, c=V)

plt.show()
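When, as in this example, the coordinates are simply the indices 0, 1, 2, ... of V, a shorter alternative (not part of the original answer) is np.nonzero, which returns one index array per axis for all True entries of the mask. A minimal sketch under that assumption; the non-integer axis values at the end are hypothetical and only show how to map indices back to real coordinates:

```python
import numpy as np

V = np.arange(27).reshape((3, 3, 3)) / 35.

mask = V > 0.2
# One index array per axis, in row-major (C) order
X, Y, Z = np.nonzero(mask)
Vf = V[mask]  # values in the same order as X, Y, Z

# For non-index coordinates (hypothetical values), look the indices up in the axis arrays
xs = np.array([0.0, 0.5, 1.0])
ys = np.array([10.0, 20.0, 30.0])
zs = np.array([-1.0, 0.0, 1.0])
Xc, Yc, Zc = xs[X], ys[Y], zs[Z]
```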

