
How to select values in an n-dimensional array

I have been trying to perform a simple operation, but I can't find a simple way to do it with NumPy functions without creating unnecessary copies of the array.

Suppose we have the following 3-dimensional array:

In [171]: x = np.arange(24).reshape((4, 3, 2))
In [172]: x
Out[172]: 
array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5]],

       [[ 6,  7],
        [ 8,  9],
        [10, 11]],

       [[12, 13],
        [14, 15],
        [16, 17]],

       [[18, 19],
        [20, 21],
        [22, 23]]])

And the following array:

In [173]: y = np.array([0, 1, 1, 0])

For each index i along the first axis of x, I want to select the entries of the last axis at index y[i]. In other words, I want:

array([[ 0,  2,  4],
       [ 7,  9, 11],
       [13, 15, 17],
       [18, 20, 22]])

The only solution I have so far is a for loop over the first dimension of x and y, as follows:

z = np.zeros((4, 3), dtype=int)
for i, row in enumerate(x):
    z[i, :] = row[:, y[i]]

Is there a way to avoid the for loop here, using NumPy functions or fancy indexing?

Thanks!

The tricky part is that you don't want every element along the 0th dimension combined with every element of y; you want the i-th element along the 0th dimension paired with y[i]. Passing an index array for the 0th dimension as well does exactly that, so you could do something like:

>>> x[np.arange(x.shape[0]), :, y]
array([[ 0,  2,  4],
       [ 7,  9, 11],
       [13, 15, 17],
       [18, 20, 22]])
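
To see why this works: the two index arrays, np.arange(x.shape[0]) and y, are paired element by element, so row i of the result is x[i, :, y[i]]. Here is a minimal sketch that checks this against the loop from the question. The names z_loop and z_fancy are only for illustration.

```python
import numpy as np

x = np.arange(24).reshape((4, 3, 2))
y = np.array([0, 1, 1, 0])

# Loop version from the question.
z_loop = np.zeros((4, 3), dtype=int)
for i, row in enumerate(x):
    z_loop[i, :] = row[:, y[i]]

# Vectorized version: the index arrays on axes 0 and 2 are paired
# element-wise, while the slice on axis 1 is kept whole.
z_fancy = x[np.arange(x.shape[0]), :, y]

print(z_fancy.shape)                    # (4, 3)
print(np.array_equal(z_loop, z_fancy))  # True
```

Note that indexing with integer arrays returns a new array rather than a view of x, so one (4, 3) result array is still allocated, just as in the loop version.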

Using fancy indexing:

x[np.arange(y.size),:,y]

gives:

array([[ 0,  2,  4],
       [ 7,  9, 11],
       [13, 15, 17],
       [18, 20, 22]])
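
As a possible alternative that neither answer mentions, np.take_along_axis (available in NumPy 1.15 and later) can express the same selection. This is only a sketch: it assumes y holds one index per entry of the first axis, and idx is an illustrative name.

```python
import numpy as np

x = np.arange(24).reshape((4, 3, 2))
y = np.array([0, 1, 1, 0])

# Give y the same number of dimensions as x: shape (4, 1, 1).
idx = y[:, None, None]

# Pick along the last axis; the index broadcasts over axis 1.
z = np.take_along_axis(x, idx, axis=2)  # shape (4, 3, 1)
z = z[..., 0]                           # drop the length-1 axis -> (4, 3)
print(z)
```

The fancy-indexing form is shorter for this case; take_along_axis may read more clearly when the axis to select along is a parameter rather than fixed.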
