
How do I "cancel" a CountDownLatch?

I have multiple consumer threads waiting on a CountDownLatch of size 1 using await(). I have a single producer thread that calls countDown() when it successfully finishes.

This works great when there are no errors.

However, if the producer detects an error, I would like for it to be able to signal the error to the consumer threads. Ideally I could have the producer call something like abortCountDown() and have all of the consumers receive an InterruptedException or some other exception. I don't want to call countDown(), because this requires all of my consumer threads to then do an additional manual check for success after their call to await(). I'd rather they just receive an exception, which they already know how to handle.

I know that an abort facility is not available in CountDownLatch. Is there another synchronization primitive that I can easily adapt to effectively create a CountDownLatch that supports aborting the countdown?

JB Nizet had a great answer. I took his and polished it a little. The result is a subclass of CountDownLatch called AbortableCountDownLatch, which adds an abort() method that causes all threads waiting on the latch to receive an AbortedException (a subclass of InterruptedException).

Also, unlike JB's class, the AbortableCountDownLatch will abort all blocking threads immediately on an abort, rather than waiting for the countdown to reach zero (for situations where you use a count>1).

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class AbortableCountDownLatch extends CountDownLatch {
    // volatile so that a waiter whose timed await() expires still sees the flag
    protected volatile boolean aborted = false;

    public AbortableCountDownLatch(int count) {
        super(count);
    }


    /**
     * Unblocks all threads waiting on this latch and causes them to receive an
     * AbortedException.  If the latch has already counted all the way down,
     * this method does nothing.
     */
    public void abort() {
        if( getCount()==0 )
            return;

        this.aborted = true;
        while(getCount()>0)
            countDown();
    }


    @Override
    public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
        final boolean rtrn = super.await(timeout,unit);
        if (aborted)
            throw new AbortedException();
        return rtrn;
    }

    @Override
    public void await() throws InterruptedException {
        super.await();
        if (aborted)
            throw new AbortedException();
    }


    public static class AbortedException extends InterruptedException {
        public AbortedException() {
        }

        public AbortedException(String detailMessage) {
            super(detailMessage);
        }
    }
}
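A quick way to see the behaviour described above is a small two-thread run. The sketch below is a condensed copy of AbortableCountDownLatch (the volatile flag and the exception message are assumptions made here), driven by a hypothetical AbortDemo class in which the producer counts down once on a latch of 3 and then aborts.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class AbortDemo {
    // Condensed copy of AbortableCountDownLatch so this sketch compiles on its own.
    static class AbortableCountDownLatch extends CountDownLatch {
        private volatile boolean aborted = false;

        AbortableCountDownLatch(int count) {
            super(count);
        }

        void abort() {
            if (getCount() == 0) {
                return;                 // already released normally
            }
            aborted = true;             // set the flag before releasing waiters
            while (getCount() > 0) {
                countDown();
            }
        }

        @Override
        public void await() throws InterruptedException {
            super.await();
            if (aborted) {
                throw new AbortedException();
            }
        }

        @Override
        public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
            boolean released = super.await(timeout, unit);
            if (aborted) {
                throw new AbortedException();
            }
            return released;
        }
    }

    static class AbortedException extends InterruptedException {
        AbortedException() {
            super("latch aborted");
        }
    }

    // One consumer waits on a latch of 3; the producer (this thread) counts
    // down once and then aborts. Returns what the consumer observed.
    static String runDemo() {
        AbortableCountDownLatch latch = new AbortableCountDownLatch(3);
        String[] outcome = new String[1];
        Thread consumer = new Thread(() -> {
            try {
                latch.await();
                outcome[0] = "released normally";
            } catch (AbortedException e) {
                outcome[0] = "aborted";
            } catch (InterruptedException e) {
                outcome[0] = "interrupted";
            }
        });
        consumer.start();
        latch.countDown();   // only one of the three counts happens
        latch.abort();       // the producer hits an error
        try {
            consumer.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return outcome[0];
    }

    public static void main(String[] args) {
        System.out.println(runDemo()); // aborted
    }
}
```

Because abort() sets the flag before the final countDown(), the result does not depend on whether the consumer reaches await() before or after the abort: either way await() returns and then sees the flag.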

Encapsulate this behavior inside a specific, higher-level class, using the CountDownLatch internally:

public class MyLatch {
    private CountDownLatch latch;
    private volatile boolean aborted;
    ...

    // called by consumers
    public void await() throws InterruptedException, AbortedException {
        latch.await();
        if (aborted) {
            throw new AbortedException();
        }
    }

    // called by producer
    public void abort() {
        this.aborted = true;
        latch.countDown();
    }

    // called by producer
    public void succeed() {
        latch.countDown();
    }
}
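Filled out for a single producer and a latch count of 1, the wrapper might look like the sketch below. The field initializer, the nested AbortedException (assumed here to extend InterruptedException) and the outcome() helper are hypothetical; the three methods follow the answer.

```java
import java.util.concurrent.CountDownLatch;

public class MyLatchSketch {
    // Hypothetical exception; extending InterruptedException lets existing
    // interrupt-handling code in the consumers catch it as well.
    static class AbortedException extends InterruptedException {
        AbortedException() {
            super("producer aborted");
        }
    }

    static class MyLatch {
        private final CountDownLatch latch = new CountDownLatch(1);
        private volatile boolean aborted;

        // called by consumers
        void await() throws InterruptedException {
            latch.await();
            if (aborted) {
                throw new AbortedException();
            }
        }

        // called by producer: the flag is written before countDown(), so any
        // consumer released by the latch is guaranteed to see it
        void abort() {
            aborted = true;
            latch.countDown();
        }

        // called by producer
        void succeed() {
            latch.countDown();
        }
    }

    // Returns "ok" or "aborted" depending on how the producer released the latch.
    static String outcome(boolean fail) {
        MyLatch latch = new MyLatch();
        if (fail) {
            latch.abort();
        } else {
            latch.succeed();
        }
        try {
            latch.await();
            return "ok";
        } catch (AbortedException e) {
            return "aborted";
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return "interrupted";
        }
    }

    public static void main(String[] args) {
        System.out.println(outcome(false)); // ok
        System.out.println(outcome(true));  // aborted
    }
}
```

The order of the two statements in abort() is what makes this work: CountDownLatch guarantees that actions before countDown() happen-before a successful return from await(), so writing the flag first lets the consumers see it. The volatile modifier is a defensive extra on top of that.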

You can create a wrapper around CountDownLatch that provides the ability to cancel the waiters. It needs to track the waiting threads so it can interrupt them on cancellation (removing each thread once its await() returns or times out), and it must remember that the latch was cancelled so that later calls to await() fail immediately with an InterruptedException.

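import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class CancellableCountDownLatch
{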
    final CountDownLatch latch;
    final List<Thread> waiters;
    boolean cancelled = false;

    public CancellableCountDownLatch(int count) {
        latch = new CountDownLatch(count);
        waiters = new ArrayList<Thread>();
    }

    public void await() throws InterruptedException {
        try {
            addWaiter();
            latch.await();
        }
        finally {
            removeWaiter();
        }
    }

    public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
        try {
            addWaiter();
            return latch.await(timeout, unit);
        }
        finally {
            removeWaiter();
        }
    }

    private synchronized void addWaiter() throws InterruptedException {
        if (cancelled) {
            Thread.currentThread().interrupt();
            throw new InterruptedException("Latch has already been cancelled");
        }
        waiters.add(Thread.currentThread());
    }

    private synchronized void removeWaiter() {
        waiters.remove(Thread.currentThread());
    }

    public void countDown() {
        latch.countDown();
    }

    public synchronized void cancel() {
        if (!cancelled) {
            cancelled = true;
            for (Thread waiter : waiters) {
                waiter.interrupt();
            }
            waiters.clear();
        }
    }

    public long getCount() {
        return latch.getCount();
    }

    @Override
    public String toString() {
        return latch.toString();
    }
}
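The wrapper's cancel() depends on a standard property of CountDownLatch.await(): interrupting the blocked thread makes it throw InterruptedException without the count changing. The hypothetical InterruptDemo below checks just that property with a plain CountDownLatch.

```java
import java.util.concurrent.CountDownLatch;

public class InterruptDemo {
    // Returns true if a thread blocked in CountDownLatch.await() was woken by
    // Thread.interrupt() with an InterruptedException while the count stayed at 1.
    static boolean interruptUnblocksAwait() {
        CountDownLatch latch = new CountDownLatch(1);
        boolean[] gotInterrupted = new boolean[1];
        Thread waiter = new Thread(() -> {
            try {
                latch.await();             // nobody ever calls countDown()
            } catch (InterruptedException e) {
                gotInterrupted[0] = true;  // what cancel() relies on
            }
        });
        waiter.start();
        // If the interrupt lands before await() blocks, await() sees the
        // interrupt status on entry and throws immediately, so no race here.
        waiter.interrupt();
        try {
            waiter.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return gotInterrupted[0] && latch.getCount() == 1;
    }

    public static void main(String[] args) {
        System.out.println(interruptUnblocksAwait()); // true
    }
}
```

One side effect of the design worth keeping in mind: if cancel() runs after a waiter's latch.await() has returned but before its removeWaiter() call, that thread is interrupted anyway and carries the interrupt status into whatever it does next.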

You could roll your own CountDownLatch using a ReentrantLock subclass that exposes the lock's protected getWaitingThreads(Condition) method.

Example:

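import java.util.Collection;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

public class FailableCountDownLatch {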
    private static class ConditionReentrantLock extends ReentrantLock {
        private static final long serialVersionUID = 2974195457854549498L;

        @Override
        public Collection<Thread> getWaitingThreads(Condition c) {
            return super.getWaitingThreads(c);
        }
    }

    private final ConditionReentrantLock lock = new ConditionReentrantLock();
    private final Condition countIsZero = lock.newCondition();
    private long count;

    public FailableCountDownLatch(long count) {
        this.count = count;
    }

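    public void await() throws InterruptedException {
        lock.lock();
        try {
            // loop, not if: Condition.await() may wake up spuriously
            while (count > 0) {
                countIsZero.await();
            }
        } finally {
            lock.unlock();
        }
    }

    public boolean await(long time, TimeUnit unit) throws InterruptedException {
        long nanos = unit.toNanos(time);
        lock.lock();
        try {
            // awaitNanos() returns the remaining wait time, so a spurious
            // wakeup does not restart or extend the timeout
            while (count > 0) {
                if (nanos <= 0) {
                    return false;
                }
                nanos = countIsZero.awaitNanos(nanos);
            }
            return true;
        } finally {
            lock.unlock();
        }
    }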

    public long getCount() {
        lock.lock();
        try {
            return count;
        } finally {
            lock.unlock();
        }
    }

    public void countDown() {
        lock.lock();
        try {
            if (count > 0) {
                count--;

                if (count == 0) {
                    countIsZero.signalAll();
                }
            }
        } finally {
            lock.unlock();
        }
    }

    public void abortCountDown() {
        lock.lock();
        try {
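            // getWaitingThreads() requires the lock to be held, and its result
            // is only a best-effort estimate of the threads waiting on countIsZero
            for (Thread t : lock.getWaitingThreads(countIsZero)) {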
                t.interrupt();
            }
        } finally {
            lock.unlock();
        }
    }
}

You may want to change this class to throw an InterruptedException on new calls to await after it has been cancelled. You could even have this class extend CountDownLatch if you needed that functionality.
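One way to make later await() calls fail after an abort is to keep an aborted flag under the same lock and wake the waiters with signalAll(), instead of interrupting the threads returned by getWaitingThreads(). FlagAbortLatch below is a hypothetical sketch of that swapped-in technique; only the untimed await() is shown, and it throws a plain InterruptedException, as the question asks.

```java
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical variant: abort() sets a flag and calls signalAll() rather than
// interrupting the waiting threads.
public class FlagAbortLatch {
    private final ReentrantLock lock = new ReentrantLock();
    private final Condition released = lock.newCondition();
    private long count;       // guarded by lock
    private boolean aborted;  // guarded by lock

    public FlagAbortLatch(long count) {
        if (count < 0) {
            throw new IllegalArgumentException("count < 0");
        }
        this.count = count;
    }

    public void await() throws InterruptedException {
        lock.lock();
        try {
            while (count > 0 && !aborted) {
                released.await();
            }
            if (aborted) {   // also true for every await() after the abort
                throw new InterruptedException("latch aborted");
            }
        } finally {
            lock.unlock();
        }
    }

    public void countDown() {
        lock.lock();
        try {
            if (count > 0 && --count == 0) {
                released.signalAll();
            }
        } finally {
            lock.unlock();
        }
    }

    // No-op once the count has reached zero, mirroring AbortableCountDownLatch.
    public void abort() {
        lock.lock();
        try {
            if (count > 0 && !aborted) {
                aborted = true;
                released.signalAll();
            }
        } finally {
            lock.unlock();
        }
    }

    // Demo helper: reports how await() ended.
    static String tryAwait(FlagAbortLatch latch) {
        try {
            latch.await();
            return "released";
        } catch (InterruptedException e) {
            return "aborted";
        }
    }

    public static void main(String[] args) {
        FlagAbortLatch failed = new FlagAbortLatch(2);
        failed.abort();
        System.out.println(tryAwait(failed)); // aborted

        FlagAbortLatch ok = new FlagAbortLatch(1);
        ok.countDown();
        ok.abort();                           // ignored: already released
        System.out.println(tryAwait(ok));     // released
    }
}
```

Since waiters are never interrupted, their interrupt status is left untouched, and the ReentrantLock subclass is no longer needed.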

Since Java 8, you can use CompletableFuture for this. One or more threads can call the blocking get() method:

CompletableFuture<Void> cf = new CompletableFuture<>();
try {
  cf.get();
} catch (ExecutionException e) {
  // act on the error; e.getCause() is the exception passed to completeExceptionally()
}

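Another thread can then complete it either successfully with cf.complete(null) or exceptionally with cf.completeExceptionally(new MyException()). In the exceptional case, every thread blocked in get() receives an ExecutionException whose cause is the MyException.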
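A complete, runnable version of the CompletableFuture approach might look like the following. The class name and the consumersThatSawFailure() helper are hypothetical, and MyException is defined here as a simple checked exception.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;

public class CompletableFutureSignal {
    // Hypothetical stand-in for the MyException in the answer above.
    static class MyException extends Exception {
        MyException(String message) {
            super(message);
        }
    }

    // Starts `consumers` threads blocked in get(), has the producer fail the
    // future, and returns how many consumers saw MyException as the cause.
    static int consumersThatSawFailure(int consumers) {
        CompletableFuture<Void> done = new CompletableFuture<>();
        AtomicInteger sawFailure = new AtomicInteger();
        Thread[] threads = new Thread[consumers];
        for (int i = 0; i < consumers; i++) {
            threads[i] = new Thread(() -> {
                try {
                    done.get();                       // blocks until completed
                } catch (ExecutionException e) {
                    if (e.getCause() instanceof MyException) {
                        sawFailure.incrementAndGet(); // the producer's error
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            threads[i].start();
        }
        // Producer side: on success this would be done.complete(null).
        done.completeExceptionally(new MyException("producer failed"));
        for (Thread t : threads) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                break;
            }
        }
        return sawFailure.get();
    }

    public static void main(String[] args) {
        System.out.println(consumersThatSawFailure(3)); // 3
    }
}
```

If the consumers should not have to handle checked exceptions, CompletableFuture.join() can be used instead of get(); it throws an unchecked CompletionException that wraps the producer's exception.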

There is a simple option here that wraps the CountDownLatch. It is similar to the AbortableCountDownLatch answer, but it does not have to call countDown() repeatedly, which could be expensive if the latch starts with a large count. It uses an AtomicInteger for the real count, backed by a CountDownLatch with a count of 1.

https://github.com/scottf/CancellableCountDownLatch/blob/main/CancellableCountDownLatch.java

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class CancellableCountDownLatch {
    private final AtomicInteger count;
    private final CountDownLatch cdl;

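    public CancellableCountDownLatch(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count < 0");
        }
        this.count = new AtomicInteger(count);
        cdl = new CountDownLatch(1);
        if (count == 0) {
            cdl.countDown(); // like CountDownLatch(0), await() returns immediately
        }
    }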

    public void cancel() {
        count.set(0);
        cdl.countDown();
    }

    public void await() throws InterruptedException {
        cdl.await();
    }

    public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
        return cdl.await(timeout, unit);
    }

    public void countDown() {
        if (count.decrementAndGet() <= 0) {
            cdl.countDown();
        }
    }

    public long getCount() {
        return Math.max(count.get(), 0);
    }

    @Override
    public String toString() {
        return super.toString() + "[Count = " + getCount() + "]";
    }
}
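As written, cancel() releases waiters exactly as a normal completion would, so they cannot tell the two apart. If waiters should instead receive an exception, as the question asks, a volatile flag can be set before the inner latch opens. CancellableLatchSketch below is a hypothetical variant, not the linked repository's code. Note that the check at the top of cancel() is not atomic with a concurrent final countDown(), so a cancel that races with completion may be reported either way.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical variant of the AtomicInteger-based latch in which cancel()
// makes waiters throw instead of returning normally.
public class CancellableLatchSketch {
    private final AtomicInteger count;
    private final CountDownLatch gate = new CountDownLatch(1);
    private volatile boolean cancelled;

    public CancellableLatchSketch(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count < 0");
        }
        this.count = new AtomicInteger(count);
        if (count == 0) {
            gate.countDown();
        }
    }

    public void cancel() {
        if (gate.getCount() == 0) {
            return;            // already released normally; nothing to cancel
        }
        cancelled = true;      // written before the gate opens, so released waiters see it
        count.set(0);
        gate.countDown();
    }

    public void await() throws InterruptedException {
        gate.await();
        if (cancelled) {
            throw new InterruptedException("latch cancelled");
        }
    }

    public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
        boolean opened = gate.await(timeout, unit);
        if (cancelled) {
            throw new InterruptedException("latch cancelled");
        }
        return opened;
    }

    public void countDown() {
        if (count.decrementAndGet() <= 0) {
            gate.countDown();
        }
    }

    public long getCount() {
        return Math.max(count.get(), 0);
    }

    // Demo helper: reports how await() ended.
    static String tryAwait(CancellableLatchSketch latch) {
        try {
            latch.await();
            return "released";
        } catch (InterruptedException e) {
            return "cancelled";
        }
    }

    public static void main(String[] args) {
        CancellableLatchSketch failed = new CancellableLatchSketch(1000);
        failed.cancel();                       // one call, however large the count
        System.out.println(tryAwait(failed));  // cancelled
        System.out.println(failed.getCount()); // 0

        CancellableLatchSketch ok = new CancellableLatchSketch(2);
        ok.countDown();
        ok.countDown();
        ok.cancel();                           // ignored: already released
        System.out.println(tryAwait(ok));      // released
    }
}
```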
