繁体   English   中英

查找两个阵列之间最短距离的有效方法?

[英]Efficient way to find the shortest distance between two arrays?

我试图找到两组数组之间的最短距离。 x-数组相同,只包含整数。 这是我要执行的操作的一个示例:

import numpy as np
x1 = x2 = np.linspace(-1000, 1000, 2001)
y1 = (lambda x, a, b: a*x + b)(x1, 2, 1)
y2 = (lambda x, a, b: a*(x-2)**2 + b)(x2, 2, 10)

def dis(x1, y1, x2, y2):
    return sqrt((y2-y1)**2+(x2-x1)**2)

min_distance = np.inf
for a, b in zip(x1, y1):
    for c, d in zip(x2, y2):
        if dis(a, b, c, d) < min_distance:
            min_distance = dis(a, b, c, d)

>>> min_distance
2.2360679774997898

此解决方案有效,但是问题是运行时。 如果x的长度约为10,000,则该解决方案不可行,因为程序运行时间为O(n ^ 2)。 现在,我尝试做出一些近似来加快程序速度:

for a, b in zip(x1, y1):
    cut = (x2 > a-20)*(x2 < a+20)
    for c, d in zip(x2, y2):
        if dis(a, b, c, d) < min_distance:
            min_distance = dis(a, b, c, d)

但是该程序花费的时间比我想要的还要长。 现在,据我所知,遍历一个numpy数组通常效率不高,因此我确定仍有改进的空间。 关于如何加快此程序的速度有什么想法吗?

您的问题也可以表示为2d碰撞检测,因此四叉树可能会有所帮助。 插入和查询都在O(log n)时间中进行,因此整个搜索将在O(n log n)中进行。

还有一个建议,因为sqrt是单调的,所以您可以比较距离的平方而不是距离本身,这将节省n ^ 2平方根的计算。

scipy具有cdist函数 ,该函数可计算所有成对的点之间的距离:

from scipy.spatial.distance import cdist
import numpy as np

x1 = x2 = np.linspace(-1000, 1000, 2001)
y1 = (lambda x, a, b: a*x + b)(x1, 2, 1)
y2 = (lambda x, a, b: a*(x-2)**2 + b)(x2, 2, 10)

R1 = np.vstack((x1,y1)).T
R2 = np.vstack((x2,y2)).T

dists = cdist(R1,R2) # find all mutual distances

print (dists.min())
# output: 2.2360679774997898

它的运行速度比原始for循环快250倍以上。

这是一个困难的问题,如果您愿意接受近似值,则可能会有所帮助。 我会检查一下Spottify的烦恼

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM