[英]What is the most efficient way of getting the intersection of k sorted arrays?
给定 k 排序 arrays 获取这些列表交集的最有效方法是什么
例子
输入:
[[1,3,5,7], [1,1,3,5,7], [1,4,7,9]]
Output:
[1,7]
有一种方法可以根据我在 nlogk 时间的编程访谈元素一书中读到的内容来获得 k 排序的 arrays 的并集。 我想知道是否有办法为十字路口做类似的事情
## merge sorted arrays in nlogk time [ regular appending and merging is nlogn time ]
import heapq
def mergeArys(srtd_arys):
heap = []
srtd_iters = [iter(x) for x in srtd_arys]
# put the first element from each srtd array onto the heap
for idx, it in enumerate(srtd_iters):
elem = next(it, None)
if elem:
heapq.heappush(heap, (elem, idx))
res = []
# collect results in nlogK time
while heap:
elem, ary = heapq.heappop(heap)
it = srtd_iters[ary]
res.append(elem)
nxt = next(it, None)
if nxt:
heapq.heappush(heap, (nxt, ary))
编辑:显然这是一个我试图解决的算法问题,所以我不能使用任何内置函数,如设置交集等
这是一种 O(n) 方法,它不需要任何特殊数据结构或辅助 memory 超出一个迭代器和每个子列表一个值的基本要求:
from itertools import cycle
def intersection(data):
ITERATOR, VALUE = 0, 1
n = len(data)
result = []
try:
pairs = cycle([(it := iter(sublist)), next(it)] for sublist in data)
pair = next(pairs)
curr = pair[VALUE] # Candidate is the largest value seen so far
matches = 1 # Number of pairs where the candidate occurs
while True:
iterator, value = pair = next(pairs)
while value < curr:
value = next(iterator)
pair[VALUE] = value
if value > curr:
curr, matches = value, 1
continue
matches += 1
if matches != n:
continue
result.append(curr)
while (value := next(iterator)) == curr:
pass
pair[VALUE] = value
curr, matches = value, 1
except StopIteration:
return result
这是一个示例 session:
>>> data = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
>>> intersection(data)
[1, 7]
该算法围绕迭代器、值对循环。 如果一个值在所有对中都匹配,则它属于交集。 如果一个值低于目前看到的任何其他值,则当前迭代器前进。 如果某个值大于目前看到的任何值,则它将成为新目标,并且匹配计数重置为 1。 当任何迭代器用尽时,算法就完成了。
itertools.cycle()的使用是完全可选的。 它很容易通过增加最后环绕的索引来模拟。
代替:
iterator, value = pair = next(pairs)
你可以写:
pairnum += 1
if pairnum == n:
pairnum = 0
iterator, value = pair = pairs[pairnum]
或者更简洁:
pairnum = (pairnum + 1) % n
iterator, value = pair = pairs[pairnum]
如果要保留重复(如多重集),这是一个简单的修改,只需更改result.append(curr)
之后的四行以从每个迭代器中删除匹配元素:
def intersection(data):
ITERATOR, VALUE = 0, 1
n = len(data)
result = []
try:
pairs = cycle([(it := iter(sublist)), next(it)] for sublist in data)
pair = next(pairs)
curr = pair[VALUE] # Candidate is the largest value seen so far
matches = 1 # Number of pairs where the candidate occurs
while True:
iterator, value = pair = next(pairs)
while value < curr:
value = next(iterator)
pair[VALUE] = value
if value > curr:
curr, matches = value, 1
continue
matches += 1
if matches != n:
continue
result.append(curr)
for i in range(n):
iterator, value = pair = next(pairs)
pair[VALUE] = next(iterator)
curr, matches = pair[VALUE], 1
except StopIteration:
return result
对的,这是可能的。 我已经修改了您的示例代码来执行此操作。
我的回答假设您的问题是关于算法的 - 如果您想要使用set
s 运行最快的代码,请参阅其他答案。
这保持了O(n log(k))
时间复杂度: if lowest:= elem or ary != times_seen:
和unbench_all = False
之间的所有代码都是O(log(k))
。 主循环中有一个嵌套循环( for unbenched in range(times_seen):
),但这只会运行times_seen
次,并且times_seen
最初为 0,并且在每次运行此内部循环后重置为 0,并且只能递增每次主循环迭代一次,因此内循环总共不能比主循环进行更多的迭代。 因此,由于内循环内部的代码是O(log(k))
并且最多运行与外循环相同的次数,而外循环是O(log(k))
并且运行n
次,所以算法是O(n log(k))
。
该算法依赖于在 Python 中如何比较元组。 它比较元组的第一项,如果它们相等,则比较第二项(即(x, a) < (x, b)
当且仅当a < b
时为真)。 在这个算法中,与问题中的示例代码不同,当一个项目从堆中弹出时,它不一定在同一次迭代中再次被推送。 由于我们需要检查所有子列表是否包含相同的数字,所以在从堆中弹出一个数字后,它的子列表就是我所说的“benched”,意思是它不会被添加回堆中。 这是因为我们需要检查其他子列表是否包含相同的项目,所以现在不需要添加这个子列表的下一个项目。
如果一个数字确实在所有子列表中,那么堆看起来像[(2,0),(2,1),(2,2),(2,3)]
,所有第一个元素元组相同,因此heappop
将 select 具有最低子列表索引的那个。 这意味着第一个索引 0 将被弹出并且times_seen
将递增到 1,然后索引 1 将被弹出并且times_seen
将递增到 2 - 如果ary
不等于times_seen
则该数字不在所有子的交集列表。 这会导致条件if lowest:= elem or ary != times_seen:
,它决定何时不应在结果中出现数字。 此if
语句的else
分支用于当它仍然可能在结果中时。
unbench_all
boolean 适用于需要从工作台中删除所有子列表的情况 - 这可能是因为:
当unbench_all
为True
时,从堆中删除的所有子列表都将重新添加。 众所周知,这些是索引在range(times_seen)
中的那些,因为该算法仅在它们具有相同编号时才从堆中删除项目,因此它们必须按索引顺序连续删除,并且从索引 0 开始,并且他们一定有times_seen
。 这意味着我们不需要存储 benched 子列表的索引,只需要存储已经被 benched 的数字。
import heapq
def mergeArys(srtd_arys):
heap = []
srtd_iters = [iter(x) for x in srtd_arys]
# put the first element from each srtd array onto the heap
for idx, it in enumerate(srtd_iters):
elem = next(it, None)
if elem:
heapq.heappush(heap, (elem, idx))
res = []
# the number of tims that the current number has been seen
times_seen = 0
# the lowest number from the heap - currently checking if the first numbers in all sub-lists are equal to this
lowest = heap[0][0] if heap else None
# collect results in nlogK time
while heap:
elem, ary = heap[0]
unbench_all = True
if lowest != elem or ary != times_seen:
if lowest == elem:
heapq.heappop(heap)
it = srtd_iters[ary]
nxt = next(it, None)
if nxt:
heapq.heappush(heap, (nxt, ary))
else:
heapq.heappop(heap)
times_seen += 1
if times_seen == len(srtd_arys):
res.append(elem)
else:
unbench_all = False
if unbench_all:
for unbenched in range(times_seen):
unbenched_it = srtd_iters[unbenched]
nxt = next(unbenched_it, None)
if nxt:
heapq.heappush(heap, (nxt, unbenched))
times_seen = 0
if heap:
lowest = heap[0][0]
return res
if __name__ == '__main__':
a1 = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]
a2 = [[1, 1], [1, 1, 2, 2, 3]]
for arys in [a1, a2]:
print(mergeArys(arys))
如果您愿意,可以这样编写等效算法:
def mergeArys(srtd_arys):
heap = []
srtd_iters = [iter(x) for x in srtd_arys]
# put the first element from each srtd array onto the heap
for idx, it in enumerate(srtd_iters):
elem = next(it, None)
if elem:
heapq.heappush(heap, (elem, idx))
res = []
# collect results in nlogK time
while heap:
elem, ary = heap[0]
lowest = elem
keep_elem = True
for i in range(len(srtd_arys)):
elem, ary = heap[0]
if lowest != elem or ary != i:
if ary != i:
heapq.heappop(heap)
it = srtd_iters[ary]
nxt = next(it, None)
if nxt:
heapq.heappush(heap, (nxt, ary))
keep_elem = False
i -= 1
break
heapq.heappop(heap)
if keep_elem:
res.append(elem)
for unbenched in range(i+1):
unbenched_it = srtd_iters[unbenched]
nxt = next(unbenched_it, None)
if nxt:
heapq.heappush(heap, (nxt, unbenched))
if len(heap) < len(srtd_arys):
heap = []
return res
您可以使用reduce
:
from functools import reduce
a = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
reduce(lambda x, y: x & set(y), a[1:], set(a[0]))
{1, 7}
您可以使用内置集合和集合交集:
d = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
result = set(d[0]).intersection(*d[1:])
{1, 7}
我想出了这个算法。 它不超过 O(n k) 我不知道它是否对你来说足够好。 该算法的要点是,您可以为每个数组设置 k 个索引,并且每次迭代都可以找到交集中下一个元素的索引并增加每个索引,直到超出数组的边界并且交集中没有更多项目. 诀窍是因为 arrays 已排序,您可以查看两个不同 arrays 中的两个元素,如果一个大于另一个,您可以立即丢弃另一个,因为您知道您的数字不能小于您正在查看的数字。 该算法的最坏情况是每个索引都将增加到需要 k n 时间的边界,因为索引无法减小其值。
inter = []
for n in range(len(arrays[0])):
if indexes[0] >= len(arrays[0]):
return inter
for i in range(1,k):
if indexes[i] >= len(arrays[i]):
return inter
while indexes[i] < len(arrays[i]) and arrays[i][indexes[i]] < arrays[0][indexes[0]]:
indexes[i] += 1
while indexes[i] < len(arrays[i]) and indexes[0] < len(arrays[0]) and arrays[i][indexes[i]] > arrays[0][indexes[0]]:
indexes[0] += 1
if indexes[0] < len(arrays[0]):
inter.append(arrays[0][indexes[0]])
indexes = [idx+1 for idx in indexes]
return inter
您可以将位掩码与 one-hot 编码一起使用。 内部列表成为 maxterms。 你和他们一起为路口,或他们为联合。 然后你必须转换回来,为此我使用了一些 hack 。
problem = [[1,3,5,7],[1,1,3,5,8,7],[1,4,7,9]];
debruijn = [0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9];
u32 = accum = (1 << 32) - 1;
for vec in problem:
maxterm = 0;
for v in vec:
maxterm |= 1 << v;
accum &= maxterm;
# https://graphics.stanford.edu/~seander/bithacks.html#IntegerLogDeBruijn
result = [];
while accum:
power = accum;
accum &= accum - 1; # Peter Wegner CACM 3 (1960), 322
power &= ~accum;
result.append(debruijn[((power * 0x077CB531) & u32) >> 27]);
print result;
这使用(模拟)32 位整数,因此您的集合中只能有[0, 31]
。
*我对Python没有经验,所以我计时了。 绝对应该使用set.intersection
。
你说我们不能使用集合,但是 dicts / hash 表怎么样? (是的,我知道它们基本上是一样的):D
如果是这样,这是一个相当简单的方法(请原谅 py2 语法):
arrays = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
counts = {}
for ar in arrays:
last = None
for i in ar:
if (i != last):
counts[i] = counts.get(i, 0) + 1
last = i
N = len(arrays)
intersection = [i for i, n in counts.iteritems() if n == N]
print intersection
与 Raymond Hettinger 的解决方案相同,但具有更基本的 python 代码:
def intersection(arrays, unique: bool=False):
result = []
if not len(arrays) or any(not len(array) for array in arrays):
return result
pointers = [0] * len(arrays)
target = arrays[0][0]
start_step = 0
current_step = 1
while True:
idx = current_step % len(arrays)
array = arrays[idx]
while pointers[idx] < len(array) and array[pointers[idx]] < target:
pointers[idx] += 1
if pointers[idx] < len(array) and array[pointers[idx]] > target:
target = array[pointers[idx]]
start_step = current_step
current_step += 1
continue
if unique:
while (
pointers[idx] + 1 < len(array)
and array[pointers[idx]] == array[pointers[idx] + 1]
):
pointers[idx] += 1
if (current_step - start_step) == len(arrays):
result.append(target)
for other_idx, other_array in enumerate(arrays):
pointers[other_idx] += 1
if pointers[idx] < len(array):
target = array[pointers[idx]]
start_step = current_step
if pointers[idx] == len(array):
return result
current_step += 1
这是一个 O(n) 答案(其中n = sum(len(sublist) for sublist in data)
)。
from itertools import cycle
def intersection(data):
result = []
maxval = float("-inf")
consecutive = 0
try:
for sublist in cycle(iter(sublist) for sublist in data):
value = next(sublist)
while value < maxval:
value = next(sublist)
if value > maxval:
maxval = value
consecutive = 0
continue
consecutive += 1
if consecutive >= len(data)-1:
result.append(maxval)
consecutive = 0
except StopIteration:
return result
print(intersection([[1,3,5,7], [1,1,3,5,7], [1,4,7,9]]))
[1, 7]
当列表的每个子集中都有重复项时,上述某些方法未涵盖示例。 下面的代码实现了这个交集,如果列表的子集中有很多重复,它会更有效:) 如果不确定重复,建议使用 Counter from collections from collections import Counter
。 自定义计数器 function 用于提高处理大型重复项的效率。 但仍然无法击败Raymond Hettinger 的执行力。
def counter(my_list):
my_list = sorted(my_list)
first_val, *all_val = my_list
p_index = my_list.index(first_val)
my_counter = {}
for item in all_val:
c_index = my_list.index(item)
diff = abs(c_index-p_index)
p_index = c_index
my_counter[first_val] = diff
first_val = item
c_index = my_list.index(item)
diff = len(my_list) - c_index
my_counter[first_val] = diff
return my_counter
def my_func(data):
if not data or not isinstance(data, list):
return
# get the first value
first_val, *all_val = data
if not isinstance(first_val, list):
return
# count items in first value
p = counter(first_val) # counter({1: 2, 3: 1, 5: 1, 7: 1})
# collect all common items and calculate the minimum occurance in intersection
for val in all_val:
# collecting common items
c = counter(val)
# calculate the minimum occurance in intersection
inner_dict = {}
for inner_val in set(c).intersection(set(p)):
inner_dict[inner_val] = min(p[inner_val], c[inner_val])
p = inner_dict
# >>>p
# {1: 2, 7: 1}
# Sort by keys of counter
sorted_items = sorted(p.items(), key=lambda x:x[0]) # [(1, 2), (7, 1)]
result=[i[0] for i in sorted_items for _ in range(i[1])] # [1, 1, 7]
return result
这是示例示例
>>> data = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
>>> my_func(data=data)
[1, 7]
>>> data = [[1,1,3,5,7],[1,1,3,5,7],[1,1,4,7,9]]
>>> my_func(data=data)
[1, 1, 7]
您可以使用函数heapq.merge 、 chain.from_iterable和groupby执行以下操作
from heapq import merge
from itertools import groupby, chain
ls = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]
def index_groups(lst):
"""[1, 1, 3, 5, 7] -> [(1, 0), (1, 1), (3, 0), (5, 0), (7, 0)]"""
return chain.from_iterable(((e, i) for i, e in enumerate(group)) for k, group in groupby(lst))
iterables = (index_groups(li) for li in ls)
flat = merge(*iterables)
res = [k for (k, _), g in groupby(flat) if sum(1 for _ in g) == len(ls)]
print(res)
Output
[1, 7]
这个想法是提供一个额外的值(使用枚举)来区分同一列表中的相等值(参见 function index_groups
)。
该算法的复杂度为O(n)
,其中n
是输入中每个列表的长度之和。
请注意 output 用于(每个列表额外 1 个):
ls = [[1, 1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 1, 4, 7, 9]]
是:
[1, 1, 7]
这是单程计数算法,是其他人建议的简化版本。
def intersection(iterables):
target, count = None, 0
for it in itertools.cycle(map(iter, iterables)):
for value in it:
if count == 0 or value > target:
target, count = value, 1
break
if value == target:
count += 1
break
else: # exhausted iterator
return
if count >= len(iterables):
yield target
count = 0
二进制和指数搜索还没有出现。 即使使用“无内置”约束,它们也很容易重新创建。
在实践中,这会更快,并且是次线性的。 在最坏的情况下——交叉点没有缩小——天真的方法会重复工作。 但是有一个解决方案:集成二进制搜索,同时将 arrays 分成两半。
def intersection(seqs):
seq = min(seqs, key=len)
if not seq:
return
pivot = seq[len(seq) // 2]
lows, counts, highs = [], [], []
for seq in seqs:
start = bisect.bisect_left(seq, pivot)
stop = bisect.bisect_right(seq, pivot, start)
lows.append(seq[:start])
counts.append(stop - start)
highs.append(seq[stop:])
yield from intersection(lows)
yield from itertools.repeat(pivot, min(counts))
yield from intersection(highs)
两者都处理重复项。 两者都保证 O(N) 最坏情况时间(将切片计算为原子)。 后者将接近 O(min_size) 速度; 通过总是将最小的分成两半,它基本上不会遭受不均匀分裂的厄运。
我不禁注意到这似乎是福利骗子问题的一种变体; 请参阅 David Gries 的书《编程科学》 。 Edsger Dijkstra 还为此写了一篇 EWD,请参阅Ascending Functions and the Welfare Crook 。
假设我们有三卷长磁带,每卷都包含按字母顺序排列的姓名列表:
实际上,这三个列表都是无穷无尽的,因此没有给出上限。 众所周知,所有三个名单上至少有一个人。 编写一个程序来定位第一个这样的人。
我们的有序列表问题的交集是福利骗子问题的推广。
这是福利骗子问题的(相当原始?)Python 解决方案:
def find_welfare_crook(f, g, h, i, j, k):
"""f, g, and h are "ascending functions," i.e.,
i <= j implies f[i] <= f[j] or, equivalently,
f[i] < f[j] implies i < j, and the same goes for g and h.
i, j, k define where to start the search in each list.
"""
# This is an implementation of a solution to the Welfare Crook
# problems presented in David Gries's book, The Science of Programming.
# The surprising and beautiful thing is that the guard predicates are
# so few and so simple.
i , j , k = i , j , k
while True:
if f[i] < g[j]:
i += 1
elif g[j] < h[k]:
j += 1
elif h[k] < f[i]:
k += 1
else:
break
return (i,j,k)
# The other remarkable thing is how the negation of the guard
# predicates works out to be: f[i] == g[j] and g[j] == c[k].
这可以推广到K个列表,这就是我设计的; 我不知道这是 Pythonic,但它非常紧凑:
def findIntersectionLofL(lofl):
"""Generalized findIntersection function which operates on a "list of lists." """
K = len(lofl)
indices = [0 for i in range(K)]
result = []
#
try:
while True:
# idea is to maintain the indices via a construct like the following:
allEqual = True
for i in range(K):
if lofl[i][indices[i]] < lofl[(i+1)%K][indices[(i+1)%K]] :
indices[i] += 1
allEqual = False
# When the above iteration finishes, if all of the list
# items indexed by the indices are equal, then another
# item common to all of the lists must be added to the result.
if allEqual :
result.append(lofl[0][indices[0]])
while lofl[0][indices[0]] == lofl[1][indices[1]]:
indices[0] += 1
except IndexError as e:
# Eventually, the foregoing iteration will advance one of the
# indices past the end of one of the lists, and when that happens
# an IndexError exception will be raised. This means the algorithm
# is finished.
return result
此解决方案不保留重复项。 通过更改程序在“while True”循环末尾的条件中所做的操作来更改程序以包含所有重复的项目是留给读者的练习。
来自@greybeard 的评论提示了如下所示的改进,在“数组索引模数”(“(i+1)%K”表达式)的预计算和进一步调查中也带来了内部迭代结构的变化,以进一步去除高架:
def findIntersectionLofLunRolled(lofl):
"""Generalized findIntersection function which operates on a "list of lists."
Accepts a list-of-lists, lofl. Each of the lists must be ordered.
Returns the list of each element which appears in all of the lists at least once.
"""
K = len(lofl)
indices = [0] * K
result = []
lt = [ (i, (i+1) % K) for i in range(K) ] # avoids evaluation of index exprs inside the loop
#
try:
while True:
allUnEqual = True
while allUnEqual:
allUnEqual = False
for i,j in lt:
if lofl[i][indices[i]] < lofl[j][indices[j]]:
indices[i] += 1
allUnEqual = True
# Now all of the lofl[i][indices[i]], for all i, are the same value.
# Store that value in the result, and then advance all of the indices
# past that common value:
v = lofl[0][indices[0]]
result.append(v)
for i,j in lt:
while lofl[i][indices[i]] == v:
indices[i] += 1
except IndexError as e:
# Eventually, the foregoing iteration will advance one of the
# indices past the end of one of the lists, and when that happens
# an IndexError exception will be raised. This means the algorithm
# is finished.
return result
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.