繁体   English   中英

如何实现一个高效的 WhenEach 流式传输任务结果的 IAsyncEnumerable?

[英]How to implement an efficient WhenEach that streams an IAsyncEnumerable of task results?

我正在尝试使用C# 8提供的新工具更新我的工具集,其中一个似乎特别有用的方法是返回IAsyncEnumerable Task.WhenAll 这种方法应该 stream 任务一旦可用就会产生结果,因此将其命名为WhenAll没有多大意义。 WhenEach听起来更合适。 该方法的签名是:

public static IAsyncEnumerable<TResult> WhenEach<TResult>(Task<TResult>[] tasks);

这种方法可以像这样使用:

var tasks = new Task<int>[]
{
    ProcessAsync(1, 300),
    ProcessAsync(2, 500),
    ProcessAsync(3, 400),
    ProcessAsync(4, 200),
    ProcessAsync(5, 100),
};

await foreach (int result in WhenEach(tasks))
{
    Console.WriteLine($"Processed: {result}");
}

static async Task<int> ProcessAsync(int result, int delay)
{
    await Task.Delay(delay);
    return result;
}

预期 output:

已处理:5
已处理:4
已处理:1
已处理:3
已处理:2

我设法在循环中使用Task.WhenAny方法编写了一个基本实现,但是这种方法存在问题:

public static async IAsyncEnumerable<TResult> WhenEach<TResult>(
    Task<TResult>[] tasks)
{
    var hashSet = new HashSet<Task<TResult>>(tasks);
    while (hashSet.Count > 0)
    {
        var task = await Task.WhenAny(hashSet).ConfigureAwait(false);
        yield return await task.ConfigureAwait(false);
        hashSet.Remove(task);
    }
}

问题是性能。 Task.WhenAny实现创建了所提供任务列表的防御性副本,因此在循环中重复调用它会导致 O(n²) 计算复杂度。 我幼稚的实现很难处理 10,000 个任务。 我的机器上的开销将近 10 秒。 我希望该方法几乎与内置Task.WhenAll ,可以轻松处理数十万个任务。 如何改进WhenEach方法以使其表现得体?

通过使用本文中的代码,您可以实现以下功能:

public static Task<Task<T>>[] Interleaved<T>(IEnumerable<Task<T>> tasks)
{
   var inputTasks = tasks.ToList();

   var buckets = new TaskCompletionSource<Task<T>>[inputTasks.Count];
   var results = new Task<Task<T>>[buckets.Length];
   for (int i = 0; i < buckets.Length; i++)
   {
       buckets[i] = new TaskCompletionSource<Task<T>>();
       results[i] = buckets[i].Task;
   }

   int nextTaskIndex = -1;
   Action<Task<T>> continuation = completed =>
   {
       var bucket = buckets[Interlocked.Increment(ref nextTaskIndex)];
       bucket.TrySetResult(completed);
   };

   foreach (var inputTask in inputTasks)
       inputTask.ContinueWith(continuation, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);

   return results;
}

然后更改您的WhenEach以调用Interleaved代码

public static async IAsyncEnumerable<TResult> WhenEach<TResult>(Task<TResult>[] tasks)
{
    foreach (var bucket in Interleaved(tasks))
    {
        var t = await bucket;
        yield return await t;
    }
}

然后你可以像往常一样打电话给你的WhenEach

await foreach (int result in WhenEach(tasks))
{
    Console.WriteLine($"Processed: {result}");
}

我对 10k 个任务进行了一些基本的基准测试,并在速度方面提高了 5 倍。

您可以将 Channel 用作异步队列。 每个任务完成后都可以写入通道。 通道中的项目将通过ChannelReader.ReadAllAsync作为 IAsyncEnumerable 返回。

IAsyncEnumerable<T> ToAsyncEnumerable<T>(IEnumerable<Task<T>> inputTasks)
{
    var channel=Channel.CreateUnbounded<T>();
    var writer=channel.Writer;
    var continuations=inputTasks.Select(t=>t.ContinueWith(x=>
                                           writer.TryWrite(x.Result)));
    _ = Task.WhenAll(continuations)
            .ContinueWith(t=>writer.Complete(t.Exception));

    return channel.Reader.ReadAllAsync();
}

当所有任务完成时,调用writer.Complete()以关闭通道。

为了测试这一点,此代码生成具有递减延迟的任务。 这应该以相反的顺序返回索引:

var tasks=Enumerable.Range(1,4)
                    .Select(async i=>
                    { 
                      await Task.Delay(300*(5-i));
                      return i;
                    });

await foreach(var i in Interleave(tasks))
{
     Console.WriteLine(i);

}

产生:

4
3
2
1

只是为了好玩,使用System.ReactiveSystem.Interactive.Async

public static async IAsyncEnumerable<TResult> WhenEach<TResult>(
    Task<TResult>[] tasks)
    => Observable.Merge(tasks.Select(t => t.ToObservable())).ToAsyncEnumerable()

我真的很喜欢Panagiotis 提供的解决方案,但仍然希望引发异常,就像在 JohanP 的解决方案中一样。

为了实现这一点,我们可以稍微修改一下,在任务失败时尝试关闭通道:

public IAsyncEnumerable<T> ToAsyncEnumerable<T>(IEnumerable<Task<T>> inputTasks)
{
    if (inputTasks == null)
    {
        throw new ArgumentNullException(nameof(inputTasks), "Task list must not be null.");
    }

    var channel = Channel.CreateUnbounded<T>();
    var channelWriter = channel.Writer;
    var inputTaskContinuations = inputTasks.Select(inputTask => inputTask.ContinueWith(completedInputTask =>
    {
        // Check whether the task succeeded or not
        if (completedInputTask.Status == TaskStatus.RanToCompletion)
        {
            // Write the result to the channel on successful completion
            channelWriter.TryWrite(completedInputTask.Result);
        }
        else
        {
            // Complete the channel on failure to immediately communicate the failure to the caller and prevent additional results from being returned
            var taskException = completedInputTask.Exception?.InnerException ?? completedInputTask?.Exception;
            channelWriter.TryComplete(taskException);
        }
    }));

    // Ensure the writer is closed after the tasks are all complete, and propagate any exceptions from the continuations
    _ = Task.WhenAll(inputTaskContinuations).ContinueWith(completedInputTaskContinuationsTask => channelWriter.TryComplete(completedInputTaskContinuationsTask.Exception));

    // Return the async enumerator of the channel so results are yielded to the caller as they're available
    return channel.Reader.ReadAllAsync();
}

这样做的明显缺点是遇到的第一个错误将结束枚举并阻止返回任何其他可能成功的结果。 这是我的用例可以接受的权衡,但可能不适用于其他用例。

我要为这个问题再添加一个答案,因为有几个问题需要解决。

  1. 建议创建异步可枚举序列的方法应具有CancellationToken参数。 这会在await foreach循环中启用WithCancellation配置。
  2. 建议当异步操作将延续附加到任务时,应在操作完成时清理这些延续。 因此,例如,如果WhenEach方法的调用者决定提前退出await foreach循环(使用breakreturn等),或者如果循环由于异常而提前终止,我们不想让一堆死延续挂起左右,执着于任务。 如果在循环中重复调用WhenEach (例如,作为Retry功能的一部分),这一点尤其重要。

下面的实现解决了这两个问题。 它基于Channel<Task<TResult>> 现在通道已成为 .NET 平台不可或缺的一部分,因此没有理由避免使用基于更复杂TaskCompletionSource的解决方案。

public async static IAsyncEnumerable<TResult> WhenEach<TResult>(
    Task<TResult>[] tasks,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    var channel = Channel.CreateUnbounded<Task<TResult>>();
    using var linkedCts = CancellationTokenSource
        .CreateLinkedTokenSource(cancellationToken);
    var continuations = new List<Task>(tasks.Length);

    try
    {
        int pendingCount = tasks.Length;
        foreach (var task in tasks)
        {
            if (task == null) throw new ArgumentException(
                $"The tasks argument included a null value.", nameof(tasks));
            continuations.Add(task.ContinueWith(t =>
            {
                var accepted = channel.Writer.TryWrite(t);
                Debug.Assert(accepted);
                if (Interlocked.Decrement(ref pendingCount) == 0)
                    channel.Writer.Complete();
            }, linkedCts.Token, TaskContinuationOptions.ExecuteSynchronously |
                TaskContinuationOptions.DenyChildAttach, TaskScheduler.Default));
        }

        await foreach (var task in channel.Reader.ReadAllAsync(cancellationToken)
            .ConfigureAwait(false))
        {
            yield return await task.ConfigureAwait(false);
            cancellationToken.ThrowIfCancellationRequested();
        }
    }
    finally
    {
        linkedCts.Cancel();
        try { await Task.WhenAll(continuations).ConfigureAwait(false); }
        catch (OperationCanceledException) { } // Ignore
    }
}

finally块负责取消附加的延续,并在退出之前等待它们完成。

await foreach循环中的ThrowIfCancellationRequested可能看起来多余,但实际上是必需的,因为ReadAllAsync方法的设计行为,这在此处进行了解释。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM