
Kth Smallest Element in multiple sorted arrays

Let's say we have two arrays:

array1 = [2,3,6,7,9]

array2 = [1,4,8,10]

I understand how to find the kth element of two sorted arrays in O(log(min(m,n))) time, where m is the length of array1 and n is the length of array2, as follows:

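MIN_VALUE = float('-inf')
MAX_VALUE = float('inf')

def kthelement(arr1, arr2, m, n, k):
    # Binary search over the shorter array, so make arr1 the shorter one.
    if m > n:
        return kthelement(arr2, arr1, n, m, k)

    # cut1 is how many of the k smallest elements come from arr1 (length m).
    low = max(0, k - n)
    high = min(k, m)

    while low <= high:
        cut1 = (low + high) >> 1
        cut2 = k - cut1
        l1 = MIN_VALUE if cut1 == 0 else arr1[cut1 - 1]
        l2 = MIN_VALUE if cut2 == 0 else arr2[cut2 - 1]
        r1 = MAX_VALUE if cut1 == m else arr1[cut1]
        r2 = MAX_VALUE if cut2 == n else arr2[cut2]

        if l1 <= r2 and l2 <= r1:
            # Valid partition: the kth element is the larger of the two left ends.
            print(cut1, cut2)
            return max(l1, l2)
        elif l1 > r2:
            high = cut1 - 1
        else:
            low = cut1 + 1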

But I couldn't figure out how to extend this to the case of multiple sorted arrays. For example, given 3 arrays, I want to find the kth element of the final sorted array.

array1 = [2,3,6,7,9]

array2 = [1,4,8,10]

array3 = [2,3,5,7]

Is it possible to achieve this in O(log(min(m,n))) time, as in the two-array case?
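To pin down what "the kth element of the final sorted array" means for the three arrays above, here is a brute-force baseline. It is only a reference for checking faster approaches, not an attempt at the requested complexity, and the name `kth_smallest_baseline` is made up for illustration. It assumes k is 1-indexed.

```python
def kth_smallest_baseline(arrays, k):
    """Return the k-th smallest (1-indexed) element across all arrays.

    Concatenates and sorts everything, so it costs O(T log T)
    for T total elements. Useful only as a correctness reference.
    """
    merged = sorted(x for a in arrays for x in a)
    if k < 1 or k > len(merged):
        raise ValueError("k out of range")
    return merged[k - 1]


arrays = [[2, 3, 6, 7, 9], [1, 4, 8, 10], [2, 3, 5, 7]]
print([kth_smallest_baseline(arrays, k) for k in range(1, 14)])
# [1, 2, 2, 3, 3, 4, 5, 6, 7, 7, 8, 9, 10]
```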

If k is very large, we can binary search on the answer. This gives a solution with time complexity O(n * log(m) * log(N)), where N is the size of the value range, n is the number of arrays, and m is the length of the longest array (each check does one binary search per array).

What we need is a way to check whether some integer x is at least the correct answer. We can go through each array and binary search it to count the number of elements less than or equal to x, add up the counts, and compare the sum with k. If the sum is at least k, the answer is at most x; otherwise the answer is greater than x.

from typing import List
import bisect

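def query_k_min(vecs: List[List[int]], k: int) -> int:
    # We assume each number is >= 1 and <= 10^9, and 1 <= k <= total length.
    # Invariant: fewer than k elements are <= l, at least k elements are <= r.
    l, r = 0, 10**9
    while r - l > 1:
        m = (l + r) >> 1
        # Count the elements <= m across all arrays.
        tot = 0
        for vec in vecs:
            tot += bisect.bisect_right(vec, m)
        if tot >= k:
            r = m
        else:
            l = m
    return r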

a = [[2,3,6,7,9],[1,4,8,10],[2,3,5,7]]
for x in range(1,14):
    print(query_k_min(a,x))
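The code above hard-codes the value range as 1 to 10^9. If that assumption is inconvenient, one possible variation is to take the search bounds from the arrays themselves. The sketch below does that; the function name `kth_by_value_search` is hypothetical, and it still assumes integer elements, since it bisects over integer values.

```python
import bisect
from typing import List


def kth_by_value_search(vecs: List[List[int]], k: int) -> int:
    """k-th smallest (1-indexed) via binary search over the value range."""
    nonempty = [v for v in vecs if v]
    total = sum(len(v) for v in nonempty)
    if k < 1 or k > total:
        raise ValueError("k out of range")

    lo = min(v[0] for v in nonempty)    # smallest value overall
    hi = max(v[-1] for v in nonempty)   # largest value overall

    # Invariant: the answer lies in [lo, hi].
    while lo < hi:
        mid = (lo + hi) // 2  # floor division, also correct for negatives
        # Number of elements <= mid across all arrays.
        count = sum(bisect.bisect_right(v, mid) for v in nonempty)
        if count >= k:
            hi = mid       # answer is at most mid
        else:
            lo = mid + 1   # answer is greater than mid
    return lo


a = [[2, 3, 6, 7, 9], [1, 4, 8, 10], [2, 3, 5, 7]]
print([kth_by_value_search(a, k) for k in range(1, 14)])
# [1, 2, 2, 3, 3, 4, 5, 6, 7, 7, 8, 9, 10]
```

The loop returns the smallest value x for which at least k elements are <= x. That value is always an element of one of the arrays, so no extra lookup is needed.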


The general solution is to use a min-heap. If you have n sorted arrays and you want the kth smallest number, then the solution is O(k log n).

The idea is that you insert the first number from each array into the min-heap. When inserting into the heap, you insert a tuple that contains the number and the array that it came from.

You then remove the smallest value from the heap and add the next number from the array that value came from. You do this k times to get the kth smallest number.

See https://www.geeksforgeeks.org/find-m-th-smallest-value-in-k-sorted-arrays/ for the general idea.
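The heap approach described above can be sketched in Python with the standard `heapq` module. This is a minimal illustration rather than the linked article's code; the name `kth_smallest_heap` is made up, and the tuple stores the array's index plus a position instead of the array itself, so that ties on the value compare cleanly.

```python
import heapq


def kth_smallest_heap(arrays, k):
    """k-th smallest (1-indexed) across sorted arrays using a min-heap."""
    if k < 1:
        raise ValueError("k must be at least 1")

    # Heap entries are (value, array index, position within that array).
    heap = [(a[0], i, 0) for i, a in enumerate(arrays) if a]
    heapq.heapify(heap)

    value = None
    for _ in range(k):
        if not heap:
            raise ValueError("k exceeds the total number of elements")
        value, i, pos = heapq.heappop(heap)
        # Replace the popped value with the next one from the same array.
        if pos + 1 < len(arrays[i]):
            heapq.heappush(heap, (arrays[i][pos + 1], i, pos + 1))
    return value


arrays = [[2, 3, 6, 7, 9], [1, 4, 8, 10], [2, 3, 5, 7]]
print(kth_smallest_heap(arrays, 5))  # 3
```

The heap never holds more than one entry per array, which is where the log n factor per pop comes from.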

The following looks complicated, but if M is the sum of the logs of len(list)+2 over all lists, then the average case is O(M) and the worst case is O(M^2). (The reason for the +2 is that even if the array has no elements, we need to do work, which we account for by making the argument of the log at least 2.) The worst case is very unlikely.

The performance is independent of k.

The idea is the same as Quickselect. We pick pivots and split the data around the pivot. But we do not look at each element; we only figure out which chunk of each array still under consideration lies before, after, or at the pivot. The average case holds because every time we look at an array, with positive probability we get rid of half of what remains. The worst case arises because every time we look at the array we got a pivot from, we get rid of half of that array but may have to binary search every other array to decide that we got rid of nothing else.

from collections import deque

def kth_of_sorted (k, arrays):
    # Initialize some global variables.
    known_low = 0
    known_high = 0
    total_size = 0

    # in_flight will be a double-ended queue of
    # (array, low, high)
    # Where:
    #    array is an input array
    #    low is the lower bound on where kth might be
    #    high is the upper bound on where kth might be
    in_flight = deque()

    for a in arrays:
        if 0 < len(a):
            total_size += len(a)
            in_flight.append((a, 0, len(a)-1))

    # Sanity check.
    if k < 1 or total_size < k:
        return None

    while 0 < len(in_flight):
        start_a, start_low, start_high = in_flight.popleft()
        start_mid = (start_low + start_high) // 2
        pivot = start_a[start_mid]

        # If pivot is placed, how many are known?
        maybe_low = start_mid - start_low
        maybe_high = start_high - start_mid

        # This will be arrays taken from in_flight with:
        #
        #    (array, low, high, orig_low, orig_high)
        #
        # We are binary searching in these to figure out where the pivot
        # is going to go. Then we copy back to in_flight.
        to_process = deque()

        # This will be arrays taken from in_flight with:
        #
        #    (array, orig_low, mid, orig_high)
        #
        # where at mid we match the pivot.
        is_match = deque()
        # And we know an array with a pivot!
        is_match.append((start_a, start_low, start_mid, start_high))

        # This will be arrays taken from in_flight which we know do not have the pivot:
        #
        #    (array, low, high, orig_low, orig_high)
        #
        no_pivot = deque()

        while 0 < len(in_flight):
            a, low, high = in_flight.popleft()
            mid = (low + high) // 2
            if a[mid] < pivot:
                # all of low, low+1, ..., mid are below the pivot
                maybe_low += mid + 1 - low
                if mid < high:
                    to_process.append((a, mid+1, high, low, high))
                else:
                    no_pivot.append((a, mid+1, high, low, high))
            elif pivot < a[mid]:
                # all of mid, mid+1, ..., high are above the pivot.
                maybe_high += high + 1 - mid
                if low < mid:
                    to_process.append((a, low, mid-1, low, high))
                else:
                    no_pivot.append((a, low, mid-1, low, high))
            else:
                # mid is at pivot
                maybe_low += mid - low
                maybe_high += high - mid
                is_match.append((a, low, mid, high))

        # We do not yet know where the pivot_pos is.
        pivot_pos = None
        if k <= known_low + maybe_low:
            pivot_pos = 'right'
        elif total_size - known_high - maybe_high < k:
            pivot_pos = 'left'
        elif k <= known_low + maybe_low + len(is_match) and total_size < k + known_high + maybe_high + len(is_match):
            return pivot # WE FOUND IT!

        while pivot_pos is None:
            # This is very similar to how we processed in_flight.
            a, low, high, orig_low, orig_high = to_process.popleft()
            mid = (low + high) // 2
            if a[mid] < pivot:
                # all of low, low+1, ..., mid are below the pivot
                maybe_low += mid + 1 - low
                if mid < high:
                    to_process.append((a, mid+1, high, orig_low, orig_high))
                else:
                    no_pivot.append((a, mid+1, high, orig_low, orig_high))
            elif pivot < a[mid]:
                # all of mid, mid+1, ..., high are above the pivot.
                maybe_high += high + 1 - mid
                if low < mid:
                    to_process.append((a, low, mid-1, orig_low, orig_high))
                else:
                    no_pivot.append((a, low, mid-1, orig_low, orig_high))
            else:
                # mid is at pivot
                maybe_low += mid - low
                maybe_high += high - mid
                is_match.append((a, orig_low, mid, orig_high))

            if k <= known_low + maybe_low:
                pivot_pos = 'right'
            elif total_size - known_high - maybe_high < k:
                pivot_pos = 'left'
            elif k <= known_low + maybe_low + len(is_match) and total_size < k + known_high + maybe_high + len(is_match):
                return pivot # WE FOUND IT!

        # And now place the pivot in the right position.
        if pivot_pos == 'right':
            known_high += maybe_high + len(is_match)
            # And put back the left side of each nonemptied array.
            for q in (to_process, no_pivot):
                while 0 < len(q):
                    a, low, high, orig_low, orig_high = q.popleft()
                    if orig_low <= high:
                        in_flight.append((a, orig_low, high))
            while 0 < len(is_match):
                a, low, mid, high = is_match.popleft()
                if low < mid:
                    in_flight.append((a, low, mid-1))
        else:
            known_low += maybe_low + len(is_match)
            # And put back the right side of each nonemptied array.
            for q in (to_process, no_pivot):
                while 0 < len(q):
                    a, low, high, orig_low, orig_high = q.popleft()
                    if low <= orig_high:
                        in_flight.append((a, low, orig_high))
            while 0 < len(is_match):
                a, low, mid, high = is_match.popleft()
                if mid < high:
                    in_flight.append((a, mid+1, high))

list1 = [2,3,6,7,9]
list2 = [1,4,8,10]
list3 = [2,3,5,7]
print(list1, list2, list3)
for i in range(1, len(list1) + len(list2) + len(list3) + 1):
    print(i, kth_of_sorted(i,[list1, list2, list3]))
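For readers who want the pivot-and-discard idea in a smaller form, here is a simplified sketch. It is not the code above: it runs a full binary search in every array on every round using `bisect`, so it does not carry the same O(M) average bound. The name `kth_by_pivot` and the choice of pivot (the middle of the largest remaining window) are assumptions made for illustration.

```python
import bisect


def kth_by_pivot(arrays, k):
    """k-th smallest (1-indexed) across sorted arrays, Quickselect style.

    Each array keeps a half-open window [lo, hi) of candidates.
    Returns None if k is out of range.
    """
    if k < 1 or k > sum(len(a) for a in arrays):
        return None
    windows = [[0, len(a)] for a in arrays]

    while True:
        # Pivot: middle element of the largest remaining window.
        j = max(range(len(arrays)), key=lambda i: windows[i][1] - windows[i][0])
        lo, hi = windows[j]
        pivot = arrays[j][(lo + hi) // 2]

        # For each window, where do elements < pivot and <= pivot end?
        lt = [bisect.bisect_left(a, pivot, w[0], w[1]) for a, w in zip(arrays, windows)]
        le = [bisect.bisect_right(a, pivot, w[0], w[1]) for a, w in zip(arrays, windows)]
        below = sum(p - w[0] for p, w in zip(lt, windows))
        at_or_below = sum(p - w[0] for p, w in zip(le, windows))

        if k <= below:
            # Answer is smaller than the pivot: keep only the left parts.
            for w, p in zip(windows, lt):
                w[1] = p
        elif k <= at_or_below:
            return pivot
        else:
            # Answer is larger than the pivot: drop everything <= pivot.
            k -= at_or_below
            for w, p in zip(windows, le):
                w[0] = p


lists = [[2, 3, 6, 7, 9], [1, 4, 8, 10], [2, 3, 5, 7]]
print([kth_by_pivot(lists, k) for k in range(1, 14)])
# [1, 2, 2, 3, 3, 4, 5, 6, 7, 7, 8, 9, 10]
```

Each round removes at least the pivot itself from consideration, so the loop always terminates.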
