
Java fork-join Performance

I have two sample implementations of merge sort: one uses Fork/Join and the other is a plain recursive function.

It looks like the Fork/Join version is slower than the plain recursive one. Why?

import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class DivideTask extends RecursiveTask<int[]> {
    private static final long serialVersionUID = -7017440434091885703L;
    int[] arrayToDivide;

    public DivideTask(int[] arrayToDivide) {
        this.arrayToDivide = arrayToDivide;
    }

    @Override
    protected int[] compute() {
        //List<RecursiveTask> forkedTasks = new ArrayList<>();

        /*
         * We divide the array until it has only one element.
         * We could also use a larger threshold, say 5 elements,
         * in which case we would sort the chunk with
         * Arrays.sort(arrayToDivide) and return it instead.
         */
        if (arrayToDivide.length > 1) {

            List<int[]> partitionedArray = partitionArray();

            DivideTask task1 = new DivideTask(partitionedArray.get(0));
            DivideTask task2 = new DivideTask(partitionedArray.get(1));
            invokeAll(task1, task2);

            //Wait for results from both the tasks
            int[] array1 = task1.join();
            int[] array2 = task2.join();

            //Initialize a merged array
            int[] mergedArray = new int[array1.length + array2.length];

            mergeArrays(array1, array2, mergedArray);

            return mergedArray;
        }
        return arrayToDivide;
    }

    private void mergeArrays(int[] array1, int[] array2, int[] mergedArray) {

        int i = 0, j = 0, k = 0;

        while ((i < array1.length) && (j < array2.length)) {

            if (array1[i] < array2[j]) {
                mergedArray[k] = array1[i++];
            } else {
                mergedArray[k] = array2[j++];
            }

            k++;
        }

        if (i == array1.length) {
            for (int a = j; a < array2.length; a++) {
                mergedArray[k++] = array2[a];
            }
        } else {
            for (int a = i; a < array1.length; a++) {
                mergedArray[k++] = array1[a];
            }
        }
    }

    private List<int[]> partitionArray() {
        int[] partition1 = Arrays.copyOfRange(arrayToDivide, 0, arrayToDivide.length / 2);

        int[] partition2 = Arrays.copyOfRange(arrayToDivide, arrayToDivide.length / 2, arrayToDivide.length);
        return Arrays.asList(partition1, partition2);
    }
}

public class ForkJoinTest {
    static int[] numbers;
    static final int SIZE = 1_000_000;
    static final int MAX = 20;

    public static void main(String[] args) {
        setUp();

        testMergeSortByFJ();
        testMergeSort();
    }

    static void setUp() {
        numbers = new int[SIZE];
        Random generator = new Random();
        for (int i = 0; i < numbers.length; i++) {
            numbers[i] = generator.nextInt(MAX);
        }
    }

    static void testMergeSort() {
        long startTime = System.currentTimeMillis();

        Mergesort sorter = new Mergesort();
        sorter.sort(numbers);

        long stopTime = System.currentTimeMillis();
        long elapsedTime = stopTime - startTime;
        System.out.println("Mergesort Time:" + elapsedTime + " msec");
    }

    static void testMergeSortByFJ() {
        //System.out.println("Unsorted array: " + Arrays.toString(numbers));
        long t1 = System.currentTimeMillis();
        DivideTask task = new DivideTask(numbers);
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        forkJoinPool.invoke(task);
        //System.out.println("Sorted array: " + Arrays.toString(task.join()));
        System.out.println("Fork-Join Time:" + (System.currentTimeMillis() - t1) + " msec");
    }
}

class Mergesort {
    private int[] msNumbers;
    private int[] helper;

    private int number;

    private void merge(int low, int middle, int high) {

        // Copy both parts into the helper array
        for (int i = low; i <= high; i++) {
            helper[i] = msNumbers[i];
        }

        int i = low;
        int j = middle + 1;
        int k = low;
        // Copy the smallest values from either the left or the right side back
        // to the original array
        while (i <= middle && j <= high) {
            if (helper[i] <= helper[j]) {
                msNumbers[k] = helper[i];
                i++;
            } else {
                msNumbers[k] = helper[j];
                j++;
            }
            k++;
        }
        // Copy the rest of the left side of the array into the target array
        while (i <= middle) {
            msNumbers[k] = helper[i];
            k++;
            i++;
        }
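        // Note: the remaining right-side elements (helper[j..high]) are already
        // in their final positions in msNumbers, so no copy-back is needed.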

    }

    private void mergesort(int low, int high) {
        // Check if low is smaller then high, if not then the array is sorted
        if (low < high) {
            // Get the index of the element which is in the middle
            int middle = low + (high - low) / 2;
            // Sort the left side of the array
            mergesort(low, middle);
            // Sort the right side of the array
            mergesort(middle + 1, high);
            // Combine them both
            merge(low, middle, high);
        }
    }

    public void sort(int[] values) {
        this.msNumbers = values;
        number = values.length;
        this.helper = new int[number];
        mergesort(0, number - 1);
    }
}

IMHO the main reason is not the overhead due to thread spawning and pooling.

I think the multi-threaded version runs slowly mainly because you keep creating new arrays, over and over, millions of times. By the time you reach the leaves of the recursion you have created a million single-element arrays, which is a headache for the garbage collector.
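A rough back-of-the-envelope estimate (mine, not part of the original answer): with SIZE = 1,000,000, partitionArray() allocates two copies at each of the roughly 1,000,000 internal nodes of the recursion tree, and compute() allocates one merged array per node, so on the order of 3 million short-lived array objects in total, with about 2 · n · log₂ n ≈ 4 × 10⁷ element copies. The sequential version allocates a single helper array up front and never allocates inside the recursion.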

All your DivideTasks can operate on different portions of the same array (the two halves), so just pass each task a range and make it work on that range.

Furthermore, your parallelization strategy makes it impossible to use the clever "helper array" optimization (notice the helper array in the sequential version). This optimization swaps the roles of the "input" array and a "helper" array between merge passes, so that a new array does not have to be created for every merge operation: a memory-saving technique that you can't use if you don't parallelize by level of the recursion tree.
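Here is a minimal sketch of that range-based idea (my own, not the answerer's code; the class name RangeSortTask and the THRESHOLD value are assumptions to be tuned): every task receives only a pair of indices into the one shared array, and all tasks merge through a single pre-allocated helper buffer, so nothing is allocated inside the recursion.

import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

class RangeSortTask extends RecursiveAction {
    // Assumed cut-off: below this many elements, forking costs more than it saves.
    private static final int THRESHOLD = 8_192;
    private final int[] array;
    private final int[] helper;   // shared, allocated once, same length as array
    private final int low, high;

    RangeSortTask(int[] array, int[] helper, int low, int high) {
        this.array = array;
        this.helper = helper;
        this.low = low;
        this.high = high;
    }

    @Override
    protected void compute() {
        if (high - low + 1 <= THRESHOLD) {
            Arrays.sort(array, low, high + 1);   // small range: sort sequentially
            return;
        }
        int middle = low + (high - low) / 2;
        invokeAll(new RangeSortTask(array, helper, low, middle),
                  new RangeSortTask(array, helper, middle + 1, high));
        merge(low, middle, high);
    }

    // Same merge as the sequential Mergesort above; concurrent tasks always work
    // on disjoint [low, high] ranges, so sharing one helper array is safe.
    private void merge(int low, int middle, int high) {
        for (int i = low; i <= high; i++) {
            helper[i] = array[i];
        }
        int i = low, j = middle + 1, k = low;
        while (i <= middle && j <= high) {
            array[k++] = (helper[i] <= helper[j]) ? helper[i++] : helper[j++];
        }
        while (i <= middle) {
            array[k++] = helper[i++];   // remaining right side is already in place
        }
    }
}

From the test harness it could be driven with something like new ForkJoinPool().invoke(new RangeSortTask(numbers, new int[numbers.length], 0, numbers.length - 1));.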

For a class assignment I had to parallelize merge sort, and I managed to get a nice speedup by parallelizing by level of the recursion tree. Unfortunately the code is in C and uses OpenMP. If you want, I can provide it.

As gd1 points out, you're doing a lot of array allocation and copying, and that is going to cost you. You should instead work on different sections of the same array, taking care that no subtask works on a section another subtask is working on.

But beyond that, there's a certain amount of overhead that comes with the fork/join approach (as with any concurrency). In fact, if you look at the Javadoc for RecursiveTask, it even points out that its simple example will perform slowly because the forking is too granular.

Long story short, you should have fewer subdivisions, each of which does more. More generally, any time you have more non-blocked threads than cores, throughput is not going to improve, and in fact overhead will start to chip away at it.
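For example (a sketch only; the 10,000-element cut-off is an assumed value you would tune), the OP's DivideTask.compute() could stop forking below a threshold and fall back to Arrays.sort:

    @Override
    protected int[] compute() {
        // Assumed cut-off: below this size, forking overhead outweighs the benefit,
        // so sort the chunk sequentially instead of splitting further.
        if (arrayToDivide.length <= 10_000) {
            Arrays.sort(arrayToDivide);
            return arrayToDivide;
        }
        List<int[]> partitionedArray = partitionArray();
        DivideTask task1 = new DivideTask(partitionedArray.get(0));
        DivideTask task2 = new DivideTask(partitionedArray.get(1));
        invokeAll(task1, task2);
        int[] array1 = task1.join();
        int[] array2 = task2.join();
        int[] mergedArray = new int[array1.length + array2.length];
        mergeArrays(array1, array2, mergedArray);
        return mergedArray;
    }

Even with this change the per-merge allocations remain, so combining it with the in-place, range-based approach above gives the bigger win.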

Without looking too deeply at your code: spawning a new thread is costly, so if there isn't much work for it to do, it's not worth doing for performance reasons alone. Speaking very generally, a single thread can loop thousands of times before a newly spawned thread even starts running (especially on Windows).

Please refer to Doug Lea's paper (under 2. DESIGN) where he states:

"However, the java.lang.Thread class (as well as POSIX pthreads, upon which Java threads are often based) are suboptimal vehicles for supporting fork/join programs"

I also found the following useful information on using Fork/Join: Dan Grossman's introduction to Fork/Join.

I encountered the same issue. In my merge sort implementation I only copy the right part into a buffer, which may be shorter than the left part, and while copying and merging I skip elements at the end of the right part that are already in their final positions. Even with these optimizations the parallel implementation is still slower than an iterative implementation. According to LeetCode, my iterative approach is faster than 91.96% of submissions, while my parallel implementation is faster than only 56.59%.

  1. https://leetcode.com/problems/sort-an-array/submissions/878182467/
  2. https://leetcode.com/problems/sort-an-array/submissions/877801578/

Parallel implementation:
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

class Solution {

    public static class Sort extends RecursiveAction {
        private int[] a;
        private int left;
        private int right;

        public Sort(int[] a, int left, int right) {
            this.a = a;
            this.left = left;
            this.right = right;
        }

        @Override
        protected void compute() {
            int m = (left + right) / 2;
            if (m >= left + 1) {
                Sort leftHalf = new Sort(a, left, m);
                leftHalf.fork();
                Sort rightHalf = new Sort(a, m+1, right);
                rightHalf.compute();
                leftHalf.join();
            }
            merge(a, left, right, m);
        }

        private void merge(int[] a, int left, int right, int mid) {
            if (left == right || left + 1 == right && a[left] <= a[right])
                return;
            // let l point to last element of left half, r point to last element of right half
            int l = mid, r = right;
            // skip possible max elements
            while (l < r && a[l] <= a[r])
                r -= 1;
            // size of remaining right half
            int size = r-l;
            int[] buf = new int[size];
            for (int i = 0; i < size; i++){
                buf[i] = a[mid + 1 + i];
            }
            int i = size-1;
            while (i >= 0) {
                if (l >= left && a[l] > buf[i]) {
                    a[r] = a[l];
                    l -= 1;
                } else {
                    a[r] = buf[i];
                    i -= 1;
                }
                r -= 1;
            }
        }
    }

    public int[] sortArray(int[] a) {
        ForkJoinPool threadPool = ForkJoinPool.commonPool();
        threadPool.invoke(new Sort(a, 0, a.length-1));
        return a;
    }
}

Iterative implementation:
class Solution {
    public int[] sortArray(int[] a) {
        int[] buf = new int[a.length];
        int size = 1;
        while (size < a.length) {
            int left = 0 + size - 1;
            int right = Math.min(left + size, a.length-1);
            while (left < a.length) {
                merge(a, size, left, right, buf);
                left += 2 * size;
                right = Math.min(left + size, a.length-1);
            }
            size *= 2;
        }
        return a;
    }

    private void merge(int[] a, int size, int l, int r, int[] buf) {
        int terminal1 = l - size;
        int right = r;
        while (l < right && a[l] <= a[right])
            right--;
        r = right;
        int rsize = right - l;
        for (int i = rsize-1; i >= 0; i--) {
            buf[i] = a[right--];
        }
        int i = rsize-1;
        while (i >= 0) {
            if (l > terminal1 && a[l] > buf[i]) {
                a[r--] = a[l--];
            } else {
                a[r--] = buf[i--];
            }
        }
    }
}
