Generating all permutations efficiently

I need to generate, as fast as possible, all permutations of the integers 0, 1, 2, ..., n - 1 and get the result as a NumPy array of shape (factorial(n), n), or to iterate through large portions of such an array to conserve memory.

Is there a built-in function in NumPy for doing this, or some combination of functions?

Using itertools.permutations(...) is too slow; I need a faster method.
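For reference, the baseline the question calls too slow can be sketched as follows, together with a rough estimate of how much memory the full table needs. The names permutations_baseline and table_size_bytes are made up for this illustration, and uint8 elements are assumed (as in the answers below).

```python
import itertools
import math

import numpy as np


def permutations_baseline(n):
    """Reference implementation: all permutations of range(n) via itertools."""
    rows = list(itertools.permutations(range(n)))
    # reshape keeps the (factorial(n), n) shape even for n == 0
    return np.array(rows, dtype=np.uint8).reshape(math.factorial(n), n)


def table_size_bytes(n, itemsize=1):
    """Memory needed for the full table with `itemsize` bytes per element."""
    return math.factorial(n) * n * itemsize


print(permutations_baseline(3))
print(table_size_bytes(11))  # bytes for n = 11 with uint8 elements
```

The size grows as n * factorial(n), which is why iterating over chunks becomes attractive for larger n.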

Update: faster version

This solution is about 5x faster than my original answer. It builds the array by expanding from one corner, using the same basic idea explained in superb rain's answer, but is much faster because it avoids unnecessary computation.

import math

import numpy as np

def faster_permutations(n):
    # empty() is fast because it does not initialize the values of the array
    # order='F' uses Fortran ordering, which makes accessing elements in the same column fast
    # math.factorial replaces np.math.factorial, which newer NumPy releases no longer provide
    perms = np.empty((math.factorial(n), n), dtype=np.uint8, order='F')
    perms[0, 0] = 0

    rows_to_copy = 1
    for i in range(1, n):
        perms[:rows_to_copy, i] = i
        for j in range(1, i + 1):
            start_row = rows_to_copy * j
            end_row = rows_to_copy * (j + 1)
            splitter = i - j
            perms[start_row: end_row, splitter] = i
            perms[start_row: end_row, :splitter] = perms[:rows_to_copy, :splitter]  # left side
            perms[start_row: end_row, splitter + 1:i + 1] = perms[:rows_to_copy, splitter:i]  # right side

        rows_to_copy *= i + 1

    return perms

Timings on my machine with n=11:

faster_permutations():                          0.12 seconds 
permutations() [superb rain's approach]:        1.44 seconds
permutations() with memory order optimization:  0.62 seconds
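When comparing implementations like these, it helps to check correctness without depending on row order: tracing faster_permutations by hand for n=3 suggests its rows do not come out in lexicographic order, so a direct np.array_equal against the itertools result would fail. The following is a minimal sketch of an order-independent check; the name is_full_permutation_table is hypothetical and not part of any answer here.

```python
import itertools
import math

import numpy as np


def is_full_permutation_table(perms):
    """Return True if `perms` holds every permutation of range(n) exactly once."""
    perms = np.asarray(perms)
    if perms.ndim != 2:
        return False
    rows, n = perms.shape
    if rows != math.factorial(n):
        return False
    # Every row must contain each of 0..n-1 exactly once.
    expected = np.broadcast_to(np.arange(n), (rows, n))
    if not np.array_equal(np.sort(perms, axis=1), expected):
        return False
    # No row may appear twice.
    return len(np.unique(perms, axis=0)) == rows


table = np.array(list(itertools.permutations(range(4))), dtype=np.uint8)
print(is_full_permutation_table(table))        # True
print(is_full_permutation_table(table[::-1]))  # True: row order does not matter
```

The np.unique call sorts all rows, so this check is only practical for moderate n.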

Original answer:

Based on superb rain's answer, this is a faster version with more efficient memory access patterns:

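import math

import numpy as np

def fast_permutations(n):
    # math.factorial replaces np.math.factorial, which newer NumPy releases no longer provide
    a = np.zeros((n, math.factorial(n)), np.uint8)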
    f = 1
    for m in range(2, n + 1):
        b = a[n - m + 1:, :f]  # the block of permutations of range(m-1)
        for i in range(1, m):
            a[n - m, i * f:(i + 1) * f] = i
            a[n - m + 1:, i * f:(i + 1) * f] = b + (b >= i)
        b += 1
        f *= m
    return a.T

This is essentially the transpose of superb rain's version. It's more efficient because the memory locations accessed are closer together.

On my machine, it's about 2x as fast as the original version (0.05 seconds vs. 0.12 seconds for n=10).
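One consequence of returning a.T is worth spelling out: the result is a column-major view of the buffer that was filled row by row, not a fresh row-major array. The small sketch below uses stand-in sizes (n = 3, f = 6) rather than the real function, just to show the layout flags involved.

```python
import numpy as np

n, f = 3, 6  # small stand-in sizes; f plays the role of factorial(n)

a = np.zeros((n, f), np.uint8)  # C order: each row of `a` is contiguous
t = a.T                         # the kind of array fast_permutations returns

print(t.shape)                  # (6, 3)
print(t.flags['F_CONTIGUOUS'])  # True: the transpose is column-major
print(np.shares_memory(a, t))   # True: a view, no data was copied

# If later code reads one permutation (one row of `t`) at a time,
# a row-major copy may be preferable; it costs one full pass over the data.
c = np.ascontiguousarray(t)
print(c.flags['C_CONTIGUOUS'])  # True
```

Whether the extra copy pays off depends on how the table is consumed afterwards, so it is best measured rather than assumed.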

As I didn't find a good or fast enough solution, I decided to implement the whole permutation algorithm from scratch using Numba, a JIT/AOT compiler and optimizer for Python code.

For large enough n, my Numba-based solution below is 25x-50x faster than doing the same task with itertools.permutations(...). See the timings after the code.

When iterating one permutation at a time, my code is just 1.25x faster than itertools.permutations(...), but according to the original question I needed either the whole array of all permutations or at least iteration over large chunks.

I've implemented both a Numba mode and a no-Numba mode, and both JIT and AOT variants within the Numba mode. You can also choose whether to iterate one permutation at a time (iter_ = True, iter_batches = False), to iterate one batch of permutations at a time, which is much faster (iter_ = True, iter_batches = True), or to return the whole array of all permutations without iteration (iter_ = False). The batch size can be tuned as well, e.g. with batch_size = 1000.
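To make the batch mode concrete, here is a hypothetical stand-in that yields arrays with the same shape as one batch, but builds them with itertools instead of the Numba code further down. The name iter_permutation_batches is an assumption for illustration only; it shows the interface, not the speed.

```python
import itertools

import numpy as np


def iter_permutation_batches(n, batch_size=1000):
    """Yield uint8 arrays of shape (<= batch_size, n) in lexicographic order."""
    perms = itertools.permutations(range(n))
    while True:
        chunk = list(itertools.islice(perms, batch_size))
        if not chunk:
            return
        yield np.array(chunk, dtype=np.uint8).reshape(len(chunk), n)


for batch in iter_permutation_batches(4, batch_size=10):
    print(batch.shape)
```

Only one batch is alive at a time, so peak memory is bounded by batch_size * n bytes rather than by the full table.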

The central internal function is next_batch(...), which implements the whole algorithm for generating the next permutations from the previous one. It is the only function that is JIT- or AOT-compiled by Numba; the rest are pure-Python helper wrappers.
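The step that next_batch repeats for every row is the classic "next lexicographic permutation" update. As a minimal sketch on a plain Python list (the helper next_permutation below is written for illustration and is not the answer's compiled function), it looks like this:

```python
def next_permutation(a):
    """Advance list `a` in place to its lexicographic successor.

    Returns False (leaving `a` unchanged) if `a` is already the last one.
    """
    n = len(a)
    # 1. Find the rightmost position i with a[i] < a[i + 1].
    i = n - 2
    while i >= 0 and a[i] >= a[i + 1]:
        i -= 1
    if i < 0:
        return False
    # 2. Find the rightmost j > i with a[j] > a[i] and swap the two.
    j = n - 1
    while a[j] <= a[i]:
        j -= 1
    a[i], a[j] = a[j], a[i]
    # 3. Reverse the (descending) tail after position i.
    a[i + 1:] = reversed(a[i + 1:])
    return True


p = [0, 1, 2]
while True:
    print(p)
    if not next_permutation(p):
        break
```

The batched version in the code below applies the same three steps, writing each successor directly into the next row of the output array.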

My timings are not very precise, because my laptop's CPU throttles by a factor of about 2.2 at random times when it overheats (which happens often).

# Needs: python -m pip install numba numpy timerit

def permutations(
    n, *, iter_ = True, numba_ = True, numba_aot = False,
    batch_size = 1000, iter_batches = False, state = {},
):
    key = (bool(numba_), bool(numba_aot))
    
    if key in state:
        return state[key](int(n), bool(iter_), int(batch_size), bool(iter_batches))
        
    def prepare(numba_, numba_aot):
        import numpy as np
        
        def next_batch(a, r):
            c, n = r.shape[0], r.shape[1]
            for ic in range(c):
                r[ic] = a
                a = r[ic]
                for i in range(n - 2, -1, -1):
                    if a[i] < a[i + 1]:
                        break
                else:
                    assert False # Already last permutation
                for j in range(n - 1, i, -1):
                    if a[i] < a[j]:
                        break
                a[i], a[j] = a[j], a[i]
                for k in range(1, (n - i + 1) >> 1):
                    a[i + k], a[n - k] = a[n - k], a[i + k]
            
        def factorial(n):
            res = 1
            for i in range(2, n + 1):
                res *= i
            return res
            
        def permutations_iter(nxb, n, batch_size, iter_batches):
            a = np.arange(n, dtype = np.uint8)
            if iter_batches:
                yield a[None, :]
            else:
                yield a
            if n <= 1:
                return
            total = factorial(n)
            for i in range(1, total, batch_size):
                batch = np.empty((min(batch_size, total - i), n), dtype = np.uint8)
                nxb(a, batch)
                if iter_batches:
                    yield batch
                else:
                    yield from iter(batch)
                a = batch[-1]
        
        def permutations_arr(nxb, n, batch_size):
            total = factorial(n)
            res = np.empty((total, n), dtype = np.uint8)
            res[0] = np.arange(n, dtype = np.uint8)
            for i in range(1, total, batch_size):
                nxb(res[i - 1], res[i : i + min(batch_size, total - i)])
            return res

        if not numba_:
            return lambda n, it, bs, ib: permutations_iter(next_batch, n, bs, ib) if it else permutations_arr(next_batch, n, bs)
        else:
            if not numba_aot:
                import numba
                nxb = numba.njit('void(u1[:], u1[:, :])', cache = True)(next_batch)
            else:
                import numba, numba.pycc
                cc = numba.pycc.CC('permutations_numba')
                cc.export('next_batch', 'void(u1[:], u1[:, :])')(next_batch)
                cc.compile()
                from permutations_numba import next_batch as nxb
                
            return lambda n, it, bs, ib: permutations_iter(nxb, n, bs, ib) if it else permutations_arr(nxb, n, bs)
            
    state[key] = prepare(numba_, numba_aot)
    return state[key](int(n), bool(iter_), int(batch_size), bool(iter_batches))

def test():
    import numpy as np, itertools
    from timerit import Timerit
    
    Timerit._default_asciimode = True

    # Heat-up / pre-compile
    permutations(2, numba_ = False)
    permutations(2, numba_ = True)

    for n in range(12):
        num = 99 if n <= 7 else 15 if n <= 8 else 3 if n <= 9 else 1
        print('-' * 60 + f'\nn = {str(n).rjust(2)}')

        print(f'itertools          : ', end = '', flush = True)
        for t in Timerit(num = num, verbose = 1):
            with t:
                ref = np.array(list(itertools.permutations(range(n))), dtype = np.uint8)

        if n <= 9:
            print(f'python_array       : ', end = '', flush = True)
            for t in Timerit(num = num, verbose = 1):
                with t:
                    curpa = permutations(n, iter_ = False, numba_ = False)
                assert np.array_equal(ref, curpa)
        
        for batch_size in [10, 100, 1000, 10000]:
            print(f'batch_size = {str(batch_size).rjust(5)}')
        
            print(f'numba_iter         : ', end = '', flush = True)
            for t in Timerit(num = num, verbose = 1):
                with t:
                    curi = np.array(list(permutations(n, iter_ = True, numba_ = True, batch_size = batch_size)))
                assert np.array_equal(ref, curi)
                
            print(f'numba_iter_batches : ', end = '', flush = True)
            for t in Timerit(num = num, verbose = 1):
                with t:
                    curib = np.concatenate(list(permutations(n, iter_ = True, numba_ = True, batch_size = batch_size, iter_batches = True)))
                assert np.array_equal(ref, curib)

            print(f'numba_array        : ', end = '', flush = True)
            for t in Timerit(num = num, verbose = 1):
                with t:
                    cura = permutations(n, iter_ = False, numba_ = True, batch_size = batch_size)
                assert np.array_equal(ref, cura)
        
if __name__ == '__main__':
    test()

Output (timings):

------------------------------------------------------------
n =  0
itertools          : Timed best=8.210 us, mean=8.335 +- 0.4 us
python_array       : Timed best=14.881 us, mean=15.457 +- 0.5 us
batch_size =    10
numba_iter         : Timed best=15.908 us, mean=16.126 +- 0.3 us
numba_iter_batches : Timed best=17.447 us, mean=17.929 +- 0.3 us
numba_array        : Timed best=15.394 us, mean=15.519 +- 0.3 us
batch_size =   100
numba_iter         : Timed best=15.908 us, mean=16.250 +- 0.3 us
numba_iter_batches : Timed best=17.447 us, mean=18.038 +- 0.2 us
numba_array        : Timed best=15.394 us, mean=15.519 +- 0.3 us
batch_size =  1000
numba_iter         : Timed best=15.908 us, mean=16.328 +- 0.3 us
numba_iter_batches : Timed best=17.960 us, mean=18.069 +- 0.2 us
numba_array        : Timed best=15.394 us, mean=15.441 +- 0.1 us
batch_size = 10000
numba_iter         : Timed best=15.908 us, mean=16.328 +- 0.2 us
numba_iter_batches : Timed best=17.448 us, mean=17.976 +- 0.2 us
numba_array        : Timed best=14.881 us, mean=15.410 +- 0.3 us
------------------------------------------------------------
n =  1
itertools          : Timed best=7.697 us, mean=7.790 +- 0.3 us
python_array       : Timed best=14.882 us, mean=15.488 +- 0.3 us
batch_size =    10
numba_iter         : Timed best=15.908 us, mean=16.064 +- 0.3 us
numba_iter_batches : Timed best=17.960 us, mean=18.318 +- 0.3 us
numba_array        : Timed best=14.881 us, mean=15.348 +- 0.3 us
batch_size =   100
numba_iter         : Timed best=15.908 us, mean=16.203 +- 0.3 us
numba_iter_batches : Timed best=17.960 us, mean=18.054 +- 0.2 us
numba_array        : Timed best=15.394 us, mean=15.472 +- 0.2 us
batch_size =  1000
numba_iter         : Timed best=15.908 us, mean=16.421 +- 0.1 us
numba_iter_batches : Timed best=17.960 us, mean=18.147 +- 0.3 us
numba_array        : Timed best=14.882 us, mean=15.379 +- 0.2 us
batch_size = 10000
numba_iter         : Timed best=15.908 us, mean=16.095 +- 0.2 us
numba_iter_batches : Timed best=17.960 us, mean=18.132 +- 0.3 us
numba_array        : Timed best=14.881 us, mean=15.395 +- 0.3 us
------------------------------------------------------------
n =  2
itertools          : Timed best=8.723 us, mean=8.786 +- 0.2 us
python_array       : Timed best=29.250 us, mean=29.670 +- 0.4 us
batch_size =    10
numba_iter         : Timed best=34.381 us, mean=35.035 +- 0.7 us
numba_iter_batches : Timed best=30.276 us, mean=30.790 +- 0.4 us
numba_array        : Timed best=22.579 us, mean=22.672 +- 0.2 us
batch_size =   100
numba_iter         : Timed best=34.381 us, mean=34.584 +- 0.3 us
numba_iter_batches : Timed best=30.277 us, mean=30.836 +- 0.2 us
numba_array        : Timed best=22.066 us, mean=22.595 +- 0.2 us
batch_size =  1000
numba_iter         : Timed best=34.381 us, mean=34.739 +- 0.4 us
numba_iter_batches : Timed best=30.277 us, mean=30.851 +- 0.3 us
numba_array        : Timed best=22.579 us, mean=22.626 +- 0.1 us
batch_size = 10000
numba_iter         : Timed best=34.381 us, mean=34.786 +- 0.4 us
numba_iter_batches : Timed best=30.276 us, mean=30.650 +- 0.3 us
numba_array        : Timed best=22.066 us, mean=22.641 +- 0.3 us
------------------------------------------------------------
n =  3
itertools          : Timed best=12.829 us, mean=13.093 +- 0.3 us
python_array       : Timed best=62.606 us, mean=63.461 +- 0.6 us
batch_size =    10
numba_iter         : Timed best=39.513 us, mean=40.120 +- 0.4 us
numba_iter_batches : Timed best=31.302 us, mean=31.661 +- 0.2 us
numba_array        : Timed best=22.579 us, mean=23.077 +- 0.3 us
batch_size =   100
numba_iter         : Timed best=39.513 us, mean=40.042 +- 0.2 us
numba_iter_batches : Timed best=31.302 us, mean=31.629 +- 0.3 us
numba_array        : Timed best=22.579 us, mean=23.154 +- 0.2 us
batch_size =  1000
numba_iter         : Timed best=39.513 us, mean=39.840 +- 0.4 us
numba_iter_batches : Timed best=31.302 us, mean=31.629 +- 0.4 us
numba_array        : Timed best=22.579 us, mean=23.170 +- 0.2 us
batch_size = 10000
numba_iter         : Timed best=39.513 us, mean=40.120 +- 0.5 us
numba_iter_batches : Timed best=30.789 us, mean=31.412 +- 0.3 us
numba_array        : Timed best=23.092 us, mean=23.232 +- 0.3 us
------------------------------------------------------------
n =  4
itertools          : Timed best=34.381 us, mean=34.911 +- 0.4 us
python_array       : Timed best=207.830 us, mean=209.152 +- 1.0 us
batch_size =    10
numba_iter         : Timed best=82.619 us, mean=83.054 +- 0.7 us
numba_iter_batches : Timed best=44.645 us, mean=44.754 +- 0.2 us
numba_array        : Timed best=31.302 us, mean=31.458 +- 0.2 us
batch_size =   100
numba_iter         : Timed best=63.632 us, mean=64.036 +- 0.4 us
numba_iter_batches : Timed best=32.329 us, mean=32.889 +- 0.2 us
numba_array        : Timed best=24.118 us, mean=24.600 +- 0.3 us
batch_size =  1000
numba_iter         : Timed best=63.632 us, mean=64.083 +- 0.5 us
numba_iter_batches : Timed best=32.329 us, mean=32.904 +- 0.3 us
numba_array        : Timed best=24.118 us, mean=24.569 +- 0.3 us
batch_size = 10000
numba_iter         : Timed best=63.119 us, mean=63.927 +- 0.4 us
numba_iter_batches : Timed best=32.329 us, mean=32.889 +- 0.5 us
numba_array        : Timed best=24.118 us, mean=24.461 +- 0.3 us
------------------------------------------------------------
n =  5
itertools          : Timed best=156.001 us, mean=166.311 +- 20.5 us
python_array       : Timed best=0.999 ms, mean=1.002 +- 0.0 ms
batch_size =    10
numba_iter         : Timed best=293.528 us, mean=294.461 +- 0.8 us
numba_iter_batches : Timed best=102.632 us, mean=103.254 +- 0.4 us
numba_array        : Timed best=64.145 us, mean=64.985 +- 0.5 us
batch_size =   100
numba_iter         : Timed best=198.080 us, mean=199.107 +- 0.8 us
numba_iter_batches : Timed best=44.132 us, mean=44.894 +- 0.4 us
numba_array        : Timed best=33.355 us, mean=33.884 +- 0.3 us
batch_size =  1000
numba_iter         : Timed best=186.791 us, mean=187.522 +- 0.4 us
numba_iter_batches : Timed best=37.973 us, mean=38.471 +- 0.3 us
numba_array        : Timed best=29.763 us, mean=30.183 +- 0.3 us
batch_size = 10000
numba_iter         : Timed best=186.790 us, mean=187.646 +- 0.7 us
numba_iter_batches : Timed best=37.974 us, mean=38.534 +- 0.3 us
numba_array        : Timed best=29.763 us, mean=30.245 +- 0.3 us
------------------------------------------------------------
n =  6
itertools          : Timed best=0.991 ms, mean=1.007 +- 0.0 ms
python_array       : Timed best=5.873 ms, mean=6.012 +- 0.0 ms
batch_size =    10
numba_iter         : Timed best=1.668 ms, mean=1.673 +- 0.0 ms
numba_iter_batches : Timed best=503.411 us, mean=506.506 +- 1.2 us
numba_array        : Timed best=293.015 us, mean=296.047 +- 1.2 us
batch_size =   100
numba_iter         : Timed best=1.036 ms, mean=1.145 +- 0.3 ms
numba_iter_batches : Timed best=120.593 us, mean=132.878 +- 23.0 us
numba_array        : Timed best=93.908 us, mean=97.438 +- 2.4 us
batch_size =  1000
numba_iter         : Timed best=962.178 us, mean=976.624 +- 23.9 us
numba_iter_batches : Timed best=78.001 us, mean=82.992 +- 7.7 us
numba_array        : Timed best=68.250 us, mean=69.852 +- 4.3 us
batch_size = 10000
numba_iter         : Timed best=963.717 us, mean=977.044 +- 27.3 us
numba_iter_batches : Timed best=77.487 us, mean=80.084 +- 7.5 us
numba_array        : Timed best=68.250 us, mean=69.634 +- 4.4 us
------------------------------------------------------------
n =  7
itertools          : Timed best=8.502 ms, mean=8.579 +- 0.0 ms
python_array       : Timed best=41.690 ms, mean=42.358 +- 0.8 ms
batch_size =    10
numba_iter         : Timed best=11.523 ms, mean=11.646 +- 0.2 ms
numba_iter_batches : Timed best=3.407 ms, mean=3.497 +- 0.1 ms
numba_array        : Timed best=1.944 ms, mean=1.975 +- 0.0 ms
batch_size =   100
numba_iter         : Timed best=7.050 ms, mean=7.397 +- 0.3 ms
numba_iter_batches : Timed best=659.925 us, mean=668.198 +- 5.9 us
numba_array        : Timed best=503.411 us, mean=506.086 +- 3.3 us
batch_size =  1000
numba_iter         : Timed best=6.576 ms, mean=6.630 +- 0.0 ms
numba_iter_batches : Timed best=382.305 us, mean=389.707 +- 4.4 us
numba_array        : Timed best=354.081 us, mean=360.364 +- 4.3 us
batch_size = 10000
numba_iter         : Timed best=6.463 ms, mean=6.504 +- 0.0 ms
numba_iter_batches : Timed best=349.976 us, mean=352.091 +- 1.5 us
numba_array        : Timed best=330.989 us, mean=337.194 +- 1.8 us
------------------------------------------------------------
n =  8
itertools          : Timed best=71.003 ms, mean=71.824 +- 0.5 ms
python_array       : Timed best=331.176 ms, mean=339.746 +- 7.3 ms
batch_size =    10
numba_iter         : Timed best=99.929 ms, mean=101.098 +- 1.3 ms
numba_iter_batches : Timed best=27.489 ms, mean=27.905 +- 0.3 ms
numba_array        : Timed best=15.370 ms, mean=15.560 +- 0.1 ms
batch_size =   100
numba_iter         : Timed best=62.168 ms, mean=62.765 +- 0.7 ms
numba_iter_batches : Timed best=5.083 ms, mean=5.119 +- 0.0 ms
numba_array        : Timed best=3.824 ms, mean=3.842 +- 0.0 ms
batch_size =  1000
numba_iter         : Timed best=57.706 ms, mean=57.935 +- 0.2 ms
numba_iter_batches : Timed best=2.824 ms, mean=2.832 +- 0.0 ms
numba_array        : Timed best=2.656 ms, mean=2.670 +- 0.0 ms
batch_size = 10000
numba_iter         : Timed best=57.457 ms, mean=60.128 +- 2.1 ms
numba_iter_batches : Timed best=2.615 ms, mean=2.635 +- 0.0 ms
numba_array        : Timed best=2.550 ms, mean=2.565 +- 0.0 ms
------------------------------------------------------------
n =  9
itertools          : Timed best=724.017 ms, mean=724.017 +- 0.0 ms
python_array       : Timed best=3.071 s, mean=3.071 +- 0.0 s
batch_size =    10
numba_iter         : Timed best=950.892 ms, mean=950.892 +- 0.0 ms
numba_iter_batches : Timed best=261.376 ms, mean=261.376 +- 0.0 ms
numba_array        : Timed best=145.207 ms, mean=145.207 +- 0.0 ms
batch_size =   100
numba_iter         : Timed best=584.761 ms, mean=584.761 +- 0.0 ms
numba_iter_batches : Timed best=50.632 ms, mean=50.632 +- 0.0 ms
numba_array        : Timed best=39.945 ms, mean=39.945 +- 0.0 ms
batch_size =  1000
numba_iter         : Timed best=535.190 ms, mean=535.190 +- 0.0 ms
numba_iter_batches : Timed best=29.557 ms, mean=29.557 +- 0.0 ms
numba_array        : Timed best=26.541 ms, mean=26.541 +- 0.0 ms
batch_size = 10000
numba_iter         : Timed best=533.592 ms, mean=533.592 +- 0.0 ms
numba_iter_batches : Timed best=27.507 ms, mean=27.507 +- 0.0 ms
numba_array        : Timed best=25.115 ms, mean=25.115 +- 0.0 ms
------------------------------------------------------------
n = 10
itertools          : Timed best=15.483 s, mean=15.483 +- 0.0 s
batch_size =    10
numba_iter         : Timed best=24.163 s, mean=24.163 +- 0.0 s
numba_iter_batches : Timed best=6.039 s, mean=6.039 +- 0.0 s
numba_array        : Timed best=3.246 s, mean=3.246 +- 0.0 s
batch_size =   100
numba_iter         : Timed best=13.891 s, mean=13.891 +- 0.0 s
numba_iter_batches : Timed best=1.136 s, mean=1.136 +- 0.0 s
numba_array        : Timed best=890.228 ms, mean=890.228 +- 0.0 ms
batch_size =  1000
numba_iter         : Timed best=12.768 s, mean=12.768 +- 0.0 s
numba_iter_batches : Timed best=693.685 ms, mean=693.685 +- 0.0 ms
numba_array        : Timed best=658.007 ms, mean=658.007 +- 0.0 ms
batch_size = 10000
numba_iter         : Timed best=11.175 s, mean=11.175 +- 0.0 s
numba_iter_batches : Timed best=278.304 ms, mean=278.304 +- 0.0 ms
numba_array        : Timed best=251.208 ms, mean=251.208 +- 0.0 ms
------------------------------------------------------------
n = 11
itertools          : Timed best=95.118 s, mean=95.118 +- 0.0 s
batch_size =    10
numba_iter         : Timed best=124.414 s, mean=124.414 +- 0.0 s
numba_iter_batches : Timed best=75.427 s, mean=75.427 +- 0.0 s
numba_array        : Timed best=28.079 s, mean=28.079 +- 0.0 s
batch_size =   100
numba_iter         : Timed best=70.749 s, mean=70.749 +- 0.0 s
numba_iter_batches : Timed best=6.084 s, mean=6.084 +- 0.0 s
numba_array        : Timed best=4.357 s, mean=4.357 +- 0.0 s
batch_size =  1000
numba_iter         : Timed best=67.576 s, mean=67.576 +- 0.0 s
numba_iter_batches : Timed best=8.572 s, mean=8.572 +- 0.0 s
numba_array        : Timed best=6.915 s, mean=6.915 +- 0.0 s
batch_size = 10000
numba_iter         : Timed best=123.208 s, mean=123.208 +- 0.0 s
numba_iter_batches : Timed best=3.348 s, mean=3.348 +- 0.0 s
numba_array        : Timed best=2.789 s, mean=2.789 +- 0.0 s
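To relate these numbers to the speedup claim, the ratios can be computed directly from the "best" values reported above for itertools and for numba_array at batch_size = 10000. The snippet only does the arithmetic on those copied figures; given the noisy measurements mentioned earlier, the ratios (roughly 28x to 62x) should be read as indicative.

```python
# (n, itertools seconds, numba_array seconds at batch_size = 10000),
# copied from the timing output above.
timings = [
    (8, 71.003e-3, 2.550e-3),
    (9, 724.017e-3, 25.115e-3),
    (10, 15.483, 251.208e-3),
    (11, 95.118, 2.789),
]

speedups = {n: t_itertools / t_numba for n, t_itertools, t_numba in timings}

for n, ratio in speedups.items():
    print(f'n = {n:2d}: speedup = {ratio:.1f}x')
```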

Here's a NumPy solution that builds the permutations of size m by modifying the permutations of size m-1 (see the explanation further down):

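import math

import numpy as np

def permutations(n):
    # math.factorial replaces np.math.factorial, which newer NumPy releases no longer provide
    a = np.zeros((math.factorial(n), n), np.uint8)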
    f = 1
    for m in range(2, n+1):
        b = a[:f, n-m+1:]      # the block of permutations of range(m-1)
        for i in range(1, m):
            a[i*f:(i+1)*f, n-m] = i
            a[i*f:(i+1)*f, n-m+1:] = b + (b >= i)
        b += 1
        f *= m
    return a

Demo:

>>> permutations(3)
array([[0, 1, 2],
       [0, 2, 1],
       [1, 0, 2],
       [1, 2, 0],
       [2, 0, 1],
       [2, 1, 0]], dtype=uint8)

For n=10, the itertools solution takes 5.5 seconds for me, while this NumPy solution takes 0.2 seconds.

How it works: it starts with a zero array of the target size, which already contains the permutations of range(1) at the top right (I "dotted out" the other parts of the array):

[[. . 0]
 [. . .]
 [. . .]
 [. . .]
 [. . .]
 [. . .]]

Then it turns that into the permutations of range(2):

[[. 0 1]
 [. 1 0]
 [. . .]
 [. . .]
 [. . .]
 [. . .]]

And then into the permutations of range(3):

[[0 1 2]
 [0 2 1]
 [1 0 2]
 [1 2 0]
 [2 0 1]
 [2 1 0]]

It does so by filling the next column to the left and by copying and modifying the previous block of permutations downwards.
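The copy-and-modify step can be followed in isolation for the move from range(2) to range(3). The sketch below rebuilds that single step with separate arrays instead of writing into one preallocated array, so it illustrates the b + (b >= i) trick rather than reproducing the function above.

```python
import numpy as np

# Block of permutations of range(2), i.e. m - 1 = 2.
b = np.array([[0, 1],
              [1, 0]], dtype=np.uint8)

m = 3
blocks = []
for i in range(1, m):
    # Rows that start with i: shift every value >= i up by one,
    # so the tail uses exactly the values of range(m) other than i.
    tail = b + (b >= i)
    lead = np.full((len(b), 1), i, dtype=np.uint8)
    blocks.append(np.hstack([lead, tail]))

# Rows that start with 0: the tail is simply b + 1.
first = np.hstack([np.zeros((len(b), 1), dtype=np.uint8), b + 1])

result = np.vstack([first] + blocks)
print(result)
```

The printed array matches the range(3) table shown above: the comparison b >= i leaves smaller values alone and bumps the rest past the value already used in the leading column.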
