
How to implement Task.WhenAny() with a predicate

I want to execute several asynchronous tasks concurrently. Each task will run an HTTP request that can either complete successfully or throw an exception. I need to await until the first task completes successfully, or until all the tasks have failed.

How can I implement an overload of the Task.WhenAny method that accepts a predicate, so that I can exclude the non-successfully completed tasks?

Wait for any task and return it if the condition is met. Otherwise, remove it and wait again for the remaining tasks until there are none left to wait for.

public static async Task<Task> WhenAny( IEnumerable<Task> tasks, Predicate<Task> condition )
{
    var tasklist = tasks.ToList();
    while ( tasklist.Count > 0 )
    {
        var task = await Task.WhenAny( tasklist );
        if ( condition( task ) )
            return task;
        tasklist.Remove( task );
    }
    return null;
}

A simple check for that:

var tasks = new List<Task> {
    Task.FromException( new Exception() ),
    Task.FromException( new Exception() ),
    Task.FromException( new Exception() ),
    Task.CompletedTask, };

var completedTask = WhenAny( tasks, t => t.Status == TaskStatus.RanToCompletion ).Result;

if ( tasks.IndexOf( completedTask ) != 3 )
    throw new Exception( "not expected" );
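The question also asks to find out when every task has failed. Below is a minimal sketch of that variant, built on the same `Task.WhenAny` loop as above. It returns the result of the first task that runs to completion, or throws an `AggregateException` collecting every failure. `FirstSuccessfulAsync` is a hypothetical name, "successful" is assumed to mean `TaskStatus.RanToCompletion`, and the code is written as a top-level program (C# 9 or later).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

// Hypothetical helper: returns the result of the first task that runs to completion,
// or throws an AggregateException holding every failure if no task succeeds.
static async Task<T> FirstSuccessfulAsync<T>(IEnumerable<Task<T>> tasks)
{
    var pending = tasks.ToList();
    if (pending.Count == 0)
        throw new ArgumentException("No tasks were supplied.", nameof(tasks));

    var failures = new List<Exception>();
    while (pending.Count > 0)
    {
        // Task.WhenAny completes with the finished task; it does not rethrow that task's exception.
        var finished = await Task.WhenAny(pending);
        pending.Remove(finished);

        if (finished.Status == TaskStatus.RanToCompletion)
            return finished.Result; // safe: the task has already completed successfully

        // Faulted tasks carry an AggregateException; canceled tasks do not.
        if (finished.Exception != null)
            failures.AddRange(finished.Exception.InnerExceptions);
        else
            failures.Add(new TaskCanceledException(finished));
    }
    throw new AggregateException("All tasks failed.", failures);
}

// Usage: two failing tasks and one that succeeds after a short delay.
var value = await FirstSuccessfulAsync(new[]
{
    Task.FromException<int>(new InvalidOperationException("first")),
    Task.FromException<int>(new InvalidOperationException("second")),
    Task.Delay(50).ContinueWith(_ => 42),
});
Console.WriteLine(value); // prints 42
```

Because the method is `async`, the `AggregateException` surfaces when the returned task is awaited, so callers can inspect `InnerExceptions` to see why each request failed.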

public static Task<Task<T>> WhenFirst<T>(IEnumerable<Task<T>> tasks, Func<Task<T>, bool> predicate)
{
    if (tasks == null) throw new ArgumentNullException(nameof(tasks));
    if (predicate == null) throw new ArgumentNullException(nameof(predicate));

    var tasksArray = (tasks as IReadOnlyList<Task<T>>) ?? tasks.ToArray();
    if (tasksArray.Count == 0) throw new ArgumentException("Empty task list", nameof(tasks));
    if (tasksArray.Any(t => t == null)) throw new ArgumentException("Tasks contains a null reference", nameof(tasks));

    var tcs = new TaskCompletionSource<Task<T>>();
    var count = tasksArray.Count;

    Action<Task<T>> continuation = t =>
        {
            if (predicate(t))
            {
                tcs.TrySetResult(t);
            }
            if (Interlocked.Decrement(ref count) == 0)
            {
                tcs.TrySetResult(null);
            }
        };

    foreach (var task in tasksArray)
    {
        task.ContinueWith(continuation);
    }

    return tcs.Task;
}

Sample usage:

var task = await WhenFirst(tasks, t => t.Status == TaskStatus.RanToCompletion);

if (task != null)
{
    // A declaration cannot be the direct body of an if statement, so braces are required.
    var value = await task;
}

Note that this doesn't propagate exceptions of failed tasks (just as WhenAny doesn't).

You can also create a version of this for the non-generic Task.
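A minimal sketch of such a non-generic version, assuming the same semantics as the generic `WhenFirst` above: it completes with the first matching task, or with `null` once all tasks have completed without a match. It differs only in passing `TaskScheduler.Default` to `ContinueWith` and creating the `TaskCompletionSource` with `RunContinuationsAsynchronously`. It is written as a top-level program (C# 9 or later).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Non-generic WhenFirst: completes with the first task that satisfies the predicate,
// or with null once every task has completed without a match.
static Task<Task> WhenFirst(IEnumerable<Task> tasks, Func<Task, bool> predicate)
{
    if (tasks == null) throw new ArgumentNullException(nameof(tasks));
    if (predicate == null) throw new ArgumentNullException(nameof(predicate));

    var tasksArray = tasks.ToArray();
    if (tasksArray.Length == 0) throw new ArgumentException("Empty task list", nameof(tasks));
    if (tasksArray.Any(t => t == null)) throw new ArgumentException("Tasks contains a null reference", nameof(tasks));

    // RunContinuationsAsynchronously keeps the caller's continuation off the thread that completes tcs.
    var tcs = new TaskCompletionSource<Task>(TaskCreationOptions.RunContinuationsAsynchronously);
    var count = tasksArray.Length;

    foreach (var task in tasksArray)
    {
        // Caveat (same as the generic version): if predicate throws, this continuation faults,
        // count is not decremented, and the returned task may never complete.
        task.ContinueWith(t =>
        {
            if (predicate(t))
                tcs.TrySetResult(t);
            if (Interlocked.Decrement(ref count) == 0)
                tcs.TrySetResult(null); // no-op if a match was already recorded
        }, TaskScheduler.Default);
    }
    return tcs.Task;
}

var inputs = new[]
{
    Task.FromException(new InvalidOperationException()),
    Task.Delay(50),
};
var first = await WhenFirst(inputs, t => t.Status == TaskStatus.RanToCompletion);
Console.WriteLine(first == inputs[1]); // prints True
```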

public static Task<T> GetFirstResult<T>(
    ICollection<Func<CancellationToken, Task<T>>> taskFactories, 
    Predicate<T> predicate) where T : class
{
    var tcs = new TaskCompletionSource<T>();
    var cts = new CancellationTokenSource();

    int completedCount = 0;
    // in case you have a lot of tasks you might need to throttle them 
    //(e.g. so you don't try to send 99999999 requests at the same time)
    // see: http://stackoverflow.com/a/25877042/67824
    foreach (var taskFactory in taskFactories)
    {
        taskFactory(cts.Token).ContinueWith(t => 
        {
            if (t.Status == TaskStatus.RanToCompletion)
            {
                if (predicate(t.Result))
                {
                    // First acceptable result wins; cancel the remaining requests.
                    cts.Cancel();
                    tcs.TrySetResult(t.Result);
                }
            }
            else
            {
                // Faulted or canceled: t.Result would throw here, so only log the outcome.
                Console.WriteLine($"Task completed with status {t.Status}: {t.Exception}");
            }

            if (Interlocked.Increment(ref completedCount) == taskFactories.Count)
            {
                // TrySetException is a no-op if a result was already set above
                // (SetException would throw InvalidOperationException in that case).
                tcs.TrySetException(new InvalidOperationException("All tasks failed"));
            }

        }, cts.Token);
    }

    return tcs.Task;
}

Sample usage:

using System.Net.Http;
var client = new HttpClient();
var response = await GetFirstResult(
    new Func<CancellationToken, Task<HttpResponseMessage>>[] 
    {
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
    }, 
    rm => rm.IsSuccessStatusCode);
Console.WriteLine($"Successful response: {response}");
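The comment in `GetFirstResult` suggests throttling when there are many factories. One way to do it, assuming a `SemaphoreSlim` limit is acceptable, is to wrap each factory so that at most a fixed number run at once. Because the wrapper returns an array (which implements `ICollection<T>`), the result can be passed to `GetFirstResult` unchanged. `Throttle` and `maxConcurrency` are hypothetical names, and the simulated requests stand in for the `HttpClient` calls above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper: wraps each factory so that at most maxConcurrency of them run at once.
// A factory waiting for a free slot stops waiting if its cancellation token is canceled.
static Func<CancellationToken, Task<T>>[] Throttle<T>(
    IEnumerable<Func<CancellationToken, Task<T>>> factories, int maxConcurrency)
{
    var gate = new SemaphoreSlim(maxConcurrency, maxConcurrency);
    return factories
        .Select(factory => (Func<CancellationToken, Task<T>>)(async ct =>
        {
            await gate.WaitAsync(ct); // throws OperationCanceledException if ct is canceled first
            try
            {
                return await factory(ct);
            }
            finally
            {
                gate.Release(); // free the slot whether the request succeeded or failed
            }
        }))
        .ToArray();
}

// Usage: ten simulated requests, at most three in flight at any time.
var sync = new object();
int running = 0, maxObserved = 0;
var requests = Enumerable.Range(0, 10).Select(i => (Func<CancellationToken, Task<int>>)(async ct =>
{
    lock (sync) { running++; maxObserved = Math.Max(maxObserved, running); }
    await Task.Delay(20, ct);
    lock (sync) { running--; }
    return i;
}));
var results = await Task.WhenAll(Throttle(requests, 3).Select(f => f(CancellationToken.None)));
Console.WriteLine($"{results.Length} results, at most {maxObserved} concurrent");
```

With the sample above, the call would become `GetFirstResult(Throttle(factories, 3), rm => rm.IsSuccessStatusCode)`, where `factories` stands for the array of lambdas.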

Here is an attempted improvement of Eli Arbel's excellent answer. These are the improvements:

  1. An exception in the predicate is propagated as a fault of the returned task.
  2. The predicate is not called after a task has been accepted as the result.
  3. The predicate is executed in the original SynchronizationContext. This makes it possible to access UI elements (if the WhenFirst method is called from a UI thread).
  4. The source IEnumerable<Task<T>> is enumerated directly, without being converted to an array first.

public static Task<Task<T>> WhenFirst<T>(IEnumerable<Task<T>> tasks,
    Func<Task<T>, bool> predicate)
{
    if (tasks == null) throw new ArgumentNullException(nameof(tasks));
    if (predicate == null) throw new ArgumentNullException(nameof(predicate));

    var tcs = new TaskCompletionSource<Task<T>>(
        TaskCreationOptions.RunContinuationsAsynchronously);
    var pendingCount = 1; // The initial 1 represents the enumeration itself
    foreach (var task in tasks)
    {
        if (task == null) throw new ArgumentException($"The {nameof(tasks)}" +
            " argument included a null value.", nameof(tasks));
        Interlocked.Increment(ref pendingCount);
        HandleTaskCompletion(task);
    }
    if (Interlocked.Decrement(ref pendingCount) == 0) tcs.TrySetResult(null);
    return tcs.Task;

    async void HandleTaskCompletion(Task<T> task)
    {
        try
        {
            await task; // Continue on the captured context
        }
        catch { } // Ignore exception

        if (tcs.Task.IsCompleted) return;

        try
        {
            if (predicate(task))
                tcs.TrySetResult(task);
            else
                if (Interlocked.Decrement(ref pendingCount) == 0)
                    tcs.TrySetResult(null);
        }
        catch (Exception ex)
        {
            tcs.TrySetException(ex);
        }
    }
}

Another way of doing this, very similar to Sir Rufo's answer but using AsyncEnumerable and Ix.NET:

Implement a small helper method that streams each task as soon as it completes:

static IAsyncEnumerable<Task<T>> WhenCompleted<T>(IEnumerable<Task<T>> source) =>
    AsyncEnumerable.Create(_ =>
    {
        var tasks = source.ToList();
        Task<T> current = null;
        return AsyncEnumerator.Create(
            async () => tasks.Any() && tasks.Remove(current = await Task.WhenAny(tasks)), 
            () => current,
            async () => { });
    });

One can then process the tasks in completion order, e.g. returning the first matching one as requested:

await WhenCompleted(tasks).FirstOrDefault(t => t.Status == TaskStatus.RanToCompletion)
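On C# 8 and later, the same `WhenCompleted` stream can also be written as an async iterator, with no dependency on Ix.NET. This is an alternative sketch, not the answer's code: it repeats the `Task.WhenAny` loop with `yield return` and replaces the Ix.NET `FirstOrDefault` operator with a plain `await foreach` loop. `FirstSuccessful` is a hypothetical name; the program uses top-level statements (C# 9 or later).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

// Yields each task as soon as it completes, in completion order.
static async IAsyncEnumerable<Task<T>> WhenCompleted<T>(IEnumerable<Task<T>> source)
{
    var pending = source.ToList();
    while (pending.Count > 0)
    {
        var finished = await Task.WhenAny(pending);
        pending.Remove(finished);
        yield return finished;
    }
}

// Returns the first task that ran to completion, or null if none did.
static async Task<Task<T>> FirstSuccessful<T>(IEnumerable<Task<T>> source)
{
    await foreach (var task in WhenCompleted(source))
    {
        if (task.Status == TaskStatus.RanToCompletion)
            return task; // leaving the loop disposes the iterator; the other tasks keep running
    }
    return null;
}

var tasks = new[]
{
    Task.FromException<string>(new InvalidOperationException()),
    Task.Delay(30).ContinueWith(_ => "ok"),
};
var winner = await FirstSuccessful(tasks);
Console.WriteLine(winner?.Result); // prints ok
```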

I just wanted to add to the answers by @Peebo and @SirRufo that use List.Remove (I can't comment yet).

I would consider using:

var tasks = source.ToHashSet();

instead of:

var tasks = source.ToList();

so that each removal is O(1) instead of O(n).
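Even with a `HashSet`, every `Task.WhenAny` call still attaches a continuation to each remaining task, so the loop stays quadratic in the number of tasks. If that matters, one alternative is the "interleaving" technique: attach exactly one continuation per task, and let the i-th task to finish complete the i-th entry of a pre-allocated array of `TaskCompletionSource` objects. The sketch below is that swapped-in technique, not code from the answers above; `Interleaved` is a hypothetical name, and the program uses top-level statements (C# 9 or later).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Returns tasks that complete in the order the inputs finish. Each input gets exactly one
// continuation, so the total work is linear in the number of tasks.
static Task<Task<T>>[] Interleaved<T>(IEnumerable<Task<T>> source)
{
    var inputs = source.ToArray();
    var buckets = new TaskCompletionSource<Task<T>>[inputs.Length];
    for (int i = 0; i < buckets.Length; i++)
        buckets[i] = new TaskCompletionSource<Task<T>>(TaskCreationOptions.RunContinuationsAsynchronously);

    int next = -1;
    foreach (var input in inputs)
    {
        input.ContinueWith(completed =>
        {
            // The n-th input to finish fills the n-th bucket.
            buckets[Interlocked.Increment(ref next)].TrySetResult(completed);
        }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
    }
    return buckets.Select(b => b.Task).ToArray();
}

// Usage: find the first task that ran to completion.
var tasks = new[]
{
    Task.FromException<int>(new InvalidOperationException()),
    Task.Delay(30).ContinueWith(_ => 7),
    Task.FromException<int>(new InvalidOperationException()),
};
Task<int> firstOk = null;
foreach (var bucket in Interleaved(tasks))
{
    var t = await bucket;
    if (t.Status == TaskStatus.RanToCompletion) { firstOk = t; break; }
}
Console.WriteLine(firstOk?.Result); // prints 7
```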

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.
