繁体   English   中英

加快python中的代码块

[英]Speeding up block of code in python

我发现了两个潜在的原因,原因是给定的点是10000个size-2列表,因此以下代码段的性能非常差。

  1. “添加”以给定键添加新值
  2. 邻居地图字典。

     def calculate_distance(point1, point2): a = (point1[0], point1[1]) b = (point2[0], point2[1]) return distance.euclidean(a, b) def get_eps_neighbours(points, eps): neighbours = {} index = 0 for p in points: for q in points: if(calculate_distance(p, q) <= eps): if index in neighbours: neighbours[index].append(q) else: neighbours[index] = q index = index + 1 return {'neighbours': neighbours} 

关于如何提高代码效率的任何建议?

这是一个平凡的并行问题的例子。

我的建议:

  • 使用numpy
  • 创建2个点^点矩阵(2D数组),一个用于x另一个用于y
  • 使用numpy的数组算法

例:

In [52]: points = [(1,1), (2,2), (3,3), (4,4)]  # super-simple data

In [54]: Xb = numpy.repeat(numpy.array(points)[:,0], 4).reshape(4, 4)

In [60]: Xb
Out[60]: 
array([[1, 1, 1, 1],
       [2, 2, 2, 2],
       [3, 3, 3, 3],
       [4, 4, 4, 4]])

In [61]: Xa = numpy.tile(numpy.array(points)[:,0], 4).reshape(4, 4)

In [62]: Xa
Out[62]: 
array([[1, 2, 3, 4],
       [1, 2, 3, 4],
       [1, 2, 3, 4],
       [1, 2, 3, 4]])

# Yb = numpy.repeat(numpy.array(points)[:,1], 4).reshape(4, 4)
# Ya = numpy.tile(numpy.array(points)[:,1], 4).reshape(4, 4)

In [65]: D = ((Xa - Xb) ** 2 + (Ya - Yb) ** 2) ** 0.5

In [66]: D
Out[66]: 
array([[ 0.        ,  1.41421356,  2.82842712,  4.24264069],
       [ 1.41421356,  0.        ,  1.41421356,  2.82842712],
       [ 2.82842712,  1.41421356,  0.        ,  1.41421356],
       [ 4.24264069,  2.82842712,  1.41421356,  0.        ]])

In [71]: D < 2
Out[71]: 
array([[ True,  True, False, False],
       [ True,  True,  True, False],
       [False,  True,  True,  True],
       [False, False,  True,  True]], dtype=bool)

# Assuming you want only one copy from each pair (a,b), (b,a)
In [73]: triangle = numpy.tri(4, 4, -1, bool)

In [74]: triangle
Out[74]: 
array([[False, False, False, False],
       [ True, False, False, False],
       [ True,  True, False, False],
       [ True,  True,  True, False]], dtype=bool)

In [76]: neighbours = (D < 2) * triangle  # multiplication for "logical and"
Out[76]: 
array([[False, False, False, False],
       [ True, False, False, False],
       [False,  True, False, False],
       [False, False,  True, False]], dtype=bool)

# Neighbours' x and y coordinates are available so:
In [107]: numpy.compress(neighbours.flatten(), Xa.flatten())
Out[107]: array([1, 2, 3])

# Indices to elements in original `points` list like this:
Indexb = numpy.repeat(numpy.arange(4), 4).reshape(4, 4)
Indexa = numpy.tile(numpy.arange(4), 4).reshape(4, 4)
numpy.transpose([numpy.compress(neighbours.flatten(), Indexa.flatten()),
                 numpy.compress(neighbours.flatten(), Indexb.flatten())])
array([[0, 1],
       [1, 2],
       [2, 3]])

有了算法的总体思路,我认为您可以通过先删除(或复制到另一个列表)仅2*abs(px - qx) <= eps (重复y)的元素来减少经过欧几里得距离测试的点的列表,这比计算所有点的欧几里得要快得多。 如果eps小,那将起作用。

我不知道这是否可以加快您的代码的速度,但是计数循环的Python方法是这样的:

for i, p in enumerate(points):

另外-我不确定我是否每次都能理解整个字典(map)键的逻辑。 这段代码看起来并不像在做有用的事情

neighBourMap[index] = q

这会将键值对q的键值对添加到字典中。 您是否尝试过仅使用列表,即

neighBourMap = []

所有其他答案都是正确的,但是它们不会给您带来巨大的提速。 使用numpy数组可以使您加速,并行化可以使您加速。 但是,如果您拥有一百万个点,并且仍然使用当前的算法(进行n ^ 2距离计算),那么加速将不够。 (100万)^ 2是通往许多目标的方法。 如果您使用numpy?

您应该切换算法。 您应该将点存储在kd树中。 这样,您可以将搜索集中在几个邻近的候选对象上。 除了遍历所有点q ,您可以简单地使用|qx - px| < eps and |qy - py| < eps遍历所有点q |qx - px| < eps and |qy - py| < eps |qx - px| < eps and |qy - py| < eps |qx - px| < eps and |qy - py| < eps 如果您的eps很小,并且每个点只有几个邻居,那么这将使您的速度大大提高。

这是一份pdf文件,其中描述了该算法如何查找特定范围内的所有点: http : //www.cse.unr.edu/~bebis/CS302/Handouts/kdtree.pdf

您希望所有点相互组合。 您可以使用itertools.combinations

由于我们仅在进行所需的组合,因此无需继续查找要附加的字典索引。 我们可以将点及其邻居列表放在一起。

list使用defaultdict意味着我们不必在第一次查找点时手动创建list

另外,您实际上并不需要欧几里得距离的值,只想知道它是否小于其他值。 因此,比较平方将为您提供相同的结果。

要将point用作字典的键,它必须是不可变的,因此我们将其转换为元组:

def distance_squared(a, b):
    diff = complex(*a) - complex(*b)
    return diff.real ** 2 + diff.imag ** 2

from itertools import combinations
from collections import defaultdict

neighbours = defaultdict(list)
eps_squared = eps ** 2

point_neighbours = ((point, neighbours[tuple(point)]) for point in points)

for (p, p_neighbours), (q, _) in combinations(point_neighbours , r=2):
    if distance_squared(p, q) <= eps_squared:
        p_neighbours.append(q)

一件事情你可以取代

index in neighBourMap.keys()):`

只是

index in neighBourMap

由于不需要创建字典键的副本,因此运行速度更快。

更好的是,使用defaultdict(list)可以避免在追加到列表值之前检查键的需求。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM