Fastest way to update a list of integers based on the duplicates in another list of integers
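I have two lists of integers of the same length: a list of items and a list of labels. If an item is repeated in the items list, its occurrences may be labeled with different integers in the labels list. I want to assign the same integer/label (e.g. the label of the first occurrence) to all items labeled with any of these integers. Note that this can affect more items than just the duplicates we find in the items list in the first place.
Here is a minimal example of what I am doing (I converted the lists to arrays):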
import numpy as np
import numba as nb
from collections import Counter
items = np.array([7,2,0,6,0,4,1,5,2,0])
labels = np.array([1,0,3,4,2,1,6,6,5,4])
dups = [x for x, c in Counter(items).items() if c>1]
#@nb.njit(fastmath=True)
def update_labels(items, labels, dups):
    for dup in dups:
        found = np.where(items==dup)[0]
        l = labels[found]
        isin = np.where((np.isin(labels, l)))[0]
        labels[isin] = labels[isin[0]]
    return labels
new_labels = update_labels(items, labels, dups)
print(new_labels) # prints [1 0 3 3 3 1 6 6 0 3]
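To make the merging rule concrete, the small sketch below (an illustration, not part of the question) collects, for the example arrays above, the set of labels each item value appears with. Item 0 links labels 2, 3 and 4, so index 3 (item 6, which is not duplicated) must also be relabeled, because it carries label 4.

```python
from collections import defaultdict

items = [7, 2, 0, 6, 0, 4, 1, 5, 2, 0]
labels = [1, 0, 3, 4, 2, 1, 6, 6, 5, 4]

# For every item value, collect the set of labels it appears with.
labels_per_item = defaultdict(set)
for item, label in zip(items, labels):
    labels_per_item[item].add(label)

# Items carrying more than one label force those labels to be merged.
links = {item: sorted(ls) for item, ls in labels_per_item.items() if len(ls) > 1}
print(links)  # {2: [0, 5], 0: [2, 3, 4]}
```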
The code works for small lists. However, for larger lists, such as
np.random.seed(0)
n = 1_000_000
items = np.random.randint(n, size=n)
labels = np.random.randint(int(0.8*n), size=n)
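it takes a very long time to return the new labels. The bottleneck is the update_labels() function. I also tried to speed it up with the numba jit decorator, but it turns out numba does not support np.isin.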
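Because the labels here are non-negative integers below a known bound, one way around the np.isin limitation is a boolean lookup table. `isin_bounded` is a hypothetical helper, assuming every value lies in [0, max_value]; it uses only plain array indexing, which should make it easier to port to numba (not tested there). It removes the np.isin dependency but not the main cost: the loop over dups still scans the whole labels array once per duplicate, so the total work stays roughly proportional to len(dups) × n.

```python
import numpy as np


def isin_bounded(values, targets, max_value):
    """Hypothetical stand-in for np.isin(values, targets).

    Assumes every entry of values and targets is an integer in
    [0, max_value]; builds a boolean lookup table instead of sorting.
    """
    table = np.zeros(max_value + 1, dtype=np.bool_)
    table[targets] = True  # mark each target value
    return table[values]   # one table lookup per element


labels = np.array([1, 0, 3, 4, 2, 1, 6, 6, 5, 4])
mask = isin_bounded(labels, np.array([3, 2, 4]), labels.max())
print(np.flatnonzero(mask))  # [2 3 4 9]
```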
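Is there a way to make this algorithm more efficient and/or make it work (efficiently) with numba? Code efficiency is very important to me because I use it with huge lists (tens of millions of elements). As a last resort, I am also open to writing the function in C or C++ and calling it from Python. I use Python 3.x.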
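One possible approach, swapping in a different technique from the question's loop: treat label values as graph nodes, let every repeated item connect the labels it carries, and merge connected components with a union-find (disjoint-set) structure. That takes two linear passes plus near-constant-time `find` calls, instead of one full scan of labels per duplicated item. `find` and `merge_labels_union_find` are hypothetical names, and the sketch assumes non-negative integer items and labels (it allocates arrays of size max + 1). Each merged group gets the original label at the group's smallest index, which, as far as I can follow the original loop, is the same representative update_labels ends up choosing.

```python
import numpy as np


def find(parent, x):
    # Walk up to the root, halving the path as we go (path halving).
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def merge_labels_union_find(items, labels):
    """Hypothetical helper: merge all labels that share an item.

    Assumes non-negative integer arrays of equal length. Each merged group
    gets the original label found at the group's smallest index.
    """
    items = np.asarray(items)
    labels = np.asarray(labels)
    parent = np.arange(labels.max() + 1)        # one node per label value
    first_label = np.full(items.max() + 1, -1)  # first label seen per item

    # Pass 1: union each occurrence's label with its item's first label.
    for i in range(len(items)):
        it, lab = items[i], labels[i]
        if first_label[it] == -1:
            first_label[it] = lab
        else:
            ra, rb = find(parent, first_label[it]), find(parent, lab)
            if ra != rb:
                parent[rb] = ra

    # Pass 2: the first index (in array order) that reaches a root fixes
    # the label for the whole group.
    rep = np.full(labels.max() + 1, -1)
    out = np.empty_like(labels)
    for i in range(len(labels)):
        r = find(parent, labels[i])
        if rep[r] == -1:
            rep[r] = labels[i]
        out[i] = rep[r]
    return out


items = np.array([7, 2, 0, 6, 0, 4, 1, 5, 2, 0])
labels = np.array([1, 0, 3, 4, 2, 1, 6, 6, 5, 4])
print(merge_labels_union_find(items, labels))  # [1 0 3 3 3 1 6 6 0 3]
```

In pure Python these loops are still slow for tens of millions of elements. The plain-array, loop-only style is meant to make an @nb.njit port straightforward, but that port has not been tested here.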
import numpy as np
items = np.array([7, 2, 0, 6, 0, 4, 1, 5, 2, 0])
labels = np.array([1, 0, 3, 4, 2, 1, 6, 6, 5, 4])
d = {}  # item -> label of its first occurrence
for i in range(len(items)):
    label = d.setdefault(items[i], labels[i])
    if label != labels[i]:
        labels[i] = label
print(labels)
Output
[1 0 3 4 3 1 6 6 0 3]
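The dict-based pass above only rewrites later occurrences of the same item, so a merge never propagates to other items that share a label. Comparing it with the question's expected output (values copied from above) shows the single position where they disagree: index 3, item 6, which keeps label 4 even though label 4 is merged with label 3 through item 0.

```python
import numpy as np

expected = np.array([1, 0, 3, 3, 3, 1, 6, 6, 0, 3])  # question's update_labels
simple = np.array([1, 0, 3, 4, 3, 1, 6, 6, 0, 3])    # dict-based pass above
print(np.flatnonzero(expected != simple))  # [3]
```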
The following version gives the same output as the original:
def update_labels(items, labels):
    # i_dict: item -> label of its first occurrence
    # l_dict: label -> label it has been merged into (parent pointer)
    # ranks:  label -> index of its first occurrence
    i_dict, l_dict, ranks = {}, {}, {}
    for i in range(len(items)):
        label = i_dict.setdefault(items[i], labels[i])
        if labels[i] not in ranks:
            ranks[labels[i]] = i
        if label != labels[i]:
            # Merge the two pointer chains, always pointing the label that
            # appeared later at the label that appeared earlier.
            label1 = label
            label2 = labels[i]
            while label1 is not None and label2 is not None:
                if ranks[label1] > ranks[label2]:
                    tmp = l_dict.get(label1)
                    l_dict[label1] = label2
                    label1 = tmp
                elif ranks[label1] < ranks[label2]:
                    tmp = l_dict.get(label2)
                    l_dict[label2] = label1
                    label2 = tmp
                else:
                    break
            labels[i] = label
    # Follow each label's pointers to the end of its chain.
    for i in range(len(labels)):
        val = 0
        label = labels[i]
        while val != -1:
            val = l_dict.get(label, -1)
            if val != -1:
                label = val
        if label != labels[i]:
            labels[i] = label
    return labels
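In this version l_dict acts as a parent-pointer map between labels, and the final loop walks each element's chain to its end, once per element. If the chains get long, path compression avoids re-walking them. `resolve` is a hypothetical helper, not part of the answer above. It assumes, as the code above appears to guarantee by always pointing later labels at earlier ones, that the pointers never form a cycle.

```python
def resolve(l_dict, label):
    """Hypothetical helper: follow l_dict to the end of label's chain.

    Afterwards every label visited points straight at the final label,
    so later lookups for those labels take a single step.
    """
    root = label
    while root in l_dict:
        root = l_dict[root]
    # Second walk: repoint each label on the chain directly at the root.
    while label in l_dict and l_dict[label] != root:
        nxt = l_dict[label]
        l_dict[label] = root
        label = nxt
    return root


chain = {4: 3, 3: 2, 2: 1}
print(resolve(chain, 4), chain)  # 1 {4: 1, 3: 1, 2: 1}
```

In the final loop, `labels[i] = resolve(l_dict, labels[i])` would then replace the inner while loop.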
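I think your code is already quite optimized. The only thing I noticed is that if you slice the dups array and apply the function update_labels to subproblems restricted to the relevant indices, you get a speedup for a problem of size n=100_000 (about 5.1 s down to 1.9 s in the timings below; see function update_labels_2). Pramote Kuacharoen's solution (see function update_labels_3) is much faster, but it does not give the correct result for large problems (I don't know whether the result it produces is acceptable for you):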
import numpy as np
import numba as nb
from collections import Counter
import time
np.random.seed(0)
n = 100_000
items = np.random.randint(n, size=n)
labels = np.random.randint(int(0.8*n), size=n)
dups = np.array([x for x, c in Counter(items).items() if c>1])
# --------------- 1. Original solution ---------------
def update_labels(items, labels, dups):
    for dup in dups:
        found = np.where(items==dup)[0]
        l = labels[found]
        isin = np.where((np.isin(labels, l)))[0]
        labels[isin] = labels[isin[0]]
    return labels
t_start = time.time()
new_labels = update_labels(items, np.copy(labels), dups)
print('Timer 1:', time.time()-t_start, 's')
# --------------- 2. Splitting into subproblems ---------------
def update_labels_2(items, labels, dups):
    nb_slices = 20
    offsets = [int(o) for o in np.linspace(0,dups.size,nb_slices+1)]
    for i in range(nb_slices):
    #for i in range(nb_slices-1,-1,-1): # ALSO WORKS
        sub_dups = dups[offsets[i]:offsets[i+1]]
        l = labels[np.isin(items, sub_dups)]
        sub_index = np.where(np.isin(labels, l))[0]
        # Apply your function to subproblem
        labels[sub_index] = update_labels(items[sub_index], labels[sub_index], sub_dups)
    return labels
t_start = time.time()
new_labels_2 = update_labels_2(items, np.copy(labels), dups)
print('Timer 2:', time.time()-t_start, 's')
print('Results 1&2 are equal!' if np.allclose(new_labels,new_labels_2) else 'Results 1&2 differ!')
# --------------- 3. Pramote Kuacharoen solution ---------------
def update_labels_3(items, labels, dups):
    i_dict, l_dict = {}, {}
    for i in range(len(labels)):
        indices = l_dict.setdefault(labels[i], [])
        indices.append(i)
    for i in range(len(items)):
        label_values = i_dict.setdefault(items[i], [])
        if len(label_values) != 0 and labels[i] not in label_values:
            labels[i] = label_values[0]
        label_values.append(labels[i])
    for key, value in l_dict.items():
        label = ''
        sizes = []
        for v in value:
            sizes.append(len(i_dict[items[v]]))
        idx = np.argmax(sizes)
        label = labels[value[idx]]
        for v in value:
            labels[v] = label
    return labels
t_start = time.time()
new_labels_3 = update_labels_3(items, np.copy(labels), dups)
print('Timer 3:', time.time()-t_start, 's')
print('Results 1&3 are equal!' if np.allclose(new_labels,new_labels_3) else 'Results 1&3 differ!')
Output:
% python3 script.py
Timer 1: 5.082866907119751 s
Timer 2: 1.9104671478271484 s
Results 1&2 are equal!
Timer 3: 0.7601778507232666 s
Results 1&3 differ!
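The "Results 1&3 differ!" line compares label values exactly. If update_labels_3 merely picked a different representative label for some groups, a grouping-only comparison would tell that apart from a genuinely different partition. `same_partition` is a hypothetical helper for that check; it has not been run on the arrays above, so what it would report for update_labels_3 is unknown.

```python
import numpy as np


def same_partition(a, b):
    """Hypothetical helper: do labelings a and b group indices identically?

    Label values may differ; only the grouping is compared. The groupings
    match exactly when the (a, b) pairs form a one-to-one correspondence.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        return False
    n_pairs = len(np.unique(np.stack([a, b], axis=1), axis=0))
    return n_pairs == len(np.unique(a)) == len(np.unique(b))


print(same_partition([1, 1, 2], [5, 5, 7]))  # True
print(same_partition([1, 1, 2], [5, 6, 7]))  # False
```

For example, `same_partition(new_labels, new_labels_3)` could sit next to the np.allclose check in the script above.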
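Unfortunately, the best speedup I obtained was with nb_slices=20. There is still hope, though: you can verify that running the loop in update_labels_2 in reverse order still gives the same result, so if you can prove that the subproblems are independent, you could process them in parallel, e.g. with mpi4py.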
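One way to test independence before parallelizing is to compute every slice's index set from the original labels, as parallel workers would see them, and check that no index falls into two slices. Under my reading of update_labels_2 (an assumption, not something the answer states), each slice rewrites only its own indices using its own label values, so disjoint index sets should make the slices order-independent. This is a sufficient condition, not a necessary one. `subproblems_are_disjoint` is a hypothetical helper that reuses the same np.linspace slicing as update_labels_2.

```python
import numpy as np


def subproblems_are_disjoint(items, labels, dups, nb_slices):
    """Hypothetical check: do the slices of update_labels_2 touch disjoint indices?

    Index sets are computed from the original labels, the way independent
    parallel workers would see them.
    """
    offsets = [int(o) for o in np.linspace(0, dups.size, nb_slices + 1)]
    counts = np.zeros(labels.size, dtype=np.int64)
    for i in range(nb_slices):
        sub_dups = dups[offsets[i]:offsets[i + 1]]
        l = labels[np.isin(items, sub_dups)]
        sub_index = np.flatnonzero(np.isin(labels, l))
        counts[sub_index] += 1  # sub_index has no repeats, so += is safe
    return bool((counts <= 1).all())


items = np.array([7, 2, 0, 6, 0, 4, 1, 5, 2, 0])
labels = np.array([1, 0, 3, 4, 2, 1, 6, 6, 5, 4])
print(subproblems_are_disjoint(items, labels, np.array([2, 0]), nb_slices=2))  # True
```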