
How to find the intersection of two sets of 2D points (tensors) in PyTorch

I have two PyTorch tensors whose rows are 2D points on a plane:

ListA = torch.tensor([[1.0, 2.0], [1.0, 3.0], [4.0, 8.0]], device='cuda:0')
ListB = torch.tensor([[5.0, 7.0], [1.0, 2.0], [4.0, 8.0]], device='cuda:0')

How can I compute the intersection of ListA and ListB, i.e. the points that appear in both? The desired output is:

tensor([[1.0, 2.0], [4.0, 8.0]], device='cuda:0')

Note: Computation should be carried out only on CUDA.

PyTorch has no built-in function for this, but there is a workaround.

Flatten both tensors and concatenate them:

combined = torch.cat((ListA.view(-1), ListB.view(-1)))
combined
Out[52]: tensor([1., 2., 1., 3., 4., 8., 5., 7., 1., 2., 4., 8.], device='cuda:0')

Then find the unique values, keep those that occur more than once, and reshape the result back into points:

unique, counts = combined.unique(return_counts=True)
intersection = unique[counts > 1].reshape(-1, ListA.shape[1])
intersection
Out[55]: 
tensor([[1., 2.],
        [4., 8.]], device='cuda:0')
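Applying unique along dim=0 instead of on the flattened values compares whole points rather than individual coordinates. The following is a minimal sketch of that variant, assuming each list is deduplicated first so that a point repeated inside one list is not mistaken for a shared point; the function name intersect_rows is hypothetical, and the device fallback to CPU is only there so the snippet also runs without a GPU.

```python
import torch

def intersect_rows(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: rows (points) present in both a and b, sorted lexicographically."""
    # Deduplicate each input so a point repeated within one list
    # is not counted as appearing in both lists.
    a_u = torch.unique(a, dim=0)
    b_u = torch.unique(b, dim=0)
    # After concatenation, a row that occurs twice must come from both inputs.
    combined = torch.cat((a_u, b_u))
    rows, counts = torch.unique(combined, dim=0, return_counts=True)
    return rows[counts > 1]

# Fall back to the CPU only so the example runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
ListA = torch.tensor([[1.0, 2.0], [1.0, 3.0], [4.0, 8.0]], device=device)
ListB = torch.tensor([[5.0, 7.0], [1.0, 2.0], [4.0, 8.0]], device=device)
print(intersect_rows(ListA, ListB))
```

Because whole rows are compared, [1.0, 2.0] and [2.0, 1.0] are treated as different points, and the coordinate order inside each returned point is preserved.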

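Note that this compares individual coordinate values rather than whole points. It gives the right answer for this example, but in general it can report points that are not in both lists (ListA = [[1.0, 2.0]] and ListB = [[2.0, 1.0]] yield [[1., 2.]]), it sorts the coordinates inside each returned point ([4.0, 1.0] comes back as [1., 4.]), and a value repeated within a single list is also counted as shared, which can make the final reshape fail.

Benchmarks: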

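import torch

def find_intersection_two_tensors(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Flatten both point sets into one 1D tensor of coordinates
    combined = torch.cat((A.view(-1), B.view(-1)))
    unique, counts = combined.unique(return_counts=True)
    # Keep values seen more than once and regroup them into points
    return unique[counts > 1].reshape(-1, A.shape[1])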
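For small inputs, a pairwise comparison by broadcasting is another way to keep whole points together while staying on the GPU. This is a sketch, not a benchmarked alternative: it builds an N x M x D boolean tensor, so memory grows with the product of the two list sizes, and it keeps every row of the first tensor that has a match (including duplicates) in its original order. The name intersect_broadcast is hypothetical.

```python
import torch

def intersect_broadcast(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: rows of a that also appear in b, in a's order."""
    # eq[i, j, k] is True when coordinate k of a[i] equals coordinate k of b[j].
    eq = a[:, None, :] == b[None, :, :]          # shape (N, M, D)
    # A pair of points matches when all D coordinates agree;
    # keep a row of a if it matches at least one row of b.
    mask = eq.all(dim=-1).any(dim=1)             # shape (N,)
    return a[mask]

# CPU fallback only so the example runs without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
ListA = torch.tensor([[1.0, 2.0], [1.0, 3.0], [4.0, 8.0]], device=device)
ListB = torch.tensor([[5.0, 7.0], [1.0, 2.0], [4.0, 8.0]], device=device)
print(intersect_broadcast(ListA, ListB))
```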

Timing it

%timeit find_intersection_two_tensors(ListA, ListB)
207 µs ± 2.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
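Outside IPython, torch.utils.benchmark.Timer can be used in place of %timeit; its documentation describes it as handling warm-up runs and CUDA synchronization, which matters because CUDA kernels run asynchronously. The sketch below redefines find_intersection_two_tensors so it runs on its own; it does not reproduce the numbers quoted here, and the run count of 1000 is an arbitrary choice.

```python
import torch
from torch.utils import benchmark

def find_intersection_two_tensors(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    combined = torch.cat((A.view(-1), B.view(-1)))
    unique, counts = combined.unique(return_counts=True)
    return unique[counts > 1].reshape(-1, A.shape[1])

# CPU fallback only so the example runs without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
ListA = torch.tensor([[1.0, 2.0], [1.0, 3.0], [4.0, 8.0]], device=device)
ListB = torch.tensor([[5.0, 7.0], [1.0, 2.0], [4.0, 8.0]], device=device)

timer = benchmark.Timer(
    stmt="find_intersection_two_tensors(A, B)",
    globals={"find_intersection_two_tensors": find_intersection_two_tensors,
             "A": ListA, "B": ListB},
)
m = timer.timeit(1000)              # Measurement object; times are in seconds
print(m)
print(f"{m.mean * 1e6:.1f} us per call")
```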

If moving the data to the CPU is acceptable, NumPy's np.intersect1d may be faster for small inputs such as this example (it compares flattened coordinates in the same way as above):

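import numpy as np
import torch

def find_intersection_two_ndarray(AGPU: torch.Tensor, BGPU: torch.Tensor) -> torch.Tensor:
    A = AGPU.view(-1).cpu().numpy()
    B = BGPU.view(-1).cpu().numpy()
    # Sorted unique coordinate values present in both flattened arrays
    C = np.intersect1d(A, B)
    # Move back to the GPU and regroup the values into points
    return torch.from_numpy(C).cuda('cuda:0').reshape(-1, AGPU.shape[1])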

Timing it

%timeit find_intersection_two_ndarray(ListA, ListB)
85.4 µs ± 1.57 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
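np.intersect1d has no axis argument, but np.unique does, so a CPU version that compares whole points can follow the same deduplicate, concatenate and count pattern. This is an untimed sketch, assuming a round trip through the CPU is acceptable; intersect_rows_numpy is a hypothetical name, and the result is returned to whichever device the first input was on.

```python
import numpy as np
import torch

def intersect_rows_numpy(a_gpu: torch.Tensor, b_gpu: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: points present in both inputs, sorted lexicographically."""
    a = a_gpu.cpu().numpy()
    b = b_gpu.cpu().numpy()
    # Deduplicate each side so repeats inside one list are not counted as shared.
    a_u = np.unique(a, axis=0)
    b_u = np.unique(b, axis=0)
    rows, counts = np.unique(np.concatenate((a_u, b_u)), axis=0, return_counts=True)
    # Rows seen twice came from both inputs; move them back to the original device.
    return torch.from_numpy(rows[counts > 1]).to(a_gpu.device)

# CPU fallback only so the example runs without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
ListA = torch.tensor([[1.0, 2.0], [1.0, 3.0], [4.0, 8.0]], device=device)
ListB = torch.tensor([[5.0, 7.0], [1.0, 2.0], [4.0, 8.0]], device=device)
print(intersect_rows_numpy(ListA, ListB))
```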


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM