[英]Fastest algorithm for Kth smallest Element (or median) finding on 2 Dimensional Array?
所以要解決這個問題,它有助於解決一個稍微不同的問題。 我們想知道每行中第 k 個截止點所在位置的上限/下限。 那么我們就可以通過,驗證下界及以下的事物數<k,上界及以下的事物數>k,並且它們之間只有一個值。
我已經提出了一種策略,可以在所有行中同時對這些邊界進行二分搜索。 作為二分搜索,它“應該”通過O(log(n))
次。 每次傳遞涉及O(m)
工作,總共O(m log(n))
次。 我把應該放在引號中,因為我沒有證據表明它實際上需要O(log(n))
次傳遞。 事實上,有可能在一行中過於激進,從其他行中發現所選的樞軸已關閉,然后不得不后退。 但我相信它幾乎沒有后退,實際上是O(m log(n))
。
策略是跟蹤下限、上限和中間的每一行。 每次通過,我們都會對范圍進行一系列加權,以降低、降低到中、從中到上、從上到結尾,權重是其中的事物數量,值是系列中的最后一個。 然后我們在該數據結構中找到第 k 個值(按權重),並將其用作我們在每個維度中進行二分搜索的主元。
如果樞軸超出從下到上的范圍,我們通過在糾正錯誤的方向上加寬間隔來進行糾正。
當我們有正確的序列時,我們就有了答案。
有很多邊緣情況,所以盯着完整的代碼可能會有所幫助。
我還假設每一行的所有元素都是不同的。 如果不是,您可能會陷入無限循環。 (解決這意味着更多的邊緣情況......)
import random
# This takes (k, [(value1, weight1), (value2, weight2), ...])
def weighted_kth (k, pairs):
# This does quickselect for average O(len(pairs)).
# Median of medians is deterministically the same, but a bit slower
pivot = pairs[int(random.random() * len(pairs))][0]
# Which side of our answer is the pivot on?
weight_under_pivot = 0
pivot_weight = 0
for value, weight in pairs:
if value < pivot:
weight_under_pivot += weight
elif value == pivot:
pivot_weight += weight
if weight_under_pivot + pivot_weight < k:
filtered_pairs = []
for pair in pairs:
if pivot < pair[0]:
filtered_pairs.append(pair)
return weighted_kth (k - weight_under_pivot - pivot_weight, filtered_pairs)
elif k <= weight_under_pivot:
filtered_pairs = []
for pair in pairs:
if pair[0] < pivot:
filtered_pairs.append(pair)
return weighted_kth (k, filtered_pairs)
else:
return pivot
# This takes (k, [[...], [...], ...])
def kth_in_row_sorted_matrix (k, matrix):
# The strategy is to discover the k'th value, and also discover where
# that would be in each row.
#
# For each row we will track what we think the lower and upper bounds
# are on where it is. Those bounds start as the start and end and
# will do a binary search.
#
# In each pass we will break each row into ranges from start to lower,
# lower to mid, mid to upper, and upper to end. Some ranges may be
# empty. We will then create a weighted list of ranges with the weight
# being the length, and the value being the end of the list. We find
# where the k'th spot is in that list, and use that approximate value
# to refine each range. (There is a chance that a range is wrong, and
# we will have to deal with that.)
#
# We finish when all of the uppers are above our k, all the lowers
# one are below, and the upper/lower gap is more than 1 only when our
# k'th element is in the middle.
# Our data structure is simply [row, lower, upper, bound] for each row.
data = [[row, 0, min(k, len(row)-1), min(k, len(row)-1)] for row in matrix]
is_search = True
while is_search:
pairs = []
for row, lower, upper, bound in data:
# Literal edge cases
if 0 == upper:
pairs.append((row[upper], 1))
if upper < bound:
pairs.append((row[bound], bound - upper))
elif lower == bound:
pairs.append((row[lower], lower + 1))
elif lower + 1 == upper: # No mid.
pairs.append((row[lower], lower + 1))
pairs.append((row[upper], 1))
if upper < bound:
pairs.append((row[bound], bound - upper))
else:
mid = (upper + lower) // 2
pairs.append((row[lower], lower + 1))
pairs.append((row[mid], mid - lower))
pairs.append((row[upper], upper - mid))
if upper < bound:
pairs.append((row[bound], bound - upper))
pivot = weighted_kth(k, pairs)
# Now that we have our pivot, we try to adjust our parameters.
# If any adjusts we continue our search.
is_search = False
new_data = []
for row, lower, upper, bound in data:
# First cases where our bounds weren't bounds for our pivot.
# We rebase the interval and either double the range.
# - double the size of the range
# - go halfway to the edge
if 0 < lower and pivot <= row[lower]:
is_search = True
if pivot == row[lower]:
new_data.append((row, lower-1, min(lower+1, bound), bound))
elif upper <= lower:
new_data.append((row, lower-1, lower, bound))
else:
new_data.append((row, max(lower // 2, lower - 2*(upper - lower)), lower, bound))
elif upper < bound and row[upper] <= pivot:
is_search = True
if pivot == row[upper]:
new_data.append((row, upper-1, upper+1, bound))
elif lower < upper:
new_data.append((row, upper, min((upper+bound+1)//2, upper + 2*(upper - lower)), bound))
else:
new_data.append((row, upper, upper+1, bound))
elif lower + 1 < upper:
if upper == lower+2 and pivot == row[lower+1]:
new_data.append((row, lower, upper, bound)) # Looks like we found the pivot.
else:
# We will split this interval.
is_search = True
mid = (upper + lower) // 2
if row[mid] < pivot:
new_data.append((row, mid, upper, bound))
elif pivot < row[mid] pivot:
new_data.append((row, lower, mid, bound))
else:
# We center our interval on the pivot
new_data.append((row, (lower+mid)//2, (mid+upper+1)//2, bound))
else:
# We look like we found where the pivot would be in this row.
new_data.append((row, lower, upper, bound))
data = new_data # And set up the next search
return pivot
已添加另一個答案以提供實際解決方案。 由於評論中有相當多的兔子洞,因此保留了這個。
我相信最快的解決方案是 k-way 合並算法。 它是一個O(N log K)
算法,將K
排序列表與總共N
項目合並為一個大小為N
排序列表。
https://en.wikipedia.org/wiki/K-way_merge_algorithm#k-way_merge
給定一個MxN
列表。 這最終是O(MNlog(M))
。 但是,這是為了對整個列表進行排序。 由於您只需要前K
最小的項目而不是所有N*M
,因此性能為O(Klog(M))
。 假設O(K) <= O(M)
,這比您正在尋找的要好得多。
盡管這假設您有N
大小為M
排序列表。 如果您實際上有M
大小為N
排序列表,則可以通過更改循環數據的方式輕松處理(請參閱下面的偽代碼),盡管這確實意味着性能是O(K log(N))
。
k-way 合並只是將每個列表的第一項添加到堆或其他具有O(log N)
插入和O(log N)
find-mind 的數據結構中。
k-way合並的偽代碼看起來有點像這樣:
[value, row_index, col_index]
插入到數據結構中,而不僅僅是value
。 這還可以讓您輕松處理列或行的循環。I
從列表中添加下一個最低值, I
的數據結構。 IE:如果值為row 5 col 4 (data[5][4])
。 然后,如果您將行用作列表,則下一個值將是row 5 col 5 (data[5][5])
。 如果您使用的是列,則下一個值為row 6 col 4 (data[6][4])
。 將下一個值插入到數據結構中,就像您在 #1 中所做的一樣(即: [value, row_index, col_index]
) 根據您的需要,執行 2-4 K
次步驟。
似乎最好的方法是在越來越大的塊中進行 k-way 合並。 k-way 合並試圖建立一個排序列表,但我們不需要它排序,我們不需要考慮每個元素。 相反,我們將創建一個半排序的區間。 間隔將被排序,但僅按最高值排序。
https://en.wikipedia.org/wiki/K-way_merge_algorithm#k-way_merge
我們使用與 k-way 合並相同的方法,但有所不同。 基本上它旨在間接構建一個半排序的子列表。 例如,不是找到 [1,2,3,4,5,6,7,8,10] 來確定 K=10,而是會找到類似 [(1,3),(4,6), (7,15)]。 通過 K-way 合並,我們一次從每個列表中考慮 1 個項目。 在這種懸停方法中,當從給定列表中提取時,我們要首先考慮 Z 個項目,然后是 2 * Z 個項目,然后是 2 * 2 * Z 個項目,因此第 i 次是 2^i * Z 個項目。 給定一個 MxN 矩陣,這意味着它需要我們從列表中提取O(log(N))
項目M
次。
K
個子列表插入到數據結構中,並使用某種方法來確定值來自哪個列表。 我們希望數據結構使用我們插入的子列表中的最高值。 在這種情況下,我們需要類似於 [max_value of sublist, row index, start_index, end_index]。 O(m)
O(log (m))
I
補充下2^i * Z
值從列表I
要在第i個時間從該特定列表拉(基本上只是增加一倍,這是存在於數量的數據結構子列表剛剛從數據結構中刪除)。 O(log m)
O(log N))
。 如果數據結構中還有任何子列表,其中最小值小於 k。 使用列表作為輸入轉到第 1 步,新的K
為k - (size of semi-sorted list)
。至於性能。 讓我們看看這里:
O(m log m)
將初始值添加到數據結構中。O(m)
個子列表,每個子列表都需要O(log n)
時間來實現`O(m log n)。O(log m)
,如果不確定 K 的值是什么(步驟 4),它可能需要將問題簡化為遞歸子列表,但我不認為'會影響大 O。編輯:我相信這在最壞的情況下只會增加另一個O(mlog(n))
,這對大 O 沒有影響。 所以看起來它是O(mlog(m) + mlog(n))
或簡單的O(mlog(mn))
。
作為優化,如果 K 高於NM/2
考慮最小值時考慮最大值,在考慮最大值時考慮最小值。 當 K 接近NM
時,這將大大提高性能。
btilly和Nuclearman的答案提供了兩種不同的方法,一種二進制搜索和行的k 路合並。
我的建議是結合這兩種方法。
如果k足夠小(比方說小於M乘以 2 或 3)或足夠大(對於對稱,接近N x M ),請找到具有 M 行合並的第k個元素。 當然,我們不應該合並所有元素,只合並第一個k 。
否則,開始檢查矩陣的第一列和最后一列,以找到最小值(女巫在第一列)和最大值(在最后一列)。
將第一個關鍵值估計為這兩個值的線性組合。 類似於pivot = min + k * (max - min) / (N * M)
。
在每一行中執行二分搜索以確定不大於樞軸的最后一個元素(更接近的元素)。 小於或等於主元的元素數量是簡單推導出來的。 將那些與k的總和進行比較將判斷選擇的樞軸值是太大還是太小,讓我們相應地修改它。 跟蹤所有行之間的最大值,它可能是第 k 個元素或僅用於評估下一個樞軸。 如果我們將所述總和視為主元的函數,那么現在的數字問題是找到sum(pivot) - k
的零,這是一個單調(離散)函數。 在最壞的情況下,我們可以使用二分法(對數復雜度)或割線法。
我們可以理想地將每一行划分為三個范圍:
不確定范圍將在每次迭代時減少,最終大多數行變為空。 在某些時候,仍然在未確定范圍內、散布在整個矩陣中的元素數量將小到足以訴諸這些范圍的單個 M 路合並。
如果我們將單次迭代的時間復雜度視為O(MlogN)
或M 個二分搜索,我們需要將其乘以樞軸收斂到第k個元素的值所需的迭代次數,這可以是O(logNM)
。 如果N > M ,則總和為O(MlogNlogM)
或O(MlogNlogN)
。
請注意,如果該算法用於查找中值,則將 M 路合並作為最后一步也很容易找到第 ( k + 1)個元素。
可能是我遺漏了一些東西,但是如果你的NxM
矩陣A
有M
行已經按升序排序,沒有元素重復,那么第k
行的最小值只是從O(1)
行中選擇第k
個元素。 要移動到 2D,您只需選擇第k
列,將其升序排序O(M.log(M))
並再次選擇導致O(N.log(N))
k-th
元素。
讓矩陣A[N][M]
其中元素是A[column][row]
排序A
升序k-th
列O(M.log(M))
所以排序A[k][i]
其中i = { 1,2,3,...M }
升序
選擇A[k][k]
作為結果
如果您想要A
中所有元素的第 k 個最小,那么您需要以類似於合並排序的形式利用已經排序的行。
創建空列表c[]
以保存k
最小值
工藝柱
創建臨時數組b[]
它保存處理過的列快速排序升序O(N.log(N))
合並c[]
和b[]
所以c[]
最多保存k
最小值
使用臨時數組d[]
將導致O(k+n)
如果在合並期間未使用b
任何項目,則停止處理列
這可以通過添加標志數組f
來完成,它將保存在合並期間從b,c
的值,然后檢查是否從b
中獲取了任何值
輸出c[k-1]
綜合起來,最終的復雜度是O(min(k,M).N.log(N))
如果我們認為k
小於M
我們可以重寫為O(kNlog(N))
否則O(MNlog(N))
。 此外,平均而言,要迭代的列數將更不可能~(1+(k/N))
所以平均復雜度將是~O(N.log(N))
但這只是我的瘋狂猜測可能是錯誤的。
這里的小 C++/VCL 示例:
//$$---- Form CPP ----
//---------------------------------------------------------------------------
#include <vcl.h>
#pragma hdrstop
#include "Unit1.h"
#include "sorts.h"
//---------------------------------------------------------------------------
#pragma package(smart_init)
#pragma resource "*.dfm"
TForm1 *Form1;
//---------------------------------------------------------------------------
const int m=10,n=8; int a[m][n],a0[m][n]; // a[col][row]
//---------------------------------------------------------------------------
void generate()
{
int i,j,k,ii,jj,d=13,b[m];
Randomize();
RandSeed=0x12345678;
// a,a0 = some distinct pseudorandom values (fully ordered asc)
for (k=Random(d),j=0;j<n;j++)
for (i=0;i<m;i++,k+=Random(d)+1)
{ a0[i][j]=k; a[i][j]=k; }
// schuffle a
for (j=0;j<n;j++)
for (i=0;i<m;i++)
{
ii=Random(m);
jj=Random(n);
k=a[i][j]; a[i][j]=a[ii][jj]; a[ii][jj]=k;
}
// sort rows asc
for (j=0;j<n;j++)
{
for (i=0;i<m;i++) b[i]=a[i][j];
sort_asc_quick(b,m);
for (i=0;i<m;i++) a[i][j]=b[i];
}
}
//---------------------------------------------------------------------------
int kmin(int k) // k-th min from a[m][n] where a rows are already sorted
{
int i,j,bi,ci,di,b[n],*c,*d,*e,*f,cn;
c=new int[k+k+k]; d=c+k; f=d+k;
// handle edge cases
if (m<1) return -1;
if (k>m*n) return -1;
if (m==1) return a[0][k];
// process columns
for (cn=0,i=0;i<m;i++)
{
// b[] = sorted_asc a[i][]
for (j=0;j<n;j++) b[j]=a[i][j]; // O(n)
sort_asc_quick(b,n); // O(n.log(n))
// c[] = c[] + b[] asc sorted and limited to cn size
for (bi=0,ci=0,di=0;;) // O(k+n)
{
if ((ci>=cn)&&(bi>=n)) break;
else if (ci>=cn) { d[di]=b[bi]; f[di]=1; bi++; di++; }
else if (bi>= n) { d[di]=c[ci]; f[di]=0; ci++; di++; }
else if (b[bi]<c[ci]){ d[di]=b[bi]; f[di]=1; bi++; di++; }
else { d[di]=c[ci]; f[di]=0; ci++; di++; }
if (di>k) di=k;
}
e=c; c=d; d=e; cn=di;
for (ci=0,j=0;j<cn;j++) ci|=f[j]; // O(k)
if (!ci) break;
}
k=c[k-1];
delete[] c;
return k;
}
//---------------------------------------------------------------------------
__fastcall TForm1::TForm1(TComponent* Owner):TForm(Owner)
{
int i,j,k;
AnsiString txt="";
generate();
txt+="a0[][]\r\n";
for (j=0;j<n;j++,txt+="\r\n")
for (i=0;i<m;i++) txt+=AnsiString().sprintf("%4i ",a0[i][j]);
txt+="\r\na[][]\r\n";
for (j=0;j<n;j++,txt+="\r\n")
for (i=0;i<m;i++) txt+=AnsiString().sprintf("%4i ",a[i][j]);
k=20;
txt+=AnsiString().sprintf("\r\n%ith smallest from a0 = %4i\r\n",k,a0[(k-1)%m][(k-1)/m]);
txt+=AnsiString().sprintf("\r\n%ith smallest from a = %4i\r\n",k,kmin(k));
mm_log->Lines->Add(txt);
}
//-------------------------------------------------------------------------
忽略 VCL 的東西。 函數 generate 計算a0, a
矩陣,其中a0
是完全排序的, a
只對行進行排序並且所有值都是不同的。 函數kmin
是上述算法,從a[m][n]
返回第 k 個最小值用於排序,我使用了這個:
template <class T> void sort_asc_quick(T *a,int n)
{
int i,j; T a0,a1,p;
if (n<=1) return; // stop recursion
if (n==2) // edge case
{
a0=a[0];
a1=a[1];
if (a0>a1) { a[0]=a1; a[1]=a0; } // condition
return;
}
for (a0=a1=a[0],i=0;i<n;i++) // pivot = midle (should be median)
{
p=a[i];
if (a0>p) a0=p;
if (a1<p) a1=p;
} if (a0==a1) return; p=(a0+a1+1)/2; // if the same values stop
if (a0==p) p++;
for (i=0,j=n-1;i<=j;) // regroup
{
a0=a[i];
if (a0<p) i++; else { a[i]=a[j]; a[j]=a0; j--; }// condition
}
sort_asc_quick(a , i); // recursion a[]<=p
sort_asc_quick(a+i,n-i); // recursion a[]> p
}
這里的輸出:
a0[][]
10 17 29 42 54 66 74 85 90 102
112 114 123 129 142 145 146 150 157 161
166 176 184 191 195 205 213 216 222 224
226 237 245 252 264 273 285 290 291 296
309 317 327 334 336 349 361 370 381 390
397 398 401 411 422 426 435 446 452 462
466 477 484 496 505 515 522 524 525 530
542 545 548 553 555 560 563 576 588 590
a[][]
114 142 176 264 285 317 327 422 435 466
166 336 349 381 452 477 515 530 542 553
157 184 252 273 291 334 446 524 545 563
17 145 150 237 245 290 370 397 484 576
42 129 195 205 216 309 398 411 505 560
10 102 123 213 222 224 226 390 496 555
29 74 85 146 191 361 426 462 525 590
54 66 90 112 161 296 401 522 548 588
20th smallest from a0 = 161
20th smallest from a = 161
這個例子只迭代了 5 列......
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.