[英]Get indices of element of one array using indices in another array
假設我有一個形狀為 (2, 2, 2) 的數組a
:
a = np.array([[[7, 9],
[19, 18]],
[[24, 5],
[18, 11]]])
和一個數組b
是a
的最大值: b=a.max(-1)
(按行):
b = np.array([[9, 19],
[24, 18]])
我想使用扁平化a
中的索引獲取b
中元素的索引,即a.reshape(-1)
:
array([ 7, 9, 19, 18, 24, 5, 18, 11])
結果應該是一個與b
具有相同形狀的數組,其索引為b
在展平a
中:
array([[1, 2],
[4, 6]])
基本上這是 pytorch 中 return_indices=True 時 maxpool2d 的結果,但我正在尋找 numpy 中的實現。我用過where
但它似乎不起作用,也可以將查找最大值和索引結合在一個go,效率更高? 謝謝你的幫助!
我現在能想到的唯一解決方案是生成一個 2d(或 3d,見下文)索引平面數組的range
,並使用定義b
的最大索引(即a.argmax(-1)
)對其進行索引:
import numpy as np
a = np.array([[[ 7, 9],
[19, 18]],
[[24, 5],
[18, 11]]])
multi_inds = a.argmax(-1)
b_shape = a.shape[:-1]
b_size = np.prod(b_shape)
flat_inds = np.arange(a.size).reshape(b_size, -1)
flat_max_inds = flat_inds[range(b_size), multi_inds.ravel()]
max_inds = flat_max_inds.reshape(b_shape)
我用一些有意義的變量名稱將步驟分開,這應該可以解釋發生了什么。
multi_inds
告訴您在a
的每個“行”中選擇哪個“列”以獲得最大值:
>>> multi_inds
array([[1, 0],
[0, 0]])
flat_inds
是一個索引列表,每行要從中選擇一個值:
>>> flat_inds
array([[0, 1],
[2, 3],
[4, 5],
[6, 7]])
這是根據每行中的最大索引精確索引的。 flat_max_inds
是您要查找的值,但在平面數組中:
>>> flat_max_inds
array([1, 2, 4, 6])
所以我們需要重塑它以匹配b.shape
:
>>> max_inds
array([[1, 2],
[4, 6]])
一個稍微晦澀但也更優雅的解決方案是使用 3d 索引數組並在其中使用廣播索引:
import numpy as np
a = np.array([[[ 7, 9],
[19, 18]],
[[24, 5],
[18, 11]]])
multi_inds = a.argmax(-1)
i, j = np.indices(a.shape[:-1])
max_inds = np.arange(a.size).reshape(a.shape)[i, j, multi_inds]
這在沒有中間展平為 2d 的情況下做同樣的事情。
最后一部分也是如何從multi_inds
獲取b
,即無需再次調用*max
function:
b = a[i, j, multi_inds]
我有一個類似於基於 np.argmax 和 np.arange 的 Andras 的解決方案。 我建議在 np.argmax 的結果中添加一個分段偏移量,而不是“索引索引”:
import numpy as np
a = np.array([[[7, 9],
[19, 18]],
[[24, 5],
[18, 11]]])
off = np.arange(0, a.size, a.shape[2]).reshape(a.shape[0], a.shape[1])
>>> off
array([[0, 2],
[4, 6]])
這導致:
>>> a.argmax(-1) + off
array([[1, 2],
[4, 6]])
或者作為單線:
>>> a.argmax(-1) + np.arange(0, a.size, a.shape[2]).reshape(a.shape[0], a.shape[1])
array([[1, 2],
[4, 6]])
這是一個長單線
new = np.array([np.where(a.reshape(-1)==x)[0][0] for x in a.max(-1).reshape(-1)]).reshape(2,2)
print(new)
array([[1, 2],
[4, 3]])
但是 number = 18 重復了兩次; 那么目標是哪個索引。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.