簡體   English   中英

如何計算 PyTorch 中點集和線之間的成對距離?

[英]How to compute pairwise distance between point set and lines in PyTorch?

點集A是一個Nx3矩陣,從具有相同大小的Mx3的兩個點集BC我們可以得到它們之間的線BC 現在我想計算從A中的每個點到BC中的每條線的距離。 BMx3CMx3 ,那么這些線來自具有相應行的點,所以BC是一個Mx3矩陣。 基本方法計算如下:

D = torch.zeros((N, M), dtype=torch.float32)
for i in range(N):
    p = A[i]  # 1x3
    for j in range(M):
        p1 = B[j] # 1x3
        p2 = C[j] # 1x3
        D[i,j] = torch.norm(torch.cross(p1 - p2, p - p1)) / torch.norm(p1 - p2) 

有沒有更快的方法來完成這項工作? 謝謝。

您可以通過這樣做刪除for循環(它應該以 memory 為代價加速,除非MN很小):

diff_B_C = B - C
diff_A_C = A[:, None] - C
norm_lines = torch.norm(diff_B_C, dim=-1)
cross_result = torch.cross(diff_B_C[None, :].expand(N, -1, -1), diff_A_C, dim=-1)
norm_cross = torch.norm(cross_result, dim=-1)
D = norm_cross / norm_lines

當然,你不需要一步一步來。 我只是想清楚變量名。

注意:如果您不向torch.cross提供dim ,它將使用第一個dim=3如果N=3會給出錯誤的結果(來自文檔):

如果沒有給出 dim ,則默認為找到的第一個尺寸為 3 的尺寸。

如果你想知道,你可以在這里查看我為什么選擇expand而不是repeat

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM