[英]How to speed up intersection of dict of sets in Python
我有一本包含一组整数的字典。
{'A': {9, 203, 404, 481},
'B': {9},
'C': {110},
'D': {9, 314, 426},
'E': {59, 395, 405}
}
您可以使用以下方法生成数据:
data = {}
for i in string.ascii_uppercase:
n = 25
rng = np.random.default_rng()
data[i] = set(rng.choice(100, size=n, replace=False))
我需要获取字典子集的交集列表。 因此,在示例中,['A','B','D'] 的交集的输出将返回 [9]
我已经想出了 2 种不同的方法来做到这一点,但是当套装价值增长时,这两种方法都会变慢。
cols = ['A','B','D']
# method 1
lis = list(map(data.get, cols))
idx = list(set.intersection(*lis))
#method 2 (10x slower then method 1)
query_dict = dict((k, data[k]) for k in cols)
idx2 = list(reduce(set.intersection, (set(val) for val in query_dict.values())))
当集合增长(每集 >10k int)时,运行时会快速增长。
我可以使用其他数据类型,然后在 dict 中设置,如列表或 numpy 数组等。
有没有更快的方法来完成这个?
编辑:
我最初遇到的问题是这个数据框:
T S A B C D
0 49.378 1.057 AA AB AA AA
1 1.584 1.107 BC BA AA AA
2 1.095 0.000 BB BB AD
3 10.572 1.224 BA AB AA AA
4 0.000 0.000 DC BA AB
对于每一行,我必须对具有共同 A、B、C、D 的所有行求和 'T',如果达到阈值,则继续在 B、C、D 上继续,然后是 C、D,然后只有 D 如果还没有达到门槛。
但是这真的很慢,所以首先我尝试使用 get_dummies 然后获取列的乘积。 然而,这很慢,所以我转向带有索引的 numpy 数组进行求和。 这是迄今为止最快的选择,但是相交是唯一仍然需要太多时间来计算的东西。
编辑2:
事实证明,我对自己太苛刻了,使用 pandas groupby 是可能的,而且速度非常快。
代码:
parts = [['A','B','C','D'],['B','C','D'],['C','D'],['D']]
for part in parts:
temp_df = df.groupby(part,as_index=False).sum()
temp_df = temp_df[temp_df['T'] > 100]
df = pd.merge(df,temp_df,on=part,how='left',suffixes=["","_" + "".join(part)])
df['T_sum'] = df[['T_ABCD','T_BCD','T_CD','T_D']].min(axis=1)
df['S_sum'] = df[['S_ABCD','S_BCD','S_CD','S_D']].min(axis=1)
df.drop(['T_ABCD','T_BCD','T_CD','T_D','S_ABCD','S_BCD','S_CD','S_D'],, axis=1, inplace=True)
可能代码可以更简洁一些,但我不知道如何在合并中仅替换 NaN 值。
这里的问题是如何有效地找到几个集合的交集。 根据评论: “最大 n 是 1000 万 - 3000 万,列 a、b、c、d 几乎可以是唯一的行,共有 100 万行。” 所以集合很大,但大小不一样。 集合交集是一个结合和交换操作,所以我们可以按照我们喜欢的任何顺序取交集。
两个集合相交的时间复杂度是O(min(len(set1), len(set2)))
,所以我们应该选择一个顺序来做交集,这样可以最小化中间集的大小。
如果我们事先不知道哪些集合对有小的交集,我们能做的最好的事情就是按大小顺序将它们相交。 在第一个交集之后,最小的集合总是最后一个交集的结果,所以我们想把它与下一个最小的输入集相交。 这是更好地利用set.intersection
一次,而不是对所有的套reduce
在这里,因为这是实现基本相同的方式为reduce
会做,但在C.
def intersect_sets(sets):
return set.intersection(*sorted(sets, key=len))
在我们对成对交集一无所知的情况下,C 实现中唯一可能的放缓可能是为多个中间集分配了不必要的内存。 这可以通过例如{ x for x in first_set if all(x in s for s in other_sets) }
来避免,但结果证明要慢得多。
我用最大 600 万的设置对其进行了测试,大约有 10% 的成对重叠。 这是四组相交的时间; 四点之后,累加器大约是原始大小的 0.1%,因此任何进一步的交叉点无论如何都将花费可以忽略不计的时间。 橙色线表示最佳顺序(最小的两个在前)的相交集,蓝线表示最差的顺序(最大的两个在前)的相交集。
正如预期的那样,两者都在设定的大小中花费大致线性的时间,但有很多噪音,因为我没有对多个样本进行平均。 在相同的数据上测量,最优顺序始终是最差顺序的 2-3 倍,大概是因为这是最小和第二大集合大小之间的比率。
在我的机器上,4组大小2-600万相交大约需要100ms,所以上到3000万应该需要半秒左右; 我认为你不太可能击败它,但半秒应该没问题。 如果它始终比您的真实数据花费的时间长得多,那么问题就在于您的数据不是均匀随机的。 如果是这种情况,那么除此之外,Stack Overflow 可能不会为您做太多事情,因为提高效率将在很大程度上取决于您的真实数据的特定分布(尽管请参阅下文有关您必须回答相同问题的许多查询的情况)套)。
我的计时代码如下。
import string
import random
def gen_sets(m, min_n, max_n):
n_range = range(min_n, max_n)
x_range = range(min_n * 10, max_n * 10)
return [
set(random.sample(x_range, n))
for n in [min_n, max_n, *random.sample(n_range, m - 2)]
]
def intersect_best_order(sets):
return set.intersection(*sorted(sets, key=len))
def intersect_worst_order(sets):
return set.intersection(*sorted(sets, key=len, reverse=True))
from timeit import timeit
print('min_n', 'max_n', 'best order', 'worst order', sep='\t')
for min_n in range(100000, 2000001, 100000):
max_n = min_n * 3
data = gen_sets(4, min_n, max_n)
t1 = timeit(lambda: intersect_best_order(data), number=1)
t2 = timeit(lambda: intersect_worst_order(data), number=1)
print(min_n, max_n, t1, t2, sep='\t')
如果你需要做很多查询,那么首先计算成对交集可能是值得的:
from itertools import combinations
pairwise_intersection_sizes = {
(a, b): set_a & set_b
for ((a, set_a), (b, set_b)) in combinations(data.items(), 2)
}
如果某些交集比其他交集小很多,那么可以使用预先计算的成对交集来选择更好的顺序进行set.intersection
in。 给定一些集合,您可以选择具有最小预计算交集的对,然后进行set.intersection
on该预先计算的结果以及其余的输入集。 特别是在一些成对交叉点几乎为空的非均匀情况下,这可能是一个很大的改进。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.