简体   繁体   English

随机获取 PyTorch 张量中最大值之一的索引

[英]Randomly get index of one of the maximum values in a PyTorch tensor

I need to perform something similar to the built-in torch.argmax() function on a one-dimensional tensor, but instead of picking the index of the first of the maximum values, I want to be able to pick a random index of one of the maximum values.我需要在一维张量上执行类似于内置 torch.argmax() 函数的操作,但不是选择第一个最大值的索引,我希望能够选择一个随机索引的最大值。 For example:例如:

my_tensor = torch.tensor([0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.1])
index_1 = random_max_val_index_fn(my_tensor)
index_2 = random_max_val_index_fn(my_tensor)

print(f"{index_1}, {index_2}")
> 5, 1

You can get the indexes of all the maximums first and then choose randomly from them:您可以先获取所有最大值的索引,然后从中随机选择:

def rand_argmax(tens):
    max_inds, = torch.where(tens == tens.max())
    return np.random.choice(max_inds)

sample runs:示例运行:

>>> my_tensor = torch.tensor([0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.1])
>>> rand_argmax(my_tensor)
2
>>> rand_argmax(my_tensor)
5
>>> rand_argmax(my_tensor)
2
>>> rand_argmax(my_tensor)
1

I think this should work:我认为这应该有效:

import numpy as np
import torch

your_tensor = torch.tensor([0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.1])
argmaxes = np.argwhere(your_tensor==torch.max(your_tensor)).flatten()
rand_argmax = np.random.choice(argmaxes)
print(rand_argmax)

make sure you adjust for np.random.choice to account for replacement确保您调整np.random.choice以考虑替换

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM