簡體   English   中英

Python 列表累積總和大於列表中元素的 rest

[英]Python list cumulative sum greater than rest of elements in list

我想寫一個 function 它將返回 cumsum 大於列表中 rest 的最小數字。 List 將只有值 -1 和 1。 列表可能有數百萬個元素。 例如

v = [1 1 -1 1 -1 1 -1 1]

這里的答案應該是 2,因為

1) 1 > 1 is False 
2) (1 + 1) 2 > 0 (-1 + 1 -1 +1 -1 +1)

再舉一個例子

v = [-1 -1 1 1]

答案 4

我已經嘗試過的代碼:

def cumsum_grt(v):
    for i in range(1, len(v)):
        k = i
        if sum(v[:k]) > sum(v[k:]):
            break
    return k

這個 function 工作正常,但是有什么方法可以提高性能嗎? 由於無法在幾秒鍾內計算出大型列表的結果,因此失敗。

def cumsum_grt(v):
    total_sum = sum(v)
    curr_sum = v[0]
    for i in range(1, len(v)):
        if curr_sum > (total_sum - abs(curr_sum)):
            break
        curr_sum += v[i]
    return i

測試:

lst = [1, 1, -1, 1, -1, 1, -1, 1]
lst2 = [1, 1, -1, 20, -1, 15, -1, 1]
lst3 = [-2, -1, 4, -1]
lst4 = [-1,-1,-1,-1]

print(cumsum_grt(lst))   # 2
print(cumsum_grt(lst2))  # 4
print(cumsum_grt(lst3))  # 3
print(cumsum_grt(lst4))  # 1

時間性能測量:

In [101]: lst = [1, 1, -1, 20, -1, 15, -1, 5, -1, -2, 40]                                                                    

In [102]: %timeit cumsum_grt(lst)                                                                                            
70.3 µs ± 175 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [103]: %timeit cumsum_grt_lenik(lst)                                                                                      
8.23 µs ± 27.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [104]: %timeit cumsum_grt_roman(lst)                                                                                      
8.22 µs ± 30.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

這是線性的,O(N),而你的版本類似於 O(N*N):

def cumsum_grt(v):
    so_far = 0
    the_rest = sum(v)
    for i in range(len(v)):
        if so_far > the_rest :
            return i
        so_far += v[i]
        the_rest -= v[i]
    return len(v)

您可以將nextitertools.accumulate一起使用,將當前累積總和與總和減去累積總和進行比較,然后使用enumerate獲取該 position 的索引。 position 在列表的第一個元素之前帶有[0]chain

>>> from itertools import accumulate, chain
>>> v = [1, 1, -1, 1, -1, 1, -1, 1]
>>> s = sum(v)
>>> next((i for i, a in enumerate(chain([0], accumulate(v))) if a > s - a), len(v))
2

注意:不要在if條件內計算sum(v) ,否則它將是 O(n²)。 最后的len(v)是默認值,以防累積總和不足以滿足任何元素。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM