簡體   English   中英

如何將numpy slices混合到索引列表?

[英]How to mix numpy slices to list of indices?

我有一個numpy.array ,稱為grid ,形狀為:

grid.shape = [N, M_1, M_2, ..., M_N]

N,M_1,M_2,...,M_N的值僅在初始化后才知道。

對於此示例,假設N = 3且M_1 = 20,M_2 = 17,M_3 = 9:

grid = np.arange(3*20*17*9).reshape(3, 20, 17, 9)

我試圖遍歷此數組,如下所示:

for indices, val in np.ndenumerate(grid[0]):
    print indices
    _some_func_with_N_arguments(*grid[:, indices])

在第一次迭代中,索引=(0,0,0)並且:

grid[:, indices] # array with shape 3,3,17,9

而我希望它僅是三個元素的數組,就像:

grid[:, indices[0], indices[1], indices[2]] # array([   0, 3060, 6120])

但是我不能像上一行那樣實現,因為我不知道a-priori indices的長度是多少。

我正在使用python 2.7,但歡迎使用版本無關的實現:-)

我想你想要這樣的東西:

In [134]: x=np.arange(24).reshape(4,3,2)

In [135]: x
Out[135]: 
array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5]],

       [[ 6,  7],
        [ 8,  9],
        [10, 11]],

       [[12, 13],
        [14, 15],
        [16, 17]],

       [[18, 19],
        [20, 21],
        [22, 23]]])

In [136]: for i,j in np.ndindex(x[0].shape):
     ...:     print(i,j,x[:,i,j])
     ...:     
(0, 0, array([ 0,  6, 12, 18]))
(0, 1, array([ 1,  7, 13, 19]))
(1, 0, array([ 2,  8, 14, 20]))
(1, 1, array([ 3,  9, 15, 21]))
(2, 0, array([ 4, 10, 16, 22]))
(2, 1, array([ 5, 11, 17, 23]))

第一行是:

In [142]: x[:,0,0]
Out[142]: array([ 0,  6, 12, 18])

將索引元組解壓縮為i,j並在x[:,i,j]是執行此索引的最簡單方法。 但是要將其推廣到其他尺寸,我將不得不使用元組。 x[i,j]x[(i,j)]

In [147]: for ind in np.ndindex(x.shape[1:]):
     ...:     print(ind,x[(slice(None),)+ind])
     ...:     
((0, 0), array([ 0,  6, 12, 18]))
((0, 1), array([ 1,  7, 13, 19]))
...

enumerate

for ind,val in np.ndenumerate(x[0]):
    print(ind,x[(slice(None),)+ind])

您可以手動將slice(None)添加到索引元組:

>>> grid.shape
(3, 20, 17, 9)
>>> indices
(19, 16, 8)
>>> grid[:,19,16,8]
array([3059, 6119, 9179])
>>> grid[(slice(None),) + indices]
array([3059, 6119, 9179])

有關更多信息 ,請參見此處的文檔。

我相信您正在尋找的是grid[1:][grid[0]]

grid = np.array([
        [0, 2, 1],  # N
        [1, 9, 3, 6],  # M_1
        [7, 8, 2, 5, 0, 8, 3],  # M_2
        [4, 8]  # M_3
    ])

np.array([grid[a[0] + 1][n] for a, n in np.ndenumerate(grid[0])])
# array([1, 2, 8])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM