簡體   English   中英

使用另一個數組中的索引獲取一個數組元素的索引

[英]Get indices of element of one array using indices in another array

假設我有一個形狀為 (2, 2, 2) 的數組a

a = np.array([[[7, 9],
               [19, 18]],
              [[24, 5],
               [18, 11]]])

和一個數組ba的最大值: b=a.max(-1) (按行):

b = np.array([[9, 19],
              [24, 18]])

我想使用扁平化a中的索引獲取b中元素的索引,即a.reshape(-1)

array([ 7,  9, 19, 18, 24,  5, 18, 11])

結果應該是一個與b具有相同形狀的數組,其索引為b在展平a中:

array([[1, 2],
       [4, 6]])

基本上這是 pytorch 中 return_indices=True 時 maxpool2d 的結果,但我正在尋找 numpy 中的實現。我用過where但它似乎不起作用,也可以將查找最大值和索引結合在一個go,效率更高? 謝謝你的幫助!

我現在能想到的唯一解決方案是生成一個 2d(或 3d,見下文)索引平面數組的range ,並使用定義b的最大索引(即a.argmax(-1) )對其進行索引:

import numpy as np

a = np.array([[[ 7,  9],
               [19, 18]],
              [[24,  5],
               [18, 11]]])
multi_inds = a.argmax(-1)
b_shape = a.shape[:-1]
b_size = np.prod(b_shape)
flat_inds = np.arange(a.size).reshape(b_size, -1)
flat_max_inds = flat_inds[range(b_size), multi_inds.ravel()]
max_inds = flat_max_inds.reshape(b_shape)

我用一些有意義的變量名稱將步驟分開,這應該可以解釋發生了什么。

multi_inds告訴您在a的每個“行”中選擇哪個“列”以獲得最大值:

>>> multi_inds
array([[1, 0],
       [0, 0]])

flat_inds是一個索引列表,每行要從中選擇一個值:

>>> flat_inds
array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]])

這是根據每行中的最大索引精確索引的。 flat_max_inds是您要查找的值,但在平面數組中:

>>> flat_max_inds
array([1, 2, 4, 6])

所以我們需要重塑它以匹配b.shape

>>> max_inds
array([[1, 2],
       [4, 6]])

一個稍微晦澀但也更優雅的解決方案是使用 3d 索引數組並在其中使用廣播索引:

import numpy as np

a = np.array([[[ 7,  9],
               [19, 18]],
              [[24,  5],
               [18, 11]]])
multi_inds = a.argmax(-1)
i, j = np.indices(a.shape[:-1])
max_inds = np.arange(a.size).reshape(a.shape)[i, j, multi_inds]

這在沒有中間展平為 2d 的情況下做同樣的事情。

最后一部分也是如何從multi_inds獲取b ,即無需再次調用*max function:

b = a[i, j, multi_inds]

我有一個類似於基於 np.argmax 和 np.arange 的 Andras 的解決方案。 我建議在 np.argmax 的結果中添加一個分段偏移量,而不是“索引索引”:

import numpy as np
a = np.array([[[7, 9],
               [19, 18]],
              [[24, 5],
               [18, 11]]])
off = np.arange(0, a.size, a.shape[2]).reshape(a.shape[0], a.shape[1])
>>> off
array([[0, 2],
       [4, 6]])

這導致:

>>> a.argmax(-1) + off
array([[1, 2],
       [4, 6]])

或者作為單線:

>>> a.argmax(-1) + np.arange(0, a.size, a.shape[2]).reshape(a.shape[0], a.shape[1])
array([[1, 2],
       [4, 6]])

這是一個長單線

new = np.array([np.where(a.reshape(-1)==x)[0][0] for x in a.max(-1).reshape(-1)]).reshape(2,2)
print(new)
array([[1, 2],
       [4, 3]])

但是 number = 18 重復了兩次; 那么目標是哪個索引。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM