[英]Java: Wait in a loop until tasks of ThreadPoolExecutor are done before continuing
我正在努力使Dijkstra算法並行化。 使每個節點線程查看當前節點的所有邊緣。 這是與線程並行執行的,但是開銷太大。 這導致比算法的順序版本更長的時間。
添加了ThreadPool來解決此問題,但是我無法等待任務完成才可以進行下一個迭代。 只有完成一個節點的所有任務后,我們才應該繼續。 我們需要所有任務的結果,然后才能按節點搜索下一個最接近的節點。
我嘗試做executor.shutdown(),但是有了這個方法,它不會接受新任務。 我們如何在循環中等待直到每個任務完成,而不必每次都聲明ThreadPoolExecutor。 這樣做將通過使用此線程而不是常規線程來達到減少開銷的目的。
我想到的一件事是添加任務(邊)的BlockingQueue。 但對於這種解決方案,我也堅持等待任務完成而無需shudown()。
public void apply(int numberOfThreads) {
ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numberOfThreads);
class DijkstraTask implements Runnable {
private String name;
public DijkstraTask(String name) {
this.name = name;
}
public String getName() {
return name;
}
@Override
public void run() {
calculateShortestDistances(numberOfThreads);
}
}
// Visit every node, in order of stored distance
for (int i = 0; i < this.nodes.length; i++) {
//Add task for each node
for (int t = 0; t < numberOfThreads; t++) {
executor.execute(new DijkstraTask("Task " + t));
}
//Wait until finished?
while (executor.getActiveCount() > 0) {
System.out.println("Active count: " + executor.getActiveCount());
}
//Look through the results of the tasks and get the next node that is closest by
currentNode = getNodeShortestDistanced();
//Reset the threadCounter for next iteration
this.setCount(0);
}
}
邊的數量除以線程數。 所以8條邊和2條線程意味着每個線程將並行處理4條邊。
public void calculateShortestDistances(int numberOfThreads) {
int threadCounter = this.getCount();
this.setCount(count + 1);
// Loop round the edges that are joined to the current node
currentNodeEdges = this.nodes[currentNode].getEdges();
int edgesPerThread = currentNodeEdges.size() / numberOfThreads;
int modulo = currentNodeEdges.size() % numberOfThreads;
this.nodes[0].setDistanceFromSource(0);
//Process the edges per thread
for (int joinedEdge = (edgesPerThread * threadCounter); joinedEdge < (edgesPerThread * (threadCounter + 1)); joinedEdge++) {
System.out.println("Start: " + (edgesPerThread * threadCounter) + ". End: " + (edgesPerThread * (threadCounter + 1) + ".JoinedEdge: " + joinedEdge) + ". Total: " + currentNodeEdges.size());
// Determine the joined edge neighbour of the current node
int neighbourIndex = currentNodeEdges.get(joinedEdge).getNeighbourIndex(currentNode);
// Only interested in an unvisited neighbour
if (!this.nodes[neighbourIndex].isVisited()) {
// Calculate the tentative distance for the neighbour
int tentative = this.nodes[currentNode].getDistanceFromSource() + currentNodeEdges.get(joinedEdge).getLength();
// Overwrite if the tentative distance is less than what's currently stored
if (tentative < nodes[neighbourIndex].getDistanceFromSource()) {
nodes[neighbourIndex].setDistanceFromSource(tentative);
}
}
}
//if we have a modulo above 0, the last thread will process the remaining edges
if (modulo > 0 && numberOfThreads == (threadCounter + 1)) {
for (int joinedEdge = (edgesPerThread * threadCounter); joinedEdge < (edgesPerThread * (threadCounter) + modulo); joinedEdge++) {
// Determine the joined edge neighbour of the current node
int neighbourIndex = currentNodeEdges.get(joinedEdge).getNeighbourIndex(currentNode);
// Only interested in an unvisited neighbour
if (!this.nodes[neighbourIndex].isVisited()) {
// Calculate the tentative distance for the neighbour
int tentative = this.nodes[currentNode].getDistanceFromSource() + currentNodeEdges.get(joinedEdge).getLength();
// Overwrite if the tentative distance is less than what's currently stored
if (tentative < nodes[neighbourIndex].getDistanceFromSource()) {
nodes[neighbourIndex].setDistanceFromSource(tentative);
}
}
}
}
// All neighbours are checked so this node is now visited
nodes[currentNode].setVisited(true);
}
感謝您的幫助!
您應該查看CyclicBarrier
或CountDownLatch
。 這兩種方法都可以防止線程啟動,除非其他線程已發出信號已完成。 它們之間的區別是CyclicBarrier
是可重用的,即可以多次使用,而CountDownLatch
是一次性的,您無法重置計數。
從Javadocs釋義:
CountDownLatch是一種同步輔助工具,它允許一個或多個線程等待,直到在其他線程中執行的一組操作完成為止。
CyclicBarrier是一種同步輔助工具,它允許一組線程互相等待以到達一個公共的屏障點。 CyclicBarriers在涉及固定大小的線程方的程序中很有用,該線程方有時必須互相等待。 該屏障稱為循環屏障,因為它可以在釋放等待線程之后重新使用。
https://docs.oracle.com/cn/java/javase/11/docs/api/java.base/java/util/concurrent/CyclicBarrier.html
這是使用CountDownLatch
等待池中所有線程的簡單演示:
import java.io.IOException;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
public class WaitForAllThreadsInPool {
private static int MAX_CYCLES = 10;
public static void main(String args[]) throws InterruptedException, IOException {
new WaitForAllThreadsInPool().apply(4);
}
public void apply(int numberOfThreads) {
ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
CountDownLatch cdl = new CountDownLatch(numberOfThreads);
class DijkstraTask implements Runnable {
private final String name;
private final CountDownLatch cdl;
private final Random rnd = new Random();
public DijkstraTask(String name, CountDownLatch cdl) {
this.name = name;
this.cdl = cdl;
}
@Override
public void run() {
calculateShortestDistances(1+ rnd.nextInt(MAX_CYCLES), cdl, name);
}
}
for (int t = 0; t < numberOfThreads; t++) {
executor.execute(new DijkstraTask("Task " + t, cdl));
}
//wait for all threads to finish
try {
cdl.await();
System.out.println("-all done-");
} catch (InterruptedException ex) {
ex.printStackTrace();
}
}
public void calculateShortestDistances(int numberOfWorkCycles, CountDownLatch cdl, String name) {
//simulate long process
for(int cycle = 1 ; cycle <= numberOfWorkCycles; cycle++){
System.out.println(name + " cycle "+ cycle + "/"+ numberOfWorkCycles );
try {
TimeUnit.MILLISECONDS.sleep(1000);
} catch (InterruptedException ex) {
ex.printStackTrace();
}
}
cdl.countDown(); //thread finished
}
}
輸出樣本:
任務0周期1/3
任務1周期1/2
任務3周期1/9
任務2周期1/3
任務0周期2/3
任務1周期2/2
任務2周期2/3
任務3周期2/9
任務0周期3/3
任務2周期3/3
任務3周期3/9
任務3周期4/9
任務3周期5/9
任務3周期6/9
任務3周期7/9
任務3周期8/9
任務3周期9/9
-全部做完-
您可以使用invokeAll :
//Add task for each node
Collection<Callable<Object>> tasks = new ArrayList<>(numberOfThreads);
for (int t = 0; t < numberOfThreads; t++) {
tasks.add(Executors.callable(new DijkstraTask("Task " + t)));
}
//Wait until finished
executor.invokeAll(tasks);
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.