Fast sum of digits in a ternary representation (Python)

I have defined a function:

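import numpy as np

def enumerateSpin(n):
    s = []
    for a in range(0, 3**n):
        ternary_rep = np.base_repr(a, 3)    # base-3 string, no leading zeros
        k = len(ternary_rep)
        r = (n - k) * '0' + ternary_rep     # left-pad with zeros to n digits
        if sum(map(int, r)) == n:           # keep only digit sums equal to n
            s.append(r)
    return s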

where I look at each number 0 <= a < 3^n and ask whether the digits of its ternary representation sum to a certain value (here, n). I do this by first converting the number into a string of its ternary representation. I pad with zeros because I want to store a list of fixed-length representations that I can later use for further computations (i.e. digit-by-digit comparison between two elements).

Right now np.base_repr and sum(map(int, r)) each take roughly 5 µs on my computer, which means roughly 10 µs per iteration. I am looking for an approach that accomplishes the same thing but 10 times faster.

(Edit: added a note about padding zeros on the left.)

(Edit2: in hindsight, it is better for the final representation to be tuples of integers rather than strings.)

(Edit3: for those wondering, the purpose of the code was to enumerate the states of a spin-1 chain that have the same total S_z value.)
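To make the connection in Edit3 concrete, here is a minimal sketch. It assumes the convention that ternary digit d encodes the spin projection m = d - 1 (so 0, 1, 2 stand for -1, 0, +1); that mapping and the helper name total_sz are illustrative assumptions, not part of the original code. Under that assumption, a digit sum equal to the chain length corresponds to total S_z = 0.

```python
def total_sz(digits):
    """Total S_z of a spin-1 chain, ASSUMING digit d encodes m = d - 1."""
    return sum(d - 1 for d in digits)

state = (0, 2, 1, 1)                 # one state of a chain of length 4
print(sum(state) == len(state))      # True: digit sum equals chain length
print(total_sz(state))               # 0
```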

You can use itertools.product to generate the digits and then convert them to the string representation:

import itertools as it

def new(n):
    s = []
    for digits in it.product((0, 1, 2), repeat=n):
        if sum(digits) == n:
            s.append(''.join(str(x) for x in digits))
    return s

This gives me about a 7x speedup:

In [8]: %timeit enumerateSpin(12)
2.39 s ± 7.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [9]: %timeit new(12)
347 ms ± 4.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Tested on Python 3.9.0 (IPython 7.20.0) on Linux.

The procedure above, using it.product, also generates candidates that we know in advance cannot satisfy the condition, and these are the majority: for n = 3, only 7 of the 27 candidates pass. For n digits, the digit sum equals n exactly when the number of 2s equals the number of 0s, so the valid digit multisets consist of i twos, i zeros and n - 2i ones for i = 0, ..., n // 2. We can then generate all distinct permutations of each multiset and thus produce only the relevant numbers:

import itertools as it
from more_itertools import distinct_permutations

def new2(n):
    all_digits = (('2',)*i + ('1',)*(n-2*i) + ('0',)*i for i in range(n//2+1))
    all_digits = it.chain.from_iterable(distinct_permutations(d) for d in all_digits)
    return (''.join(digits) for digits in all_digits)

Especially for large n, this gives an additional, significant speedup:

In [44]: %timeit -r 1 -n 1 new(16)
31.4 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [45]: %timeit -r 1 -n 1 list(new2(16))
7.82 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
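As a sanity check on the multiset idea behind new2, the number of results can be computed without generating any permutation. The sketch below lists the (twos, ones, zeros) counts and sums the multinomial coefficients; the helper names digit_multisets and count_states are made up for this illustration.

```python
from math import factorial

def digit_multisets(n):
    """Yield (twos, ones, zeros) counts whose digits sum to n over n positions."""
    for i in range(n // 2 + 1):
        yield i, n - 2 * i, i

def count_states(n):
    """Number of distinct arrangements, summed over all valid multisets."""
    return sum(
        factorial(n) // (factorial(twos) * factorial(ones) * factorial(zeros))
        for twos, ones, zeros in digit_multisets(n)
    )

print(list(digit_multisets(4)))  # [(0, 4, 0), (1, 2, 1), (2, 0, 2)]
print(count_states(4))           # 19
```

For n = 4 this gives 1 + 12 + 6 = 19 arrangements, against 81 candidates visited by the brute-force loop.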

Note that new2 already returns a generator, so its memory use does not grow with the number of results. new can get the same O(1) memory scaling if you change it to yield each string instead of appending it to a list.
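A minimal sketch of that change, under the hypothetical name new_gen so that it does not clash with new:

```python
import itertools as it

def new_gen(n):
    """Generator variant of new(): yields one matching string at a time."""
    for digits in it.product((0, 1, 2), repeat=n):
        if sum(digits) == n:
            yield ''.join(map(str, digits))

# Results are produced lazily, so taking the first one is cheap.
print(next(new_gen(12)))  # 000000222222
```

Callers that still need a list can wrap the call in list(...), at which point memory again grows with the number of results.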

A 10x improvement can be achieved by delegating all calculations to numpy in order to leverage vectorized processing:

import numpy as np

def eSpin(n):
    nums    = np.arange(3**n,dtype=np.int64)  # np.int was removed in NumPy 1.24
    base3   = nums // (3**np.arange(n))[:,None] % 3
    matches = np.sum(base3,axis=0) == n
    digits  = np.sum(base3[:,matches] * 10**np.arange(n)[:,None],axis=0)
    return [f"{a:0{n}}" for a in digits]

How it works (example for eSpin(3)):

nums is an array of all integers from 0 to 3**n - 1:

   [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26]  

base3 converts them into base-3 digits along an additional dimension (one row per digit position, least significant first):

[[0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2]
 [0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2]
 [0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2]]

matches flags the columns whose base-3 digits sum to n (shown here as 0/1):

 [0 0 0 0 0 1 0 1 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 0 0 0]

digits converts each matching column into a base-10 number whose decimal digits are the base-3 digits:

 [ 12  21 102 111 120 201 210]

Finally, the matching (base-10) numbers are formatted with leading zeros.

Performance:

from timeit import timeit
count = 1

print(enumerateSpin(10)==eSpin(10)) # True

t1 = timeit(lambda:eSpin(13),number=count)
print("eSpin",t1) # 0.634 sec

t0 = timeit(lambda:enumerateSpin(13),number=count)
print("enumerateSpin",t0) # 7.362 sec

Tuple version:

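def eSpin2(n):
    nums    = np.arange(3**n,dtype=np.int64)  # np.int was removed in NumPy 1.24
    base3   = nums// (3**np.arange(n))[:,None]  % 3
    matches = np.sum(base3,axis=0) == n
    # .tolist() yields plain Python ints; digits are least significant first
    return [*map(tuple,base3[:,matches].T.tolist())]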

eSpin2(3)
[(2, 1, 0), (1, 2, 0), (2, 0, 1), (1, 1, 1), (0, 2, 1), (1, 0, 2), (0, 1, 2)]

[EDIT] An even faster approach (40x to 80x faster than enumerateSpin)

Using dynamic programming and memoization can provide much better performance:

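from functools import lru_cache

@lru_cache()
def eSpin(n,base=3,target=None):
    if target is None: target = n
    if target == 0: return [(0,)*n]
    if target>n*(base-1): return []  # largest digit sum reachable with n digits
    if n==1: return [(target,)]
    result = []
    for d in range(min(base,target+1)):  # leading digit, then recurse on the rest
        result.extend((d,)+suffix for suffix in eSpin(n-1,base,target-d) )
    return result  # cached list: callers should not mutate it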

t4 = timeit(lambda:eSpin(13),number=count)
print("eSpin",t4) # 0.108 sec

eSpin.cache_clear()
t5 = timeit(lambda:eSpin(16),number=count)
print("eSpin",t5) # 2.25 sec
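The same recurrence can be used to count the matching states without building them, which is a cheap way to estimate the output size before enumerating. This is a sketch rather than part of the answer above; the function name count_spin is hypothetical.

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def count_spin(n, target, base=3):
    """Count length-n digit tuples in the given base whose digits sum to target."""
    if target < 0 or target > n * (base - 1):
        return 0          # target is out of reach with n digits
    if n == 0:
        return 1          # only the empty tuple remains, and target is 0 here
    # Choose the leading digit d, then count completions of the remaining digits.
    return sum(count_spin(n - 1, target - d, base) for d in range(base))

print(count_spin(3, 3))   # 7
print(count_spin(4, 4))   # 19
```

Because each (n, target) pair is computed once, the count takes on the order of n * target * base additions, while the enumeration itself necessarily grows with the number of results.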

Here's a multiprocessing approach. The larger the problem, the more time it saves.

import multiprocessing as mp
import numpy as np

# On platforms that spawn workers (Windows, macOS), call enumerateSpin
# from under an `if __name__ == "__main__":` guard.


def filter_chunk(n, qIn, qOut):  # this is the function that will be parallelized
    nums = range(3**n)
    answer = []
    for low,high in iter(qIn.get, None):  # pull chunks until the None sentinel
        for num in nums[low:high]:
            r = np.base_repr(num, 3)  # ternary representation
            if sum(int(i) for i in r) == n:  # digit sum must equal n, not num
                answer.append('0'*(n-len(r)) +r)  # turn it into a fixed length
    qOut.put(answer)
    qOut.put(None)


def enumerateSpin(n):  # this is the primary entry point
    numProcs = max(1, mp.cpu_count()-1)  # fiddle to taste
    chunkSize = max(1, 3**n//numProcs)  # split the 3**n candidates, never a step of 0

    qIn, qOut = [mp.Queue() for _ in range(2)]
    procs = [mp.Process(target=filter_chunk, args=(n, qIn, qOut)) for _ in range(numProcs)]

    for p in procs: p.start()
    for i in range(0, 3**n, chunkSize):  # chunkify your numbers so that IPC is more efficient
        qIn.put((i, i+chunkSize))
    for p in procs: qIn.put(None)  # one sentinel per worker

    answer = []
    done = 0
    while done < len(procs):
        t = qOut.get()
        if t is None:
            done += 1
            continue
        answer.extend(t)

    for p in procs: p.join()  # every result has been drained, so workers can exit

    return sorted(answer)  # workers finish in arbitrary order

In general, to get the digits of a number in a specific base we can do:

while num > 0:
    digit = num % base
    num //= base
    print(digit)

When running this with num = 14, base = 3 we get:

2
1
1

The digits come out least significant first, so 14 in ternary is 112.
We can extract that loop into a function digits(num, base) and call np.base_repr(a,3) only when we actually need to convert the number into a string:

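import numpy as np

def digits(num, base):
    # yield the digits of num in the given base, least significant first
    while num > 0:
        yield num % base
        num //= base

def enumerateSpin(n):
    s = []
    for a in range(0,3**n):
        if sum(digits(a, 3)) == n:
            ternary_rep = np.base_repr(a,3)
            k = len(ternary_rep)
            r = (n-k)*'0'+ternary_rep
            s.append(r)
    return s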

Output for enumerateSpin(4):

['0022', '0112', '0121', '0202', '0211', '0220', '1012', '1021', '1102', '1111', '1120', '1201', '1210', '2002', '2011', '2020', '2101', '2110', '2200']
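Since Edit2 in the question prefers tuples of integers over strings, the same divide-and-remainder loop can produce a fixed-length tuple directly, with no string conversion at all. The following is a sketch; the name digits_fixed and its width parameter are assumptions made for this example.

```python
def digits_fixed(num, base, width):
    """Return the digits of num in the given base, most significant first,
    left-padded with zeros to exactly `width` digits."""
    out = []
    for _ in range(width):
        num, digit = divmod(num, base)   # peel off the least significant digit
        out.append(digit)
    if num:
        raise ValueError("num does not fit in the requested width")
    return tuple(reversed(out))

print(digits_fixed(14, 3, 4))  # (0, 1, 1, 2)
```

With this helper, the filter in enumerateSpin could keep digits_fixed(a, 3, n) whenever its sum equals n, and skip np.base_repr entirely.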
