
2D peak finding algorithm in O(n) worst case time?

I was doing this course on algorithms from MIT. In the very first lecture the professor presents the following problem:

A peak in a 2D array is a value such that all four of its neighbours are less than or equal to it, i.e. for a[i][j] to be a local maximum,

a[i+1][j] <= a[i][j] 
&& a[i-1][j] <= a[i][j]
&& a[i][j+1] <= a[i][j]
&& a[i][j-1] <= a[i][j]
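
The condition above can be sketched as a small Python check. The function name `is_peak` and the treatment of the array border (a neighbour outside the array never disqualifies a cell) are assumptions made for illustration, not something the course material prescribes.

```python
def is_peak(a, i, j):
    """Return True if a[i][j] is >= each of its (up to four) in-bounds neighbours."""
    rows, cols = len(a), len(a[0])
    for di, dj in ((1, 0), (-1, 0), (0, 1), (0, -1)):
        ni, nj = i + di, j + dj
        # Assumption: cells outside the array are ignored.
        if 0 <= ni < rows and 0 <= nj < cols and a[ni][nj] > a[i][j]:
            return False
    return True


grid = [[1, 2, 3],
        [4, 9, 5],
        [6, 7, 8]]
print(is_peak(grid, 1, 1))  # True: 9 is >= 2, 7, 4 and 5
print(is_peak(grid, 0, 0))  # False: 1 is smaller than 4 and 2
```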

Now given an NxN 2D array, find a peak in the array.

This question can be easily solved in O(N^2) time by iterating over all the elements and returning a peak.
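
A minimal sketch of that O(N^2) scan, assuming a non-empty rectangular list of lists; the name `find_peak_brute_force` is hypothetical.

```python
def find_peak_brute_force(a):
    """Scan every cell and return the (row, col) of the first peak found."""
    rows, cols = len(a), len(a[0])
    for i in range(rows):
        for j in range(cols):
            v = a[i][j]
            # A missing neighbour (array border) counts as "not larger".
            if ((i == 0 or a[i - 1][j] <= v) and
                    (i == rows - 1 or a[i + 1][j] <= v) and
                    (j == 0 or a[i][j - 1] <= v) and
                    (j == cols - 1 or a[i][j + 1] <= v)):
                return (i, j)
    return None  # not reached for a non-empty array: the global maximum is a peak


print(find_peak_brute_force([[1, 2, 3],
                             [4, 9, 5],
                             [6, 7, 8]]))  # (1, 1)
```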

However, it can be optimized to run in O(NlogN) time by using a divide and conquer solution as explained here.
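
For reference, the usual divide and conquer idea can be sketched as below: take the middle column, find its maximum, and move toward a larger horizontal neighbour if there is one. This is an illustrative version under the assumption of a non-empty rectangular list of lists; the function name is hypothetical, and it is not taken from the linked explanation.

```python
def find_peak_divide_and_conquer(a):
    """Binary search over columns: O(rows * log(cols)) comparisons."""
    rows, cols = len(a), len(a[0])
    lo, hi = 0, cols - 1
    while lo <= hi:
        mid = (lo + hi) // 2
        # Row index of the largest value in the middle column.
        i = max(range(rows), key=lambda r: a[r][mid])
        if mid > 0 and a[i][mid - 1] > a[i][mid]:
            hi = mid - 1          # a peak must exist in the left part
        elif mid < cols - 1 and a[i][mid + 1] > a[i][mid]:
            lo = mid + 1          # a peak must exist in the right part
        else:
            return (i, mid)       # column maximum with no larger side neighbour
    return None  # not reached for a non-empty array


grid = [[10, 8, 10, 10],
        [14, 13, 12, 11],
        [15, 9, 11, 21],
        [16, 17, 19, 20]]
print(find_peak_divide_and_conquer(grid))  # (2, 3), the value 21
```

Each round costs one pass over a column (N steps) and halves the number of candidate columns, which is where the N log N bound comes from.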

But they have said that there exists an O(N) time algorithm that solves this problem. Please suggest how we can solve this problem in O(N) time.

PS (for those who know Python): The course staff has explained an approach here (Problem 1-5. Peak-Finding Proof) and also provided some Python code in their problem sets. But the approach explained is totally non-obvious and very hard to decipher. The Python code is equally confusing. So I have copied the main part of the code below for those who know Python and can tell from the code what algorithm is being used.

def algorithm4(problem, bestSeen = None, rowSplit = True, trace = None):
    # if it's empty, we're done 
    if problem.numRow <= 0 or problem.numCol <= 0:
        return None

    subproblems = []
    divider = []

    if rowSplit:
        # the recursive subproblem will involve half the number of rows
        mid = problem.numRow // 2

        # information about the two subproblems
        (subStartR1, subNumR1) = (0, mid)
        (subStartR2, subNumR2) = (mid + 1, problem.numRow - (mid + 1))
        (subStartC, subNumC) = (0, problem.numCol)

        subproblems.append((subStartR1, subStartC, subNumR1, subNumC))
        subproblems.append((subStartR2, subStartC, subNumR2, subNumC))

        # get a list of all locations in the dividing row
        divider = crossProduct([mid], range(problem.numCol))
    else:
        # the recursive subproblem will involve half the number of columns
        mid = problem.numCol // 2

        # information about the two subproblems
        (subStartR, subNumR) = (0, problem.numRow)
        (subStartC1, subNumC1) = (0, mid)
        (subStartC2, subNumC2) = (mid + 1, problem.numCol - (mid + 1))

        subproblems.append((subStartR, subStartC1, subNumR, subNumC1))
        subproblems.append((subStartR, subStartC2, subNumR, subNumC2))

        # get a list of all locations in the dividing column
        divider = crossProduct(range(problem.numRow), [mid])

    # find the maximum in the dividing row or column
    bestLoc = problem.getMaximum(divider, trace)
    neighbor = problem.getBetterNeighbor(bestLoc, trace)

    # update the best we've seen so far based on this new maximum
    if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
        bestSeen = neighbor
        if not trace is None: trace.setBestSeen(bestSeen)

    # return when we know we've found a peak
    if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
        if not trace is None: trace.foundPeak(bestLoc)
        return bestLoc

    # figure out which subproblem contains the largest number we've seen so
    # far, and recurse, alternating between splitting on rows and splitting
    # on columns
    sub = problem.getSubproblemContaining(subproblems, bestSeen)
    newBest = sub.getLocationInSelf(problem, bestSeen)
    if not trace is None: trace.setProblemDimensions(sub)
    result = algorithm4(sub, newBest, not rowSplit, trace)
    return problem.getLocationInSelf(sub, result)

#Helper Method
def crossProduct(list1, list2):
    """
    Returns all pairs with one item from the first list and one item from 
    the second list.  (Cartesian product of the two lists.)

    The code is equivalent to the following list comprehension:
        return [(a, b) for a in list1 for b in list2]
    but for easier reading and analysis, we have included more explicit code.
    """

    answer = []
    for a in list1:
        for b in list2:
            answer.append ((a, b))
    return answer

  1. Let's assume that the width of the array is bigger than the height; otherwise we will split in the other direction.
  2. Split the array into three parts: central column, left side and right side.
  3. Go through the central column and its two neighbour columns and look for the maximum.
    • If it's in the central column - this is our peak
    • If it's in the left side, run this algorithm on subarray left_side + central_column
    • If it's in the right side, run this algorithm on subarray right_side + central_column

Why this works:

For cases where the maximum element is in the central column - obvious. If it's not, we can step from that maximum to increasing elements and will definitely not cross the central column, so a peak will definitely exist in the corresponding half.

Why this is O(n):

Step #3 takes at most max_dimension iterations, and max_dimension at least halves on every two algorithm steps. This gives n+n/2+n/4+... which is O(n). Important detail: we split along the larger dimension. For square arrays this means that the split directions will alternate. This is a difference from the last attempt in the PDF you linked to.
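
Written as a recurrence, the argument looks roughly as follows. This is only a sketch: T(n) denotes the worst-case cost when the larger dimension is n, and the constant c is an assumption introduced here to bound the cost of step #3 per unit of max_dimension.

```latex
T(n) \;\le\; \underbrace{cn + cn}_{\text{two consecutive steps}} + T\!\left(\tfrac{n}{2}\right)
\quad\Longrightarrow\quad
T(n) \;\le\; 2cn \sum_{k \ge 0} 2^{-k} \;=\; 4cn \;=\; O(n).
```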

A note: I'm not sure if it exactly matches the algorithm in the code you gave; it may or may not be a different approach.

Here is the working Java code that implements @maxim1000's algorithm. The following code finds a peak in the 2D array in linear time.

import java.util.*;

class Ideone{
    public static void main (String[] args) throws java.lang.Exception{
        new Ideone().run();
    }
    int N , M ;

    void run(){
        N = 1000;
        M = 100;

        // arr is a random NxM array
        int[][] arr = randomArray();
        long start = System.currentTimeMillis();
//      for(int i=0; i<N; i++){   // TO print the array. 
//          System. out.println(Arrays.toString(arr[i]));
//      }
        System.out.println(findPeakLinearTime(arr));
        long end = System.currentTimeMillis();
        System.out.println("time taken : " + (end-start));
    }

    int findPeakLinearTime(int[][] arr){
        int rows = arr.length;
        int cols = arr[0].length;
        return kthLinearColumn(arr, 0, cols-1, 0, rows-1);
    }

    // helper function that splits on the middle Column
    int kthLinearColumn(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loCol==hiCol){
            int max = arr[loRow][loCol];
            int foundRow = loRow;
            for(int row = loRow; row<=hiRow; row++){
                if(max < arr[row][loCol]){
                    max = arr[row][loCol];
                    foundRow = row;
                }
            }
            if(!correctPeak(arr, foundRow, loCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return max;
        }
        int midCol = (loCol+hiCol)/2;
        int max = arr[loRow][midCol];  // start from the central column, which is what the loop below scans
        for(int row=loRow; row<=hiRow; row++){
            max = Math.max(max, arr[row][midCol]);
        }
        boolean centralMax = true;
        boolean rightMax = false;
        boolean leftMax  = false;

        if(midCol-1 >= 0){
            for(int row = loRow; row<=hiRow; row++){
                if(arr[row][midCol-1] > max){
                    max = arr[row][midCol-1];
                    centralMax = false;
                    leftMax = true;
                }
            }
        }

        if(midCol+1 < M){
            for(int row=loRow; row<=hiRow; row++){
                if(arr[row][midCol+1] > max){
                    max = arr[row][midCol+1];
                    centralMax = false;
                    leftMax = false;
                    rightMax = true;
                }
            }
        }

        if(centralMax) return max;
        if(rightMax)  return kthLinearRow(arr, midCol+1, hiCol, loRow, hiRow);
        if(leftMax)   return kthLinearRow(arr, loCol, midCol-1, loRow, hiRow);
        throw new RuntimeException("INCORRECT CODE");
    }

    // helper function that splits on the middle row
    int kthLinearRow(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loRow==hiRow){
            int ans = arr[loRow][loCol];  // row index first, then column
            int foundCol = loCol;
            for(int col=loCol; col<=hiCol; col++){
                if(arr[loRow][col] > ans){
                    ans = arr[loRow][col];
                    foundCol = col;
                }
            }
            if(!correctPeak(arr, loRow, foundCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return ans;
        }
        boolean centralMax = true;
        boolean upperMax = false;
        boolean lowerMax = false;

        int midRow = (loRow+hiRow)/2;
        int max = arr[midRow][loCol];

        for(int col=loCol; col<=hiCol; col++){
            max = Math.max(max, arr[midRow][col]);
        }

        if(midRow-1>=0){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow-1][col] > max){
                    max = arr[midRow-1][col];
                    upperMax = true;
                    centralMax = false;
                }
            }
        }

        if(midRow+1<N){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow+1][col] > max){
                    max = arr[midRow+1][col];
                    lowerMax = true;
                    centralMax = false;
                    upperMax   = false;
                }
            }
        }

        if(centralMax) return max;
        if(lowerMax)   return kthLinearColumn(arr, loCol, hiCol, midRow+1, hiRow);
        if(upperMax)   return kthLinearColumn(arr, loCol, hiCol, loRow, midRow-1);
        throw new RuntimeException("Incorrect code");
    }

    int[][] randomArray(){
        int[][] arr = new int[N][M];
        for(int i=0; i<N; i++)
            for(int j=0; j<M; j++)
                arr[i][j] = (int)(Math.random()*1000000000);
        return arr;
    }

    boolean correctPeak(int[][] arr, int row, int col){//Function that checks if arr[row][col] is a peak or not
        if(row-1>=0 && arr[row-1][col]>arr[row][col])  return false;
        if(row+1<N && arr[row+1][col]>arr[row][col])   return false;
        if(col-1>=0 && arr[row][col-1]>arr[row][col])  return false;
        if(col+1<M && arr[row][col+1]>arr[row][col])   return false;
        return true;
    }
}

To see that it is Θ(n):

The calculation step is in the picture.

To see the algorithm implementation:

1) Start with either 1a) or 1b).

1a) Set left half, divider, right half.

1b) Set top half, divider, bottom half.

2) Find the global maximum on the divider. [Θ(n)]

3) Find the values of its neighbours, and record the largest node ever visited as the bestSeen node. [Θ(1)]

# update the best we've seen so far based on this new maximum
if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
    bestSeen = neighbor
    if not trace is None: trace.setBestSeen(bestSeen)

4) Check if the global maximum is larger than the bestSeen and its neighbours. [Θ(1)]

// Step 4 is the main key to why this algorithm works

# return when we know we've found a peak
if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
    if not trace is None: trace.foundPeak(bestLoc)
    return bestLoc

5) If 4) is True, return the global maximum as the 2-D peak.

Else, if this round did 1a), choose the half containing bestSeen and go back to step 1b).

Else, choose the half containing bestSeen and go back to step 1a).
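
Steps 1) to 5) can be sketched without the course's helper classes. The version below works directly on a list of lists and tracks the current window with inclusive bounds. The function name `find_peak_best_seen` and the variable names are assumptions for illustration, and the midpoint is computed slightly differently from the course code, so treat it as a sketch of the idea rather than a line-by-line port.

```python
def find_peak_best_seen(a):
    """Alternate row/column dividers while remembering the best value seen."""
    r0, r1 = 0, len(a) - 1        # inclusive row bounds of the current window
    c0, c1 = 0, len(a[0]) - 1     # inclusive column bounds
    best = None                   # (row, col) of the largest value seen so far
    row_split = True
    while r0 <= r1 and c0 <= c1:
        # Steps 1-2: choose the divider and find its maximum.
        if row_split:
            mid = (r0 + r1) // 2
            divider = [(mid, c) for c in range(c0, c1 + 1)]
        else:
            mid = (c0 + c1) // 2
            divider = [(r, mid) for r in range(r0, r1 + 1)]
        bi, bj = max(divider, key=lambda p: a[p[0]][p[1]])
        # Step 3: best of the divider maximum and its neighbours inside the window.
        ni, nj = bi, bj
        for i, j in ((bi - 1, bj), (bi + 1, bj), (bi, bj - 1), (bi, bj + 1)):
            if r0 <= i <= r1 and c0 <= j <= c1 and a[i][j] > a[ni][nj]:
                ni, nj = i, j
        if best is None or a[ni][nj] > a[best[0]][best[1]]:
            best = (ni, nj)
        # Steps 4-5: the divider maximum beats its neighbours and bestSeen.
        if (ni, nj) == (bi, bj) and a[bi][bj] >= a[best[0]][best[1]]:
            return (bi, bj)
        # Otherwise keep the half that contains bestSeen and switch direction.
        if row_split:
            if best[0] < mid:
                r1 = mid - 1
            else:
                r0 = mid + 1
        else:
            if best[1] < mid:
                c1 = mid - 1
            else:
                c0 = mid + 1
        row_split = not row_split
    return None  # not reached for a non-empty array


grid = [[10, 8, 10, 10],
        [14, 13, 12, 11],
        [15, 9, 11, 21],
        [16, 17, 19, 20]]
print(find_peak_best_seen(grid))  # (2, 3), the value 21
```

The comparison against bestSeen in step 4 is what guards the window border: every cell just outside the window lies on an earlier divider, and no earlier divider held a value larger than bestSeen.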


To see visually why this algorithm works: it is like grabbing the side with the greatest value, repeatedly shrinking the boundaries, and eventually arriving at the bestSeen value.

# Visualised simulation

[Figures: round 1 through round 6, then the final result]

For this 10*10 matrix, we used only 6 steps to search for the 2-D peak; it's quite convincing that it is indeed Θ(n).


By Falcon
