[英]Pytorch: How can I find indices of first nonzero element in each row of a 2D tensor?
I have a 2D tensor with some nonzero element in each row like this:我有一个二维张量,每行都有一些非零元素,如下所示:
import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
I want a tensor containing the index of first nonzero element in each row:我想要一个包含每行中第一个非零元素索引的张量:
indices = tensor([2],
[3])
How can I calculate it in Pytorch?我如何在 Pytorch 中计算它?
I have simplified Iman's approach to do the following:我简化了 Iman 的方法来执行以下操作:
idx = torch.arange(tmp.shape[1], 0, -1)
tmp2= tmp * idx
indices = torch.argmax(tmp2, 1, keepdim=True)
I could find a tricky answer for my question:我可以为我的问题找到一个棘手的答案:
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
idx = reversed(torch.Tensor(range(1,8)))
print(idx)
tmp2= torch.einsum("ab,b->ab", (tmp, idx))
print(tmp2)
indices = torch.argmax(tmp2, 1, keepdim=True)
print(indeces)
The result is:结果是:
tensor([7., 6., 5., 4., 3., 2., 1.])
tensor([[0., 0., 5., 0., 3., 0., 0.],
[0., 0., 0., 4., 3., 0., 0.]])
tensor([[2],
[3]])
All the nonzero values are equal, so argmax
returns the first index.所有非零值都相等,因此argmax
返回第一个索引。
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0]])
indices = tmp.argmax(1)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.