[英]how to get the row and column indices of the maximum element from a 2D pytorch tensor?
有什么方法可以检索二维pytorch
张量中包含的最大元素的行和列索引? 例如,请参见下面的pytorch
张量a
:
a
>> torch.tensor([1,2,3],
[9,5,4],
[6,7,8])
张量a
最大的元素是 9,它发生在第二行的第一列。 如果我将其更改为从零开始的 python 列和行索引,则元素的列索引将为 0,行索引将为 1。
有什么方法可以从二维 pytorch 张量a
检索索引 [1,0] 吗?
不幸的是,没有内置方法。 但是你可以使用 numpy:
np.unravel_index(torch.argmax(a), a.shape)
否则,您需要编写自己的逻辑,例如:
def unravel_index(flat_idx, shape):
multi_idx = []
r = flat_idx
for s in shape[:-1]:
multi_idx.append(r // s)
r = r % s
multi_idx.append(r % s)
return multi_idx
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.