Get number of keys matching a value in itertools.groupby

I have lists of binary values and I am trying to get the number of groups of consecutive 1s in each list.

Here are a few examples:

[0, 0, 0, 0, 0, 0, 0, 0] -> 0
[1, 1, 1, 1, 1, 1, 1, 1] -> 1
[0, 1, 1, 1, 1, 0, 0, 0] -> 1
[0, 1, 1, 1, 0, 0, 1, 0] -> 2

I use itertools.groupby() to split the lists into groups, which gives me an iterator of (key, group) pairs, but I can't quite figure out how to count the groups of 1s specifically.

Obviously, I could iterate over the keys and count up with an if statement, but I'm sure there's a better way.

While writing the question, I found the following solution (which was obvious in retrospect).

run_count = sum(k == 1 for k, g in itertools.groupby(labels_sample))

I'm not sure if it's the best, but it works.
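For reference, running it on the example lists from the question should reproduce the expected counts:

from itertools import groupby

samples = [
    [0, 0, 0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [0, 1, 1, 1, 1, 0, 0, 0],
    [0, 1, 1, 1, 0, 0, 1, 0],
]

for labels_sample in samples:
    # groupby yields one (key, group) pair per run of equal values,
    # so counting the pairs whose key is 1 counts the runs of ones
    print(sum(k == 1 for k, g in groupby(labels_sample)))  # prints 0, 1, 1, 2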

In this specific case, since the keys are only 0 and 1, you can omit the k == 1 check and include the zeros in the sum (they contribute nothing).

sum(k for k, _ in groupby([0, 1, 1, 1, 0, 0, 1, 0])) -> 2
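A quick way to convince yourself that both versions agree on the example lists:

from itertools import groupby

for lst in [[0] * 8, [1] * 8, [0, 1, 1, 1, 1, 0, 0, 0], [0, 1, 1, 1, 0, 0, 1, 0]]:
    # groups with key 0 add nothing to the sum, so dropping the check is safe
    assert sum(k for k, _ in groupby(lst)) == sum(k == 1 for k, _ in groupby(lst))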

Not with groupby, but to address the "better way" part: this appears to be faster:

def count_groups_of_ones(lst):
    it = iter(lst)
    count = 0
    # "1 in it" consumes the iterator until it finds the next 1,
    # i.e. the start of a new group (or exhausts it and the loop stops)
    while 1 in it:
        count += 1
        # skip to the end of the current group by consuming up to the next 0
        0 in it
    return count
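To see why this works: x in iterator consumes elements from the iterator as a side effect, stopping as soon as it finds a match. For example:

it = iter([0, 1, 1, 1, 0, 0, 1, 0])
print(1 in it)   # True: consumed the elements up to and including the first 1
print(list(it))  # [1, 1, 0, 0, 1, 0]: the rest is still available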

Benchmark results for your four small lists:

  3.72 μs  with_groupby
  1.76 μs  with_in_iterator

And with longer lists (your lists multiplied by 1000):

984.32 μs  with_groupby
669.11 μs  with_in_iterator

Benchmark code (Try it online!):

from timeit import repeat
from itertools import groupby
from collections import deque

def with_groupby(lst):
    return sum(k for k, _ in groupby(lst))

def with_in_iterator(lst):
    it = iter(lst)
    count = 0
    while 1 in it:
        count += 1
        0 in it
    return count

funcs = [
    with_groupby,
    with_in_iterator,
]

def benchmark(lists, number):
    print('lengths:', *map(len, lists))
    for _ in range(3):
        for func in funcs:
            # deque(..., 0) just consumes the map iterator without storing results
            t = min(repeat(lambda: deque(map(func, lists), 0), number=number)) / number
            print('%6.2f μs ' % (t * 1e6), func.__name__)
        print()

lists = [
    [0, 0, 0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [0, 1, 1, 1, 1, 0, 0, 0],
    [0, 1, 1, 1, 0, 0, 1, 0],
]

for func in funcs:
    print(*map(func, lists))
benchmark(lists, 10000)
benchmark([lst * 1000 for lst in lists], 40)

Another option that is more general:

def count_groups(lst, value):
    start = object()  # sentinel: the first element has no real predecessor
    # count the positions where a group of `value` starts (current matches, previous doesn't)
    return sum((a is start or a != value) and b == value for a, b in zip([start] + lst, lst))

count_groups([0, 1, 1, 1, 0, 0, 1, 0], 1) # 2
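Being general, it can also count groups of other values, e.g. the groups of zeros in the same list:

count_groups([0, 1, 1, 1, 0, 0, 1, 0], 0) # 3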

If optimizing for speed over long lists, try adapting this numpy-based approach:

import numpy as np

def count_groups(lst, value):
    # each group of matches contributes exactly two transitions (False->True and True->False)
    return np.diff(np.array(lst) == value, prepend=False, append=False).sum() // 2
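A quick check on the example lists, which should agree with the groupby results above:

for lst in [[0] * 8, [1] * 8, [0, 1, 1, 1, 1, 0, 0, 0], [0, 1, 1, 1, 0, 0, 1, 0]]:
    print(count_groups(lst, 1))  # prints 0, 1, 1, 2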
