簡體   English   中英

O(n)最壞情況時間的2D峰值發現算法?

[英]2D peak finding algorithm in O(n) worst case time?

我在做從MIT算法課程。 在第一堂課中,教授提出了以下問題: -

2D陣列中的峰值是使得它的所有4個鄰居都小於或等於它的值,即。 對於

a[i][j]是局部最大值,

a[i+1][j] <= a[i][j] 
&& a[i-1][j] <= a[i][j]
&& a[i][j+1] <= a[i][j]
&& a[i+1][j-1] <= a[i][j]

現在給定一個NxN 2D陣列,在陣列中找到一個峰值

通過迭代所有元素並返回峰值,可以在O(N^2)時間內輕松解決該問題。

然而,它可以被優化所要解決O(NlogN)通過使用分而治之溶液作為解釋時間這里

但是他們說有一種O(N)時間算法可以解決這個問題。 請建議我們如何在O(N)時間內解決這個問題。

PS(對於那些了解python的人)課程工作人員在這里解釋了一種方法(問題1-5。峰值證明),並在他們的問題集中提供了一些python代碼。 但解釋的方法完全不明顯,很難破譯。 python代碼同樣令人困惑。 所以我已經為那些了解python的人復制了下面代碼的主要部分,並且可以從代碼中告訴我們使用了什么算法。

def algorithm4(problem, bestSeen = None, rowSplit = True, trace = None):
    # if it's empty, we're done 
    if problem.numRow <= 0 or problem.numCol <= 0:
        return None

    subproblems = []
    divider = []

    if rowSplit:
        # the recursive subproblem will involve half the number of rows
        mid = problem.numRow // 2

        # information about the two subproblems
        (subStartR1, subNumR1) = (0, mid)
        (subStartR2, subNumR2) = (mid + 1, problem.numRow - (mid + 1))
        (subStartC, subNumC) = (0, problem.numCol)

        subproblems.append((subStartR1, subStartC, subNumR1, subNumC))
        subproblems.append((subStartR2, subStartC, subNumR2, subNumC))

        # get a list of all locations in the dividing column
        divider = crossProduct([mid], range(problem.numCol))
    else:
        # the recursive subproblem will involve half the number of columns
        mid = problem.numCol // 2

        # information about the two subproblems
        (subStartR, subNumR) = (0, problem.numRow)
        (subStartC1, subNumC1) = (0, mid)
        (subStartC2, subNumC2) = (mid + 1, problem.numCol - (mid + 1))

        subproblems.append((subStartR, subStartC1, subNumR, subNumC1))
        subproblems.append((subStartR, subStartC2, subNumR, subNumC2))

        # get a list of all locations in the dividing column
        divider = crossProduct(range(problem.numRow), [mid])

    # find the maximum in the dividing row or column
    bestLoc = problem.getMaximum(divider, trace)
    neighbor = problem.getBetterNeighbor(bestLoc, trace)

    # update the best we've seen so far based on this new maximum
    if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
        bestSeen = neighbor
        if not trace is None: trace.setBestSeen(bestSeen)

    # return when we know we've found a peak
    if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
        if not trace is None: trace.foundPeak(bestLoc)
        return bestLoc

    # figure out which subproblem contains the largest number we've seen so
    # far, and recurse, alternating between splitting on rows and splitting
    # on columns
    sub = problem.getSubproblemContaining(subproblems, bestSeen)
    newBest = sub.getLocationInSelf(problem, bestSeen)
    if not trace is None: trace.setProblemDimensions(sub)
    result = algorithm4(sub, newBest, not rowSplit, trace)
    return problem.getLocationInSelf(sub, result)

#Helper Method
def crossProduct(list1, list2):
    """
    Returns all pairs with one item from the first list and one item from 
    the second list.  (Cartesian product of the two lists.)

    The code is equivalent to the following list comprehension:
        return [(a, b) for a in list1 for b in list2]
    but for easier reading and analysis, we have included more explicit code.
    """

    answer = []
    for a in list1:
        for b in list2:
            answer.append ((a, b))
    return answer
  1. 讓我們假設數組的寬度大於高度,否則我們將分裂到另一個方向。
  2. 將陣列分成三部分:中央列,左側和右側。
  3. 瀏覽中心列和兩個相鄰列並查找最大值。
    • 如果它在中央列 - 這是我們的高峰期
    • 如果它在左側,則在子陣列left_side + central_column上運行此算法
    • 如果它在右側,則在子陣列right_side + central_column上運行此算法

為什么這樣有效:

對於最大元素位於中心列的情況 - 顯而易見。 如果不是,我們可以從最大值逐步增加到增加的元素,並且絕對不會越過中心行,因此相應的一半肯定存在峰值。

為什么這是O(n):

步驟#3在每兩個算法步驟中采用小於或等於max_dimension迭代並且max_dimension至少減半。 這給出n+n/2+n/4+... ,即O(n) 重要細節:我們按最大方向划分。 對於方形陣列,這意味着分割方向將是交替的。 這與您鏈接到的PDF中的最后一次嘗試有所不同。

注意:我不確定它是否與您給出的代碼中的算法完全匹配,它可能是也可能不是一種不同的方法。

這是實現@ maxim1000算法的工作Java代碼 以下代碼在線性時間內在2D數組中找到峰值。

import java.util.*;

class Ideone{
    public static void main (String[] args) throws java.lang.Exception{
        new Ideone().run();
    }
    int N , M ;

    void run(){
        N = 1000;
        M = 100;

        // arr is a random NxM array
        int[][] arr = randomArray();
        long start = System.currentTimeMillis();
//      for(int i=0; i<N; i++){   // TO print the array. 
//          System. out.println(Arrays.toString(arr[i]));
//      }
        System.out.println(findPeakLinearTime(arr));
        long end = System.currentTimeMillis();
        System.out.println("time taken : " + (end-start));
    }

    int findPeakLinearTime(int[][] arr){
        int rows = arr.length;
        int cols = arr[0].length;
        return kthLinearColumn(arr, 0, cols-1, 0, rows-1);
    }

    // helper function that splits on the middle Column
    int kthLinearColumn(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loCol==hiCol){
            int max = arr[loRow][loCol];
            int foundRow = loRow;
            for(int row = loRow; row<=hiRow; row++){
                if(max < arr[row][loCol]){
                    max = arr[row][loCol];
                    foundRow = row;
                }
            }
            if(!correctPeak(arr, foundRow, loCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return max;
        }
        int midCol = (loCol+hiCol)/2;
        int max = arr[loRow][loCol];
        for(int row=loRow; row<=hiRow; row++){
            max = Math.max(max, arr[row][midCol]);
        }
        boolean centralMax = true;
        boolean rightMax = false;
        boolean leftMax  = false;

        if(midCol-1 >= 0){
            for(int row = loRow; row<=hiRow; row++){
                if(arr[row][midCol-1] > max){
                    max = arr[row][midCol-1];
                    centralMax = false;
                    leftMax = true;
                }
            }
        }

        if(midCol+1 < M){
            for(int row=loRow; row<=hiRow; row++){
                if(arr[row][midCol+1] > max){
                    max = arr[row][midCol+1];
                    centralMax = false;
                    leftMax = false;
                    rightMax = true;
                }
            }
        }

        if(centralMax) return max;
        if(rightMax)  return kthLinearRow(arr, midCol+1, hiCol, loRow, hiRow);
        if(leftMax)   return kthLinearRow(arr, loCol, midCol-1, loRow, hiRow);
        throw new RuntimeException("INCORRECT CODE");
    }

    // helper function that splits on the middle 
    int kthLinearRow(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loRow==hiRow){
            int ans = arr[loCol][loRow];
            int foundCol = loCol;
            for(int col=loCol; col<=hiCol; col++){
                if(arr[loRow][col] > ans){
                    ans = arr[loRow][col];
                    foundCol = col;
                }
            }
            if(!correctPeak(arr, loRow, foundCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return ans;
        }
        boolean centralMax = true;
        boolean upperMax = false;
        boolean lowerMax = false;

        int midRow = (loRow+hiRow)/2;
        int max = arr[midRow][loCol];

        for(int col=loCol; col<=hiCol; col++){
            max = Math.max(max, arr[midRow][col]);
        }

        if(midRow-1>=0){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow-1][col] > max){
                    max = arr[midRow-1][col];
                    upperMax = true;
                    centralMax = false;
                }
            }
        }

        if(midRow+1<N){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow+1][col] > max){
                    max = arr[midRow+1][col];
                    lowerMax = true;
                    centralMax = false;
                    upperMax   = false;
                }
            }
        }

        if(centralMax) return max;
        if(lowerMax)   return kthLinearColumn(arr, loCol, hiCol, midRow+1, hiRow);
        if(upperMax)   return kthLinearColumn(arr, loCol, hiCol, loRow, midRow-1);
        throw new RuntimeException("Incorrect code");
    }

    int[][] randomArray(){
        int[][] arr = new int[N][M];
        for(int i=0; i<N; i++)
            for(int j=0; j<M; j++)
                arr[i][j] = (int)(Math.random()*1000000000);
        return arr;
    }

    boolean correctPeak(int[][] arr, int row, int col){//Function that checks if arr[row][col] is a peak or not
        if(row-1>=0 && arr[row-1][col]>arr[row][col])  return false;
        if(row+1<N && arr[row+1][col]>arr[row][col])   return false;
        if(col-1>=0 && arr[row][col-1]>arr[row][col])  return false;
        if(col+1<M && arr[row][col+1]>arr[row][col])   return false;
        return true;
    }
}

要看那個(n):

計算步驟如圖所示

要查看算法實現:

1)從1a)或1b)開始

1a)設置左半部分,分隔線,右半部分。

1b)設置上半部分,分隔線,下半部分。

2)在分頻器上找到全局最大值。 [theta n]

3)找到其鄰居的值。 並記錄有史以來最大的節點作為bestSeen節點。 [1]

# update the best we've seen so far based on this new maximum
if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
    bestSeen = neighbor
    if not trace is None: trace.setBestSeen(bestSeen)

4)檢查全局最大值是否大於bestSeen及其鄰居。 [1]

//第4步是此算法工作原理的主要關鍵

# return when we know we've found a peak
if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
    if not trace is None: trace.foundPeak(bestLoc)
    return bestLoc

5)如果4)為True,則將全局最大值返回為2-D峰值。

如果這次是1a),選擇BestSeen的一半,返回步驟1b)

否則,選擇BestSeen的一半,回到步驟1a)


為了直觀地看到這個算法的工作原理,就像抓住最大的價值面,不斷縮小邊界,最終得到BestSeen值。

#可視化模擬

第1輪

round2

round3

round4

round5

round6

最后

對於這個10 * 10矩陣,我們只使用了6個步驟來搜索2-D峰值,它非常有說服力,它確實是theta n


通過獵鷹

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM