How to count the number of combinations?

I want to count the number of combinations that fulfill the following condition:

 a < b < a+d < c < b+d

where a, b, c are elements of a list and d is a fixed delta.

Here is a naive (brute-force, O(n^3)) implementation:

def count(l, d):
    s = 0
    for a in l:
        for b in l:
            for c in l:
                if a < b < a + d < c < b + d:
                    s += 1
    return s

Here is a test:

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    assert(32 == count(l, 4)) # Gone through everything by hand.

Question

How can I speed this up? I am dealing with lists of about 2 million elements.

Supplementary Information

I am dealing with floats in the range [-pi, pi]. For example, with d = pi this forces a < 0, since a + d < c <= pi.

What I have so far:

I have an implementation where I build index arrays that I use for b and c. However, the code below fails in some cases (i.e. it is wrong).

from math import pi

def count(l, d=pi):
    low = lower(l, d)
    high = upper(l, d)
    s = 0
    for indA in range(len(l)):
        for indB in range(indA+1, low[indA]+1):
            s += low[indB] + 1 - high[indA]
    return s

def lower(l, d=pi):
    '''Returns ind, s.t l[ind[i]] < l[i] + d and l[ind[i]+1] >= l[i] + d, for all i
    Input must be sorted!
    '''
    ind = []
    x = 0
    length = len(l)
    for  elem in l:
        while x < length and l[x] < elem + d:
            x += 1
        if l[x-1] < elem + d:
            ind.append(x-1)
        else:
            assert(x == length)
            ind.append(x)
    return ind


def upper(l, d=pi):
    '''Returns ind, s.t. ind[i] is the first index with l[ind[i]] > l[i] + d (or len(l)).'''
    ind = []
    x = 0
    length = len(l)
    for elem in l:
        while x < length and l[x] <= elem + d:
            x += 1
        ind.append(x)
    return ind

Original Problem

The original problem is from a well-known math/computer-science competition. The competition asks that you don't post solutions on the net, but this problem is from two weeks ago.

I can generate the list with this function:

from math import atan2, pi

def points(n):
    x = 1
    y = 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    # points() is a generator, which has no .sort(); use sorted() instead
    angles = sorted(points(n))
    return count(angles, pi)

from bisect import bisect_left, bisect_right
from collections import Counter

def count(l, d):
    # cdef long bleft, bright, cleft, cright, ccount, s
    s = 0

    # Find the unique elements and their counts
    cc = Counter(l)

    l = sorted(cc.keys())

    # Generate a cumulative sum array
    cumulative = [0] * (len(l) + 1)
    for i, key in enumerate(l, start=1):
        cumulative[i] = cumulative[i-1] + cc[key]

    # Pregenerate all the left and right lookups
    lefthand = [bisect_right(l, a + d) for a in l]
    righthand = [bisect_left(l, a + d) for a in l]

    aright = bisect_left(l, l[-1] - d)
    for ai in range(len(l)):
        bleft = ai + 1
        # Search only the values of a that have a+d in range
        if bleft > aright:
            break
        # This finds b such that a < b < a + d.
        bright = righthand[ai]
        for bi in range(bleft, bright):
            # This finds the range for c such that a+d < c < b+d.
            cleft = lefthand[ai]
            cright = righthand[bi]
            if cleft != cright:
                # Find the count of c elements in the range cleft..cright.
                ccount = cumulative[cright] - cumulative[cleft]
                s += cc[l[ai]] * cc[l[bi]] * ccount
    return s

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    result = count(l, 4)
    assert(32 == result)

testCount()

This version:

  1. gets rid of repeated, identical values

  2. iterates only over the required range for each value

  3. uses a cumulative count across two indices to eliminate the loop over c

  4. caches the lookups of x + d

This is no longer O(n^3) but closer to O(n^2).

This clearly does not yet scale to 2 million. Here are my timings on smaller floating-point data sets (i.e. few or no duplicates), using Cython to speed up execution:

50: 0:00:00.157849 seconds
100: 0:00:00.003752 seconds
200: 0:00:00.022494 seconds
400: 0:00:00.071192 seconds
800: 0:00:00.253750 seconds
1600: 0:00:00.951133 seconds
3200: 0:00:03.508596 seconds
6400: 0:00:10.869102 seconds
12800: 0:00:55.986448 seconds

Here is my benchmarking code (excluding the count code above):

from math import atan2, pi

def points(n):
    x, y = 1, 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    angles = sorted(points(n))
    return count(angles, pi)

def test_large():
    from datetime import datetime
    for n in [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800]:
        s = datetime.now()
        C(n)
        elapsed = datetime.now() - s
        print("{1}: {0} seconds".format(elapsed, n))

if __name__ == '__main__':
    testCount()
    test_large()

There is an approach to your problem that yields an O(n log n) algorithm. Let X be the set of values. Now let's fix b. Let A_b be the set of values { x in X : b - d < x < b } and C_b be the set of values { x in X : b < x < b + d }. If we can find |{ (x, y) ∈ A_b × C_b : y > x + d }| fast, we have solved the problem: for a fixed b, the valid a are exactly the elements of A_b, and every valid c lies in C_b because c > a + d > b.
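Written out as one formula, the reduction above reads as follows (a sketch; X is treated as a multiset, so the sum runs over every element b, duplicates included):

```latex
\#\{(a,b,c) \in X^3 : a < b < a+d < c < b+d\}
  \;=\; \sum_{b \in X} \bigl|\{(a,c) \in A_b \times C_b : c > a + d\}\bigr|
```

Membership a ∈ A_b is exactly the pair of conditions a < b < a + d, and any c with a + d < c < b + d automatically satisfies c > b, so it lies in C_b.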

If we sort X, we can represent A_b and C_b as pointers into the sorted array, because they are contiguous. If we process the b candidates in non-decreasing order, we can maintain these sets using a sliding window algorithm. It goes like this:

  1. Sort X. Let X = { x_1, x_2, ..., x_n }, x_1 <= x_2 <= ... <= x_n.
  2. Set left = i = 1 and set right so that C_b = { x_{i+1}, ..., x_right }. Set count = 0.
  3. Iterate i from 1 to n. In every iteration we find the number of valid triples (a, b, c) with b = x_i. To do that, increase left and right as far as necessary so that A_b = { x_left, ..., x_{i-1} } and C_b = { x_{i+1}, ..., x_right } still hold. In the process, you add elements to and remove elements from the implicit sets A_b and C_b. Whenever you add an element to, or remove one from, one of the sets, check how many pairs (a, c) with c > a + d, a from A_b and c from C_b, you create or destroy (a simple binary search in the other set does this). Update count accordingly, so that the invariant count = |{ (x, y) ∈ A_b × C_b : y > x + d }| still holds.
  4. Sum up the values of count over all iterations. This is the final result.

The complexity is O(n log n).
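Steps 1 to 4 can be sketched in Python as below. This is a minimal version of my own, assuming d > 0; the name count_sliding and the 0-based half-open windows xs[al:ar] (for A_b) and xs[cl:cr] (for C_b) are assumptions, not part of the answer. Every comparison reuses the same expressions as the original condition (x + d is computed once per element and stored in ys), so a pair (a, c) is always added and later removed under the same test, even for floats.

```python
from bisect import bisect_left, bisect_right


def count_sliding(values, d):
    """Count triples (a, b, c) with a < b < a + d < c < b + d, assuming d > 0."""
    xs = sorted(values)
    ys = [x + d for x in xs]  # ys[j] = xs[j] + d; non-decreasing because xs is
    n = len(xs)
    al = ar = 0  # A_b = xs[al:ar]: the a with a < b < a + d
    cl = cr = 0  # C_b = xs[cl:cr]: the c with b < c < b + d
    pairs = 0    # invariant: number of (a, c) in A_b x C_b with a + d < c
    total = 0
    for i, b in enumerate(xs):
        # Grow A_b: admit a < b. It pairs with every c in C_b with c > a + d.
        while ar < n and xs[ar] < b:
            pairs += cr - bisect_right(xs, ys[ar], cl, cr)
            ar += 1
        # Shrink A_b: drop a with a + d <= b.
        while al < ar and ys[al] <= b:
            pairs -= cr - bisect_right(xs, ys[al], cl, cr)
            al += 1
        # Grow C_b: admit c < b + d. It pairs with every a in A_b with a + d < c.
        while cr < n and xs[cr] < ys[i]:
            pairs += bisect_left(ys, xs[cr], al, ar) - al
            cr += 1
        # Shrink C_b: drop c <= b.
        while cl < cr and xs[cl] <= b:
            pairs -= bisect_left(ys, xs[cl], al, ar) - al
            cl += 1
        total += pairs
    return total
```

Each of the four pointers moves at most n times and each move costs one binary search, which is where the O(n log n) bound comes from. An element can enter a window and leave it again within the same iteration (when b jumps by more than d); its add and its remove are computed against the same other window, so they cancel out.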

If you want to solve the Euler problem with this algorithm, you have to avoid floating-point issues. I suggest sorting the points by angle using a custom comparison function that uses integer arithmetic only (using 2D vector geometry). The |a-b| < d comparisons can also be implemented with integer operations only. Also, since you are working modulo 2*pi, you would probably have to introduce three copies of every angle a: a - 2*pi, a and a + 2*pi. You then only look for b in the range [0, 2*pi) and divide the result by three.
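A sketch of the integer-only angle sort suggested above. It assumes each point is kept as the integer vector (u, v) whose angle atan2(v, u) equals the atan2(x - 16161, y - 15051) used by points(); the names half, compare_angle and int_points are hypothetical. It covers only the sort, not the |a-b| < d comparisons or the three-copies trick, and it assumes no vector is (0, 0), whose angle is undefined.

```python
from functools import cmp_to_key


def half(u, v):
    """Which part of (-pi, pi] the angle atan2(v, u) falls in:
    0 for (-pi, 0), 1 for [0, pi), 2 for exactly pi."""
    if v < 0:
        return 0
    if v == 0 and u < 0:
        return 2
    return 1


def compare_angle(p, q):
    """Compare atan2(p[1], p[0]) with atan2(q[1], q[0]) using integers only.

    Returns -1, 0 or 1 like a classic cmp function. Assumes p, q != (0, 0).
    """
    hp, hq = half(*p), half(*q)
    if hp != hq:
        return -1 if hp < hq else 1
    # Within one part the angles differ by less than pi, so the sign of the
    # cross product tells which vector comes first counter-clockwise.
    cross = p[0] * q[1] - p[1] * q[0]
    if cross > 0:
        return -1  # q is counter-clockwise of p, so p has the smaller angle
    if cross < 0:
        return 1
    return 0


def int_points(n):
    """Same recurrence as points(), but yields the integer vector (u, v)
    with atan2(v, u) == atan2(x - 16161, y - 15051)."""
    x, y = 1, 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield (y - 15051, x - 16161)


# Usage: sort the generated points by angle without any floating point.
# ordered = sorted(int_points(1000), key=cmp_to_key(compare_angle))
```

Splitting the plane into parts first is what makes the cross-product test valid: the sign of p × q only orders two directions correctly when their angular difference is below pi.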

UPDATE: OP implemented this algorithm in Python. Apparently it contains some bugs, but it demonstrates the general idea:

from bisect import bisect_left, bisect_right

def count(X, d):
    X.sort()
    count = 0
    s = 0
    length = len(X)
    a_l = 0
    a_r = 1
    c_l = 0
    c_r = 0
    for b in X:
        if X[a_r-1] < b:
            # find boundaries of A s.t. b -d < a < b
            while a_r < length and X[a_r] < b:
                a_r += 1  # This adds an element to A_b. 
                ind = bisect_right(X, X[a_r-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count += (ind - c_l)
            while a_l < length and X[a_l] <= b - d:
                a_l += 1  # This removes an element from A_b
                ind = bisect_right(X, X[a_l-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count -= (c_r - ind)
            # Find boundaries of C s.t. b < c < b + d
            while c_l < length and X[c_l] <= b:
                c_l += 1  # this removes an element from C_b
                ind = bisect_left(X, X[c_l-1]-d, a_l, a_r)
                if a_l <= ind <= a_r:
                    count -= (ind - a_l)
            while c_r  < length and X[c_r] < b + d:
                c_r += 1 # this adds an element to C_b
                ind = bisect_left(X, X[c_r-1]-d, a_l, a_r)
                if a_l <= ind <= a_r:
                    count += (ind - a_l)
            s += count
    return s
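Because the add/remove bookkeeping is easy to get wrong, here is an alternative sketch of the same fix-b idea that replaces the sliding window with prefix sums (my substitution, not the answer's method; the name count_prefix is hypothetical). For a fixed b, each a in A_b contributes #{c : c < b + d} - #{c : c <= a + d}, and since A_b is a contiguous slice of the sorted list, the second term summed over A_b is a difference of two prefix sums. It is O(n log n) and exact for integer inputs; with floats, a + d and b + d can round to the same value for two very close a < b, which can make the result differ slightly from the brute force (the floating-point issue the answer warns about).

```python
from bisect import bisect_left, bisect_right
from itertools import accumulate


def count_prefix(values, d):
    """Count (a, b, c) with a < b < a + d < c < b + d, assuming d > 0."""
    xs = sorted(values)
    ys = [x + d for x in xs]               # ys[j] = xs[j] + d, also sorted
    # le[j] = number of c with c <= xs[j] + d (ruled out when a = xs[j])
    le = [bisect_right(xs, y) for y in ys]
    prefix = [0, *accumulate(le)]          # prefix[k] = le[0] + ... + le[k-1]
    total = 0
    for i, b in enumerate(xs):
        lo = bisect_right(ys, b)  # first index j with xs[j] + d > b
        hi = bisect_left(xs, b)   # first index j with xs[j] >= b
        if lo < hi:               # A_b = xs[lo:hi] is non-empty
            below = bisect_left(xs, ys[i])  # number of c with c < b + d
            total += (hi - lo) * below - (prefix[hi] - prefix[lo])
    return total
```

For the test list with d = 4 and b = 2, A_b holds the five values 0, 0, 0, 1, 1, there are 11 elements below 6, and the prefix-sum term is 3*10 + 2*11 = 52, so each copy of b = 2 contributes 5*11 - 52 = 3 triples.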

Since l is sorted and a < b < c must hold, you could use itertools.combinations() to do fewer loops:

sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)

Using combinations reduces this loop to only 816 iterations (18 choose 3), compared with 18^3 = 5832 for the triple loop.

>>> from itertools import combinations
>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> d = 4
>>> sum(1 for a, b, c in combinations(l, r=3))
816
>>> sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
32

where the a < b test is redundant: combinations() over a sorted list already yields a <= b, and a == b would make a + d < c < b + d impossible.

1) To reduce the number of iterations on each level, you can remove elements from the list that don't pass the condition on that level.
2) Using a set together with collections.Counter, you can reduce iterations by removing duplicates:

from collections import Counter
def count(l, d):
    n = Counter(l)
    l = set(l)
    s = 0
    for a in l:
        for b in (i for i in l if a < i < a+d):
            for c in (i for i in l if a+d < i < b+d):
                s += (n[a] * n[b] * n[c])
    return s

>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> count(l, 4)
32

Tested iteration counts (of the a, b and c loops) for your version:

>>> count1(l, 4)
18 324 5832

My version:

>>> count2(l, 4)
9 16 7

The basic ideas are:

  1. Get rid of repeated, identical values.
  2. Have each value iterate only over the range it has to iterate.

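As a result you can increase s unconditionally. The work is proportional to the number of (a, b) pairs visited plus the number of valid triples of distinct values, which is far less than the N^3 of the brute force, although not O(N) in general (N being the size of the array).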

import collections

def count(l, d):
    s = 0
    # at first we get rid of repeated items
    counter = collections.Counter(l)
    # sort the list
    uniq = sorted(set(l))
    n = len(uniq)
    # kad is the index of the first element > a+d
    kad = 0 
    # ka is the index of a
    for ka in range(n):
        a = uniq[ka]
        while uniq[kad] <= a+d:
            kad += 1
            if kad == n:
                return s

        for kb in range( ka+1, kad ):
            # b only runs in the range (a..a+d)
            b = uniq[kb]
            if b  >= a+d:
                break
            for kc in range( kad, n ):
                # c only runs in the range (a+d..b+d)
                c = uniq[kc]
                if c >= b+d:
                    break
                s += counter[a] * counter[b] * counter[c]
    return s

EDIT: Sorry, I messed up the submission. Fixed.

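Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.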

 