繁体   English   中英

如何从 2D pytorch 张量获取最大元素的行和列索引?

[英]how to get the row and column indices of the maximum element from a 2D pytorch tensor?

有什么方法可以检索二维pytorch张量中包含的最大元素的行和列索引? 例如,请参见下面的pytorch张量a

a
>> torch.tensor([1,2,3],
                [9,5,4],
                [6,7,8])

张量a最大的元素是 9,它发生在第二行的第一列。 如果我将其更改为从零开始的 python 列和行索引,则元素的列索引将为 0,行索引将为 1。

有什么方法可以从二维 pytorch 张量a检索索引 [1,0] 吗?

不幸的是,没有内置方法。 但是你可以使用 numpy:

np.unravel_index(torch.argmax(a), a.shape)

否则,您需要编写自己的逻辑,例如:

def unravel_index(flat_idx, shape): 
     multi_idx = [] 
     r = flat_idx 
     for s in shape[:-1]: 
         multi_idx.append(r // s) 
         r = r % s 
     multi_idx.append(r % s) 
     return multi_idx

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM