簡體   English   中英

查找兩個陣列之間最短距離的有效方法?

[英]Efficient way to find the shortest distance between two arrays?

我試圖找到兩組數組之間的最短距離。 x-數組相同,只包含整數。 這是我要執行的操作的一個示例:

import numpy as np
x1 = x2 = np.linspace(-1000, 1000, 2001)
y1 = (lambda x, a, b: a*x + b)(x1, 2, 1)
y2 = (lambda x, a, b: a*(x-2)**2 + b)(x2, 2, 10)

def dis(x1, y1, x2, y2):
    return sqrt((y2-y1)**2+(x2-x1)**2)

min_distance = np.inf
for a, b in zip(x1, y1):
    for c, d in zip(x2, y2):
        if dis(a, b, c, d) < min_distance:
            min_distance = dis(a, b, c, d)

>>> min_distance
2.2360679774997898

此解決方案有效,但是問題是運行時。 如果x的長度約為10,000,則該解決方案不可行,因為程序運行時間為O(n ^ 2)。 現在,我嘗試做出一些近似來加快程序速度:

for a, b in zip(x1, y1):
    cut = (x2 > a-20)*(x2 < a+20)
    for c, d in zip(x2, y2):
        if dis(a, b, c, d) < min_distance:
            min_distance = dis(a, b, c, d)

但是該程序花費的時間比我想要的還要長。 現在,據我所知,遍歷一個numpy數組通常效率不高,因此我確定仍有改進的空間。 關於如何加快此程序的速度有什么想法嗎?

您的問題也可以表示為2d碰撞檢測,因此四叉樹可能會有所幫助。 插入和查詢都在O(log n)時間中進行,因此整個搜索將在O(n log n)中進行。

還有一個建議,因為sqrt是單調的,所以您可以比較距離的平方而不是距離本身,這將節省n ^ 2平方根的計算。

scipy具有cdist函數 ,該函數可計算所有成對的點之間的距離:

from scipy.spatial.distance import cdist
import numpy as np

x1 = x2 = np.linspace(-1000, 1000, 2001)
y1 = (lambda x, a, b: a*x + b)(x1, 2, 1)
y2 = (lambda x, a, b: a*(x-2)**2 + b)(x2, 2, 10)

R1 = np.vstack((x1,y1)).T
R2 = np.vstack((x2,y2)).T

dists = cdist(R1,R2) # find all mutual distances

print (dists.min())
# output: 2.2360679774997898

它的運行速度比原始for循環快250倍以上。

這是一個困難的問題,如果您願意接受近似值,則可能會有所幫助。 我會檢查一下Spottify的煩惱

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM