[英]argmax on 2 axis for 3-d numpy array
我想从 3D 矩阵中获取一维索引数组。
例如给定x = np.random.randint(10, size=(10,3,3))
,我想做一些像np.argmax(x, axis=(1,2))
就像你可以使用np.max
,即获得长度为 10 的一维数组,其中包含大小为(3,3)
的每个子矩阵的最大值的索引(0 到 8 (3,3)
。
到目前为止,我还没有发现任何有用的东西,我想避免在第一维上循环(并使用np.argmax(x)
),因为它非常大。
干杯!
重塑以合并最后两个轴,然后使用np.argmax
-
idx = x.reshape(x.shape[0],-1).argmax(-1)
out = np.unravel_index(idx, x.shape[-2:])
样品运行 -
In [263]: x = np.random.randint(10, size=(4,3,3))
In [264]: x
Out[264]:
array([[[0, 9, 2],
[7, 7, 8],
[2, 5, 9]],
[[1, 7, 2],
[8, 9, 0],
[2, 8, 3]],
[[7, 5, 0],
[7, 1, 6],
[5, 1, 1]],
[[0, 7, 3],
[5, 4, 1],
[9, 8, 9]]])
In [265]: idx = x.reshape(x.shape[0],-1).argmax(-1)
In [266]: np.unravel_index(idx, x.shape[-2:])
Out[266]: (array([0, 1, 0, 2]), array([1, 1, 0, 0]))
如果您的意思是获取合并索引,那么它更简单 -
x.reshape(x.shape[0],-1).argmax(1)
样品运行 -
In [283]: x
Out[283]:
array([[[2, 3, 7],
[8, 1, 0],
[3, 6, 9]],
[[8, 0, 5],
[2, 2, 9],
[9, 0, 9]],
[[1, 9, 2],
[5, 0, 3],
[7, 2, 1]],
[[1, 6, 5],
[2, 3, 7],
[7, 4, 6]]])
In [284]: x.reshape(x.shape[0],-1).argmax(1)
Out[284]: array([8, 5, 1, 5])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.