
How to compute pairwise distances between a point set and lines in PyTorch?

The point set A is an Nx3 matrix. Two point sets B and C, both of size Mx3, define the lines BC between them: line j passes through the points in the corresponding rows B[j] and C[j], so BC can itself be stored as an Mx3 matrix. Now I want to compute the distance from each point in A to each line in BC, giving an NxM matrix D. The basic method is as follows:
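For reference, the loop below evaluates the standard point-to-line distance, written here in the question's notation (A_i is a point of A, and the line is the infinite line through B_j and C_j, not the segment). The norm of the cross product is the area of the parallelogram spanned by the two vectors; dividing by the base length gives its height, which is the perpendicular distance:

```latex
D_{ij} = \frac{\lVert (B_j - C_j) \times (A_i - B_j) \rVert}{\lVert B_j - C_j \rVert}
```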

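import torch

N, M = A.shape[0], B.shape[0]
D = torch.zeros((N, M), dtype=torch.float32)
for i in range(N):
    p = A[i]  # shape (3,)
    for j in range(M):
        p1 = B[j]  # shape (3,)
        p2 = C[j]  # shape (3,)
        # |(p1 - p2) x (p - p1)| / |p1 - p2|: distance from p to the line through p1 and p2
        D[i, j] = torch.norm(torch.cross(p1 - p2, p - p1, dim=0)) / torch.norm(p1 - p2)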

Is there a faster way to do this? Thanks.

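You can remove the for loops by vectorizing the computation with broadcasting. This should be much faster, at the cost of memory: the intermediate tensors have shape N x M x 3, which only becomes a problem when N and M are large.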

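import torch

N = A.shape[0]
diff_B_C = B - C                           # (M, 3): line directions
diff_A_C = A[:, None] - C                  # (N, M, 3) by broadcasting
norm_lines = torch.norm(diff_B_C, dim=-1)  # (M,)
# (B - C) x (A - C) equals (B - C) x (A - B), since (B - C) x (B - C) = 0
cross_result = torch.cross(diff_B_C[None, :].expand(N, -1, -1), diff_A_C, dim=-1)  # (N, M, 3)
norm_cross = torch.norm(cross_result, dim=-1)  # (N, M)
D = norm_cross / norm_lines                # (N, M); norm_lines broadcasts across rows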
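A quick way to convince yourself the vectorized version matches the loop is to wrap both in two helpers and compare them on random inputs. This is a sketch: the names `loop_distances` and `vectorized_distances` are hypothetical, float64 is used only to keep rounding differences negligible, and the loop uses `torch.linalg.cross`, whose `dim` defaults to -1.

```python
import torch

def loop_distances(A, B, C):
    # Mirrors the question's double loop (hypothetical helper name).
    N, M = A.shape[0], B.shape[0]
    D = torch.zeros((N, M), dtype=A.dtype)
    for i in range(N):
        for j in range(M):
            d = B[j] - C[j]
            D[i, j] = torch.norm(torch.linalg.cross(d, A[i] - B[j])) / torch.norm(d)
    return D

def vectorized_distances(A, B, C):
    # Same steps as the vectorized answer, with an explicit dim for the cross product.
    N = A.shape[0]
    diff_B_C = B - C                           # (M, 3)
    diff_A_C = A[:, None] - C                  # (N, M, 3)
    norm_lines = torch.norm(diff_B_C, dim=-1)  # (M,)
    cross_result = torch.cross(diff_B_C[None, :].expand(N, -1, -1), diff_A_C, dim=-1)
    return torch.norm(cross_result, dim=-1) / norm_lines  # (N, M)

torch.manual_seed(0)
A = torch.randn(5, 3, dtype=torch.float64)
B = torch.randn(4, 3, dtype=torch.float64)
C = torch.randn(4, 3, dtype=torch.float64)
D_loop = loop_distances(A, B, C)
D_vec = vectorized_distances(A, B, C)
print(D_vec.shape)                    # torch.Size([5, 4])
print(torch.allclose(D_loop, D_vec))  # True
```

Both helpers compute the distance to the infinite line through B[j] and C[j]; if B[j] equals C[j], the division by a zero line length produces nan or inf in either version.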

Of course, you don't need to do it step by step; I only split it up to give each intermediate a clear variable name.
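For illustration, here is the same computation compressed into one expression inside a function. The function name `point_line_distances` and its `chunk_size` parameter are hypothetical additions, not part of the answer: when `chunk_size` is set, rows of A are processed in blocks so the (chunk, M, 3) intermediate stays bounded, which addresses the memory cost mentioned above.

```python
import torch

def point_line_distances(A, B, C, chunk_size=None):
    # A: (N, 3) points; B, C: (M, 3), with line j passing through B[j] and C[j].
    # chunk_size (hypothetical option): if set, split A into blocks of at most
    # chunk_size rows so the (block, M, 3) intermediate stays small.
    d = B - C  # (M, 3) line directions

    def block(A_part):
        # Same computation as the step-by-step version, in one expression.
        cross = torch.cross(d.expand(A_part.shape[0], -1, -1), A_part[:, None] - C, dim=-1)
        return cross.norm(dim=-1) / d.norm(dim=-1)  # (block, M)

    if chunk_size is None:
        return block(A)
    return torch.cat([block(A_part) for A_part in A.split(chunk_size)], dim=0)

torch.manual_seed(0)
A = torch.randn(7, 3, dtype=torch.float64)
B = torch.randn(4, 3, dtype=torch.float64)
C = torch.randn(4, 3, dtype=torch.float64)
D_full = point_line_distances(A, B, C)
D_chunked = point_line_distances(A, B, C, chunk_size=3)  # blocks of 3, 3 and 1 rows
print(D_full.shape)                          # torch.Size([7, 4])
print(torch.allclose(D_full, D_chunked))     # True
```

Chunking trades a Python-level loop over blocks for bounded peak memory; with a large `chunk_size` it behaves like the fully vectorized version.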

Note: if you don't pass dim to torch.cross, it uses the first dimension of size 3, which gives wrong results if N = 3 or M = 3, because that dimension would be picked instead of the last (xyz) one. From the docs:

If dim is not given, it defaults to the first dimension found with the size 3.

If you are wondering why I chose expand instead of repeat: expand returns a view of the original tensor without allocating new memory, whereas repeat copies the data.
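A small demonstration of the difference (the sizes M = 4 and N = 1000 are arbitrary): the expanded tensor shares storage with the original and uses a stride of 0 along the new leading dimension, while repeat allocates a separate buffer of N * M * 3 elements holding the same values.

```python
import torch

M, N = 4, 1000
diff_B_C = torch.randn(M, 3)

expanded = diff_B_C[None, :].expand(N, -1, -1)  # view: no new memory
repeated = diff_B_C[None, :].repeat(N, 1, 1)    # copy: N * M * 3 new elements

print(expanded.shape, repeated.shape)              # both torch.Size([1000, 4, 3])
print(expanded.stride())                           # (0, 3, 1): stride 0 on the new dim
print(expanded.data_ptr() == diff_B_C.data_ptr())  # True: shares storage
print(repeated.data_ptr() == diff_B_C.data_ptr())  # False: separate buffer
print(torch.equal(expanded, repeated))             # True: same values
```

Because several elements of an expanded tensor refer to the same memory location, it is fine as a read-only input to torch.cross but should be cloned before any in-place write.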
