
Fastest way to update a list of integers based on the duplicates in another list of integers

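I have two lists of integer values of the same length: a list of items and a list of labels. If an item is duplicated in the items list, its occurrences are labeled with different integers in the labels list. I want to assign the same integer/label (e.g. the label of the first occurrence) to all items labeled with any of these integers (note that this may affect more than just the duplicates we first found in the items list).

Here is a minimal example of what I am doing (I converted the lists to arrays):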

import numpy as np
import numba as nb
from collections import Counter

items  = np.array([7,2,0,6,0,4,1,5,2,0])
labels = np.array([1,0,3,4,2,1,6,6,5,4])

dups = [x for x, c in Counter(items).items() if c>1]

#@nb.njit(fastmath=True)
def update_labels(items, labels, dups):
    for dup in dups:
        found = np.where(items==dup)[0]
        l = labels[found]
        isin = np.where((np.isin(labels, l)))[0]
        labels[isin] = labels[isin[0]]
    return labels

new_labels = update_labels(items, labels, dups)
print(new_labels) # prints [1 0 3 3 3 1 6 6 0 3]

This code works for small lists. However, for larger lists, e.g.

np.random.seed(0)
n = 1_000_000
items  = np.random.randint(n, size=n)
labels = np.random.randint(int(0.8*n), size=n)

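returning the new labels takes a very long time. The bottleneck is in the update_labels() function. I also tried to speed it up with the numba jit decorator, but it turns out that numba does not support np.isin.

Is there a way to make this algorithm more efficient and/or make it work (efficiently) with numba? Code efficiency is very important to me, as I use it with huge lists (tens of millions of elements). I am also willing to use a C or C++ function and call it from Python as a last resort. I am using Python 3.x.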

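import numpy as np

items = np.array([7, 2, 0, 6, 0, 4, 1, 5, 2, 0])
labels = np.array([1, 0, 3, 4, 2, 1, 6, 6, 5, 4])

d = {}  # maps each item to the label of its first occurrence

for i in range(len(items)):
    # setdefault stores labels[i] the first time an item is seen
    # and returns the stored label on every later occurrence
    label = d.setdefault(items[i], labels[i])
    if label != labels[i]:
        labels[i] = label

print(labels)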

Output

[1 0 3 4 3 1 6 6 0 3]
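Note that this output is not identical to the one printed by the code in the question. The following is a small pure-Python sketch that makes the difference visible; `first_label_per_item` is a hypothetical helper name used only for this illustration, and `expected` is copied from the comment in the question's example.

```python
items = [7, 2, 0, 6, 0, 4, 1, 5, 2, 0]
labels = [1, 0, 3, 4, 2, 1, 6, 6, 5, 4]
# Output printed by the question's update_labels for the same input.
expected = [1, 0, 3, 3, 3, 1, 6, 6, 0, 3]

def first_label_per_item(items, labels):
    """Give every occurrence of an item the label of its first occurrence."""
    seen = {}
    return [seen.setdefault(item, label) for item, label in zip(items, labels)]

result = first_label_per_item(items, labels)
# Positions where the dictionary pass disagrees with the question's output.
mismatches = [i for i, (a, b) in enumerate(zip(result, expected)) if a != b]
print(result)
print(mismatches)
```

The only mismatch is at index 3. Item 6 is not duplicated, but its label 4 is also the label of the last occurrence of item 0, so the question's code relabels it to 3, while a pass that looks only at items leaves it unchanged.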

The following version gives the same output as the original code:

def update_labels(items, labels):
    i_dict, l_dict, ranks = {}, {}, {}

    for i in range(len(items)):
        label = i_dict.setdefault(items[i], labels[i])
        if labels[i] not in ranks:
            ranks[labels[i]] = i

        if label != labels[i]:
            label1 = label
            label2 = labels[i]
            while label1 is not None and label2 is not None:
                if ranks[label1] > ranks[label2]:
                    tmp = l_dict.get(label1)
                    l_dict[label1] = label2
                    label1 = tmp
                elif ranks[label1] < ranks[label2]:
                    tmp = l_dict.get(label2)
                    l_dict[label2] = label1
                    label2 = tmp
                else:
                    break

            labels[i] = label

    for i in range(len(labels)):
        val = 0
        label = labels[i]
        while val != -1:
            val = l_dict.get(label, -1)
            if val != -1:
                label = val
        if label != labels[i]:
            labels[i] = label

    return labels
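The function above appears to merge labels that are linked through a shared item and to keep, for each merged group, the label whose first occurrence is earliest. A more conventional way to express that idea is a disjoint-set (union-find) structure. The following is a minimal pure-Python sketch under that assumption, not the answer's own method; `merge_labels`, `first_pos` and `find` are hypothetical names, and the sketch has only been checked against the small example from the question.

```python
def merge_labels(items, labels):
    """Relabel so that labels linked through a shared item become one label.

    The representative of each merged group is the label whose first
    occurrence in `labels` is earliest.
    """
    # Position of the first occurrence of every label.
    first_pos = {}
    for pos, lab in enumerate(labels):
        first_pos.setdefault(lab, pos)
    parent = {lab: lab for lab in first_pos}

    def find(lab):
        # Walk up to the root, then compress the path just walked.
        root = lab
        while parent[root] != root:
            root = parent[root]
        while parent[lab] != root:
            parent[lab], lab = root, parent[lab]
        return root

    item_label = {}  # label seen at the first occurrence of each item
    for item, lab in zip(items, labels):
        seen = item_label.setdefault(item, lab)
        if seen != lab:
            a, b = find(seen), find(lab)
            if a != b:
                # Attach the later-appearing root under the earlier one.
                if first_pos[a] < first_pos[b]:
                    parent[b] = a
                else:
                    parent[a] = b
    return [find(lab) for lab in labels]

items = [7, 2, 0, 6, 0, 4, 1, 5, 2, 0]
labels = [1, 0, 3, 4, 2, 1, 6, 6, 5, 4]
new_labels = merge_labels(items, labels)
print(new_labels)
```

The sketch makes one pass over the items plus one pass to resolve the final labels, and it uses only dictionaries and loops. Whether it is fast enough for tens of millions of elements would have to be measured.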

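I think your code is already quite optimized. The only thing I noticed is that if you slice the dups array and apply the function update_labels to subproblems restricted to the relevant indices, the run time drops for a problem of size n=100_000 (see function update_labels_2; about 5.1 s versus 1.9 s in the timings below). The solution of Pramote Kuacharoen (see function update_labels_3) is much faster but does not give the correct result for large problems (I don't know whether the solutions it produces are acceptable for you):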

import numpy as np
import numba as nb
from collections import Counter
import time

np.random.seed(0)
n = 100_000
items  = np.random.randint(n, size=n)
labels = np.random.randint(int(0.8*n), size=n)

dups = np.array([x for x, c in Counter(items).items() if c>1])

# --------------- 1. Original solution ---------------
def update_labels(items, labels, dups):
    for dup in dups:
        found = np.where(items==dup)[0]
        l = labels[found]
        isin = np.where((np.isin(labels, l)))[0]
        labels[isin] = labels[isin[0]]
    return labels

t_start = time.time()
new_labels = update_labels(items, np.copy(labels), dups)
print('Timer 1:', time.time()-t_start, 's')

# --------------- 2. Splitting into subproblems ---------------
def update_labels_2(items, labels, dups):
    nb_slices = 20
    offsets = [int(o) for o in np.linspace(0,dups.size,nb_slices+1)]
    for i in range(nb_slices):
    #for i in range(nb_slices-1,-1,-1): # ALSO WORKS
        sub_dups = dups[offsets[i]:offsets[i+1]]
        l = labels[np.isin(items, sub_dups)]
        sub_index = np.where(np.isin(labels, l))[0]
        # Apply your function to subproblem
        labels[sub_index] = update_labels(items[sub_index], labels[sub_index], sub_dups)
    return labels

t_start = time.time()
new_labels_2 = update_labels_2(items, np.copy(labels), dups)
print('Timer 2:', time.time()-t_start, 's')

print('Results 1&2 are equal!' if np.allclose(new_labels,new_labels_2) else 'Results 1&2 differ!')

# --------------- 3. Pramote Kuacharoen solution ---------------
def update_labels_3(items, labels, dups):
    i_dict, l_dict = {}, {}
    for i in range(len(labels)):
        indices = l_dict.setdefault(labels[i], [])
        indices.append(i)
    for i in range(len(items)):
        label_values = i_dict.setdefault(items[i], [])
        if len(label_values) != 0 and labels[i] not in label_values:
            labels[i] = label_values[0]
        label_values.append(labels[i])
    for key, value in l_dict.items():
        label = ''
        sizes = []
        for v in value:
            sizes.append(len(i_dict[items[v]]))
            idx = np.argmax(sizes)
            label = labels[value[idx]]
        for v in value:
            labels[v] = label
    return labels

t_start = time.time()
new_labels_3 = update_labels_3(items, np.copy(labels), dups)
print('Timer 3:', time.time()-t_start, 's')
print('Results 1&3 are equal!' if np.allclose(new_labels,new_labels_3) else 'Results 1&3 differ!')

Output:

% python3 script.py
Timer 1: 5.082866907119751 s
Timer 2: 1.9104671478271484 s
Results 1&2 are equal!
Timer 3: 0.7601778507232666 s
Results 1&3 differ!

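Unfortunately, the best speedup I obtained was with nb_slices=20. There is still hope, however: you can verify that running the loop in update_labels_2 in reverse order still gives the same result, so if you can prove that the subproblems are independent, you could go parallel, e.g. using mpi4py.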
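As a rough way to probe the order-independence remark on small inputs, the sketch below re-implements the question's per-duplicate update in pure Python and compares forward and reversed processing of the duplicates. `update_labels_py` and `order_independent` are hypothetical names. The check works on individual duplicates rather than on the slices of update_labels_2, and passing it says nothing about whether concurrent writes would be safe in an actual parallel run.

```python
import random

def update_labels_py(items, labels, dups):
    """Pure-Python stand-in for the question's update_labels; returns a new list."""
    labels = list(labels)
    for dup in dups:
        # Current labels carried by the occurrences of this duplicated item.
        group = {lab for it, lab in zip(items, labels) if it == dup}
        # Every position carrying one of those labels.
        hits = [i for i, lab in enumerate(labels) if lab in group]
        target = labels[hits[0]]
        for i in hits:
            labels[i] = target
    return labels

def order_independent(n, trials, seed=0):
    """Compare forward and reversed duplicate order on random small inputs."""
    rng = random.Random(seed)
    for _ in range(trials):
        items = [rng.randrange(n) for _ in range(n)]
        labels = [rng.randrange(max(1, int(0.8 * n))) for _ in range(n)]
        dups = [x for x in dict.fromkeys(items) if items.count(x) > 1]
        forward = update_labels_py(items, labels, dups)
        backward = update_labels_py(items, labels, dups[::-1])
        if forward != backward:
            return False
    return True

items = [7, 2, 0, 6, 0, 4, 1, 5, 2, 0]
labels = [1, 0, 3, 4, 2, 1, 6, 6, 5, 4]
forward = update_labels_py(items, labels, [2, 0])
backward = update_labels_py(items, labels, [0, 2])
print(forward == backward)
print(order_independent(n=30, trials=200))
```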

