[英]How to count number of combinations?
我有問題,我想計算滿足以下條件的組合數:
a < b < a+d < c < b+d
其中a, b, c
是列表的元素, d
是固定的delta。
這是一個香草實現:
def count(l, d):
s = 0
for a in l:
for b in l:
for c in l:
if a < b < a + d < c < b + d:
s += 1
return s
這是一個測試:
def testCount():
l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
assert(32 == count(l, 4)) # Gone through everything by hand.
我怎樣才能加快速度呢? 我正在查看200萬的列表大小。
我正在處理[-pi,pi]范圍內的浮點數。 例如,這限制a < 0
。
我有一些實現,我建立了我用於b
和c
索引。 但是,以下代碼在某些情況下失敗。 (即這是錯誤的 )。
def count(l, d=pi):
low = lower(l, d)
high = upper(l, d)
s = 0
for indA in range(len(l)):
for indB in range(indA+1, low[indA]+1):
s += low[indB] + 1 - high[indA]
return s
def lower(l, d=pi):
'''Returns ind, s.t l[ind[i]] < l[i] + d and l[ind[i]+1] >= l[i] + d, for all i
Input must be sorted!
'''
ind = []
x = 0
length = len(l)
for elem in l:
while x < length and l[x] < elem + d:
x += 1
if l[x-1] < elem + d:
ind.append(x-1)
else:
assert(x == length)
ind.append(x)
return ind
def upper(l, d=pi):
''' Returns first index where l[i] > l + d'''
ind = []
x = 0
length = len(l)
for elem in l:
while x < length and l[x] <= elem + d:
x += 1
ind.append(x)
return ind
最初的問題來自眾所周知的數學/綜合競賽。 比賽要求您不要在網上發布解決方案。 但它是從兩周前開始的。
我可以用這個函數生成列表:
def points(n):
x = 1
y = 1
for _ in range(n):
x = (x * 1248) % 32323
y = (y * 8421) % 30103
yield atan2(x - 16161, y - 15051)
def C(n):
angles = points(n)
angles.sort()
return count(angles, pi)
from bisect import bisect_left, bisect_right
from collections import Counter
def count(l, d):
# cdef long bleft, bright, cleft, cright, ccount, s
s = 0
# Find the unique elements and their counts
cc = Counter(l)
l = sorted(cc.keys())
# Generate a cumulative sum array
cumulative = [0] * (len(l) + 1)
for i, key in enumerate(l, start=1):
cumulative[i] = cumulative[i-1] + cc[key]
# Pregenerate all the left and right lookups
lefthand = [bisect_right(l, a + d) for a in l]
righthand = [bisect_left(l, a + d) for a in l]
aright = bisect_left(l, l[-1] - d)
for ai in range(len(l)):
bleft = ai + 1
# Search only the values of a that have a+d in range
if bleft > aright:
break
# This finds b such that a < b < a + d.
bright = righthand[ai]
for bi in range(bleft, bright):
# This finds the range for c such that a+d < c < b+d.
cleft = lefthand[ai]
cright = righthand[bi]
if cleft != cright:
# Find the count of c elements in the range cleft..cright.
ccount = cumulative[cright] - cumulative[cleft]
s += cc[l[ai]] * cc[l[bi]] * ccount
return s
def testCount():
l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
result = count(l, 4)
assert(32 == result)
testCount()
擺脫重復的,相同的價值觀
僅迭代值的所需范圍
使用兩個索引的累積計數來消除c
的循環
在x + d
上緩存查找
這不再是O(n^3)
而是更像O(n ^ 2)`。
這顯然還沒有達到200萬。 以下是使用cython加速執行的較小浮點數據集(即很少或沒有重復)的時間:
50: 0:00:00.157849 seconds
100: 0:00:00.003752 seconds
200: 0:00:00.022494 seconds
400: 0:00:00.071192 seconds
800: 0:00:00.253750 seconds
1600: 0:00:00.951133 seconds
3200: 0:00:03.508596 seconds
6400: 0:00:10.869102 seconds
12800: 0:00:55.986448 seconds
這是我的基准代碼(不包括上面的操作代碼):
from math import atan2, pi
def points(n):
x, y = 1, 1
for _ in range(n):
x = (x * 1248) % 32323
y = (y * 8421) % 30103
yield atan2(x - 16161, y - 15051)
def C(n):
angles = sorted(points(n))
return count(angles, pi)
def test_large():
from datetime import datetime
for n in [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800]:
s = datetime.now()
C(n)
elapsed = datetime.now() - s
print("{1}: {0} seconds".format(elapsed, n))
if __name__ == '__main__':
testCount()
test_large()
有一種方法可以產生O(n log n)
算法。 設X
是一組值。 現在讓我們修復b
。 令A_b
為值的集合{ x in X: b - d < x < b }
並且C_b
是值的集合{ x in X: b < x < b + d }
。 如果我們能找到|{ (x,y) : A_b X C_b | y > x + d }|
|{ (x,y) : A_b X C_b | y > x + d }|
快,我們解決了這個問題。
如果我們對X
排序,我們可以將A_b
和C_b
表示為排序數組的指針,因為它們是連續的。 如果我們以非遞減順序處理b
候選,我們可以使用滑動窗口算法來維護這些集合。 它是這樣的:
X
設X = { x_1, x_2, ..., x_n }
, x_1 <= x_2 <= ... <= x_n
。 left = i = 1
並right
設置,使得C_b = { x_{i + 1}, ..., x_right }
。 設置count = 0
1
到n
迭代i
。 在每次迭代中,我們找出有效三元組(a,b,c)
的數量,其中b = x_i
。 要做到這一點,盡可能多地left
和right
增加,以便A_b = { x_left, ..., x_{i-1} }
和C_b = { x_{i + 1}, ..., x_right }
仍然成立。 在此過程中,您基本上添加和刪除虛構集A_b
和C_b
。 如果刪除或元素添加到集合中的一個,檢查有多少對(a, c)
與c > a + d
, a
從A_b
和c
從C_b
您添加或破壞(這可以通過一個簡單的二進制搜索實現在另一組)。 相應地更新count
,使得不變count = |{ (x,y) : A_b X C_b | y > x + d }|
count = |{ (x,y) : A_b X C_b | y > x + d }|
仍然持有。 count
值。 這是最終結果。 復雜度為O(n log n)
。
如果要使用此算法解決Euler問題,則必須避免出現浮點問題。 我建議使用僅使用整數算術的自定義比較函數(使用2D矢量幾何)按角度對點進行排序。 實現|ab| < d
|ab| < d
比較也可以僅使用整數運算來完成。 此外,由於您正在使用模2*pi
,您可能需要引入每個角度a
三個副本a
: a - 2*pi
, a
和a + 2*pi
。 然后,您只在[0, 2*pi)
范圍內查找b
並將結果除以3。
UPDATE OP在Python中實現了這個算法。 顯然它包含一些錯誤,但它表明了一般的想法:
def count(X, d):
X.sort()
count = 0
s = 0
length = len(X)
a_l = 0
a_r = 1
c_l = 0
c_r = 0
for b in X:
if X[a_r-1] < b:
# find boundaries of A s.t. b -d < a < b
while a_r < length and X[a_r] < b:
a_r += 1 # This adds an element to A_b.
ind = bisect_right(X, X[a_r-1]+d, c_l, c_r)
if c_l <= ind < c_r:
count += (ind - c_l)
while a_l < length and X[a_l] <= b - d:
a_l += 1 # This removes an element from A_b
ind = bisect_right(X, X[a_l-1]+d, c_l, c_r)
if c_l <= ind < c_r:
count -= (c_r - ind)
# Find boundaries of C s.t. b < c < b + d
while c_l < length and X[c_l] <= b:
c_l += 1 # this removes an element from C_b
ind = bisect_left(X, X[c_l-1]-d, a_l, a_r)
if a_l <= ind <= a_r:
count -= (ind - a_l)
while c_r < length and X[c_r] < b + d:
c_r += 1 # this adds an element to C_b
ind = bisect_left(X, X[c_r-1]-d, a_l, a_r)
if a_l <= ind <= a_r:
count += (ind - a_l)
s += count
return s
由於l
已排序且a < b < c
必須為true,因此您可以使用itertools.combinations()
來執行更少的循環:
sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
查看組合僅將此循環減少到816次迭代。
>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> d = 4
>>> sum(1 for a, b, c in combinations(l, r=3))
816
>>> sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
32
其中a < b
測試是多余的。
1)為了減少每個級別上的迭代次數,您可以從列表中刪除不在每個級別上傳遞條件的元素
2)使用set
with collections.counter
您可以通過刪除重復項來減少迭代:
from collections import Counter
def count(l, d):
n = Counter(l)
l = set(l)
s = 0
for a in l:
for b in (i for i in l if a < i < a+d):
for c in (i for i in l if a+d < i < b+d):
s += (n[a] * n[b] * n[c])
return s
>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> count(l, 4)
32
測試版本的迭代次數(a,b,c):
>>> count1(l, 4)
18 324 5832
我的版本:
>>> count2(l, 4)
9 16 7
基本思路是:
因此,您可以無條件地增加s,性能大致為O(N),N是數組的大小。
import collections
def count(l, d):
s = 0
# at first we get rid of repeated items
counter = collections.Counter(l)
# sort the list
uniq = sorted(set(l))
n = len(uniq)
# kad is the index of the first element > a+d
kad = 0
# ka is the index of a
for ka in range(n):
a = uniq[ka]
while uniq[kad] <= a+d:
kad += 1
if kad == n:
return s
for kb in range( ka+1, kad ):
# b only runs in the range [a..a+d)
b = uniq[kb]
if b >= a+d:
break
for kc in range( kad, n ):
# c only rund from (a+d..b+d)
c = uniq[kc]
if c >= b+d:
break
print( a, b, c )
s += counter[a] * counter[b] * counter[c]
return s
編輯:對不起,我搞砸了提交。 固定。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.