簡體   English   中英

如何計算組合數?

[英]How to count number of combinations?

我有問題,我想計算滿足以下條件的組合數:

 a < b < a+d < c < b+d

其中a, b, c是列表的元素, d是固定的delta。

這是一個香草實現:

def count(l, d):
    s = 0
    for a in l:
        for b in l:
            for c in l:
                if a < b < a + d < c < b + d:
                    s += 1
    return s

這是一個測試:

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    assert(32 == count(l, 4)) # Gone through everything by hand.

我怎樣才能加快速度呢? 我正在查看200萬的列表大小。

補充資料

我正在處理[-pi,pi]范圍內的浮點數。 例如,這限制a < 0

到目前為止我所擁有的:

我有一些實現,我建立了我用於bc索引。 但是,以下代碼在某些情況下失敗。 (即這是錯誤的 )。

def count(l, d=pi):
    low = lower(l, d)
    high = upper(l, d)
    s = 0
    for indA in range(len(l)):
            for indB in range(indA+1, low[indA]+1):
                    s += low[indB] + 1 - high[indA]
    return s

def lower(l, d=pi):
    '''Returns ind, s.t l[ind[i]] < l[i] + d and l[ind[i]+1] >= l[i] + d, for all i
    Input must be sorted!
    '''
    ind = []
    x = 0
    length = len(l)
    for  elem in l:
        while x < length and l[x] < elem + d:
            x += 1
        if l[x-1] < elem + d:
            ind.append(x-1)
        else:
            assert(x == length)
            ind.append(x)
    return ind


def upper(l, d=pi):
    ''' Returns first index where l[i] > l + d'''
    ind = []
    x = 0
    length = len(l)
    for elem in l:
        while x < length and l[x] <= elem + d:
            x += 1
        ind.append(x)
    return ind

原始問題

最初的問題來自眾所周知的數學/綜合競賽。 比賽要求您不要在網上發布解決方案。 但它是從兩周前開始的。

我可以用這個函數生成列表:

def points(n):
    x = 1
    y = 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    angles = points(n)
    angles.sort()
    return count(angles, pi)
from bisect import bisect_left, bisect_right
from collections import Counter

def count(l, d):
    # cdef long bleft, bright, cleft, cright, ccount, s
    s = 0

    # Find the unique elements and their counts
    cc = Counter(l)

    l = sorted(cc.keys())

    # Generate a cumulative sum array
    cumulative = [0] * (len(l) + 1)
    for i, key in enumerate(l, start=1):
        cumulative[i] = cumulative[i-1] + cc[key]

    # Pregenerate all the left and right lookups
    lefthand = [bisect_right(l, a + d) for a in l]
    righthand = [bisect_left(l, a + d) for a in l]

    aright = bisect_left(l, l[-1] - d)
    for ai in range(len(l)):
        bleft = ai + 1
        # Search only the values of a that have a+d in range
        if bleft > aright:
            break
        # This finds b such that a < b < a + d.
        bright = righthand[ai]
        for bi in range(bleft, bright):
            # This finds the range for c such that a+d < c < b+d.
            cleft = lefthand[ai]
            cright = righthand[bi]
            if cleft != cright:
                # Find the count of c elements in the range cleft..cright.
                ccount = cumulative[cright] - cumulative[cleft]
                s += cc[l[ai]] * cc[l[bi]] * ccount
    return s

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    result = count(l, 4)
    assert(32 == result)

testCount()
  1. 擺脫重復的,相同的價值觀

  2. 僅迭代值的所需范圍

  3. 使用兩個索引的累積計數來消除c的循環

  4. x + d上緩存查找

這不再是O(n^3)而是更像O(n ^ 2)`。

這顯然還沒有達到200萬。 以下是使用cython加速執行的較小浮點數據集(即很少或沒有重復)的時間:

50: 0:00:00.157849 seconds
100: 0:00:00.003752 seconds
200: 0:00:00.022494 seconds
400: 0:00:00.071192 seconds
800: 0:00:00.253750 seconds
1600: 0:00:00.951133 seconds
3200: 0:00:03.508596 seconds
6400: 0:00:10.869102 seconds
12800: 0:00:55.986448 seconds

這是我的基准代碼(不包括上面的操作代碼):

from math import atan2, pi

def points(n):
    x, y = 1, 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    angles = sorted(points(n))
    return count(angles, pi)

def test_large():
    from datetime import datetime
    for n in [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800]:
        s = datetime.now()
        C(n)
        elapsed = datetime.now() - s
        print("{1}: {0} seconds".format(elapsed, n))

if __name__ == '__main__':
    testCount()
    test_large()

有一種方法可以產生O(n log n)算法。 X是一組值。 現在讓我們修復b A_b為值的集合{ x in X: b - d < x < b }並且C_b是值的集合{ x in X: b < x < b + d } 如果我們能找到|{ (x,y) : A_b X C_b | y > x + d }| |{ (x,y) : A_b X C_b | y > x + d }| 快,我們解決了這個問題。

如果我們對X排序,我們可以將A_bC_b表示為排序數組的指針,因為它們是連續的。 如果我們以非遞減順序處理b候選,我們可以使用滑動窗口算法來維護這些集合。 它是這樣的:

  1. 排序X X = { x_1, x_2, ..., x_n }x_1 <= x_2 <= ... <= x_n
  2. 設置left = i = 1right設置,使得C_b = { x_{i + 1}, ..., x_right } 設置count = 0
  3. 1n迭代i 在每次迭代中,我們找出有效三元組(a,b,c)的數量,其中b = x_i 要做到這一點,盡可能多地leftright增加,以便A_b = { x_left, ..., x_{i-1} }C_b = { x_{i + 1}, ..., x_right }仍然成立。 在此過程中,您基本上添加和刪除虛構集A_bC_b 如果刪除或元素添加到集合中的一個,檢查有多少對(a, c)c > a + daA_bcC_b您添加或破壞(這可以通過一個簡單的二進制搜索實現在另一組)。 相應地更新count ,使得不變count = |{ (x,y) : A_b X C_b | y > x + d }| count = |{ (x,y) : A_b X C_b | y > x + d }| 仍然持有。
  4. 總結每次迭代中的count值。 這是最終結果。

復雜度為O(n log n)

如果要使用此算法解決Euler問題,則必須避免出現浮點問題。 我建議使用僅使用整數算術的自定義比較函數(使用2D矢量幾何)按角度對點進行排序。 實現|ab| < d |ab| < d比較也可以僅使用整數運算來完成。 此外,由於您正在使用模2*pi ,您可能需要引入每個角度a三個副本aa - 2*piaa + 2*pi 然后,您只在[0, 2*pi)范圍內查找b並將結果除以3。

UPDATE OP在Python中實現了這個算法。 顯然它包含一些錯誤,但它表明了一般的想法:

def count(X, d):
    X.sort()
    count = 0
    s = 0
    length = len(X)
    a_l = 0
    a_r = 1
    c_l = 0
    c_r = 0
    for b in X:
        if X[a_r-1] < b:
            # find boundaries of A s.t. b -d < a < b
            while a_r < length and X[a_r] < b:
                a_r += 1  # This adds an element to A_b. 
                ind = bisect_right(X, X[a_r-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count += (ind - c_l)
            while a_l < length and X[a_l] <= b - d:
                a_l += 1  # This removes an element from A_b
                ind = bisect_right(X, X[a_l-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count -= (c_r - ind)
            # Find boundaries of C s.t. b < c < b + d
            while c_l < length and X[c_l] <= b:
                c_l += 1  # this removes an element from C_b
                ind = bisect_left(X, X[c_l-1]-d, a_l, a_r)
                if a_l <= ind <= a_r:
                    count -= (ind - a_l)
            while c_r  < length and X[c_r] < b + d:
                c_r += 1 # this adds an element to C_b
                ind = bisect_left(X, X[c_r-1]-d, a_l, a_r)
                if a_l <= ind <= a_r:
                    count += (ind - a_l)
            s += count
    return s

由於l已排序且a < b < c必須為true,因此您可以使用itertools.combinations()來執行更少的循環:

sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)

查看組合僅將此循環減少到816次迭代。

>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> d = 4
>>> sum(1 for a, b, c in combinations(l, r=3))
816
>>> sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
32

其中a < b測試是多余的。

1)為了減少每個級別上的迭代次數,您可以從列表中刪除不在每個級別上傳遞條件的元素
2)使用set with collections.counter您可以通過刪除重復項來減少迭代:

from collections import Counter
def count(l, d):
    n = Counter(l)
    l = set(l)
    s = 0
    for a in l:
        for b in (i for i in l if a < i < a+d):
            for c in (i for i in l if a+d < i < b+d):
                s += (n[a] * n[b] * n[c])
    return s

>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> count(l, 4)
32

測試版本的迭代次數(a,b,c):

>>> count1(l, 4)
18 324 5832

我的版本:

>>> count2(l, 4)
9 16 7

基本思路是:

  1. 擺脫重復的,相同的價值觀
  2. 讓每個值僅在它必須迭代的范圍內迭代。

因此,您可以無條件地增加s,性能大致為O(N),N是數組的大小。

import collections

def count(l, d):
    s = 0
    # at first we get rid of repeated items
    counter = collections.Counter(l)
    # sort the list
    uniq = sorted(set(l))
    n = len(uniq)
    # kad is the index of the first element > a+d
    kad = 0 
    # ka is the index of a
    for ka in range(n):
        a = uniq[ka]
        while uniq[kad] <= a+d:
            kad += 1
            if kad == n:
                return s

        for kb in range( ka+1, kad ):
            # b only runs in the range [a..a+d)
            b = uniq[kb]
            if b  >= a+d:
                break
            for kc in range( kad, n ):
                # c only rund from (a+d..b+d)
                c = uniq[kc]
                if c >= b+d:
                    break
                print( a, b, c )
                s += counter[a] * counter[b] * counter[c]
    return s

編輯:對不起,我搞砸了提交。 固定。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM