[英]Fastest way to count element in a list satisfying conditions
我希望列表中的元素数量满足另一个列表指定的某些条件。 我的方式是使用sum
和any
。 简单的测试代码是:
>>> x1 = list(xrange(300))
>>> x2 = [random.randrange(20, 50) for i in xrange(30)]
>>> def test():
ns = []
for i in xrange(10000):
ns.append(sum(1 for j in x2 if any(abs(k-j)<=10 for k in x1)))
return ns
使用Profiler显示sum
和any
引起的最多时间,有什么方法可以改善这个?
>>> cProfile.run('ns = test()')
8120003 function calls in 0.699 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.003 0.003 1.552 1.552 <pyshell#678>:2(test)
310000 0.139 0.000 1.532 0.000 <pyshell#678>:5(<genexpr>)
1 0.000 0.000 1.552 1.552 <string>:1(<module>)
7490000 0.196 0.000 0.196 0.000 {abs}
300000 0.345 0.000 1.377 0.000 {any}
10000 0.001 0.000 0.001 0.000 {method 'append' of 'list' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
10000 0.016 0.000 1.548 0.000 {sum}
功能test
仅包含10000次迭代。 通常,我会进行数万次迭代,并且使用cProfile.run
显示此块会导致大部分执行时间。
================================================== =================
编辑
根据@ DavisHerring的回答,使用二进制搜索。
from _bisect import *
>>> x1 = list(xrange(300))
>>> x2 = [random.randrange(20, 50) for i in xrange(30)]
>>> def testx():
ns = []
x2k = sorted(x2)
x1k = sorted(x1)
for i in xrange(10000):
bx = [bisect_left(x1k, xk) for xk in x2k]
rn = sum(1 if k==0 and x1k[k]-xk<=10
else 1 if k==len(x1k) and xk-x1k[k-1]<=10
else xk-x1k[k-1]<=10 or x1k[k]-xk<=10
for k, xk in zip(bx, x2k))
ns.append(rn)
return ns
根据cProfile.run
,达到0.196 seconds
,快3倍+。
你的谓词的本质是至关重要的; 因为它是一条线的距离,您可以为您的数据提供相应的结构以加快搜索速度。 有几种变化:
对列表x1
排序:然后您可以使用二进制搜索来查找最近的值并检查它们是否足够接近。
如果列表x2
更长,并且其大部分元素不在范围内,则可以通过对其进行排序并搜索每个可接受间隔的开始和结束来使其更快一些。
如果对两个列表进行排序,则可以一起执行它们并在线性时间内完成。 这是渐近等价的,当然,除非有其他理由对它们进行排序。
使用区间树数据结构。 适合您需求的非常简单的实现可以如下:
class SimpleIntervalTree:
def __init__(self, points, radius):
intervals = []
l, r = None, None
for p in sorted(points):
if r is None or r < p - radius:
if r is not None:
intervals.append((l, r))
l = p - radius
r = p + radius
if r is not None:
intervals.append((l, r))
self._tree = self._to_tree(intervals)
def _to_tree(self, intervals):
if len(intervals) == 0:
return None
i = len(intervals) // 2
return {
'left': self._to_tree(intervals[0:i]),
'value': intervals[i],
'right': self._to_tree(intervals[i + 1:])
}
def __contains__(self, item):
t = self._tree
while t is not None:
l, r = t['value']
if item < l:
t = t['left']
elif item > r:
t = t['right']
else:
return True
return False
然后你的代码看起来像这样:
x1 = list(range(300))
x2 = [random.randrange(20, 50) for i in range(30)]
it = SimpleIntervalTree(x1, 10)
def test():
ns = []
for i in range(10000):
ns.append(sum(1 for j in x2 if j in it))
return ns
在__init__
,点列表首先转换为连续间隔列表。 接下来,将间隔放入平衡二叉搜索树中 。 在该树中,每个节点包含一个间隔,节点的每个左子树包含较低的间隔,并且节点的每个右子树包含较高的间隔。 这样,每当我们想要测试一个点是否在任何段( __contains__
)中时,我们就从根开始执行二进制搜索。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.