[英]How to compute pairwise distance between point set and lines in PyTorch?
The point set A
is a Nx3
matrix, and from two point sets B
and C
with the same size of Mx3
we could get the lines BC
betwen them.点集A
是一个Nx3
矩阵,从具有相同大小的Mx3
的两个点集B
和C
我们可以得到它们之间的线BC
。 Now I want to compute the distance from each point in A
to each line in BC
.现在我想计算从A
中的每个点到BC
中的每条线的距离。 B
is Mx3
and C
is Mx3
, then the lines are from the points with correspoinding rows, so BC
is a Mx3
matrix. B
是Mx3
和C
是Mx3
,那么这些线来自具有相应行的点,所以BC
是一个Mx3
矩阵。 The basic method is computed as follows:基本方法计算如下:
D = torch.zeros((N, M), dtype=torch.float32)
for i in range(N):
p = A[i] # 1x3
for j in range(M):
p1 = B[j] # 1x3
p2 = C[j] # 1x3
D[i,j] = torch.norm(torch.cross(p1 - p2, p - p1)) / torch.norm(p1 - p2)
Are there any faster method to do this work?有没有更快的方法来完成这项工作? Thanks.谢谢。
You can remove the for
loops by doing this (it should speed-up at the cost of memory, unless M
and N
are small):您可以通过这样做删除for
循环(它应该以 memory 为代价加速,除非M
和N
很小):
diff_B_C = B - C
diff_A_C = A[:, None] - C
norm_lines = torch.norm(diff_B_C, dim=-1)
cross_result = torch.cross(diff_B_C[None, :].expand(N, -1, -1), diff_A_C, dim=-1)
norm_cross = torch.norm(cross_result, dim=-1)
D = norm_cross / norm_lines
Of course, you don't need to do it step-by-step.当然,你不需要一步一步来。 I just tried to be clear with the variable names.我只是想清楚变量名。
Note : if you don't provide dim
to torch.cross
, it will use the first dim=3
which would give the wrong results if N=3
(from the docs ):注意:如果您不向torch.cross
提供dim
,它将使用第一个dim=3
如果N=3
会给出错误的结果(来自文档):
If dim is not given, it defaults to the first dimension found with the size 3.如果没有给出 dim ,则默认为找到的第一个尺寸为 3 的尺寸。
If you are wondering, you can check here why I chose expand
instead of repeat
.如果你想知道,你可以在这里查看我为什么选择expand
而不是repeat
。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.