简体   繁体   English

Java fork-join 性能

[英]Java fork-join Performance

I have sample implementations for Merge-Sort, one using Fork-Join and other is straight recursive function.我有 Merge-Sort 的示例实现,一个使用 Fork-Join,另一个是直接递归 function。

It looks like fork-join is slower than straight recursive, why?看起来 fork-join 比直接递归慢,为什么?

import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class DivideTask extends RecursiveTask<int[]> {
    private static final long serialVersionUID = -7017440434091885703L;
    int[] arrayToDivide;

    public DivideTask(int[] arrayToDivide) {
        this.arrayToDivide = arrayToDivide;
    }

    @Override
    protected int[] compute() {
        //List<RecursiveTask> forkedTasks = new ArrayList<>();

        /*
         * We divide the array till it has only 1 element. 
         * We can also custom define this value to say some 
         * 5 elements. In which case the return would be
         * Arrays.sort(arrayToDivide) instead.
         */
        if (arrayToDivide.length > 1) {

            List<int[]> partitionedArray = partitionArray();

            DivideTask task1 = new DivideTask(partitionedArray.get(0));
            DivideTask task2 = new DivideTask(partitionedArray.get(1));
            invokeAll(task1, task2);

            //Wait for results from both the tasks
            int[] array1 = task1.join();
            int[] array2 = task2.join();

            //Initialize a merged array
            int[] mergedArray = new int[array1.length + array2.length];

            mergeArrays(task1.join(), task2.join(), mergedArray);

            return mergedArray;
        }
        return arrayToDivide;
    }

    private void mergeArrays(int[] array1, int[] array2, int[] mergedArray) {

        int i = 0, j = 0, k = 0;

        while ((i < array1.length) && (j < array2.length)) {

            if (array1[i] < array2[j]) {
                mergedArray[k] = array1[i++];
            } else {
                mergedArray[k] = array2[j++];
            }

            k++;
        }

        if (i == array1.length) {
            for (int a = j; a < array2.length; a++) {
                mergedArray[k++] = array2[a];
            }
        } else {
            for (int a = i; a < array1.length; a++) {
                mergedArray[k++] = array1[a];
            }
        }
    }

    private List<int[]> partitionArray() {
        int[] partition1 = Arrays.copyOfRange(arrayToDivide, 0, arrayToDivide.length / 2);

        int[] partition2 = Arrays.copyOfRange(arrayToDivide, arrayToDivide.length / 2, arrayToDivide.length);
        return Arrays.asList(partition1, partition2);
    }
}

public class ForkJoinTest {
    static int[] numbers;
    static final int SIZE = 1_000_000;
    static final int MAX = 20;

    public static void main(String[] args) {
        setUp();

        testMergeSortByFJ();
        testMergeSort();
    }

    static void setUp() {
        numbers = new int[SIZE];
        Random generator = new Random();
        for (int i = 0; i < numbers.length; i++) {
            numbers[i] = generator.nextInt(MAX);
        }
    }

    static void testMergeSort() {
        long startTime = System.currentTimeMillis();

        Mergesort sorter = new Mergesort();
        sorter.sort(numbers);

        long stopTime = System.currentTimeMillis();
        long elapsedTime = stopTime - startTime;
        System.out.println("Mergesort Time:" + elapsedTime + " msec");
    }

    static void testMergeSortByFJ() {
        //System.out.println("Unsorted array: " + Arrays.toString(numbers));
        long t1 = System.currentTimeMillis();
        DivideTask task = new DivideTask(numbers);
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        forkJoinPool.invoke(task);
        //System.out.println("Sorted array: " + Arrays.toString(task.join()));
        System.out.println("Fork-Join Time:" + (System.currentTimeMillis() - t1) + " msec");
    }
 }

class Mergesort {
    private int[] msNumbers;
    private int[] helper;

    private int number;

    private void merge(int low, int middle, int high) {

        // Copy both parts into the helper array
        for (int i = low; i <= high; i++) {
            helper[i] = msNumbers[i];
        }

        int i = low;
        int j = middle + 1;
        int k = low;
        // Copy the smallest values from either the left or the right side back
        // to the original array
        while (i <= middle && j <= high) {
            if (helper[i] <= helper[j]) {
                msNumbers[k] = helper[i];
                i++;
            } else {
                msNumbers[k] = helper[j];
                j++;
            }
            k++;
        }
        // Copy the rest of the left side of the array into the target array
        while (i <= middle) {
            msNumbers[k] = helper[i];
            k++;
            i++;
        }

    }

    private void mergesort(int low, int high) {
        // Check if low is smaller then high, if not then the array is sorted
        if (low < high) {
            // Get the index of the element which is in the middle
            int middle = low + (high - low) / 2;
            // Sort the left side of the array
            mergesort(low, middle);
            // Sort the right side of the array
            mergesort(middle + 1, high);
            // Combine them both
            merge(low, middle, high);
        }
    }

    public void sort(int[] values) {
        this.msNumbers = values;
        number = values.length;
        this.helper = new int[number];
        mergesort(0, number - 1);
    }
}

IMHO the main reason is not the overhead due to thread spawning and pooling. 恕我直言,主要原因不是由于线程产生和汇集而产生的开销。

I think the multi-threaded version runs slow mainly because you are continuously creating new arrays , all the times, millions of times. 我认为多线程版本运行缓慢主要是因为您不断创建新阵列数百万次。 Eventually, you create 1 million of arrays with a single element, a headache for the garbage collector. 最终,您使用单个元素创建了100万个数组,这对垃圾收集器来说是一个令人头痛的问题。

All your DivideTask s can just operate on different portions of the array (the two halves), so just send them a range and make them operate on that range. 所有DivideTask都可以在阵列的不同部分(两半)上运行,因此只需向它们发送一个范围并让它们在该范围内运行。

Furthermore, your parallelization strategy makes it impossible to use the clever "helper array" optimization (notice the helper array in the sequential version). 此外,您的并行化策略使得无法使用聪明的“辅助数组”优化(注意顺序版本中的helper数组)。 This optimization swaps the "input" array with a "helper" array on which merges are made, so that a new array shouldn't be created for every merge operation: a memory-saving technique that you can't do if you don't parallelize by level of the recursion tree . 这种优化将“输入”数组与一个“辅助”数组进行交换,在该数组上进行合并,因此不应为每个合并操作创建一个新数组:一种节省内存的技术,如果你不这样做就不能做t 按递归树的级别并行化。

For a classwork, I had to parallelize MergeSort and I managed to get a nice speedup by parallelizing by level of the recursion tree. 对于一个类作品,我不得不并行化MergeSort,并通过递归树的级别并行化来设法获得一个很好的加速。 Unfortunately the code is in C and uses OpenMP. 不幸的是,代码在C中并使用OpenMP。 If you want I can provide it. 如果你想我可以提供它。

As gd1 points out, you're doing a lot of array allocation and copying; 正如gd1指出的那样,你正在进行大量的数组分配和复制; this is going to cost you. 这将花费你。 You should instead work on different sections of the same one array, taking care that no subtask works on a section that another subtask is working on. 您应该在同一个数组的不同部分上工作,注意没有子任务在另一个子任务正在处理的部分上工作。

But beyond that, there's a certain amount of overhead that comes with the fork/join approach (as with any concurrency). 但除此之外,fork / join方法带来了一定的开销(与任何并发一样)。 In fact, if you look at the javadocs for RecursiveTask , they even point out that their simple example will perform slowly because the forking is too granular. 事实上,如果你看一下RecursiveTask的javadocs,他们甚至会指出他们的简单例子会表现得很慢,因为分叉过于细化。

Long story short, you should have fewer subdivisions, each of which does more. 长话短说,你应该有更少的细分,每个细分做得更多。 More generally, any time you have more non-blocked threads than cores, throughput is not going to improve, and in fact overhead will start to chip away at it. 更一般地说,只要你拥有比核心更多的非阻塞线程,吞吐量就不会有所改善,实际上开销会开始削减它。

Without looking too deeply at your code, spawning a new thread is costly. 在不深入研究代码的情况下,生成新线程代价高昂。 If you haven't got much work to do, then it's not worth it for performance reasons alone. 如果你没有太多的工作要做,那么仅仅出于性能原因它是不值得的。 Very generally talking here, but a single thread could loop thousands of times before a new thread is spawned and starts running (especially on Windows). 非常普遍地在这里谈论,但是在生成新线程并开始运行之前,单个线程可以循环数千次(特别是在Windows上)。

Please refer to Doug Lea's paper (under 2. DESIGN) where he states: 请参阅Doug Lea的论文 (在2. DESIGN下),他说:

"However, the java.lang.Thread class (as well as POSIX pthreads, upon which Java threads are often based) are suboptimal vehicles for supporting fork/join programs" “但是,java.lang.Thread类(以及Java线程通常基于POSIX pthreads)是支持fork / join程序的次优工具”

还发现了有关使用Fork / Join Dan Grossman的Fork / Join简介的以下信息

I encountered the same issue.我遇到了同样的问题。 In my implementation of merge sort I only copy the right part which might be shorter than the left part.在我的合并排序实现中,我只复制可能比左边部分短的右边部分。 Also I skip possible max elements in the right part while copying and merging.此外,我在复制和合并时跳过了右侧部分可能的最大元素。 Even with that optimization the parallel implementation is still slower than an iterative implementation.即使进行了优化,并行实现仍然比迭代实现慢。 According to Leetcode, my iterative approach is faster than 91.96%, my parallel implementation is faster than 56.59%.根据 Leetcode,我的迭代方法比 91.96% 快,我的并行实现比 56.59% 快。

  1. https://leetcode.com/problems/sort-an-array/submissions/878182467/ . https://leetcode.com/problems/sort-an-array/submissions/878182467/ https://leetcode.com/problems/sort-an-array/submissions/877801578/ . https://leetcode.com/problems/sort-an-array/submissions/877801578/
import java.util.concurrent.RecursiveAction;

class Solution {

    public static class Sort extends RecursiveAction {
        private int[] a;
        private int left;
        private int right;

        public Sort(int[] a, int left, int right) {
            this.a = a;
            this.left = left;
            this.right = right;
        }

        @Override
        protected void compute() {
            int m = (left + right) / 2;
            if (m >= left + 1) {
                Sort leftHalf = new Sort(a, left, m);
                leftHalf.fork();
                Sort rightHalf = new Sort(a, m+1, right);
                rightHalf.compute();
                leftHalf.join();
            }
            merge(a, left, right, m);
        }

        private void merge(int[] a, int left, int right, int mid) {
            if (left == right || left + 1 == right && a[left] <= a[right])
                return;
            // let l point to last element of left half, r point to last element of right half
            int l = mid, r = right;
            // skip possible max elements
            while (l < r && a[l] <= a[r])
                r -= 1;
            // size of remaining right half
            int size = r-l;
            int[] buf = new int[size];
            for (int i = 0; i < size; i++){
                buf[i] = a[mid + 1 + i];
            }
            int i = size-1;
            while (i >= 0) {
                if (l >= left && a[ l] >buf[i]) {
                    a[r] = a[l];
                    l -= 1;
                } else {
                    a[r] = buf[i];
                    i -= 1;
                }
                r -= 1;
            }
        }
    }

    public int[] sortArray(int[] a) {
        ForkJoinPool threadPool = ForkJoinPool.commonPool();
        threadPool.invoke(new Sort(a, 0, a.length-1));
        return a;
    }
}

Iterative implementation:
class Solution {
    public int[] sortArray(int[] a) {
        int[] buf = new int[a.length];
        int size = 1;
        while (size < a.length) {
            int left = 0 + size - 1;
            int right = Math.min(left + size, a.length-1);
            while (left < a.length) {
                merge(a, size, left, right, buf);
                left += 2 * size;
                right = Math.min(left + size, a.length-1);
            }
            size *= 2;
        }
        return a;
    }

    private void merge(int[] a, int size, int l, int r, int[] buf) {
        int terminal1 = l - size;
        int right = r;
        while (l < right && a[l] <= a[right])
            right--;
        r = right;
        int rsize = right - l;
        for (int i = rsize-1; i >= 0; i--) {
            buf[i] = a[right--];
        }
        int i = rsize-1;
        while (i >= 0) {
            if (l > terminal1 && a[l] > buf[i]) {
                a[r--] = a[l--];
            } else {
                a[r--] = buf[i--];
            }
        }
    }
}

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM