
Parallel processing of an array in Java

I am trying to get faster output by using threads. This is just a small POC (proof of concept) of sorts.
Suppose the problem statement is to find all the numbers in an array that occur an odd number of times. Below are my sequential and parallel attempts.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

public class Test1 {

    final static Map<Integer, Integer> mymap  = new HashMap<>();

    static Map<Integer, AtomicInteger> mymap1 = new ConcurrentHashMap<>();

    public static void generateData(final int[] arr) {
        final Random aRandom = new Random();
        for (int i = 0; i < arr.length; i++) {
            arr[i] = aRandom.nextInt(10);
        }
    }

    public static void calculateAllOddOccurrence(final int[] arr) {

        for (int i = 0; i < arr.length; i++) {
            if (mymap.containsKey(arr[i])) {
                mymap.put(arr[i], mymap.get(arr[i]) + 1);
            } else {
                mymap.put(arr[i], 1);
            }
        }

        for (final Map.Entry<Integer, Integer> entry : mymap.entrySet()) {
            if (entry.getValue() % 2 != 0) {
                System.out.println(entry.getKey() + "=" + entry.getValue());
            }

        }

    }

    public static void calculateAllOddOccurrenceThread(final int[] arr) {

        final ExecutorService executor = Executors.newFixedThreadPool(10);
        final List<Future<?>> results = new ArrayList<>();
        final int range = arr.length / 10;
        for (int count = 0; count < 10; ++count) {
            final int startAt = count * range;
            final int endAt = startAt + range;
            executor.submit(() -> {
                for (int i = startAt; i < endAt; i++) {
                    if (mymap1.containsKey(arr[i])) {
                        final AtomicInteger accumulator = mymap1.get(arr[i]);
                        accumulator.incrementAndGet();
                        mymap1.put(arr[i], accumulator);
                    } else {
                        mymap1.put(arr[i], new AtomicInteger(1));
                    }
                }
            });
        }

        awaitTerminationAfterShutdown(executor);

        for (final Entry<Integer, AtomicInteger> entry : mymap1.entrySet()) {
            if (entry.getValue().get() % 2 != 0) {
                System.out.println(entry.getKey() + "=" + entry.getValue());
            }

        }

    }

    public static void calculateAllOddOccurrenceStream(final int[] arr) {

        final ConcurrentMap<Integer, List<Integer>> map2 = Arrays.stream(arr).parallel().boxed()
                .collect(Collectors.groupingByConcurrent(i -> i));
        map2.entrySet().stream().parallel()
                .filter(e -> e.getValue().size() % 2 != 0)
                .forEach(entry -> System.out.println(entry.getKey() + "=" + entry.getValue().size()));

    }

    public static void awaitTerminationAfterShutdown(final ExecutorService threadPool) {
        threadPool.shutdown();
        try {
            if (!threadPool.awaitTermination(60, TimeUnit.SECONDS)) {
                threadPool.shutdownNow();
            }
        } catch (final InterruptedException ex) {
            threadPool.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    public static void main(final String... doYourBest) {

        final int[] arr = new int[200000000];

        generateData(arr);
        long starttime = System.currentTimeMillis();
        calculateAllOddOccurrence(arr);

        System.out.println("Total time=" + (System.currentTimeMillis() - starttime));

        starttime = System.currentTimeMillis();
        calculateAllOddOccurrenceStream(arr);

        System.out.println("Total time Thread=" + (System.currentTimeMillis() - starttime));

    }

}

Output:

1=20003685
2=20000961
3=19991311
5=20006433
7=19995737
8=19999463
Total time=3418
5=20006433
7=19995737
1=20003685
8=19999463
2=20000961
3=19991311
Total time Thread=19640

The parallel execution (calculateAllOddOccurrenceStream) takes more time. What is the best way to process an array in parallel and then merge the results?

My goal is not to find the fastest algorithm, but to take any algorithm and run it on different threads so that they process different parts of the array simultaneously.

It seems that those threads are working on the same parts of the array simultaneously, hence the answer is not correct.

Instead, divide the array into parts with proper start and end indexes. Allocate a separate thread to process each part and count the occurrences of each number within that part.

At the end, you would have multiple maps, each holding the counts calculated from one of those parts. Merge those maps to get the final answer.
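A minimal sketch of that split, count and merge approach, assuming an ExecutorService whose Callable tasks each return their own HashMap; the class name ChunkedOddOccurrence and the method countInParallel are hypothetical. Because every task owns its map, no synchronization is needed until the merge, which runs on the calling thread.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class ChunkedOddOccurrence {

    // Counts occurrences of each value; every task gets its own chunk and its own map.
    static Map<Integer, Integer> countInParallel(int[] arr, int threads)
            throws InterruptedException, ExecutionException {
        ExecutorService executor = Executors.newFixedThreadPool(threads);
        try {
            List<Future<Map<Integer, Integer>>> partials = new ArrayList<>();
            // Round the chunk size up so the last chunk picks up any remainder.
            int chunk = (arr.length + threads - 1) / threads;
            for (int start = 0; start < arr.length; start += chunk) {
                final int from = start;
                final int to = Math.min(start + chunk, arr.length);
                partials.add(executor.submit(() -> {
                    // Thread-confined map: no other task ever touches it.
                    Map<Integer, Integer> local = new HashMap<>();
                    for (int i = from; i < to; i++) {
                        local.merge(arr[i], 1, Integer::sum);
                    }
                    return local;
                }));
            }
            // Merge the partial maps once each task has finished.
            Map<Integer, Integer> total = new HashMap<>();
            for (Future<Map<Integer, Integer>> partial : partials) {
                partial.get().forEach((k, v) -> total.merge(k, v, Integer::sum));
            }
            return total;
        } finally {
            executor.shutdown();
        }
    }

    public static void main(String[] args) throws Exception {
        int[] arr = {1, 2, 2, 3, 3, 3, 4, 4, 4, 4};
        countInParallel(arr, 4).forEach((k, v) -> {
            if (v % 2 != 0) {
                System.out.println(k + "=" + v);
            }
        });
    }
}
```

Unlike the question's version, the chunk size is rounded up, so elements left over by arr.length / threads are not silently skipped.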

Or you could have a single ConcurrentHashMap for storing the counts coming from all those threads, but I guess a bug could creep in there, as there would still be concurrent write conflicts. Each individual put on a ConcurrentHashMap is thread-safe, but a check-then-act sequence such as containsKey followed by put (or get followed by put) is not atomic, so threads can overwrite each other's updates. For guaranteed write behaviour, the correct way is to use the atomicity of the ConcurrentHashMap.putIfAbsent(K key, V value) method and pay attention to its return value: it returns null if your value was inserted, or the value that was already present otherwise. A simple put might not be correct. See https://stackoverflow.com/a/14947844/945214
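A minimal sketch of the putIfAbsent pattern, assuming AtomicInteger counters as in the question's mymap1; the class name PutIfAbsentCounter and its methods are hypothetical. Whichever thread loses the race simply adopts the counter the winner inserted, so no increment is lost.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class PutIfAbsentCounter {

    private final ConcurrentMap<Integer, AtomicInteger> counts = new ConcurrentHashMap<>();

    // Safe to call from many threads at once.
    void increment(int key) {
        AtomicInteger counter = counts.get(key);
        if (counter == null) {
            AtomicInteger fresh = new AtomicInteger();
            // putIfAbsent returns null if 'fresh' was inserted,
            // otherwise the counter another thread inserted first.
            AtomicInteger existing = counts.putIfAbsent(key, fresh);
            counter = (existing == null) ? fresh : existing;
        }
        counter.incrementAndGet();
    }

    int get(int key) {
        AtomicInteger counter = counts.get(key);
        return counter == null ? 0 : counter.get();
    }

    public static void main(String[] args) throws InterruptedException {
        PutIfAbsentCounter counter = new PutIfAbsentCounter();
        ExecutorService executor = Executors.newFixedThreadPool(4);
        // 4 tasks each record 1000 occurrences of the value 7
        for (int t = 0; t < 4; t++) {
            executor.submit(() -> {
                for (int i = 0; i < 1000; i++) {
                    counter.increment(7);
                }
            });
        }
        executor.shutdown();
        executor.awaitTermination(1, TimeUnit.MINUTES);
        System.out.println("7=" + counter.get(7)); // 7=4000, no lost updates
    }
}
```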

You could use the Java 8 Streams API (https://www.journaldev.com/2774/java-8-stream) to write the code, or simple threading code using Java 5 constructs would also do.

I added Java 8 stream code; notice the timing differences. Using an ArrayList instead of an array makes a difference:

package com.test;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;

public class Test {

    public static void generateData(final int[] arr) {
        final Random aRandom = new Random();
        for (int i = 0; i < arr.length; i++) {
            arr[i] = aRandom.nextInt(10);
        }
    }

    public static void calculateAllOddOccurrence(final int[] arr) {
        final Map<Integer, Integer> mymap  = new HashMap<>();
        for (int i = 0; i < arr.length; i++) {
            if (mymap.containsKey(arr[i])) {
                mymap.put(arr[i], mymap.get(arr[i]) + 1);
            } else {
                mymap.put(arr[i], 1);
            }
        }
        for (final Map.Entry<Integer, Integer> entry : mymap.entrySet()) {
            if (entry.getValue() % 2 != 0) {
                System.out.println(entry.getKey() + "=" + entry.getValue());
            }

        }
    }

    public static void calculateAllOddOccurrenceStream(int[] arr) {
        Arrays.stream(arr).boxed()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
                .entrySet().parallelStream().filter(e -> e.getValue() % 2 != 0)
                .forEach(entry -> System.out.println(entry.getKey() + "=" + entry.getValue()));
    }

    public static void calculateAllOddOccurrenceStream(List<Integer> list) {
        list.parallelStream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
                .entrySet().parallelStream().filter(e -> e.getValue() % 2 != 0)
                .forEach(entry -> System.out.println(entry.getKey() + "=" + entry.getValue()));
    }

    public static void main(final String... doYourBest) {

        final int[] arr = new int[200000000];

        generateData(arr);
        long starttime = System.currentTimeMillis();
        calculateAllOddOccurrence(arr);
        System.out.println("Total time with simple map=" + (System.currentTimeMillis() - starttime));

        List<Integer> list = Arrays.stream(arr).boxed().collect(Collectors.toList());
        starttime = System.currentTimeMillis();
        calculateAllOddOccurrenceStream(list);
        System.out.println("Total time stream - with a readymade list, which might be the case for most apps as arraylist is more easier to work with =" + (System.currentTimeMillis() - starttime));

        starttime = System.currentTimeMillis();
        calculateAllOddOccurrenceStream(arr);
        System.out.println("Total time Stream with array=" + (System.currentTimeMillis() - starttime));

    }
}

Output:


0=19999427
2=20001707
4=20002331
5=20001585
7=20001859
8=19993989
Total time with simple map=2813
4=20002331
0=19999427
2=20001707
7=20001859
8=19993989
5=20001585
Total time stream - with a readymade list, which might be the case for most apps as arraylist is more easier to work with = 3328
8=19993989
7=20001859
0=19999427
4=20002331
2=20001707
5=20001585
Total time Stream with array=6115
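The array-based stream versions box every int into an Integer before grouping, which probably accounts for much of their extra time. As a sketch that avoids boxing, assuming every value lies in a small known range [0, bound) (true for generateData, which uses nextInt(10)), IntStream.collect can count into primitive long[] bins: each parallel sub-task fills its own array and the combiner merges them. The class HistogramOddOccurrence and its method are hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

class HistogramOddOccurrence {

    // Assumption: every value lies in [0, bound), as with generateData's nextInt(10).
    static long[] countParallel(int[] arr, int bound) {
        return Arrays.stream(arr).parallel().collect(
                () -> new long[bound],           // one private bin array per sub-task
                (bins, value) -> bins[value]++,  // accumulate without boxing
                (left, right) -> {               // merge two partial results
                    for (int i = 0; i < bound; i++) {
                        left[i] += right[i];
                    }
                });
    }

    public static void main(String[] args) {
        int[] arr = new int[1_000_000];
        Random random = new Random();
        for (int i = 0; i < arr.length; i++) {
            arr[i] = random.nextInt(10);
        }
        long[] counts = countParallel(arr, 10);
        for (int value = 0; value < counts.length; value++) {
            if (counts[value] % 2 != 0) {
                System.out.println(value + "=" + counts[value]);
            }
        }
    }
}
```

This is the same "process parts in parallel, then merge" idea, expressed through the supplier, accumulator and combiner arguments of IntStream.collect rather than explicit threads.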

You are looking for the Streams API introduced in Java 8: http://www.baeldung.com/java-8-streams

Example:

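// sequential processing (myArray is assumed to be a List, e.g. an ArrayList;
// for an int[] use Arrays.stream(myArray).boxed() instead of myArray.stream())
myArray.stream().filter( ... ).map( ... ).collect(Collectors.toList());

// parallel processing
myArray.parallelStream().filter( ... ).map( ... ).collect(Collectors.toList());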
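A concrete, runnable version of those two pipelines, assuming a List&lt;Integer&gt; input and picking an arbitrary filter (keep odd numbers) and map (square them) for the "..." placeholders; the class and method names are hypothetical. The parallel version runs on the common ForkJoinPool, and because Collectors.toList() respects encounter order, both return the same list.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

class StreamExample {

    // Keeps the odd numbers and squares them, on the calling thread.
    static List<Integer> oddSquaresSequential(List<Integer> numbers) {
        return numbers.stream()
                .filter(n -> n % 2 != 0)
                .map(n -> n * n)
                .collect(Collectors.toList());
    }

    // Same pipeline, with the work split across the common ForkJoinPool.
    static List<Integer> oddSquaresParallel(List<Integer> numbers) {
        return numbers.parallelStream()
                .filter(n -> n % 2 != 0)
                .map(n -> n * n)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5, 6, 7);
        System.out.println(oddSquaresSequential(numbers)); // [1, 9, 25, 49]
        System.out.println(oddSquaresParallel(numbers));   // [1, 9, 25, 49]
    }
}
```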

Looking at your code, this is the line where you're going wrong:

mymap1.put(arr[i], mymap1.get(arr[i]) + 1);

You are overwriting the values in parallel, for example:

Thread 1 'get' = 0
Thread 2 'get' = 0
Thread 1 'put 1' 
Thread 2 'put 1'
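One way to make that read-modify-write atomic on Java 8 or later (a different technique from the AtomicInteger change described next) is ConcurrentHashMap.merge, whose entire invocation is performed atomically, so the get/get/put/put interleaving above cannot occur. A minimal sketch; the class MergeCounter and its countRange method are hypothetical.

```java
import java.util.concurrent.ConcurrentHashMap;

class MergeCounter {

    // merge(key, 1, Integer::sum) inserts 1 or adds 1 to the existing value in one atomic step.
    static ConcurrentHashMap<Integer, Integer> countRange(int[] arr, int from, int to,
                                                          ConcurrentHashMap<Integer, Integer> counts) {
        for (int i = from; i < to; i++) {
            counts.merge(arr[i], 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) throws InterruptedException {
        int[] arr = {5, 5, 5, 6, 6, 6, 6, 7};
        ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap<>();
        // Two threads, each on its own half of the array, sharing one map.
        Thread first = new Thread(() -> countRange(arr, 0, arr.length / 2, counts));
        Thread second = new Thread(() -> countRange(arr, arr.length / 2, arr.length, counts));
        first.start();
        second.start();
        first.join();
        second.join();
        System.out.println("5=" + counts.get(5) + ", 6=" + counts.get(6) + ", 7=" + counts.get(7)); // 5=3, 6=4, 7=1
    }
}
```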

Change your map to:

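// Pre-fill one AtomicInteger per possible value (0-9, matching generateData's nextInt(10)),
// so that worker threads never have to put new entries concurrently
static Map<Integer, AtomicInteger> mymap1 = new ConcurrentHashMap<>();
static {
    for (int i = 0; i < 10; i++) {
        mymap1.put(i, new AtomicInteger());
    }
}
// ...
    // in your loop (startAt/endAt as in your worker task): only read the map
    // and increment the existing counter atomically
    for (int i = startAt; i < endAt; i++) {
        AtomicInteger accumulator = mymap1.get(arr[i]);
        accumulator.incrementAndGet();
    }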

Edit: The catch with the above approach is, of course, the initialization of mymap1. To avoid falling into the same trap (creating an AtomicInteger inside the loop and threads overwriting each other yet again), the map needs to be prefilled with values.
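If the set of possible values is not known in advance, ConcurrentHashMap.computeIfAbsent is one way to avoid prefilling: the whole call is atomic and the mapping function runs at most once per key, so two threads cannot each install their own AtomicInteger. A minimal sketch reusing the name mymap1 from above; the class ComputeIfAbsentCounter and its count method are hypothetical.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

class ComputeIfAbsentCounter {

    static final ConcurrentHashMap<Integer, AtomicInteger> mymap1 = new ConcurrentHashMap<>();

    // Creates the counter atomically on first use, then increments it.
    static void count(int value) {
        mymap1.computeIfAbsent(value, k -> new AtomicInteger()).incrementAndGet();
    }

    public static void main(String[] args) {
        int[] arr = {9, 9, 42, 9, 42, 13};
        // parallel() makes count() run on several threads at once
        IntStream.of(arr).parallel().forEach(ComputeIfAbsentCounter::count);
        for (Map.Entry<Integer, AtomicInteger> entry : mymap1.entrySet()) {
            if (entry.getValue().get() % 2 != 0) {
                System.out.println(entry.getKey() + "=" + entry.getValue());
            }
        }
    }
}
```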

Since I'm feeling generous, here's what might work with the Streams API:

// count() returns the number of matching elements (reduce(0, Integer::sum) would sum their values)
long totalEvenCount = Arrays.stream(arr).parallel().filter(i -> i % 2 == 0).count();
long totalOddCount = Arrays.stream(arr).parallel().filter(i -> i % 2 != 0).count();

//or this to count by individual numbers:
ConcurrentMap<Integer,List<Integer>> map1 = Arrays.stream(arr).parallel().boxed().collect(Collectors.groupingByConcurrent(i->i));
map1.entrySet().stream().filter(e -> e.getKey()%2!=0).forEach(entry -> System.out.println(entry.getKey() + "=" + entry.getValue().size()));

As an exercise for the reader, perhaps you can look into how the various Collectors work, in order to write your own countingBy(i -> i % 2 != 0) that outputs a map containing only the counts instead of lists of values.
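Note that countingBy is not a JDK method; it is the name the exercise above imagines. The standard library already provides a downstream collector, Collectors.counting(), that can be combined with groupingByConcurrent (counts per value) or partitioningBy (counts for odd vs. even). A sketch under that substitution; the class and method names are hypothetical.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Function;
import java.util.stream.Collectors;

class CountingCollectors {

    // Count per individual value: counting() replaces the List of values with a Long.
    static ConcurrentMap<Integer, Long> countByValue(int[] arr) {
        return Arrays.stream(arr).parallel().boxed()
                .collect(Collectors.groupingByConcurrent(Function.identity(), Collectors.counting()));
    }

    // Count odd vs. even elements: key true = odd, key false = even.
    static Map<Boolean, Long> countOddEven(int[] arr) {
        return Arrays.stream(arr).parallel().boxed()
                .collect(Collectors.partitioningBy(i -> i % 2 != 0, Collectors.counting()));
    }

    public static void main(String[] args) {
        int[] arr = {1, 2, 2, 3, 3, 3};
        System.out.println("3 occurs " + countByValue(arr).get(3) + " times"); // 3 occurs 3 times
        Map<Boolean, Long> oddEven = countOddEven(arr);
        System.out.println("odd=" + oddEven.get(true) + ", even=" + oddEven.get(false)); // odd=4, even=2
    }
}
```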

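Disclaimer: The technical posts on this site are published under the CC BY-SA 4.0 license. If you repost them, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.

Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM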