繁体   English   中英

获取 pytorch 张量中的最高映射值

[英]get top mapped values in pytorch tensor

我正在尝试获取张量中的前 N 个元素。 我有一个映射告诉我如何对张量值进行排序

values_mapping = {1: 12, 3: 1, 4: 2, 2: 34, 12: 3}
tensor = torch.tensor([1, 4, 12, 2])
tensor.topk(3)

这里的结果应该是torch.tensor([1, 12, 2])即使用values_mapping映射后的最高值

有什么办法可以使用手电筒吗? 我们可以告诉火炬如何对它获得的值进行排序吗?

我不知道是否存在更优雅的解决方案,但您可以使用所需的键 select 值,select 值中的前 k 项,使用 topk 索引到 select 键:

values_mapping = {1: 12, 3: 1, 4: 2, 2: 34, 12: 3}
tensor0 = torch.tensor([1, 4, 12, 2])
mapping0 = torch.Tensor([(k, v) for k, v in values_mapping.items() if k in tensor0])
topk = mapping0[:,1].topk(3)
top_keys = mapping0[:,0][topk.indices]
print(top_keys)
>>> tensor([ 2.,  1., 12.])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM