簡體   English   中英

獲取與 itertools.groupby 中的值匹配的鍵數

[英]Get number of keys matching a value in itertools.groupby

我有二進制值列表,我正在嘗試獲取每個列表中連續 1 的組數。

下面是幾個例子:

[0, 0, 0, 0, 0, 0, 0, 0] -> 0
[1, 1, 1, 1, 1, 1, 1, 1] -> 1
[0, 1, 1, 1, 1, 0, 0, 0] -> 1
[0, 1, 1, 1, 0, 0, 1, 0] -> 2

我使用 itertools.groupby() 將列表分成組,這讓我得到了一個帶有鍵和組的迭代器,但我不太清楚如何具體獲取 1 組的數量。

顯然,我可以遍歷鍵並使用 if 語句進行計數,但我確信有更好的方法。

在寫問題時,我找到了以下解決方案(回想起來很明顯)。

run_count = sum(k == 1 for k, g in itertools.groupby(labels_sample))

我不確定它是否是最好的,但它確實有效。

在這種特定情況下,使用鍵01 ,您可以省略k == 1檢查並在總和中包含零。

sum(k for k, _ in groupby([0, 1, 1, 1, 0, 0, 1, 0])) -> 2

不是groupby ,而是可能回答“更好的方法”,這似乎更快:

def count_groups_of_ones(lst):
    it = iter(lst)
    count = 0
    while 1 in it:
        count += 1
        0 in it
    return count

四個小列表的基准測試結果:

  3.72 ms  with_groupby
  1.76 ms  with_in_iterator

使用更長的列表(您的列表乘以 1000):

984.32 ms  with_groupby
669.11 ms  with_in_iterator

基准代碼( 在線試用! ):

def with_groupby(lst):
    return sum(k for k, _ in groupby(lst))

def with_in_iterator(lst):
    it = iter(lst)
    count = 0
    while 1 in it:
        count += 1
        0 in it
    return count

from timeit import repeat
from itertools import groupby
from collections import deque
from operator import itemgetter, countOf

funcs = [
    with_groupby,
    with_in_iterator,
]

def benchmark(lists, number):
    print('lengths:', *map(len, lists))
    for _ in range(3):
        for func in funcs:
            t = min(repeat(lambda: deque(map(func, lists), 0), number=number)) / number
            print('%6.2f ms ' % (t * 1e6), func.__name__)
        print()    

lists = [
    [0, 0, 0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [0, 1, 1, 1, 1, 0, 0, 0],
    [0, 1, 1, 1, 0, 0, 1, 0],
]

for func in funcs:
    print(*map(func, lists))
benchmark(lists, 10000)
benchmark([lst * 1000 for lst in lists], 40)

另一個更通用的選項:

def count_groups(lst, value):
    start = object()
    return sum((a is start or a != value) and b == value for a, b in zip([start] + lst, lst))

count_groups([0, 1, 1, 1, 0, 0, 1, 0], 1) # 2

如果針對長列表的速度進行優化,請嘗試調整使用numpy的答案

def count_groups(lst, value):
    return np.diff(np.array(lst) == value, prepend=False, append=False).sum() // 2

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM