
How to slice a 3D ndarray with an index vector

Consider a 3D NumPy array:

import numpy as np

ax1, ax2, ax3 = 3, 3, 2
arr = np.arange(ax1*ax2*ax3).reshape(ax1, ax2, ax3)
arr: [[[0, 1], [2, 3], [4, 5]],
      [[6, 7], [8, 9], [10, 11]],
      [[12, 13], [14, 15], [16, 17]]]

and an index vector idx = [0, 1, 2].

I want to slice arr with idx using the following statement:

res = [arr[i, :idx[i]+1] for i in range(ax1)]
res: [[[0, 1]],
      [[6, 7], [8, 9]],
      [[12, 13], [14, 15], [16, 17]]]

But this kind of slicing looks complicated.

Does NumPy support such an operation without using a loop? I am looking for something like arr[range(ax1), :idx+1].

The problem is that the result is not rectangular: each i keeps a different number of rows (idx[i] + 1), so it cannot be represented as a regular NumPy array.

If you're fine with getting the values in a different format, you can get them with a boolean mask:

>>> mask = np.tri(3, 3, dtype=bool)
>>> arr[mask]
array([[ 0,  1],
       [ 6,  7],
       [ 8,  9],
       [12, 13],
       [14, 15],
       [16, 17]])
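The masked result is flattened along the first two axes, so it no longer says which i each vector came from. If you need that (an assumption about your use case), a minimal sketch: np.nonzero on the same mask returns the (i, j) pairs in the same row-major order that boolean indexing uses, and np.bincount turns the i values into per-row counts.

```python
import numpy as np

arr = np.arange(3 * 3 * 2).reshape(3, 3, 2)
mask = np.tri(3, 3, dtype=bool)

flat = arr[mask]                # shape (6, 2): the selected arr[i, j] vectors
rows, cols = np.nonzero(mask)   # (i, j) of each selected vector, same order as flat

# Show where each flattened vector came from.
for i, j, v in zip(rows, cols, flat):
    print(i, j, v)

# Number of kept vectors per first-axis index (here equal to idx + 1).
counts = np.bincount(rows, minlength=arr.shape[0])
print(counts)  # [1 2 3]
```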

The principle is that the mask states, for each index pair (i, j) in {0, 1, 2} × {0, 1, 2}, whether arr[i, j] should be taken or not:

>>> np.tri(3, 3, dtype=bool)
array([[ True, False, False],
       [ True,  True, False],
       [ True,  True,  True]], dtype=bool)

This leads to the marvellously concise one-liner:

>>> arr[np.tri(3, 3, dtype=bool)]
array([[ 0,  1],
       [ 6,  7],
       [ 8,  9],
       [12, 13],
       [14, 15],
       [16, 17]])
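np.tri matches only the specific case idx = [0, 1, 2]. A sketch of building the equivalent mask for an arbitrary idx via broadcasting (my own generalization, assuming idx is a 1-D integer NumPy array with one entry per row of arr and 0 <= idx[i] < arr.shape[1]):

```python
import numpy as np

arr = np.arange(3 * 3 * 2).reshape(3, 3, 2)

# Row i keeps column j when j <= idx[i]: comparing a (3,) vector of column
# indices with a (3, 1) column of limits broadcasts to a (3, 3) boolean mask.
idx = np.array([0, 1, 2])
mask = np.arange(arr.shape[1]) <= idx[:, None]

# For idx = [0, 1, 2] this is exactly the lower-triangular np.tri mask.
print(np.array_equal(mask, np.tri(3, 3, dtype=bool)))  # True

# Any other idx works the same way, e.g. keep 3, 1 and 2 leading vectors.
mask2 = np.arange(arr.shape[1]) <= np.array([2, 0, 1])[:, None]
print(arr[mask2])  # rows [0 1], [2 3], [4 5], [6 7], [12 13], [14 15]
```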

Here's a vectorized approach, assuming arr and idx are NumPy arrays:

np.split(arr[np.arange(arr.shape[1]) <= idx[:,None]], (idx+1).cumsum())[:-1]
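The one-liner has two steps: a boolean mask flattens the kept vectors, and np.split cuts that flat array back into pieces of length idx[i] + 1. A commented sketch of the same steps as a small helper (the name ragged_prefix_slices is hypothetical, not a NumPy function; it assumes 0 <= idx[i] < arr.shape[1]):

```python
import numpy as np

def ragged_prefix_slices(arr, idx):
    """Return the list [arr[i, :idx[i] + 1] for each i], via mask + np.split.

    Hypothetical helper; assumes arr is at least 2-D, idx is 1-D with
    len(idx) == arr.shape[0], and 0 <= idx[i] < arr.shape[1].
    """
    idx = np.asarray(idx)
    # (arr.shape[0], arr.shape[1]) mask: keep column j of row i when j <= idx[i].
    mask = np.arange(arr.shape[1]) <= idx[:, None]
    flat = arr[mask]                 # kept vectors, in row-major order
    # Row i contributes idx[i] + 1 vectors; running totals are the cut points.
    bounds = np.cumsum(idx + 1)
    # The last bound equals len(flat), so np.split yields a trailing empty
    # piece; drop it.
    return np.split(flat, bounds)[:-1]

arr = np.arange(3 * 3 * 2).reshape(3, 3, 2)
pieces = ragged_prefix_slices(arr, np.array([2, 0, 1]))
for p in pieces:
    print(p.shape)  # (3, 2), then (1, 2), then (2, 2)
```

Note that np.split still iterates over the cut points internally in Python, so the saving is in building the selection, not in producing the list of pieces.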

Sample run to verify the results:

In [5]: arr
Out[5]: 
array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5]],

       [[ 6,  7],
        [ 8,  9],
        [10, 11]],

       [[12, 13],
        [14, 15],
        [16, 17]]])

In [6]: idx
Out[6]: array([2, 0, 1])

In [7]: np.split(arr[np.arange(arr.shape[1]) <= idx[:,None]], (idx+1).cumsum())[:-1]
Out[7]: 
[array([[0, 1],
        [2, 3],
        [4, 5]]), array([[6, 7]]), array([[12, 13],
        [14, 15]])]

In [8]: [arr[i, :idx[i]+1] for i in range(ax1)] # Loopy approach
Out[8]: 
[array([[0, 1],
        [2, 3],
        [4, 5]]), array([[6, 7]]), array([[12, 13],
        [14, 15]])]
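The sample run uses a 3×3×2 array, where arr.shape[0] == arr.shape[1], so it cannot show which axis length the mask needs. A quick sketch comparing the mask-and-split approach with the loop on a non-square array (the shape (4, 5, 2) and the random idx are arbitrary test choices):

```python
import numpy as np

rng = np.random.default_rng(0)
arr = np.arange(4 * 5 * 2).reshape(4, 5, 2)              # ax1 = 4, ax2 = 5
idx = rng.integers(0, arr.shape[1], size=arr.shape[0])   # 0 <= idx[i] < ax2

# The mask runs over the second axis, so its width must be arr.shape[1].
mask = np.arange(arr.shape[1]) <= idx[:, None]           # shape (4, 5)
vectorized = np.split(arr[mask], (idx + 1).cumsum())[:-1]
loopy = [arr[i, :idx[i] + 1] for i in range(arr.shape[0])]

print(all(np.array_equal(a, b) for a, b in zip(vectorized, loopy)))  # True

# Using arr.shape[0] instead gives a (4, 4) mask, which does not match
# arr's leading (4, 5) dimensions, so NumPy refuses the boolean index.
bad_mask = np.arange(arr.shape[0]) <= idx[:, None]
raised = False
try:
    arr[bad_mask]
except IndexError as exc:
    raised = True
    print("IndexError:", exc)
```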
