繁体   English   中英

根据另一个整数列表中的重复项更新整数列表的最快方法

[英]Fastest way to update a list of integers based on the duplicates in another list of integers

我有两个相同长度的 integer 值列表:项目列表和标签列表。 如果一个项目在项目列表中重复,这意味着它们在标签列表中使用不同的整数进行标记。 我想为所有标有这些整数的项目分配相同的整数/标签(例如第一次出现的 label)(请注意,这可能不仅仅是我们在项目列表中首先找到的重复项) .

这是我正在做的一个最小示例(我将列表转换为数组):

import numpy as np
import numba as nb
from collections import Counter

items  = np.array([7,2,0,6,0,4,1,5,2,0])
labels = np.array([1,0,3,4,2,1,6,6,5,4])

dups = [x for x, c in Counter(items).items() if c>1]

#@nb.njit(fastmath=True)
def update_labels(items, labels, dups):
    for dup in dups:
        found = np.where(items==dup)[0]
        l = labels[found]
        isin = np.where((np.isin(labels, l)))[0]
        labels[isin] = labels[isin[0]]
    return labels

new_labels = update_labels(items, labels, dups)
print(new_labels) # prints [1 0 3 3 3 1 6 6 0 3]

该代码适用于小列表。 但是,对于较大的列表,例如

np.random.seed(0)
n = 1_000_000
items  = np.random.randint(n, size=n)
labels = np.random.randint(int(0.8*n), size=n)

返回新标签需要很长时间。 瓶颈在update_labels() function 中,我也尝试使用numba jit 装饰器来加速它,但事实证明numba不支持np.isin

有没有办法让这个算法更有效和/或让它与numba一起(有效地)工作? 代码效率对我来说非常重要,因为我将它与大量列表(数千万)一起使用。 我也愿意使用 C 或 C++ function 并将其从 ZA7F5FZ23 称为最后的手段。 我使用 Python 3.x。

items = np.array([7, 2, 0, 6, 0, 4, 1, 5, 2, 0])
labels = np.array([1, 0, 3, 4, 2, 1, 6, 6, 5, 4])

d = {}

for i in range(len(items)):
    label = d.setdefault(items[i], labels[i])
    if label != labels[i]:
        labels[i] = label

Output

[1 0 3 4 3 1 6 6 0 3]

这个给出了与原始版本相同的 output。

def update_labels(items, labels):
    i_dict, l_dict, ranks = {}, {}, {}

    for i in range(len(items)):
        label = i_dict.setdefault(items[i], labels[i])
        if labels[i] not in ranks:
            ranks[labels[i]] = i

        if label != labels[i]:
            label1 = label
            label2 = labels[i]
            while label1 is not None and label2 is not None:
                if ranks[label1] > ranks[label2]:
                    tmp = l_dict.get(label1)
                    l_dict[label1] = label2
                    label1 = tmp
                elif ranks[label1] < ranks[label2]:
                    tmp = l_dict.get(label2)
                    l_dict[label2] = label1
                    label2 = tmp
                else:
                    break

            labels[i] = label

    for i in range(len(labels)):
        val = 0
        label = labels[i]
        while val != -1:
            val = l_dict.get(label, -1)
            if val != -1:
                label = val
        if label != labels[i]:
            labels[i] = label

    return labels

我觉得你的代码已经很优化了。 我唯一注意到的是,如果您对dups数组进行切片并将 function update_labels应用于仅限于相关索引的子问题,则对于大小n=100_000的问题(参见update_labels_2 . Pramote Kuacharoen 的解决方案(参见 function update_labels_2 )要快得多,但不能针对大问题给出正确的解决方案(不知道它产生的解决方案是否适合您):

import numpy as np
import numba as nb
from collections import Counter
import time

np.random.seed(0)
n = 100_000
items  = np.random.randint(n, size=n)
labels = np.random.randint(int(0.8*n), size=n)

dups = np.array([x for x, c in Counter(items).items() if c>1])

# --------------- 1. Original solution ---------------
def update_labels(items, labels, dups):
    for dup in dups:
        found = np.where(items==dup)[0]
        l = labels[found]
        isin = np.where((np.isin(labels, l)))[0]
        labels[isin] = labels[isin[0]]
    return labels

t_start = time.time()
new_labels = update_labels(items, np.copy(labels), dups)
print('Timer 1:', time.time()-t_start, 's')

# --------------- 2. Splitting into subproblems ---------------
def update_labels_2(items, labels, dups):
    nb_slices = 20
    offsets = [int(o) for o in np.linspace(0,dups.size,nb_slices+1)]
    for i in range(nb_slices):
    #for i in range(nb_slices-1,-1,-1): # ALSO WORKS
        sub_dups = dups[offsets[i]:offsets[i+1]]
        l = labels[np.isin(items, sub_dups)]
        sub_index = np.where(np.isin(labels, l))[0]
        # Apply your function to subproblem
        labels[sub_index] = update_labels(items[sub_index], labels[sub_index], sub_dups)
    return labels

t_start = time.time()
new_labels_2 = update_labels_2(items, np.copy(labels), dups)
print('Timer 2:', time.time()-t_start, 's')

print('Results 1&2 are equal!' if np.allclose(new_labels,new_labels_2) else 'Results 1&2 differ!')

# --------------- 3. Pramote Kuacharoen solution ---------------
def update_labels_3(items, labels, dups):
    i_dict, l_dict = {}, {}
    for i in range(len(labels)):
        indices = l_dict.setdefault(labels[i], [])
        indices.append(i)
    for i in range(len(items)):
        label_values = i_dict.setdefault(items[i], [])
        if len(label_values) != 0 and labels[i] not in label_values:
            labels[i] = label_values[0]
        label_values.append(labels[i])
    for key, value in l_dict.items():
        label = ''
        sizes = []
        for v in value:
            sizes.append(len(i_dict[items[v]]))
            idx = np.argmax(sizes)
            label = labels[value[idx]]
        for v in value:
            labels[v] = label
    return labels

t_start = time.time()
new_labels_3 = update_labels_3(items, np.copy(labels), dups)
print('Timer 3:', time.time()-t_start, 's')
print('Results 1&3 are equal!' if np.allclose(new_labels,new_labels_3) else 'Results 1&3 differ!')

Output:

% python3 script.py
Timer 1: 5.082866907119751 s
Timer 2: 1.9104671478271484 s
Results 1&2 are equal!
Timer 3: 0.7601778507232666 s
Results 1&3 differ!

不幸的是,我获得的最佳加速是使用nb_slices=20 但是仍然有希望,因为您可以验证在 function update_labels_2中以相反的顺序运行循环时,您仍然可以获得相同的顺序,因此,如果您可以证明子问题是独立的,您可以 go例如,使用mpi4py并行。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM