[英]get top mapped values in pytorch tensor
我正在尝试获取张量中的前 N 个元素。 我有一个映射告诉我如何对张量值进行排序
values_mapping = {1: 12, 3: 1, 4: 2, 2: 34, 12: 3}
tensor = torch.tensor([1, 4, 12, 2])
tensor.topk(3)
这里的结果应该是torch.tensor([1, 12, 2])
即使用values_mapping
映射后的最高值
有什么办法可以使用手电筒吗? 我们可以告诉火炬如何对它获得的值进行排序吗?
我不知道是否存在更优雅的解决方案,但您可以使用所需的键 select 值,select 值中的前 k 项,使用 topk 索引到 select 键:
values_mapping = {1: 12, 3: 1, 4: 2, 2: 34, 12: 3}
tensor0 = torch.tensor([1, 4, 12, 2])
mapping0 = torch.Tensor([(k, v) for k, v in values_mapping.items() if k in tensor0])
topk = mapping0[:,1].topk(3)
top_keys = mapping0[:,0][topk.indices]
print(top_keys)
>>> tensor([ 2., 1., 12.])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.