简体   繁体   English

如何获取两个不同大小的 PyTorch 张量中相等元素的索引?

[英]How to get the indexes of equal elements in two different size PyTorch tensors?

Let's say I have two PyTorch tensors:假设我有两个 PyTorch 张量:

t_1d = torch.Tensor([6, 5, 1, 7, 8, 4, 7, 1, 0, 4, 11, 7, 4, 7, 4, 1])
t = torch.Tensor([4, 7])

I want to get the indices of exact match intersection between the sets for the tensor t_1d with tensor t.我想获得张量 t_1d 与张量 t 的集合之间精确匹配交集的索引。

Desired output of t_1d and t : [5, 12] (first index of exact intersection) t_1dt的所需 output : [5, 12] (精确交集的第一个索引)

Preferably on GPU for large Tensors, so no loops or Numpy casts.对于大张量,最好在 GPU 上,所以没有循环或 Numpy 演员表。

In general, we can check where each element in t is equal to elements in t_1d .一般来说,我们可以检查t中的每个元素在哪里等于t_1d中的元素。

After that, shift back the last element by as many places as it misses from the first element (in general case, here shift by -1 ) and check whether arrays are equal:之后,将最后一个元素从第一个元素移回尽可能多的位置(在一般情况下,这里移-1 )并检查 arrays 是否相等:

intersection = (t_1d == t[0]) & torch.roll(t_1d == t[1], shifts=-1)
torch.where(intersection)[0] # torch.tensor([5, 12])

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM