
Top K indices of a multi-dimensional tensor

I have a 2D tensor and I want to get the indices of the top k values. I know about PyTorch's topk function, but it computes the top-k values along a single dimension. I want the top-k values over both dimensions, i.e. over the whole tensor.

For example for the following tensor

import torch

a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])

PyTorch's topk, which works along the last dimension by default, gives me the following:

values, indices = torch.topk(a, 3)

print(indices)
# tensor([[1, 2, 0],
#        [0, 2, 1],
#        [0, 1, 4],
#        [1, 4, 3],
#        [1, 0, 4]])

But I want to get the following:

tensor([[0, 1],
        [2, 0],
        [3, 1]])

These are the indices of the 9s (the three largest values) in the 2D tensor.

Is there a way to achieve this with PyTorch?

import numpy as np

v, i = torch.topk(a.flatten(), 3)
print(np.array(np.unravel_index(i.numpy(), a.shape)).T)

Output:

[[3 1]
 [2 0]
 [0 1]]
  1. Flatten and find top k
  2. Convert 1D indices to 2D using unravel_index
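The same two steps can also be done entirely in PyTorch, without the NumPy round trip. Below is a minimal sketch for the 2D case, assuming the row-major order that flatten() produces: the row is the flat index floor-divided by the number of columns, and the column is the remainder. topk_2d is a hypothetical helper name, not a PyTorch API.

```python
import torch


def topk_2d(t: torch.Tensor, k: int) -> torch.Tensor:
    """Return a (k, 2) tensor of (row, col) indices of the k largest entries of a 2D tensor.

    topk_2d is a hypothetical helper for illustration, not part of PyTorch.
    """
    n_cols = t.shape[1]
    # Step 1: flatten (row-major) and take the top k over all elements
    flat_idx = torch.topk(t.flatten(), k).indices
    # Step 2: convert each flat index back to (row, col)
    rows = torch.div(flat_idx, n_cols, rounding_mode="floor")
    cols = flat_idx % n_cols
    return torch.stack([rows, cols], dim=1)


a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])
print(topk_2d(a, 3))
```

Because the three largest entries are all 9, topk may return the tied positions in any order, so compare results as a set rather than row by row.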

You can flatten the original tensor, apply topk, and then convert the resulting flat (scalar) indices back to multidimensional indices with a helper like the following:

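import math

def descalarization(idx, shape):
    # Convert a flat (row-major) index into a tuple of per-dimension indices
    idx = int(idx)        # copy to a Python int so the source tensor is not modified in place
    res = []
    N = math.prod(shape)  # total number of elements
    for n in shape:
        N //= n           # N is now the stride of the current dimension
        res.append(idx // N)
        idx %= N
    return tuple(res)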

Example:

torch.tensor([descalarization(k, a.size()) for k in torch.topk(a.flatten(), 5).indices])
# Returns (tied values may come back in a different order):
# tensor([[3, 1],
#         [2, 0],
#         [0, 1],
#         [3, 4],
#         [2, 4]])
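The helper above handles one index at a time in a Python loop. A vectorized variant can convert all the topk indices at once and works for any number of dimensions. This is a sketch that assumes row-major (C-contiguous) flattening; unravel_flat_indices is a hypothetical name, not a PyTorch function.

```python
import torch


def unravel_flat_indices(flat_idx: torch.Tensor, shape) -> torch.Tensor:
    """Map row-major flat indices to an (n, len(shape)) tensor of coordinates.

    Hypothetical helper (not a PyTorch API): a vectorized take on descalarization.
    """
    coords = []
    rem = flat_idx
    # Walk the dimensions from last to first: the remainder is the coordinate
    # in the current dimension and the quotient carries over to the next one.
    for size in reversed(shape):
        coords.append(rem % size)
        rem = torch.div(rem, size, rounding_mode="floor")
    # Coordinates were collected last-dimension-first, so reverse them
    return torch.stack(coords[::-1], dim=1)


a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])
print(unravel_flat_indices(torch.topk(a.flatten(), 5).indices, a.shape))
```

Walking the dimensions from last to first avoids precomputing the strides: each step peels off one coordinate with % and carries the quotient to the next dimension.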

Alternatively, you can skip topk entirely and use vectorized operations to filter for the value you need (here, 9):

print(a)
tensor([[4, 9, 7, 4, 0],
        [8, 1, 3, 1, 0],
        [9, 8, 4, 4, 8],
        [0, 9, 4, 7, 8],
        [8, 8, 0, 1, 4]])

values, indices = torch.max(a, 1)    # per-row max values and their column indices
temp = torch.zeros_like(values)      # helper flag tensor, one entry per row
temp[values == 9] = 1                # flag rows whose max is 9 (the wanted value)
seq = torch.arange(values.shape[0])  # row numbers 0..n-1
new_seq = seq[temp > 0]              # keep only the flagged rows
new_temp = indices[new_seq]          # column of the 9 in each flagged row
final = torch.stack([new_seq, new_temp], dim=1)  # pair (row, col) to get the result

print(final)
tensor([[0, 1],
        [2, 0],
        [3, 1]])
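One limitation of the max-per-row approach: torch.max along dim 1 reports a single column per row, so a row containing two 9s contributes only one (row, col) pair. If every position holding the value is wanted, a boolean mask combined with torch.nonzero returns all of them in row-major order. This is a substitute technique rather than part of the answer above; comparing against a.max() assumes the target is the global maximum (use a == 9 for a fixed value).

```python
import torch

a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])

# Boolean mask of every entry equal to the target value (here the global max, 9)
mask = a == a.max()
# nonzero() returns one (row, col) pair per True entry, in row-major order
positions = torch.nonzero(mask)
print(positions.tolist())  # [[0, 1], [2, 0], [3, 1]]
```

Unlike the per-row version, this also picks up repeated values within a row.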
