
How to implement an efficient WhenEach that streams an IAsyncEnumerable of task results?

I am trying to update my toolset with the new tools offered by C# 8, and one method that seems particularly useful is a version of Task.WhenAll that returns an IAsyncEnumerable. This method should stream the task results as soon as they become available, so naming it WhenAll doesn't make much sense; WhenEach sounds more appropriate. The signature of the method is:

public static IAsyncEnumerable<TResult> WhenEach<TResult>(Task<TResult>[] tasks);

This method could be used like this:

var tasks = new Task<int>[]
{
    ProcessAsync(1, 300),
    ProcessAsync(2, 500),
    ProcessAsync(3, 400),
    ProcessAsync(4, 200),
    ProcessAsync(5, 100),
};

await foreach (int result in WhenEach(tasks))
{
    Console.WriteLine($"Processed: {result}");
}

static async Task<int> ProcessAsync(int result, int delay)
{
    await Task.Delay(delay);
    return result;
}

Expected output:

Processed: 5
Processed: 4
Processed: 1
Processed: 3
Processed: 2

I managed to write a basic implementation using the method Task.WhenAny in a loop, but there is a problem with this approach:

public static async IAsyncEnumerable<TResult> WhenEach<TResult>(
    Task<TResult>[] tasks)
{
    var hashSet = new HashSet<Task<TResult>>(tasks);
    while (hashSet.Count > 0)
    {
        var task = await Task.WhenAny(hashSet).ConfigureAwait(false);
        yield return await task.ConfigureAwait(false);
        hashSet.Remove(task);
    }
}

The problem is performance. The implementation of Task.WhenAny creates a defensive copy of the supplied list of tasks, so calling it repeatedly in a loop results in O(n²) computational complexity. My naive implementation struggles to process 10,000 tasks: the overhead is nearly 10 sec on my machine. I would like the method to be nearly as performant as the built-in Task.WhenAll, which can handle hundreds of thousands of tasks with ease. How could I improve the WhenEach method to make it perform decently?
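For reference, a rudimentary measurement along those lines could look like the sketch below. It assumes the naive WhenEach above is in scope; the Stopwatch setup and the use of already-completed tasks are my own choices, not part of the original question, and they isolate the looping overhead from the actual work:

using System;
using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;

// Rough benchmark sketch (assumed setup): 10,000 already-completed tasks,
// so the elapsed time is dominated by the repeated WhenAny copying cost.
Task<int>[] tasks = Enumerable.Range(0, 10_000)
    .Select(i => Task.FromResult(i))
    .ToArray();

var stopwatch = Stopwatch.StartNew();
int count = 0;
await foreach (int result in WhenEach(tasks))
{
    count++;
}
stopwatch.Stop();
Console.WriteLine($"Consumed {count} results in {stopwatch.Elapsed}");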

By using code from this article, you can implement the following:

public static Task<Task<T>>[] Interleaved<T>(IEnumerable<Task<T>> tasks)
{
    var inputTasks = tasks.ToList();

    // One bucket per input task; each bucket completes with whichever
    // input task happens to finish next.
    var buckets = new TaskCompletionSource<Task<T>>[inputTasks.Count];
    var results = new Task<Task<T>>[buckets.Length];
    for (int i = 0; i < buckets.Length; i++)
    {
        buckets[i] = new TaskCompletionSource<Task<T>>();
        results[i] = buckets[i].Task;
    }

    // As each input task completes, hand it to the next free bucket.
    int nextTaskIndex = -1;
    Action<Task<T>> continuation = completed =>
    {
        var bucket = buckets[Interlocked.Increment(ref nextTaskIndex)];
        bucket.TrySetResult(completed);
    };

    foreach (var inputTask in inputTasks)
        inputTask.ContinueWith(continuation, CancellationToken.None,
            TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);

    return results;
}

Then change your WhenEach to call the Interleaved code:

public static async IAsyncEnumerable<TResult> WhenEach<TResult>(Task<TResult>[] tasks)
{
    foreach (var bucket in Interleaved(tasks))
    {
        var t = await bucket;
        yield return await t;
    }
}

Then you can call your WhenEach as usual:

await foreach (int result in WhenEach(tasks))
{
    Console.WriteLine($"Processed: {result}");
}

I did some rudimentary benchmarking with 10k tasks, and this performed about 5 times better in terms of speed.

You can use a Channel as an async queue. Each task can write to the channel when it completes. Items in the channel will be returned as an IAsyncEnumerable through ChannelReader.ReadAllAsync.

IAsyncEnumerable<T> ToAsyncEnumerable<T>(IEnumerable<Task<T>> inputTasks)
{
    var channel = Channel.CreateUnbounded<T>();
    var writer = channel.Writer;
    // Each task writes its result to the channel as soon as it completes.
    var continuations = inputTasks.Select(t => t.ContinueWith(x =>
                                           writer.TryWrite(x.Result)));
    // When all continuations have run, close the channel,
    // propagating any exception from the tasks.
    _ = Task.WhenAll(continuations)
            .ContinueWith(t => writer.Complete(t.Exception));

    return channel.Reader.ReadAllAsync();
}

When all tasks complete, writer.Complete() is called to close the channel.

To test this, the following code produces tasks with decreasing delays, which should return the indexes in reverse order:

var tasks = Enumerable.Range(1, 4)
                      .Select(async i =>
                      {
                          await Task.Delay(300 * (5 - i));
                          return i;
                      });

await foreach (var i in ToAsyncEnumerable(tasks))
{
    Console.WriteLine(i);
}

Produces:

4
3
2
1

Just for the fun of it, using System.Reactive and System.Interactive.Async:

public static IAsyncEnumerable<TResult> WhenEach<TResult>(
    Task<TResult>[] tasks)
    => Observable.Merge(tasks.Select(t => t.ToObservable())).ToAsyncEnumerable();
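The extension methods in that one-liner come from a few different namespaces. The directives below are my recollection of where they live in the System.Reactive and Ix.NET packages, so treat them as an assumption rather than gospel:

// Assumed using directives for the snippet above:
using System.Collections.Generic;      // IAsyncEnumerable<T>
using System.Linq;                     // Enumerable.Select, ToAsyncEnumerable
using System.Reactive.Linq;            // Observable.Merge
using System.Reactive.Threading.Tasks; // Task<T>.ToObservable()
using System.Threading.Tasks;          // Task<TResult>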

I really liked the solution provided by Panagiotis, but still wanted to get exceptions raised as they happen, like in JohanP's solution.

To achieve that, we can slightly modify it so that the continuations try to complete the channel when a task fails:

public IAsyncEnumerable<T> ToAsyncEnumerable<T>(IEnumerable<Task<T>> inputTasks)
{
    if (inputTasks == null)
    {
        throw new ArgumentNullException(nameof(inputTasks), "Task list must not be null.");
    }

    var channel = Channel.CreateUnbounded<T>();
    var channelWriter = channel.Writer;
    var inputTaskContinuations = inputTasks.Select(inputTask => inputTask.ContinueWith(completedInputTask =>
    {
        // Check whether the task succeeded or not
        if (completedInputTask.Status == TaskStatus.RanToCompletion)
        {
            // Write the result to the channel on successful completion
            channelWriter.TryWrite(completedInputTask.Result);
        }
        else
        {
            // Complete the channel on failure to immediately communicate the failure to the caller and prevent additional results from being returned
            var taskException = completedInputTask.Exception?.InnerException ?? completedInputTask.Exception;
            channelWriter.TryComplete(taskException);
        }
    }));

    // Ensure the writer is closed after the tasks are all complete, and propagate any exceptions from the continuations
    _ = Task.WhenAll(inputTaskContinuations).ContinueWith(
        completedInputTaskContinuationsTask =>
            channelWriter.TryComplete(completedInputTaskContinuationsTask.Exception));

    // Return the async enumerator of the channel so results are yielded to the caller as they're available
    return channel.Reader.ReadAllAsync();
}

The obvious downside to this is that the first error encountered will end enumeration and prevent any other, possibly successful, results from being returned. This is a tradeoff that's acceptable for my use case, but may not be for others.
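As a rough illustration of that tradeoff, a sketch of what the caller observes is shown below. It assumes the ToAsyncEnumerable method above is accessible and reuses the ProcessAsync helper from the question; the FailAsync helper and the delays are made up for the example:

var tasks = new[]
{
    ProcessAsync(1, 100),   // succeeds first
    FailAsync(200),         // faults second, completing the channel with its exception
    ProcessAsync(3, 300),   // succeeds, but its result is never yielded
};

try
{
    await foreach (int result in ToAsyncEnumerable(tasks))
    {
        Console.WriteLine($"Processed: {result}"); // prints "Processed: 1"
    }
}
catch (Exception ex)
{
    // The exception supplied to TryComplete surfaces here and ends the enumeration.
    Console.WriteLine($"Enumeration stopped: {ex.Message}");
}

static async Task<int> FailAsync(int delay)
{
    await Task.Delay(delay);
    throw new InvalidOperationException("boom");
}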

I am adding one more answer to this question, because there are a couple of issues that need to be addressed.

  1. It is recommended that methods creating async-enumerable sequences should have a CancellationToken parameter. This enables the WithCancellation configuration in await foreach loops.
  2. It is recommended that when an asynchronous operation attaches continuations to tasks, these continuations should be cleaned up when the operation completes. So if, for example, the caller of the WhenEach method decides to exit the await foreach loop prematurely (using break, return etc.), or if the loop terminates prematurely because of an exception, we don't want to leave a bunch of dead continuations hanging around, attached to the tasks. This can be particularly important if WhenEach is called repeatedly in a loop (as part of a Retry functionality, for example).

The implementation below addresses these two issues. It is based on a Channel<Task<TResult>>. Now that channels have become an integral part of the .NET platform, there is no reason to avoid them in favor of more complex TaskCompletionSource-based solutions.

public async static IAsyncEnumerable<TResult> WhenEach<TResult>(
    Task<TResult>[] tasks,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    var channel = Channel.CreateUnbounded<Task<TResult>>();
    using var linkedCts = CancellationTokenSource
        .CreateLinkedTokenSource(cancellationToken);
    var continuations = new List<Task>(tasks.Length);

    try
    {
        int pendingCount = tasks.Length;
        foreach (var task in tasks)
        {
            if (task == null) throw new ArgumentException(
                $"The tasks argument included a null value.", nameof(tasks));
            continuations.Add(task.ContinueWith(t =>
            {
                var accepted = channel.Writer.TryWrite(t);
                Debug.Assert(accepted);
                if (Interlocked.Decrement(ref pendingCount) == 0)
                    channel.Writer.Complete();
            }, linkedCts.Token, TaskContinuationOptions.ExecuteSynchronously |
                TaskContinuationOptions.DenyChildAttach, TaskScheduler.Default));
        }

        await foreach (var task in channel.Reader.ReadAllAsync(cancellationToken)
            .ConfigureAwait(false))
        {
            yield return await task.ConfigureAwait(false);
            cancellationToken.ThrowIfCancellationRequested();
        }
    }
    finally
    {
        linkedCts.Cancel();
        try { await Task.WhenAll(continuations).ConfigureAwait(false); }
        catch (OperationCanceledException) { } // Ignore
    }
}

The finally block takes care of cancelling the attached continuations, and awaiting them to complete before exiting.

The ThrowIfCancellationRequested inside the await foreach loop might seem redundant, but it is actually required because of a by-design behavior of the ReadAllAsync method, which is explained here.
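For reference, a consumption sketch with cancellation might look like this (the 250 ms timeout is arbitrary, and ProcessAsync is the helper from the question). The token can equivalently be supplied through WithCancellation, which is what the [EnumeratorCancellation] attribute enables:

using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(250));

var tasks = new Task<int>[]
{
    ProcessAsync(1, 100),
    ProcessAsync(2, 500),
    ProcessAsync(3, 800),
};

try
{
    await foreach (int result in WhenEach(tasks, cts.Token))
    {
        Console.WriteLine($"Processed: {result}"); // "Processed: 1", then cancellation
    }
}
catch (OperationCanceledException)
{
    // The enumeration observed the cancellation; the finally block above has
    // already cancelled and awaited the remaining continuations.
    Console.WriteLine("Canceled.");
}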
