[英]Sorting a three-dimensional numpy array by a two-dimensional index
我在這方面遇到了真正的麻煩。 我有一個三維 numpy 數組,我想通過二維索引數組對其進行重新排序。 實際上,arrays 將以編程方式確定,並且三維數組可能是二維或四維數組,但為簡單起見,如果 arrays 都是二維的,則以下是所需的結果:
ph = np.array([[1,2,3], [3,2,1]])
ph_idx = np.array([[0,1,2], [2,1,0]])
for sub_dim_n, sub_dim_ph_idx in enumerate(ph_idx):
ph[sub_dim_n] = ph[sub_dim_n][sub_dim_ph_idx]
這使得 ph 數組變為:
array([[1, 2, 3],
[1, 2, 3]])
這就是我想要的。 如果情況相同,任何人都可以幫忙,但是我有一個三維數組(psh)而不是 ph,例如:
psh = np.array(
[[[1,2,3]],
[[3,2,1]]]
)
希望這很清楚,如果不是,請詢問。 提前致謝!
如果您想要最終得到一個ph.shape
形狀的數組,您可以簡單地np.squeeze
ph_ixs
使形狀匹配,並使用它來索引ph
:
print(ph)
[[[1 2 3]]
[[3 2 1]]]
print(ph_idx)
[[0 1 2]
[2 1 0]]
np.take_along_axis(np.squeeze(ph), ph_idx, axis=-1)
array([[1, 2, 3],
[1, 2, 3]])
因此,線索已經在此處的有用評論中,但為了完整起見,它就像使用 np.take_along_axis 和 2d 數組的廣播版本一樣簡單:
psh = np.array(
[[[1,2,3]],
[[3,2,1]]]
)
ph_idx = np.array(
[[0,1,2],
[2,1,0]]
)
np.take_along_axis(psh, ph_idx[:, None, :], axis=2)
如果 3d 陣列具有多個元素的 dim1 ,這也具有以下優點:
psh = np.array(
[[[1,2,3],[4,5,6],[7,8,9]],
[[3,2,1],[6,5,4],[9,8,7]]]
)
ph_idx = np.array([[0,1,2], [2,1,0]])
np.take_along_axis(psh, ph_idx[:, None, :], axis=2)
這使
array([[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]],
[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.