How to count number of combinations?
I want to count the number of combinations that fulfill the following condition:
a < b < a+d < c < b+d
where a, b, c are elements of a list, and d is a fixed delta.
Here is a vanilla implementation:
def count(l, d):
    s = 0
    for a in l:
        for b in l:
            for c in l:
                if a < b < a + d < c < b + d:
                    s += 1
    return s
Here is a test:
def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    assert(32 == count(l, 4))  # Gone through everything by hand.
How can I speed this up? I am looking at list sizes of 2 million.
I am dealing with floats in the range of [-pi, pi]. For example, this limits a < 0.
I have an implementation where I build indices that I use for b and c. However, the code below fails some cases (i.e. it is wrong).
from math import pi

def count(l, d=pi):
    low = lower(l, d)
    high = upper(l, d)
    s = 0
    for indA in range(len(l)):
        for indB in range(indA+1, low[indA]+1):
            s += low[indB] + 1 - high[indA]
    return s

def lower(l, d=pi):
    '''Returns ind, s.t l[ind[i]] < l[i] + d and l[ind[i]+1] >= l[i] + d, for all i
    Input must be sorted!
    '''
    ind = []
    x = 0
    length = len(l)
    for elem in l:
        while x < length and l[x] < elem + d:
            x += 1
        if l[x-1] < elem + d:
            ind.append(x-1)
        else:
            assert(x == length)
            ind.append(x)
    return ind

def upper(l, d=pi):
    ''' Returns first index where l[i] > l + d'''
    ind = []
    x = 0
    length = len(l)
    for elem in l:
        while x < length and l[x] <= elem + d:
            x += 1
        ind.append(x)
    return ind
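For comparison, the two index arrays described in the docstrings above can also be sketched with the standard bisect module. This is only a cross-check under the assumption that l is sorted and d > 0; the names lower_bisect and upper_bisect are made up for this sketch, and it costs O(n log n) instead of the two-pointer O(n) above.

```python
from bisect import bisect_left, bisect_right

def lower_bisect(l, d):
    # For each element x: index of the last value that is < x + d.
    # bisect_left returns the first index with value >= x + d.
    return [bisect_left(l, x + d) - 1 for x in l]

def upper_bisect(l, d):
    # For each element x: index of the first value that is > x + d
    # (len(l) if there is no such value).
    return [bisect_right(l, x + d) for x in l]
```

Comparing these lists against the hand-written loops on small sorted inputs may help to find out whether the index arrays or the summation in count is at fault.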
The original problem is from a well-known math/comp-sci competition. The competition asks that you don't post solutions on the net. But it is from two weeks ago.
I can generate the list with this function:
from math import atan2, pi

def points(n):
    x = 1
    y = 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    # points() is a generator, so sort it with sorted() rather than .sort()
    angles = sorted(points(n))
    return count(angles, pi)
from bisect import bisect_left, bisect_right
from collections import Counter

def count(l, d):
    # cdef long bleft, bright, cleft, cright, ccount, s
    s = 0

    # Find the unique elements and their counts
    cc = Counter(l)

    l = sorted(cc.keys())

    # Generate a cumulative sum array
    cumulative = [0] * (len(l) + 1)
    for i, key in enumerate(l, start=1):
        cumulative[i] = cumulative[i-1] + cc[key]

    # Pregenerate all the left and right lookups
    lefthand = [bisect_right(l, a + d) for a in l]
    righthand = [bisect_left(l, a + d) for a in l]

    aright = bisect_left(l, l[-1] - d)
    for ai in range(len(l)):
        bleft = ai + 1
        # Search only the values of a that have a+d in range
        if bleft > aright:
            break
        # This finds b such that a < b < a + d.
        bright = righthand[ai]
        for bi in range(bleft, bright):
            # This finds the range for c such that a+d < c < b+d.
            cleft = lefthand[ai]
            cright = righthand[bi]
            if cleft != cright:
                # Find the count of c elements in the range cleft..cright.
                ccount = cumulative[cright] - cumulative[cleft]
                s += cc[l[ai]] * cc[l[bi]] * ccount
    return s

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    result = count(l, 4)
    assert(32 == result)

testCount()
gets rid of repeated, identical values
iterates over only the required range for a value
uses a cumulative count across two indices to eliminate the loop over c
caches lookups on x + d
This is no longer O(n^3) but more like O(n^2).
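The cumulative-count step is what removes the loop over c, so here it is in isolation as a minimal sketch. The helper names build_cumulative and count_between are hypothetical and not part of the code above; the sketch only shows how a prefix sum over the unique sorted values answers "how many original elements lie strictly between lo and hi" with two binary searches.

```python
from bisect import bisect_left, bisect_right
from collections import Counter

def build_cumulative(values):
    # Unique sorted keys plus a prefix sum of their multiplicities:
    # cumulative[i] is the number of original elements among keys[:i].
    counts = Counter(values)
    keys = sorted(counts)
    cumulative = [0]
    for key in keys:
        cumulative.append(cumulative[-1] + counts[key])
    return keys, cumulative

def count_between(keys, cumulative, lo, hi):
    # Number of original elements x with lo < x < hi.
    left = bisect_right(keys, lo)   # first key > lo
    right = bisect_left(keys, hi)   # first key >= hi
    return max(0, cumulative[right] - cumulative[left])
```

For the test list with d = 4, count_between(keys, cumulative, 4, 8) gives 3 (the values 5, 7, 7), which is the c range for a = 0, b = 4 if such a b existed.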
This clearly does not yet scale up to 2 million. Here are my times on smaller floating point data sets (i.e. few or no duplicates) using Cython to speed up execution:
50: 0:00:00.157849 seconds
100: 0:00:00.003752 seconds
200: 0:00:00.022494 seconds
400: 0:00:00.071192 seconds
800: 0:00:00.253750 seconds
1600: 0:00:00.951133 seconds
3200: 0:00:03.508596 seconds
6400: 0:00:10.869102 seconds
12800: 0:00:55.986448 seconds
Here is my benchmarking code (not including the operative code above):
from math import atan2, pi

def points(n):
    x, y = 1, 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    angles = sorted(points(n))
    return count(angles, pi)

def test_large():
    from datetime import datetime
    for n in [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800]:
        s = datetime.now()
        C(n)
        elapsed = datetime.now() - s
        print("{1}: {0} seconds".format(elapsed, n))

if __name__ == '__main__':
    testCount()
    test_large()
There is an approach to your problem that yields an O(n log n) algorithm. Let X be the set of values. Now let's fix b. Let A_b be the set of values { x in X: b - d < x < b } and C_b be the set of values { x in X: b < x < b + d }. If we can find |{ (x,y) : A_b X C_b | y > x + d }| fast, we have solved the problem.
If we sort X, we can represent A_b and C_b as pointers into the sorted array, because they are contiguous. If we process the b candidates in non-decreasing order, we can thus maintain these sets using a sliding window algorithm. It goes like this:
1) Sort X. Let X = { x_1, x_2, ..., x_n }, x_1 <= x_2 <= ... <= x_n.
2) Set left = i = 1 and set right so that C_b = { x_{i + 1}, ..., x_right }. Set count = 0.
3) Iterate i from 1 to n. In every iteration we find out the number of valid triples (a,b,c) with b = x_i. To do that, increase left and right as much as necessary so that A_b = { x_left, ..., x_{i-1} } and C_b = { x_{i + 1}, ..., x_right } still holds. In the process, you basically add and remove elements from the imaginary sets A_b and C_b. If you remove or add an element to one of the sets, check how many pairs (a, c) with c > a + d, a from A_b and c from C_b you add or destroy (this can be achieved by a simple binary search in the other set). Update count accordingly so that the invariant count = |{ (x,y) : A_b X C_b | y > x + d }| still holds.
4) Sum up the values of count in every iteration. This is the final result.
The complexity is O(n log n).
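The "fix b and count pairs (a, c)" idea can also be sketched without the incremental window bookkeeping. Note that this is a prefix-sum variant, not the sliding-window procedure from the steps above: for a fixed b, every a in A_b contributes the number of elements strictly between a + d and b + d, and those contributions can be summed over the contiguous index range of A_b with a precomputed prefix sum. The function name count_fast is made up for this sketch, and it assumes d > 0 and values for which x + d and x - d compare exactly (integers, for example).

```python
from bisect import bisect_left, bisect_right
from itertools import accumulate

def count_fast(values, d):
    xs = sorted(values)
    # g[i] = number of elements <= xs[i] + d
    g = [bisect_right(xs, x + d) for x in xs]
    # prefix[i] = g[0] + ... + g[i-1]
    prefix = [0] + list(accumulate(g))
    total = 0
    for b in xs:
        # A_b = { a : b - d < a < b } is the index range [lo, hi)
        lo = bisect_right(xs, b - d)
        hi = bisect_left(xs, b)
        # number of elements < b + d
        below = bisect_left(xs, b + d)
        # Sum over a in A_b of #{ c : a + d < c < b + d }
        #   = |A_b| * below - sum of g over A_b
        total += (hi - lo) * below - (prefix[hi] - prefix[lo])
    return total
```

On the test list from the question, count_fast(l, 4) gives 32. Duplicates need no special handling here because every element is processed individually; the cost is one sort plus a constant number of binary searches per element.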
If you want to solve the Euler problem with this algorithm, you have to avoid floating point issues. I suggest sorting the points by angle using a custom comparison function that uses integer arithmetic only (using 2D vector geometry). Implementing the |a-b| < d comparisons can also be done using integer operations only. Also, since you are working modulo 2*pi, you would probably have to introduce three copies of every angle a: a - 2*pi, a and a + 2*pi. You then only look for b in the range [0, 2*pi) and divide the result by three.
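An integer-only angular comparison could look roughly like the sketch below. It is an assumption about how such a comparator might be written, not code from the original problem: points are (px, py) integer pairs ordered like atan2(py, px), the origin is assumed not to occur, and the names _group and compare_by_angle are hypothetical. The plane is split into three groups so that within one group the sign of the cross product decides the order.

```python
from functools import cmp_to_key

def _group(p):
    px, py = p
    if py < 0:
        return 0  # angles in (-pi, 0)
    if py > 0 or px > 0:
        return 1  # angles in [0, pi)
    return 2      # negative x-axis, angle pi (origin assumed absent)

def compare_by_angle(p, q):
    # Negative if p comes before q in atan2(py, px) order, using integers only.
    gp, gq = _group(p), _group(q)
    if gp != gq:
        return -1 if gp < gq else 1
    # Same group: the two angles differ by less than pi, so the sign
    # of the cross product tells which one is counter-clockwise.
    cross = p[0] * q[1] - p[1] * q[0]
    if cross > 0:
        return -1
    if cross < 0:
        return 1
    return 0  # same direction
```

It can be used as sorted(pts, key=cmp_to_key(compare_by_angle)). For the generator in the question, which calls atan2(x - 16161, y - 15051), the pair would be (y - 15051, x - 16161), since atan2 takes the vertical component first.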
UPDATE: OP implemented this algorithm in Python. Apparently it contains some bugs, but it demonstrates the general idea:
from bisect import bisect_left, bisect_right

def count(X, d):
    X.sort()
    count = 0
    s = 0
    length = len(X)
    a_l = 0
    a_r = 1
    c_l = 0
    c_r = 0
    for b in X:
        if X[a_r-1] < b:
            # find boundaries of A s.t. b -d < a < b
            while a_r < length and X[a_r] < b:
                a_r += 1  # This adds an element to A_b.
                ind = bisect_right(X, X[a_r-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count += (ind - c_l)
            while a_l < length and X[a_l] <= b - d:
                a_l += 1  # This removes an element from A_b
                ind = bisect_right(X, X[a_l-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count -= (c_r - ind)
        # Find boundaries of C s.t. b < c < b + d
        while c_l < length and X[c_l] <= b:
            c_l += 1  # this removes an element from C_b
            ind = bisect_left(X, X[c_l-1]-d, a_l, a_r)
            if a_l <= ind <= a_r:
                count -= (ind - a_l)
        while c_r < length and X[c_r] < b + d:
            c_r += 1  # this adds an element to C_b
            ind = bisect_left(X, X[c_r-1]-d, a_l, a_r)
            if a_l <= ind <= a_r:
                count += (ind - a_l)
        s += count
    return s
Since l is sorted and a < b < c must be true, you could use itertools.combinations() to do fewer loops:
sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
Looking only at combinations reduces this loop to 816 iterations.
>>> from itertools import combinations
>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> d = 4
>>> sum(1 for a, b, c in combinations(l, r=3))
816
>>> sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
32
where the a < b test is redundant.
1) To reduce the number of iterations on each level, you can remove elements from the list that don't pass the condition on that level.
2) Using set with collections.Counter you can reduce iterations by removing duplicates:
from collections import Counter

def count(l, d):
    n = Counter(l)
    l = set(l)
    s = 0
    for a in l:
        for b in (i for i in l if a < i < a+d):
            for c in (i for i in l if a+d < i < b+d):
                s += (n[a] * n[b] * n[c])
    return s
>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> count(l, 4)
32
Tested count of iterations (a, b, c) for your version:
>>> count1(l, 4)
18 324 5832
My version:
>>> count2(l, 4)
9 16 7
The basic ideas are:
As a result you can increase s unconditionally and the performance is roughly O(N), with N being the size of the array.
import collections

def count(l, d):
    s = 0
    # at first we get rid of repeated items
    counter = collections.Counter(l)
    # sort the list
    uniq = sorted(set(l))
    n = len(uniq)
    # kad is the index of the first element > a+d
    kad = 0
    # ka is the index of a
    for ka in range(n):
        a = uniq[ka]
        while uniq[kad] <= a+d:
            kad += 1
            if kad == n:
                return s

        for kb in range(ka+1, kad):
            # b only runs in the range [a..a+d)
            b = uniq[kb]
            if b >= a+d:
                break
            for kc in range(kad, n):
                # c only runs from (a+d..b+d)
                c = uniq[kc]
                if c >= b+d:
                    break
                print(a, b, c)
                s += counter[a] * counter[b] * counter[c]
    return s
EDIT: Sorry, I messed up the submission. Fixed.