簡體   English   中英

在二維數組上查找第 K 個最小元素(或中值)的最快算法?

[英]Fastest algorithm for Kth smallest Element (or median) finding on 2 Dimensional Array?

我在相關主題上看到了很多 SO 主題,但沒有一個提供有效的方法。

我想在二維數組[1..M][1..N]上找到第k-th最小元素(或中值),其中每行按升序排序並且所有元素都是不同的。

我認為有O(M log MN)解決方案,但我不知道實現。 (中位數的中位數或使用具有線性復雜性的分區是一些方法,但不再知道......)。

這是一個舊的谷歌面試問題,可以在這里搜索。

但現在我想提示或描述最有效的算法最快的算法)。

我也在這里讀過一篇論文,但我不明白。

更新 1: 此處找到一種解決方案,但當維度為奇數時。

所以要解決這個問題,它有助於解決一個稍微不同的問題。 我們想知道每行中第 k 個截止點所在位置的上限/下限。 那么我們就可以通過,驗證下界及以下的事物數<k,上界及以下的事物數>k,並且它們之間只有一個值。

我已經提出了一種策略,可以在所有行中同時對這些邊界進行二分搜索。 作為二分搜索,它“應該”通過O(log(n))次。 每次傳遞涉及O(m)工作,總共O(m log(n))次。 我把應該放在引號中,因為我沒有證據表明它實際上需要O(log(n))次傳遞。 事實上,有可能在一行中過於激進,從其他行中發現所選的樞軸已關閉,然后不得不后退。 但我相信它幾乎沒有后退,實際上是O(m log(n))

策略是跟蹤下限、上限和中間的每一行。 每次通過,我們都會對范圍進行一系列加權,以降低、降低到中、從中到上、從上到結尾,權重是其中的事物數量,值是系列中的最后一個。 然后我們在該數據結構中找到第 k 個值(按權重),並將其用作我們在每個維度中進行二分搜索的主元。

如果樞軸超出從下到上的范圍,我們通過在糾正錯誤的方向上加寬間隔來進行糾正。

當我們有正確的序列時,我們就有了答案。

有很多邊緣情況,所以盯着完整的代碼可能會有所幫助。

我還假設每一行的所有元素都是不同的。 如果不是,您可能會陷入無限循環。 (解決這意味着更多的邊緣情況......)

import random

# This takes (k, [(value1, weight1), (value2, weight2), ...])
def weighted_kth (k, pairs):
    # This does quickselect for average O(len(pairs)).
    # Median of medians is deterministically the same, but a bit slower
    pivot = pairs[int(random.random() * len(pairs))][0]

    # Which side of our answer is the pivot on?
    weight_under_pivot = 0
    pivot_weight = 0
    for value, weight in pairs:
        if value < pivot:
            weight_under_pivot += weight
        elif value == pivot:
            pivot_weight += weight

    if weight_under_pivot + pivot_weight < k:
        filtered_pairs = []
        for pair in pairs:
            if pivot < pair[0]:
                filtered_pairs.append(pair)
        return weighted_kth (k - weight_under_pivot - pivot_weight, filtered_pairs)
    elif k <= weight_under_pivot:
        filtered_pairs = []
        for pair in pairs:
            if pair[0] < pivot:
                filtered_pairs.append(pair)
        return weighted_kth (k, filtered_pairs)
    else:
        return pivot

# This takes (k, [[...], [...], ...])
def kth_in_row_sorted_matrix (k, matrix):
    # The strategy is to discover the k'th value, and also discover where
    # that would be in each row.
    #
    # For each row we will track what we think the lower and upper bounds
    # are on where it is.  Those bounds start as the start and end and
    # will do a binary search.
    #
    # In each pass we will break each row into ranges from start to lower,
    # lower to mid, mid to upper, and upper to end.  Some ranges may be
    # empty.  We will then create a weighted list of ranges with the weight
    # being the length, and the value being the end of the list.  We find
    # where the k'th spot is in that list, and use that approximate value
    # to refine each range.  (There is a chance that a range is wrong, and
    # we will have to deal with that.)
    #
    # We finish when all of the uppers are above our k, all the lowers
    # one are below, and the upper/lower gap is more than 1 only when our
    # k'th element is in the middle.

    # Our data structure is simply [row, lower, upper, bound] for each row.
    data = [[row, 0, min(k, len(row)-1), min(k, len(row)-1)] for row in matrix]
    is_search = True
    while is_search:
        pairs = []
        for row, lower, upper, bound in data:
            # Literal edge cases
            if 0 == upper:
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            elif lower == bound:
                pairs.append((row[lower], lower + 1))
            elif lower + 1 == upper: # No mid.
                pairs.append((row[lower], lower + 1))
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            else:
                mid = (upper + lower) // 2
                pairs.append((row[lower], lower + 1))
                pairs.append((row[mid], mid - lower))
                pairs.append((row[upper], upper - mid))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))

        pivot = weighted_kth(k, pairs)

        # Now that we have our pivot, we try to adjust our parameters.
        # If any adjusts we continue our search.
        is_search = False
        new_data = []
        for row, lower, upper, bound in data:
            # First cases where our bounds weren't bounds for our pivot.
            # We rebase the interval and either double the range.
            # - double the size of the range
            # - go halfway to the edge
            if 0 < lower and pivot <= row[lower]:
                is_search = True
                if pivot == row[lower]:
                    new_data.append((row, lower-1, min(lower+1, bound), bound))
                elif upper <= lower:
                    new_data.append((row, lower-1, lower, bound))
                else:
                    new_data.append((row, max(lower // 2, lower - 2*(upper - lower)), lower, bound))
            elif upper < bound and row[upper] <= pivot:
                is_search = True
                if pivot == row[upper]:
                    new_data.append((row, upper-1, upper+1, bound))
                elif lower < upper:
                    new_data.append((row, upper, min((upper+bound+1)//2, upper + 2*(upper - lower)), bound))
                else:
                    new_data.append((row, upper, upper+1, bound))
            elif lower + 1 < upper:
                if upper == lower+2 and pivot == row[lower+1]:
                    new_data.append((row, lower, upper, bound)) # Looks like we found the pivot.
                else:
                    # We will split this interval.
                    is_search = True
                    mid = (upper + lower) // 2
                    if row[mid] < pivot:
                        new_data.append((row, mid, upper, bound))
                    elif pivot < row[mid] pivot:
                        new_data.append((row, lower, mid, bound))
                    else:
                        # We center our interval on the pivot
                        new_data.append((row, (lower+mid)//2, (mid+upper+1)//2, bound))
            else:
                # We look like we found where the pivot would be in this row.
                new_data.append((row, lower, upper, bound))
        data = new_data # And set up the next search
    return pivot

已添加另一個答案以提供實際解決方案。 由於評論中有相當多的兔子洞,因此保留了這個。


我相信最快的解決方案是 k-way 合並算法。 它是一個O(N log K)算法,將K排序列表與總共N項目合並為一個大小為N排序列表。

https://en.wikipedia.org/wiki/K-way_merge_algorithm#k-way_merge

給定一個MxN列表。 這最終是O(MNlog(M)) 但是,這是為了對整個列表進行排序。 由於您只需要前K最小的項目而不是所有N*M ,因此性能為O(Klog(M)) 假設O(K) <= O(M) ,這比您正在尋找的要好得多。

盡管這假設您有N大小為M排序列表。 如果您實際上有M大小為N排序列表,則可以通過更改循環數據的方式輕松處理(請參閱下面的偽代碼),盡管這確實意味着性能是O(K log(N))

k-way 合並只是將每個列表的第一項添加到堆或其他具有O(log N)插入和O(log N) find-mind 的數據結構中。

k-way合並的偽代碼看起來有點像這樣:

  1. 對於每個排序列表,將第一個值插入到數據結構中,並通過某種方式確定該值來自哪個列表。 IE:您可以將[value, row_index, col_index]插入到數據結構中,而不僅僅是value 這還可以讓您輕松處理列或行的循環。
  2. 從數據結構中刪除最低值並附加到排序列表。
  3. 鑒於在第2步中的項目來自列表I從列表中添加下一個最低值, I的數據結構。 IE:如果值為row 5 col 4 (data[5][4]) 然后,如果您將行用作列表,則下一個值將是row 5 col 5 (data[5][5]) 如果您使用的是列,則下一個值為row 6 col 4 (data[6][4]) 將下一個值插入到數據結構中,就像您在 #1 中所做的一樣(即: [value, row_index, col_index]
  4. 根據需要返回第 2 步。

根據您的需要,執行 2-4 K次步驟。

似乎最好的方法是在越來越大的塊中進行 k-way 合並。 k-way 合並試圖建立一個排序列表,但我們不需要它排序,我們不需要考慮每個元素。 相反,我們將創建一個半排序的區間。 間隔將被排序,但僅按最高值排序。

https://en.wikipedia.org/wiki/K-way_merge_algorithm#k-way_merge

我們使用與 k-way 合並相同的方法,但有所不同。 基本上它旨在間接構建一個半排序的子列表。 例如,不是找到 [1,2,3,4,5,6,7,8,10] 來確定 K=10,而是會找到類似 [(1,3),(4,6), (7,15)]。 通過 K-way 合並,我們一次從每個列表中考慮 1 個項目。 在這種懸停方法中,當從給定列表中提取時,我們要首先考慮 Z 個項目,然后是 2 * Z 個項目,然后是 2 * 2 * Z 個項目,因此第 i 次是 2^i * Z 個項目。 給定一個 MxN 矩陣,這意味着它需要我們從列表中提取O(log(N))項目M次。

  1. 對於每個已排序的列表,將前K個子列表插入到數據結構中,並使用某種方法來確定值來自哪個列表。 我們希望數據結構使用我們插入的子列表中的最高值。 在這種情況下,我們需要類似於 [max_value of sublist, row index, start_index, end_index]。 O(m)
  2. 從數據結構中刪除最低值(現在是值列表)並附加到排序列表。 O(log (m))
  3. 鑒於在第2步中的項目來自列表I補充下2^i * Z值從列表I要在第i個時間從該特定列表拉(基本上只是增加一倍,這是存在於數量的數據結構子列表剛剛從數據結構中刪除)。 O(log m)
  4. 如果半排序子列表的大小大於 K,則使用二分查找找到第 k 個值。 O(log N)) 如果數據結構中還有任何子列表,其中最小值小於 k。 使用列表作為輸入轉到第 1 步,新的Kk - (size of semi-sorted list)
  5. 如果半排序子列表的大小等於K,則返回半排序子列表中的最后一個值,這是第K個值。
  6. 如果半排序子列表的大小小於 K,則返回步驟 2。

至於性能。 讓我們看看這里:

  • 花費O(m log m)將初始值添加到數據結構中。
  • 它最多需要考慮O(m)個子列表,每個子列表都需要O(log n)時間來實現`O(m log n)。
  • 它需要在最后執行二分搜索O(log m) ,如果不確定 K 的值是什么(步驟 4),它可能需要將問題簡化為遞歸子列表,但我不認為'會影響大 O。編輯:我相信這在最壞的情況下只會增加另一個O(mlog(n)) ,這對大 O 沒有影響。

所以看起來它是O(mlog(m) + mlog(n))或簡單的O(mlog(mn))

作為優化,如果 K 高於NM/2考慮最小值時考慮最大值,在考慮最大值時考慮最小值。 當 K 接近NM時,這將大大提高性能。

btillyNuclearman的答案提供了兩種不同的方法,一種二進制搜索和行的k 路合並

我的建議是結合這兩種方法。

  • 如果k足夠小(比方說小於M乘以 2 或 3)或足夠大(對於對稱,接近N x M ),請找到具有 M 行合並的k元素。 當然,我們不應該合並所有元素,只合並第一個k

  • 否則,開始檢查矩陣的第一列和最后一列,以找到最小值(女巫在第一列)和最大值(在最后一列)。

  • 將第一個關鍵值估計為這兩個值的線性組合。 類似於pivot = min + k * (max - min) / (N * M)

  • 在每一行中執行二分搜索以確定不大於樞軸的最后一個元素(更接近的元素)。 小於或等於主元的元素數量是簡單推導出來的。 將那些與k的總和進行比較將判斷選擇的樞軸值是太大還是太小,讓我們相應地修改它。 跟蹤所有行之間的最大值,它可能是第 k 個元素或僅用於評估下一個樞軸。 如果我們將所述總和視為主元的函數,那么現在的數字問題是找到sum(pivot) - k的零,這是一個單調(離散)函數。 在最壞的情況下,我們可以使用二分法(對數復雜度)或割線法。

  • 我們可以理想地將每一行划分為三個范圍:

    • 在左邊,肯定小於或等於第k元素的元素。
    • 在中間,未確定的范圍。
    • 在右邊,肯定大於第k元素的元素。
  • 不確定范圍將在每次迭代時減少,最終大多數行變為空。 在某些時候,仍然在未確定范圍內、散布在整個矩陣中的元素數量將小到足以訴諸這些范圍的單個 M 路合並。

  • 如果我們將單次迭代的時間復雜度視為O(MlogN)M 個二分搜索,我們需要將其乘以樞軸收斂到第k元素的值所需的迭代次數,這可以是O(logNM) 如果N > M ,則總和為O(MlogNlogM)O(MlogNlogN)

  • 請注意,如果該算法用於查找中值,則將 M 路合並作為最后一步也很容易找到第 ( k + 1)元素。

可能是我遺漏了一些東西,但是如果你的NxM矩陣AM行已經按升序排序,沒有元素重復,那么第k行的最小值只是從O(1)行中選擇第k個元素。 要移動到 2D,您只需選擇第k列,將其升序排序O(M.log(M))並再次選擇導致O(N.log(N)) k-th元素。

  1. 讓矩陣A[N][M]

    其中元素是A[column][row]

  2. 排序A升序k-thO(M.log(M))

    所以排序A[k][i]其中i = { 1,2,3,...M }升序

  3. 選擇A[k][k]作為結果

如果您想要A中所有元素的第 k 個最小,那么您需要以類似於合並排序的形式利用已經排序的行。

  1. 創建空列表c[]以保存k最小值

  2. 工藝柱

  3. 創建臨時數組b[]

    它保存處理過的列快速排序升序O(N.log(N))

  4. 合並c[]b[]所以c[]最多保存k最小值

    使用臨時數組d[]將導致O(k+n)

  5. 如果在合並期間未使用b任何項目,則停止處理列

    這可以通過添加標志數組f來完成,它將保存在合並期間從b,c的值,然后檢查是否從b中獲取了任何值

  6. 輸出c[k-1]

綜合起來,最終的復雜度是O(min(k,M).N.log(N))如果我們認為k小於M我們可以重寫為O(kNlog(N))否則O(MNlog(N)) 此外,平均而言,要迭代的列數將更不可能~(1+(k/N))所以平均復雜度將是~O(N.log(N))但這只是我的瘋狂猜測可能是錯誤的。

這里的小 C++/VCL 示例:

//$$---- Form CPP ----
//---------------------------------------------------------------------------
#include <vcl.h>
#pragma hdrstop
#include "Unit1.h"
#include "sorts.h"
//---------------------------------------------------------------------------
#pragma package(smart_init)
#pragma resource "*.dfm"
TForm1 *Form1;
//---------------------------------------------------------------------------
const int m=10,n=8; int a[m][n],a0[m][n]; // a[col][row]
//---------------------------------------------------------------------------
void generate()
    {
    int i,j,k,ii,jj,d=13,b[m];
    Randomize();
    RandSeed=0x12345678;
    // a,a0 = some distinct pseudorandom values (fully ordered asc)
    for (k=Random(d),j=0;j<n;j++)
     for (i=0;i<m;i++,k+=Random(d)+1)
      { a0[i][j]=k; a[i][j]=k; }
    // schuffle a
    for (j=0;j<n;j++)
     for (i=0;i<m;i++)
        {
        ii=Random(m);
        jj=Random(n);
        k=a[i][j]; a[i][j]=a[ii][jj]; a[ii][jj]=k;
        }
    // sort rows asc
    for (j=0;j<n;j++)
        {
        for (i=0;i<m;i++) b[i]=a[i][j];
        sort_asc_quick(b,m);
        for (i=0;i<m;i++) a[i][j]=b[i];
        }

    }
//---------------------------------------------------------------------------
int kmin(int k) // k-th min from a[m][n] where a rows are already sorted
    {
    int i,j,bi,ci,di,b[n],*c,*d,*e,*f,cn;
    c=new int[k+k+k]; d=c+k; f=d+k;
    // handle edge cases
    if (m<1) return -1;
    if (k>m*n) return -1;
    if (m==1) return a[0][k];
    // process columns
    for (cn=0,i=0;i<m;i++)
        {
        // b[] = sorted_asc a[i][]
        for (j=0;j<n;j++) b[j]=a[i][j];     // O(n)
        sort_asc_quick(b,n);                // O(n.log(n))
        // c[] = c[] + b[] asc sorted and limited to cn size
        for (bi=0,ci=0,di=0;;)              // O(k+n)
            {
                 if ((ci>=cn)&&(bi>=n)) break;
            else if (ci>=cn)     { d[di]=b[bi]; f[di]=1; bi++; di++; }
            else if (bi>= n)     { d[di]=c[ci]; f[di]=0; ci++; di++; }
            else if (b[bi]<c[ci]){ d[di]=b[bi]; f[di]=1; bi++; di++; }
            else                 { d[di]=c[ci]; f[di]=0; ci++; di++; }
            if (di>k) di=k;
            }
        e=c; c=d; d=e; cn=di;
        for (ci=0,j=0;j<cn;j++) ci|=f[j];   // O(k)
        if (!ci) break;
        }
    k=c[k-1];
    delete[] c;
    return k;
    }
//---------------------------------------------------------------------------
__fastcall TForm1::TForm1(TComponent* Owner):TForm(Owner)
    {
    int i,j,k;
    AnsiString txt="";

    generate();

    txt+="a0[][]\r\n";
    for (j=0;j<n;j++,txt+="\r\n")
     for (i=0;i<m;i++) txt+=AnsiString().sprintf("%4i ",a0[i][j]);

    txt+="\r\na[][]\r\n";
    for (j=0;j<n;j++,txt+="\r\n")
     for (i=0;i<m;i++) txt+=AnsiString().sprintf("%4i ",a[i][j]);

    k=20;
    txt+=AnsiString().sprintf("\r\n%ith smallest from a0 = %4i\r\n",k,a0[(k-1)%m][(k-1)/m]);
    txt+=AnsiString().sprintf("\r\n%ith smallest from a  = %4i\r\n",k,kmin(k));

    mm_log->Lines->Add(txt);
    }
//-------------------------------------------------------------------------

忽略 VCL 的東西。 函數 generate 計算a0, a矩陣,其中a0是完全排序的, a只對行進行排序並且所有值都是不同的。 函數kmin是上述算法,從a[m][n]返回第 k 個最小值用於排序,我使用了這個:

template <class T> void sort_asc_quick(T *a,int n)
    {
    int i,j; T a0,a1,p;
    if (n<=1) return;                                   // stop recursion
    if (n==2)                                           // edge case
        {
        a0=a[0];
        a1=a[1];
        if (a0>a1) { a[0]=a1; a[1]=a0; }                // condition
        return;
        }
    for (a0=a1=a[0],i=0;i<n;i++)                        // pivot = midle (should be median)
        {
        p=a[i];
        if (a0>p) a0=p;
        if (a1<p) a1=p;
        } if (a0==a1) return; p=(a0+a1+1)/2;            // if the same values stop
    if (a0==p) p++;
    for (i=0,j=n-1;i<=j;)                               // regroup
        {
        a0=a[i];
        if (a0<p) i++; else { a[i]=a[j]; a[j]=a0; j--; }// condition
        }
    sort_asc_quick(a  ,  i);                            // recursion a[]<=p
    sort_asc_quick(a+i,n-i);                            // recursion a[]> p
    }

這里的輸出:

a0[][]
  10   17   29   42   54   66   74   85   90  102 
 112  114  123  129  142  145  146  150  157  161 
 166  176  184  191  195  205  213  216  222  224 
 226  237  245  252  264  273  285  290  291  296 
 309  317  327  334  336  349  361  370  381  390 
 397  398  401  411  422  426  435  446  452  462 
 466  477  484  496  505  515  522  524  525  530 
 542  545  548  553  555  560  563  576  588  590 

a[][]
 114  142  176  264  285  317  327  422  435  466 
 166  336  349  381  452  477  515  530  542  553 
 157  184  252  273  291  334  446  524  545  563 
  17  145  150  237  245  290  370  397  484  576 
  42  129  195  205  216  309  398  411  505  560 
  10  102  123  213  222  224  226  390  496  555 
  29   74   85  146  191  361  426  462  525  590 
  54   66   90  112  161  296  401  522  548  588 

20th smallest from a0 =  161

20th smallest from a  =  161

這個例子只迭代了 5 列......

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM