
Java: Wait in a loop until tasks of ThreadPoolExecutor are done before continuing

I'm working on parallelizing Dijkstra's algorithm. For each node, threads are created to look at all the edges of the current node. This ran in parallel, but the overhead of creating the threads was too high: it took longer than the sequential version of the algorithm.

I added a thread pool to solve this problem, but I'm having trouble waiting until the tasks are done before moving on to the next iteration. We should only move on after all tasks for one node are done, because we need the results of all of them before we can search for the next closest node.

I tried calling executor.shutdown(), but with this approach the executor won't accept new tasks. How can we wait in the loop until every task is finished without having to create a new ThreadPoolExecutor every time? Doing that would defeat the purpose of using a pool instead of regular threads, which was to reduce overhead.
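A minimal illustration of why shutdown() does not fit a loop: once it has been called, a pool created by Executors.newFixedThreadPool rejects any further submission with RejectedExecutionException (its default rejection policy is AbortPolicy). The class and method names are made up for this demo.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;

public class ShutdownRejectsNewTasks {

    // Returns true if the pool refuses a task submitted after shutdown().
    static boolean rejectsAfterShutdown() {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        executor.execute(() -> { });   // accepted: the pool is still running
        executor.shutdown();           // from here on, no new tasks are accepted
        try {
            executor.execute(() -> { });
            return false;              // not reached with the default AbortPolicy
        } catch (RejectedExecutionException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println("Rejected after shutdown: " + rejectsAfterShutdown());
    }
}
```

So the pool has to stay open across iterations, and waiting has to be done with something other than shutdown()/awaitTermination().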

One thing I thought about was a BlockingQueue to which the tasks (edges) are added. But with this solution too, I'm stuck on waiting for the tasks to finish without shutdown().
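The BlockingQueue idea can work if the queue holds completed tasks rather than pending ones. java.util.concurrent.ExecutorCompletionService wraps an executor and places each finished task's Future on an internal blocking queue; calling take() once per submitted task blocks until the whole batch is done, and the executor stays open for the next round. This is a sketch under the assumption that each task returns an Integer; the squares are placeholder work, not the Dijkstra edge relaxation.

```java
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class CompletionServiceRounds {

    // Runs `rounds` iterations on ONE pool; each iteration submits `tasksPerRound`
    // tasks and blocks until all of them have finished. Returns the sum of all results.
    static long runRounds(int threads, int rounds, int tasksPerRound)
            throws InterruptedException, ExecutionException {
        ExecutorService executor = Executors.newFixedThreadPool(threads);
        CompletionService<Integer> completion = new ExecutorCompletionService<>(executor);
        long total = 0;
        try {
            for (int round = 0; round < rounds; round++) {
                for (int t = 0; t < tasksPerRound; t++) {
                    final int value = t;                  // hypothetical per-task input
                    completion.submit(() -> value * value);
                }
                // take() blocks until the next task completes; calling it once per
                // submitted task means we only continue when the whole round is done.
                for (int t = 0; t < tasksPerRound; t++) {
                    total += completion.take().get();
                }
            }
        } finally {
            executor.shutdown(); // shut down only after the last round
        }
        return total;
    }

    public static void main(String[] args) throws Exception {
        System.out.println(runRounds(4, 3, 5)); // 3 rounds of 0+1+4+9+16 = 90
    }
}
```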

public void apply(int numberOfThreads) {
        ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numberOfThreads);

        class DijkstraTask implements Runnable {

            private String name;

            public DijkstraTask(String name) {
                this.name = name;
            }

            public String getName() {
                return name;
            }

            @Override
            public void run() {
                calculateShortestDistances(numberOfThreads);
            }
        }

        // Visit every node, in order of stored distance
        for (int i = 0; i < this.nodes.length; i++) {

            //Add task for each node
            for (int t = 0; t < numberOfThreads; t++) {
                executor.execute(new DijkstraTask("Task " + t));
            }

            //Wait until finished?
            while (executor.getActiveCount() > 0) {
                System.out.println("Active count: " + executor.getActiveCount());
            }

            //Look through the results of the tasks and get the next node that is closest by
            currentNode = getNodeShortestDistanced();

            //Reset the threadCounter for next iteration
            this.setCount(0);
        }
    }

The number of edges is divided by the number of threads, so 8 edges and 2 threads means each thread handles 4 edges in parallel.
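The split described above can be written as a pure function. This sketch assumes (hypothetically) that each task receives its threadIndex as a constructor argument instead of reading the shared getCount()/setCount() counter, so no two tasks can end up with the same range. As in the question's code, the last thread also takes the size % numberOfThreads leftover edges. rangeFor is an illustrative name.

```java
public class EdgePartition {

    // Half-open range [start, end) of edge indices handled by thread `threadIndex`.
    // Every thread gets size / numberOfThreads edges; the last thread also takes
    // the remainder (size % numberOfThreads), so no edge is skipped or repeated.
    static int[] rangeFor(int size, int numberOfThreads, int threadIndex) {
        int edgesPerThread = size / numberOfThreads;
        int start = edgesPerThread * threadIndex;
        int end = (threadIndex == numberOfThreads - 1) ? size : start + edgesPerThread;
        return new int[] {start, end};
    }

    public static void main(String[] args) {
        // 9 edges over 2 threads
        for (int t = 0; t < 2; t++) {
            int[] r = rangeFor(9, 2, t);
            System.out.println("Thread " + t + ": edges " + r[0] + " to " + (r[1] - 1));
        }
    }
}
```

With 9 edges and 2 threads this prints edges 0 to 3 for thread 0 and edges 4 to 8 for thread 1, so the ninth edge is picked up by the last thread.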

public void calculateShortestDistances(int numberOfThreads) {

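        // Caution: this read-then-increment of a shared counter is not atomic, so two
        // tasks can read the same value and process the same range of edges.
        int threadCounter = this.getCount();
        this.setCount(count + 1);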

        // Loop round the edges that are joined to the current node
        currentNodeEdges = this.nodes[currentNode].getEdges();

        int edgesPerThread = currentNodeEdges.size() / numberOfThreads;
        int modulo = currentNodeEdges.size() % numberOfThreads;
        this.nodes[0].setDistanceFromSource(0);
        //Process the edges per thread
        for (int joinedEdge = (edgesPerThread * threadCounter); joinedEdge < (edgesPerThread * (threadCounter + 1)); joinedEdge++) {

            System.out.println("Start: " + (edgesPerThread * threadCounter) + ". End: " + (edgesPerThread * (threadCounter + 1) + ".JoinedEdge: " + joinedEdge) + ". Total: " + currentNodeEdges.size());
            // Determine the joined edge neighbour of the current node
            int neighbourIndex = currentNodeEdges.get(joinedEdge).getNeighbourIndex(currentNode);

            // Only interested in an unvisited neighbour
            if (!this.nodes[neighbourIndex].isVisited()) {
                // Calculate the tentative distance for the neighbour
                int tentative = this.nodes[currentNode].getDistanceFromSource() + currentNodeEdges.get(joinedEdge).getLength();
                // Overwrite if the tentative distance is less than what's currently stored
                if (tentative < nodes[neighbourIndex].getDistanceFromSource()) {
                    nodes[neighbourIndex].setDistanceFromSource(tentative);
                }
            }
        }

        //if we have a modulo above 0, the last thread also processes the remaining edges,
        //which start right after its own range
        if (modulo > 0 && numberOfThreads == (threadCounter + 1)) {
            for (int joinedEdge = (edgesPerThread * (threadCounter + 1)); joinedEdge < (edgesPerThread * (threadCounter + 1) + modulo); joinedEdge++) {
                // Determine the joined edge neighbour of the current node
                int neighbourIndex = currentNodeEdges.get(joinedEdge).getNeighbourIndex(currentNode);

                // Only interested in an unvisited neighbour
                if (!this.nodes[neighbourIndex].isVisited()) {
                    // Calculate the tentative distance for the neighbour
                    int tentative = this.nodes[currentNode].getDistanceFromSource() + currentNodeEdges.get(joinedEdge).getLength();
                    // Overwrite if the tentative distance is less than what's currently stored
                    if (tentative < nodes[neighbourIndex].getDistanceFromSource()) {
                        nodes[neighbourIndex].setDistanceFromSource(tentative);
                    }
                }
            }
        }
        // All neighbours are checked so this node is now visited
        nodes[currentNode].setVisited(true);
    }

Thanks for helping me!

You should look into CyclicBarrier or CountDownLatch. Both let threads block until other threads have signaled that they're done. The difference between them is that a CyclicBarrier is reusable, i.e. it can be used multiple times, while a CountDownLatch is one-shot: you cannot reset its count.

Paraphrasing from the Javadocs:

A CountDownLatch is a synchronization aid that allows one or more threads to wait until a set of operations being performed in other threads completes.

A CyclicBarrier is a synchronization aid that allows a set of threads to all wait for each other to reach a common barrier point. CyclicBarriers are useful in programs involving a fixed sized party of threads that must occasionally wait for each other. The barrier is called cyclic because it can be re-used after the waiting threads are released.

https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/util/concurrent/CyclicBarrier.html

https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/util/concurrent/CountDownLatch.html
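For the per-node loop in the question, a CyclicBarrier can be created once and reused on every iteration, because it resets itself after each trip. This is a minimal sketch assuming the coordinating thread also counts as a party (hence workers + 1) and that the pool has at least as many threads as tasks per round; the AtomicInteger counter stands in for the real edge processing and is not from the question.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

public class BarrierPerIteration {

    // Runs `iterations` rounds of `workers` tasks on one pool, reusing one barrier.
    // Returns how many tasks completed in total.
    static int run(int workers, int iterations) throws InterruptedException, BrokenBarrierException {
        // Each task blocks in await(), so the pool needs at least `workers` threads;
        // with fewer, some tasks would stay queued and the barrier would never trip.
        ExecutorService executor = Executors.newFixedThreadPool(workers);
        // parties = workers + 1 because the coordinating (calling) thread also waits.
        CyclicBarrier barrier = new CyclicBarrier(workers + 1);
        AtomicInteger completed = new AtomicInteger();
        try {
            for (int i = 0; i < iterations; i++) {
                for (int t = 0; t < workers; t++) {
                    executor.execute(() -> {
                        completed.incrementAndGet();   // stand-in for the real work
                        try {
                            barrier.await();           // signal "this task is done"
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                        } catch (BrokenBarrierException e) {
                            // another party failed or was interrupted; just exit
                        }
                    });
                }
                // Returns only when all workers of this round have arrived;
                // the barrier then resets itself for the next round.
                barrier.await();
            }
        } finally {
            executor.shutdown();
        }
        return completed.get();
    }

    public static void main(String[] args) throws Exception {
        System.out.println(run(3, 4)); // 3 workers x 4 rounds = 12
    }
}
```

The pool-size requirement is the main drawback compared with a latch: tasks waiting in await() keep their threads occupied until the round ends.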

Here is a simple demo of using CountDownLatch to wait for all threads in the pool:

import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class WaitForAllThreadsInPool {

    private static final int MAX_CYCLES = 10;

    public static void main(String[] args) {
        new WaitForAllThreadsInPool().apply(4);
    }

    public void apply(int numberOfThreads) {

        ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
        CountDownLatch cdl = new CountDownLatch(numberOfThreads);

        class DijkstraTask implements Runnable {

            private final String name;
            private final CountDownLatch cdl;
            private final Random rnd = new Random();

            public DijkstraTask(String name, CountDownLatch cdl) {
                this.name = name;
                this.cdl = cdl;
            }

            @Override
            public void run() {
                calculateShortestDistances(1+ rnd.nextInt(MAX_CYCLES), cdl, name);
            }
        }

        for (int t = 0; t < numberOfThreads; t++) {
            executor.execute(new DijkstraTask("Task " + t, cdl));
        }

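        //wait for all threads to finish
        try {
            cdl.await();
            System.out.println("-all done-");
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt(); // restore the interrupt flag
        } finally {
            executor.shutdown(); // the demo is done with the pool; lets the JVM exit
        }
    }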

    public void calculateShortestDistances(int numberOfWorkCycles, CountDownLatch cdl, String name) {
        try {
            //simulate long process
            for (int cycle = 1; cycle <= numberOfWorkCycles; cycle++) {
                System.out.println(name + " cycle " + cycle + "/" + numberOfWorkCycles);
                TimeUnit.MILLISECONDS.sleep(1000);
            }
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt(); // restore the interrupt flag and stop working
        } finally {
            cdl.countDown(); // thread finished; always counted, so await() cannot hang
        }
    }
}

Output sample:

Task 0 cycle 1/3
Task 1 cycle 1/2
Task 3 cycle 1/9
Task 2 cycle 1/3
Task 0 cycle 2/3
Task 1 cycle 2/2
Task 2 cycle 2/3
Task 3 cycle 2/9
Task 0 cycle 3/3
Task 2 cycle 3/3
Task 3 cycle 3/9
Task 3 cycle 4/9
Task 3 cycle 5/9
Task 3 cycle 6/9
Task 3 cycle 7/9
Task 3 cycle 8/9
Task 3 cycle 9/9
-all done-
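The demo waits only once. Because a CountDownLatch cannot be reset, the question's per-node loop would create a fresh latch on each iteration while keeping a single pool. A sketch under that assumption; the counter is placeholder work standing in for one slice of edges.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

public class LatchPerIteration {

    // Runs `iterations` rounds of `numberOfThreads` tasks on one pool and
    // returns how many tasks finished in total.
    static int run(int numberOfThreads, int iterations) throws InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads); // created once
        AtomicInteger finishedTasks = new AtomicInteger();
        try {
            for (int i = 0; i < iterations; i++) {
                // A CountDownLatch cannot be reset, so each iteration gets a new one;
                // the pool itself is reused.
                CountDownLatch done = new CountDownLatch(numberOfThreads);
                for (int t = 0; t < numberOfThreads; t++) {
                    executor.execute(() -> {
                        try {
                            finishedTasks.incrementAndGet(); // stand-in for one slice of edges
                        } finally {
                            done.countDown();                // always signal, even on failure
                        }
                    });
                }
                done.await(); // blocks until every task of this iteration has counted down
                // (the question would pick the next closest node here)
            }
        } finally {
            executor.shutdown(); // only once, after the last iteration
        }
        return finishedTasks.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run(4, 5)); // 4 tasks x 5 iterations = 20
    }
}
```

Unlike the barrier approach, the tasks here never block, so the pool size does not have to match the number of tasks per iteration.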

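You can use invokeAll, which blocks until all of the given tasks have completed and leaves the executor running, so it can be reused in the next iteration: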

//Add task for each node
Collection<Callable<Object>> tasks = new ArrayList<>(numberOfThreads);
for (int t = 0; t < numberOfThreads; t++) {
    tasks.add(Executors.callable(new DijkstraTask("Task " + t)));
}

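//Wait until finished: invokeAll blocks until every task has completed
try {
    executor.invokeAll(tasks);
} catch (InterruptedException e) {
    Thread.currentThread().interrupt(); // restore the flag and stop iterating
    return;
}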
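A self-contained sketch of the same idea inside a loop, assuming each task is a Callable<Integer> whose return value stands in for a per-slice result (the tasks in the question are Runnables, which Executors.callable adapts into Callable<Object>). One pool serves every iteration and is shut down only at the end.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class InvokeAllPerIteration {

    // Runs `iterations` rounds of `numberOfThreads` tasks on one pool and
    // returns the sum of all task results.
    static int run(int numberOfThreads, int iterations)
            throws InterruptedException, ExecutionException {
        ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
        int total = 0;
        try {
            for (int i = 0; i < iterations; i++) {
                List<Callable<Integer>> tasks = new ArrayList<>(numberOfThreads);
                for (int t = 0; t < numberOfThreads; t++) {
                    final int slice = t;      // hypothetical: index of the edge slice
                    tasks.add(() -> slice);   // stand-in for processing that slice
                }
                // invokeAll blocks until every task is done; the pool stays usable.
                List<Future<Integer>> results = executor.invokeAll(tasks);
                for (Future<Integer> f : results) {
                    // Already complete; a task's exception would surface here
                    // wrapped in an ExecutionException.
                    total += f.get();
                }
            }
        } finally {
            executor.shutdown();
        }
        return total;
    }

    public static void main(String[] args) throws Exception {
        System.out.println(run(4, 3)); // (0+1+2+3) x 3 = 18
    }
}
```

Checking the returned futures matters because invokeAll itself does not throw when a task fails; the failure is only visible through Future.get().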
