繁体   English   中英

3-d numpy 数组的 2 轴上的 argmax

[英]argmax on 2 axis for 3-d numpy array

我想从 3D 矩阵中获取一维索引数组。

例如给定x = np.random.randint(10, size=(10,3,3)) ,我想做一些像np.argmax(x, axis=(1,2))就像你可以使用np.max ,即获得长度为 10 的一维数组,其中包含大小为(3,3)的每个子矩阵的最大值的索引(0 到 8 (3,3)

到目前为止,我还没有发现任何有用的东西,我想避免在第一维上循环(并使用np.argmax(x) ),因为它非常大。

干杯!

重塑以合并最后两个轴,然后使用np.argmax -

idx = x.reshape(x.shape[0],-1).argmax(-1)
out = np.unravel_index(idx, x.shape[-2:])

样品运行 -

In [263]: x = np.random.randint(10, size=(4,3,3))

In [264]: x
Out[264]: 
array([[[0, 9, 2],
        [7, 7, 8],
        [2, 5, 9]],

       [[1, 7, 2],
        [8, 9, 0],
        [2, 8, 3]],

       [[7, 5, 0],
        [7, 1, 6],
        [5, 1, 1]],

       [[0, 7, 3],
        [5, 4, 1],
        [9, 8, 9]]])

In [265]: idx = x.reshape(x.shape[0],-1).argmax(-1)

In [266]: np.unravel_index(idx, x.shape[-2:])
Out[266]: (array([0, 1, 0, 2]), array([1, 1, 0, 0]))

如果您的意思是获取合并索引,那么它更简单 -

x.reshape(x.shape[0],-1).argmax(1)

样品运行 -

In [283]: x
Out[283]: 
array([[[2, 3, 7],
        [8, 1, 0],
        [3, 6, 9]],

       [[8, 0, 5],
        [2, 2, 9],
        [9, 0, 9]],

       [[1, 9, 2],
        [5, 0, 3],
        [7, 2, 1]],

       [[1, 6, 5],
        [2, 3, 7],
        [7, 4, 6]]])

In [284]: x.reshape(x.shape[0],-1).argmax(1)
Out[284]: array([8, 5, 1, 5])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM