
Python3 / Numpy: ndarray conditional indexing

I come from a MATLAB background and I'm trying to write the following in Python/NumPy:

[l, m, n] = ndgrid(1:size(dct, 1), 1:size(dct, 2), 1:size(dct, 3));
mycell{i, j} = dct(...
    min.^2 <= l.^2 + m.^2 + n.^2 & ...
    l.^2 + m.^2 + n.^2 <= max.^2)';

The code is supposed to take all values of the array whose index (e.g. x, y, z) has a 2-norm between min and max, i.e. min^2 <= x^2 + y^2 + z^2 <= max^2.

All I could find was about indexing an array with a condition on the values stored in it; however, I want to index with a condition on the index itself.

I read about broadcasting, the ix_ function, and advanced indexing, but I cannot fit the pieces together.

NumPy lets us create open meshes with np.ogrid, which can stand in for the full 3D meshes that ndgrid produces in the MATLAB code. The benefit is that these open meshes can be squared and added to compute the equivalent of l.^2 + m.^2 + n.^2 without ever creating full 3D versions of l, m and n. Only the final sum is broadcast to the full 3D shape, which saves both memory and time. A previous post explored this approach and showed performance benefits.
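To make the difference concrete, here is a small sketch comparing the open grids from np.ogrid with the dense grids from np.mgrid, which is the closer analogue of ndgrid. The shape (4, 5, 6) is an arbitrary choice for illustration and does not come from the question.

```python
import numpy as np

shape = (4, 5, 6)  # arbitrary illustrative shape, not from the question

# Open grids: one small array per dimension, each varying along a single axis
x, y, z = np.ogrid[0:shape[0], 0:shape[1], 0:shape[2]]
open_shapes = (x.shape, y.shape, z.shape)

# Dense grids: every array has the full shape, like MATLAB's ndgrid output
X, Y, Z = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]]
dense_shapes = (X.shape, Y.shape, Z.shape)

# Broadcasting expands the open grids only when they are combined
vals_open = x**2 + y**2 + z**2
vals_dense = X**2 + Y**2 + Z**2
same = np.array_equal(vals_open, vals_dense)

print(open_shapes, dense_shapes, vals_open.shape, same)
```

Both routes give the same 3D array of squared norms; the open version just avoids allocating three full-size index arrays along the way.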

Thus, porting over to NumPy, we would have:

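import numpy as np

m, n, r = dct.shape
# Open (broadcastable) index grids; 0-based here, unlike MATLAB's 1-based ndgrid
x, y, z = np.ogrid[0:m, 0:n, 0:r]
vals = x**2 + y**2 + z**2  # broadcasts to shape (m, n, r)
mycell[i][j] = dct[(min**2 <= vals) & (vals <= max**2)]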
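As a quick sanity check, the following self-contained sketch runs the same idea on made-up data and compares the result with a plain triple loop. The array contents, the bounds r_min and r_max (named that way to avoid shadowing Python's built-in min and max), and the offset variable are all assumptions for illustration. Setting offset = 1 should reproduce MATLAB's 1-based index values.

```python
import numpy as np

# Made-up data and bounds, purely for illustration
dct = np.arange(4 * 5 * 6, dtype=float).reshape(4, 5, 6)
r_min, r_max = 2, 4
offset = 0  # hypothetical knob: use 1 to mimic MATLAB's 1-based indices

m, n, r = dct.shape
x, y, z = np.ogrid[offset:m + offset, offset:n + offset, offset:r + offset]
vals = x**2 + y**2 + z**2
mask = (r_min**2 <= vals) & (vals <= r_max**2)
selected = dct[mask]  # 1-D array, elements in row-major (C) order

# Brute-force reference using explicit loops over every index
expected = [
    dct[a, b, c]
    for a in range(m)
    for b in range(n)
    for c in range(r)
    if r_min**2 <= (a + offset)**2 + (b + offset)**2 + (c + offset)**2 <= r_max**2
]
matches = np.array_equal(selected, np.array(expected))

print(selected.size, matches)
```

Note that boolean indexing returns the selected elements in row-major order, whereas MATLAB's logical indexing walks the array in column-major order. If the element order matters, something like dct.ravel(order='F')[mask.ravel(order='F')] should give the MATLAB ordering.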
