
What is the fastest way to get k smallest (or largest) elements of array in Java?

I have an array of elements (in the example these are simply integers) that are compared using a custom comparator. In this example I simulate the comparator by defining i SMALLER j if and only if scores[i] <= scores[j] .

I have two approaches:

  • using heap of the current k candidates
  • using array of the current k candidates

I update the two structures above in the following way:

  • heap: via the methods PriorityQueue.poll and PriorityQueue.offer ,
  • array: the index top of the worst among the current k candidates is stored. If a newly seen example is better than the element at index top , that element is replaced by the new one and top is recomputed by iterating through all k elements of the array.

However, when I tested which of the approaches is faster, I found that it is the second one. The questions are:

  • Is my use of PriorityQueue suboptimal?
  • What is the fastest way to compute k smallest elements?

I am interested in the case where the number of examples can be large but the number of neighbours is relatively small (between 10 and 20).

Here is the code:

public static void main(String[] args) {
    long kopica, navadno; // kopica = heap time, navadno = simple (array) time

    int numTries = 10000;
    int numExamples = 1000;
    int numNeighbours = 10;

    navadno = testSimple(numExamples, numNeighbours, numTries);
    kopica = testHeap(numExamples, numNeighbours, numTries);

    System.out.println(String.format("tries: %d examples: %d neighbours: %d\n time heap[ms]: %d\n time simple[ms]: %d", numTries, numExamples, numNeighbours, kopica, navadno));
}

public static long testHeap(int numberExamples, int numberNeighbours, int numberTries){
    Random rnd = new Random(123);   
    long startTime = System.currentTimeMillis();
    for(int iteration = 0; iteration < numberTries; iteration++){
        final double[] scores = new double[numberExamples];
        for(int i = 0; i < numberExamples; i++){
            scores[i] = rnd.nextDouble();
        }
        PriorityQueue<Integer> myHeap = new PriorityQueue<>(numberNeighbours, new Comparator<Integer>(){
            @Override
            public int compare(Integer o1, Integer o2) {
                return -Double.compare(scores[o1], scores[o2]);
            }
        });

        int top;
        for(int i = 0; i < numberExamples; i++){
            if(i < numberNeighbours){
                myHeap.offer(i);
            } else{
                top = myHeap.peek();
                if(scores[top] > scores[i]){
                    myHeap.poll();
                    myHeap.offer(i);
                }
            }
        }

    }
    long endTime = System.currentTimeMillis();
    return endTime - startTime;     
}

public static long testSimple(int numberExamples, int numberNeighbours, int numberTries){
    Random rnd = new Random(123);   
    long startTime = System.currentTimeMillis();
    for(int iteration = 0; iteration < numberTries; iteration++){
        final double[] scores = new double[numberExamples];
        for(int i = 0; i < numberExamples; i++){
            scores[i] = rnd.nextDouble();
        }
        int[] candidates = new int[numberNeighbours];
        int top = 0;
        for(int i = 0; i < numberExamples; i++){
            if(i < numberNeighbours){
                candidates[i] = i;
                if(scores[candidates[top]] < scores[candidates[i]]) top = i;
            } else{
                if(scores[candidates[top]] > scores[i]){
                    candidates[top] = i;
                    top = 0;
                    for(int j = 1; j < numberNeighbours; j++){
                        if(scores[candidates[top]] < scores[candidates[j]]) top = j;                            
                    }
                }
            }
        }

    }
    long endTime = System.currentTimeMillis();
    return endTime - startTime;     
}

This produces the following result:

tries: 10000 examples: 1000 neighbours: 10
   time heap[ms]: 393
   time simple[ms]: 388

Creating the fastest algorithm is never simple; you need to consider many things. For example: do the k elements need to be returned in sorted order or not? Does your selection need to be stable (when two elements compare equal, must the one seen first be extracted first) or not?

In this context the theoretically best solution is to keep the k smallest elements in an ordered data structure. Because insertions often happen in the middle of such a structure, a balanced search tree seems to be an optimal choice.
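As a sketch of that idea (my own illustration, not code from the question): a `TreeMap` can act as a multiset of the k smallest values seen so far, with `lastKey()` giving the current worst kept candidate in O(log k):

```java
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;

public class KSmallestTree {
    // Keep the k smallest values in a balanced tree (TreeMap used as a
    // multiset: value -> number of occurrences). lastKey() is the worst
    // candidate currently kept, so each update costs O(log k).
    public static double[] kSmallest(double[] values, int k) {
        TreeMap<Double, Integer> tree = new TreeMap<>();
        int size = 0;
        for (double v : values) {
            if (size < k) {
                tree.merge(v, 1, Integer::sum);
                size++;
            } else if (v < tree.lastKey()) {
                // evict the current largest, then insert the smaller value
                Map.Entry<Double, Integer> worst = tree.lastEntry();
                if (worst.getValue() == 1) tree.remove(worst.getKey());
                else tree.put(worst.getKey(), worst.getValue() - 1);
                tree.merge(v, 1, Integer::sum);
            }
        }
        double[] out = new double[size];
        int i = 0;
        for (Map.Entry<Double, Integer> e : tree.entrySet())
            for (int c = 0; c < e.getValue(); c++) out[i++] = e.getKey();
        return out;
    }
}
```

Note that for small k the constant overhead of boxing and tree rebalancing can easily dominate, which is exactly the point made below.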

But reality is very different from that.

Probably a mix of different data structures, chosen depending on the size of the original array and the value of k, is the best solution:

  • If k is small, use an array to store the k smallest values
  • If k is large, use a balanced tree
  • If k is very large and close to the size of the array, simply sort the array (or, if you cannot modify it, sort a copy), then take the first k elements.

This kind of algorithm is called a hybrid algorithm . A famous hybrid algorithm is Timsort, which the Java class library uses to sort collections.

Note: If you can use the power of multithreading, different algorithms, and therefore different data structures, can be combined.
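One hypothetical way to exploit multithreading (my sketch, not part of the answer): split the array into chunks, find each chunk's k smallest values in parallel, then merge the partial results. Per-chunk sorting here merely stands in for whichever per-chunk top-k method you prefer:

```java
import java.util.Arrays;
import java.util.stream.IntStream;

public class KSmallestParallel {
    // Find each chunk's k smallest in parallel, then merge the partial
    // results and keep the global k smallest.
    public static double[] kSmallest(double[] values, int k, int chunks) {
        int n = values.length;
        int step = (n + chunks - 1) / chunks; // ceiling division
        double[] merged = IntStream.range(0, chunks).parallel()
                .mapToObj(c -> {
                    int from = Math.min(n, c * step);
                    int to = Math.min(n, (c + 1) * step);
                    double[] part = Arrays.copyOfRange(values, from, to);
                    Arrays.sort(part); // stand-in for any per-chunk top-k method
                    return Arrays.copyOf(part, Math.min(k, part.length));
                })
                .flatMapToDouble(Arrays::stream)
                .sorted()
                .toArray();
        return Arrays.copyOf(merged, Math.min(k, merged.length));
    }
}
```

For the small k discussed in the question the parallel overhead would likely dwarf any gain; this only pays off for large inputs.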


Additional note on micro benchmarks . Your performance measurements can be strongly influenced by external factors unrelated to the efficiency of your algorithm. Creating objects, as you do in both functions, may require memory that is not immediately available, triggering extra work by the GC. Such factors influence your results very much. At a minimum, keep code that is not part of the measured algorithm out of the timed region. Repeat the tests in different orders, and wait before invoking a test to make sure no GC is in action.
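As a rough illustration of these points (a hand-rolled sketch, not a substitute for a proper harness such as JMH, which another answer uses below): create the data once outside the timed region, warm the workload up so the JIT can compile it, and take the best of several runs to dampen GC and scheduler noise:

```java
import java.util.Random;
import java.util.function.ToLongFunction;

public class MiniBench {
    // Minimal benchmarking hygiene: input is created outside the timed
    // region, the workload is warmed up before timing, and the best of
    // several runs is reported.
    public static long bestOfRuns(ToLongFunction<double[]> work, double[] data,
                                  int warmup, int runs, int iterations) {
        long sink = 0;
        for (int i = 0; i < warmup; i++) sink += work.applyAsLong(data); // warm-up
        long best = Long.MAX_VALUE;
        for (int r = 0; r < runs; r++) {
            long t0 = System.nanoTime();
            for (int i = 0; i < iterations; i++) sink += work.applyAsLong(data);
            best = Math.min(best, System.nanoTime() - t0);
        }
        if (sink == 42) System.out.println(); // consume sink so the JIT cannot drop the work
        return best;
    }

    public static void main(String[] args) {
        double[] data = new Random(123).doubles(1000).toArray();
        // hypothetical workload: count values below 0.5
        long ns = bestOfRuns(a -> {
            long c = 0;
            for (double v : a) if (v < 0.5) c++;
            return c;
        }, data, 5000, 5, 10000);
        System.out.println("best run [ns]: " + ns);
    }
}
```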

The first solution has time complexity O(numberExamples * log numberNeighbours) , while the second is O(numberExamples * numberNeighbours) , so the second has to be slower for large enough input. The second solution is faster here because you test with a small numberNeighbours , and PriorityQueue has a bigger constant overhead than a simple array. Your use of PriorityQueue is already optimal.

Simpler, but not optimal, would be to just sort the array; the k smallest elements are then the first k entries.
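As a sketch of that baseline (my own illustration; the input is cloned so the caller's array stays intact):

```java
import java.util.Arrays;

public class KSmallestSort {
    // O(n log n) baseline: sort a copy, then the k smallest values are
    // simply the first k entries of the sorted array.
    public static double[] kSmallest(double[] values, int k) {
        double[] copy = values.clone(); // keep the caller's array intact
        Arrays.sort(copy);
        return Arrays.copyOf(copy, k);
    }
}
```

This is hard to beat for k close to n, but it does more work than necessary when k is small.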

Anyway, you may want to implement the QuickSelect algorithm; if you choose the pivot element smartly, you should get better performance. You may want to see this: https://discuss.leetcode.com/topic/55501/2ms-java-quick-select-only-2-points-to-mention

First of all, your benchmarking method is incorrect: you are measuring input-data creation along with the algorithm's performance, and you aren't warming up the JVM before measuring. Results for your code when tested through JMH:

Benchmark                     Mode  Cnt      Score   Error  Units
CounterBenchmark.testHeap    thrpt    2  18103,296          ops/s
CounterBenchmark.testSimple  thrpt    2  59490,384          ops/s

Modified benchmark pastebin .

Regarding the 3x difference between the two provided solutions: in terms of big-O notation your first algorithm may seem better, but big-O notation only tells you how well an algorithm scales, never how fast it actually performs (see this question also). And in your case scaling is not the issue, as your numNeighbours is limited to 20. In other words, big-O notation describes how many steps an algorithm needs to complete, but it says nothing about the duration of a single step; it only guarantees that the per-step cost doesn't grow with the input. And in terms of per-step cost your second algorithm surely wins.

What is the fastest way to compute k smallest elements?

I came up with the following solution, which I believe allows branch prediction to do its job:

@Benchmark
public void testModified(Blackhole bh) {
    final double[] scores = sampleData;
    int[] candidates = new int[numberNeighbours];
    for (int i = 0; i < numberNeighbours; i++) {
        candidates[i] = i;
    }
    // sorting candidates so scores[candidates[0]] is the largest
    for (int i = 0; i < numberNeighbours; i++) {
        for (int j = i+1; j < numberNeighbours; j++) {
            if (scores[candidates[i]] < scores[candidates[j]]) {
                int temp = candidates[i];
                candidates[i] = candidates[j];
                candidates[j] = temp;
            }
        }
    }
    // processing other scores, while keeping candidates array sorted in the descending order
    for (int i = numberNeighbours; i < numberExamples; i++) {
        if (scores[i] > scores[candidates[0]]) {
            continue;
        }
        // moving all larger candidates to the left, to keep the array sorted
        int j; // here the branch prediction should kick-in
        for (j = 1; j < numberNeighbours && scores[i] < scores[candidates[j]]; j++) {
            candidates[j - 1] = candidates[j];
        }
        // inserting the new item
        candidates[j - 1] = i;
    }
    bh.consume(candidates);
}

Benchmark results (about 2x faster than your current solution):

(10 neighbours) CounterBenchmark.testModified    thrpt    2  136492,151          ops/s
(20 neighbours) CounterBenchmark.testModified    thrpt    2  118395,598          ops/s

Others mentioned quickselect , but as one might expect, that algorithm's constant overhead outweighs its strengths in your case:

@Benchmark
public void testQuickSelect(Blackhole bh) {
    final int[] candidates = new int[sampleData.length];
    for (int i = 0; i < candidates.length; i++) {
        candidates[i] = i;
    }
    final int[] resultIndices = new int[numberNeighbours];
    int neighboursToAdd = numberNeighbours;

    int left = 0;
    int right = candidates.length - 1;
    while (neighboursToAdd > 0) {
        int partitionIndex = partition(candidates, left, right);
        int smallerItemsPartitioned = partitionIndex - left;
        if (smallerItemsPartitioned <= neighboursToAdd) {
            while (left < partitionIndex) {
                resultIndices[numberNeighbours - neighboursToAdd--] = candidates[left++];
            }
        } else {
            right = partitionIndex - 1;
        }
    }
    bh.consume(resultIndices);
}

private int partition(int[] locations, int left, int right) {
    final int pivotIndex = ThreadLocalRandom.current().nextInt(left, right + 1);
    final double pivotValue = sampleData[locations[pivotIndex]];
    int storeIndex = left;
    for (int i = left; i <= right; i++) {
        if (sampleData[locations[i]] <= pivotValue) {
            final int temp = locations[storeIndex];
            locations[storeIndex] = locations[i];
            locations[i] = temp;

            storeIndex++;
        }
    }
    return storeIndex;
}

Benchmark results are pretty upsetting in this case:

CounterBenchmark.testQuickSelect  thrpt    2   11586,761          ops/s
