
Thread-safe Cached Enumerator - lock with yield

I have a custom "CachedEnumerable" class (inspired by Caching IEnumerable) that I need to make thread safe for my ASP.NET Core web app.

Is the following implementation of the Enumerator thread safe? (All other reads/writes to the IList _cache are locked appropriately.) (Possibly related to Does the C# Yield free a lock?)

And more specifically, if there are 2 threads accessing the enumerator, how do I protect against one thread incrementing "index" and causing a second enumerating thread to get the wrong element from the _cache (i.e. the element at index + 1 instead of at index)? Is this race condition a real concern?

public IEnumerator<T> GetEnumerator()
{
    var index = 0;

    while (true)
    {
        T current;
        lock (_enumeratorLock)
        {
            if (index >= _cache.Count && !MoveNext()) break;
            current = _cache[index];
            index++;
        }
        yield return current;
    }
}

Full code of my version of CachedEnumerable:

 public class CachedEnumerable<T> : IDisposable, IEnumerable<T>
    {
        IEnumerator<T> _enumerator;
        private IList<T> _cache = new List<T>();
        public bool CachingComplete { get; private set; } = false;

        public CachedEnumerable(IEnumerable<T> enumerable)
        {
            switch (enumerable)
            {
                case CachedEnumerable<T> cachedEnumerable: //This case is actually dealt with by the extension method.
                    _cache = cachedEnumerable._cache;
                    CachingComplete = cachedEnumerable.CachingComplete;
                    _enumerator = cachedEnumerable.GetEnumerator();

                    break;
                case IList<T> list:
                    //_cache = list; //without clone...
                    //Clone:
                    _cache = new T[list.Count];
                    list.CopyTo((T[]) _cache, 0);
                    CachingComplete = true;
                    break;
                default:
                    _enumerator = enumerable.GetEnumerator();
                    break;
            }
        }

        public CachedEnumerable(IEnumerator<T> enumerator)
        {
            _enumerator = enumerator;
        }

        private int CurCacheCount
        {
            get
            {
                lock (_enumeratorLock)
                {
                    return _cache.Count;
                }
            }
        }

        public IEnumerator<T> GetEnumerator()
        {
            var index = 0;

            while (true)
            {
                T current;
                lock (_enumeratorLock)
                {
                    if (index >= _cache.Count && !MoveNext()) break;
                    current = _cache[index];
                    index++;
                }
                yield return current;
            }
        }

        //private readonly AsyncLock _enumeratorLock = new AsyncLock();
        private readonly object _enumeratorLock = new object();

        private bool MoveNext()
        {
            if (CachingComplete) return false;

            if (_enumerator != null && _enumerator.MoveNext()) //The null check should have been unnecessary b/c of the lock...
            {
                _cache.Add(_enumerator.Current);
                return true;
            }
            else
            {
                CachingComplete = true;
                DisposeWrappedEnumerator(); //Release the enumerator, as it is no longer needed.
            }

            return false;
        }

        public T ElementAt(int index)
        {
            lock (_enumeratorLock)
            {
                if (index < _cache.Count)
                {
                    return _cache[index];
                }
            }

            EnumerateUntil(index);

            lock (_enumeratorLock)
            {
                if (_cache.Count <= index) throw new ArgumentOutOfRangeException(nameof(index));
                return _cache[index];
            }
        }


        public bool TryGetElementAt(int index, out T value)
        {
            lock (_enumeratorLock)
            {
                value = default;
                if (index < CurCacheCount)
                {
                    value = _cache[index];
                    return true;
                }
            }

            EnumerateUntil(index);

            lock (_enumeratorLock)
            {
                if (_cache.Count <= index) return false;
                value = _cache[index];
            }

            return true;
        }

        private void EnumerateUntil(int index)
        {
            while (true)
            {
                lock (_enumeratorLock)
                {
                    if (_cache.Count > index || !MoveNext()) break;
                }
            }
        }


        public void Dispose()
        {
            DisposeWrappedEnumerator();
        }

        private void DisposeWrappedEnumerator()
        {
            if (_enumerator != null)
            {
                _enumerator.Dispose();
                _enumerator = null;
                if (_cache is List<T> list)
                {
                    list.TrimExcess(); //List<T> has no Trim method; TrimExcess releases the unused capacity.
                }
            }
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        public int CachedCount
        {
            get
            {
                lock (_enumeratorLock)
                {
                    return _cache.Count;
                }
            }
        }

        public int Count()
        {
            if (CachingComplete)
            {
                return _cache.Count;
            }

            EnsureCachingComplete();

            return _cache.Count;
        }

        private void EnsureCachingComplete()
        {
            if (CachingComplete)
            {
                return;
            }

            //Enumerate the rest of the collection
            while (!CachingComplete)
            {
                lock (_enumeratorLock)
                {
                    if (!MoveNext()) break;
                }
            }
        }

        public T[] ToArray()
        {
            EnsureCachingComplete();
            //Once Caching is complete, we don't need to lock
            if (!(_cache is T[] array))
            {
                array = _cache.ToArray();
                _cache = array;
            }

            return array;
        }

        public T this[int index] => ElementAt(index);
    }

    public static class CachedEnumerableExtensions //Extension methods must be declared in a non-nested static class.
    {
        public static CachedEnumerable<T> Cached<T>(this IEnumerable<T> source)
    {
        //no gain in caching a cache.
        if (source is CachedEnumerable<T> cached)
        {
            return cached;
        }

        return new CachedEnumerable<T>(source);
    }
}

Basic Usage: (Although not a meaningful use case)

var cached = expensiveEnumerable.Cached();
foreach (var element in cached) {
   Console.WriteLine(element);
}

Update

I tested the current implementation based on @Theodor's answer https://stackoverflow.com/a/58547863/5683904 and confirmed (AFAICT) that it is thread-safe when enumerated with a foreach, without creating duplicate values:

class Program
{
    static async Task Main(string[] args)
    {
        var enumerable = Enumerable.Range(0, 1_000_000);
        var cachedEnumerable = new CachedEnumerable<int>(enumerable);
        var c = new ConcurrentDictionary<int, List<int>>();
        var tasks = Enumerable.Range(1, 100).Select(id => Test(id, cachedEnumerable, c));
        Task.WaitAll(tasks.ToArray());
        foreach (var keyValuePair in c)
        {
            var hasDuplicates = keyValuePair.Value.Distinct().Count() != keyValuePair.Value.Count;
            Console.WriteLine($"Task #{keyValuePair.Key} count: {keyValuePair.Value.Count}. Has duplicates? {hasDuplicates}");
        }
    }

    static async Task Test(int id, IEnumerable<int> cache, ConcurrentDictionary<int, List<int>> c)
    {
        foreach (var i in cache)
        {
            //await Task.Delay(10);
            c.AddOrUpdate(id, v => new List<int>() {i}, (k, v) =>
            {
                v.Add(i);
                return v;
            });
        }
    }
}

Your class is not thread safe, because shared state is mutated in unprotected regions inside your class. The unprotected regions are:

  1. The constructor
  2. The Dispose method

The shared state is:

  1. The _enumerator private field
  2. The _cache private field
  3. The CachingComplete public property

Some other issues regarding your class:

  1. Implementing IDisposable creates the responsibility for the caller to dispose your class. There is no need for IEnumerables to be disposable. On the contrary, IEnumerators are disposable, but there is language support for their automatic disposal (a feature of the foreach statement).
  2. Your class offers extended functionality not expected from an IEnumerable (ElementAt, Count etc). Maybe you intended to implement a CachedList instead? Without implementing the IList<T> interface, LINQ methods like Count() and ToArray() cannot take advantage of your extended functionality, and will use the slow path like they do with plain vanilla IEnumerables.
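As a hedged illustration of that slow path (my own example, not from the original post): LINQ's Enumerable.Count() special-cases ICollection<T>, so the same data is counted in O(1) when the collection interface is visible, and in O(n) when it is hidden behind a bare IEnumerable<T>:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

class LinqFastPathDemo
{
    static void Main()
    {
        var list = new List<int> { 1, 2, 3 };

        // List<int> implements ICollection<int>, so Count() returns
        // list.Count directly without enumerating anything.
        Console.WriteLine(list.Count()); // 3, via the O(1) fast path

        // Select() returns a plain IEnumerable<int>, hiding ICollection<int>,
        // so Count() must walk the whole sequence.
        IEnumerable<int> hidden = list.Select(x => x);
        Console.WriteLine(hidden.Count()); // 3, via the O(n) slow path
    }
}
```

A CachedEnumerable that implemented IList<T> would get the fast path for free once caching is complete.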

Update: I just noticed another thread-safety issue. This one is related to the public IEnumerator<T> GetEnumerator() method. The enumerator is compiler-generated, since the method is an iterator (it utilizes yield return). Compiler-generated enumerators are not thread safe. Consider this code for example:

var enumerable = Enumerable.Range(0, 1_000_000);
var cachedEnumerable = new CachedEnumerable<int>(enumerable);
var enumerator = cachedEnumerable.GetEnumerator();
var tasks = Enumerable.Range(1, 4).Select(id => Task.Run(() =>
{
    int count = 0;
    while (enumerator.MoveNext())
    {
        count++;
    }
    Console.WriteLine($"Task #{id} count: {count}");
})).ToArray();
Task.WaitAll(tasks);

Four threads are concurrently using the same IEnumerator. The enumerable has 1,000,000 items. You may expect that each thread would enumerate ~250,000 items, but that's not what happens.

Output:

Task #1 count: 0
Task #4 count: 0
Task #3 count: 0
Task #2 count: 1000000

The MoveNext in the line while (enumerator.MoveNext()) is not your safe MoveNext. It is the compiler-generated, unsafe MoveNext. Although unsafe, it includes a mechanism probably intended for dealing with exceptions, which temporarily marks the enumerator as finished before calling the externally provided code. So when multiple threads call MoveNext concurrently, all but the first will get a return value of false, and will terminate the enumeration instantly, having completed zero loops. To solve this you would probably have to code your own IEnumerator class.
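A minimal sketch of what such a hand-written, lock-protected enumerator could look like (hypothetical type and member names; for simplicity it only walks items that are already in the cache, without pulling new items from a source):

```csharp
using System;
using System.Collections;
using System.Collections.Generic;

// Hypothetical hand-written enumerator. Unlike a compiler-generated iterator,
// MoveNext captures Current under the lock, so each call is internally
// consistent even when several threads share this enumerator instance.
class SafeCacheEnumerator<T> : IEnumerator<T>
{
    private readonly IList<T> _cache;
    private readonly object _sharedLock;
    private int _index = -1;

    public SafeCacheEnumerator(IList<T> cache, object sharedLock)
    {
        _cache = cache;
        _sharedLock = sharedLock;
    }

    public T Current { get; private set; }
    object IEnumerator.Current => Current;

    public bool MoveNext()
    {
        lock (_sharedLock)
        {
            if (_index + 1 >= _cache.Count) return false;
            _index++;
            Current = _cache[_index]; // read under the same lock as the index
            return true;
        }
    }

    public void Reset() => _index = -1;
    public void Dispose() { }
}
```

Even so, two threads sharing one instance can still overwrite each other's Current between a MoveNext call and the subsequent read, which is the inherent MoveNext/Current limitation the answer discusses further below.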


Update: Actually my last point about thread-safe enumeration is a bit unfair, because enumerating with the IEnumerator interface is an inherently unsafe operation, which is impossible to fix without the cooperation of the calling code. This is because obtaining the next element is not an atomic operation, since it involves two steps (calling MoveNext() + reading Current). So your thread-safety concerns are limited to the protection of the internal state of your class (the fields _enumerator, _cache and CachingComplete). These are left unprotected only in the constructor and in the Dispose method, but I suppose that the normal use of your class may not follow code paths that create the race conditions that would result in internal state corruption.

Personally I would prefer to take care of these code paths too, and I wouldn't leave it to the whims of chance.


Update: I wrote a cache for IAsyncEnumerables, to demonstrate an alternative technique. The enumeration of the source IAsyncEnumerable is not driven by the callers using locks or semaphores to obtain exclusive access, but by a separate worker-task. The first caller starts the worker-task. Each caller at first yields all items that are already cached, and then awaits more items, or a notification that there are no more items. As the notification mechanism I used a TaskCompletionSource<bool>. A lock is still used to ensure that all access to shared resources is synchronized.

public class CachedAsyncEnumerable<T> : IAsyncEnumerable<T>
{
    private readonly object _locker = new object();
    private IAsyncEnumerable<T> _source;
    private Task _sourceEnumerationTask;
    private List<T> _buffer;
    private TaskCompletionSource<bool> _moveNextTCS;
    private Exception _sourceEnumerationException;
    private int _sourceEnumerationVersion; // Incremented on exception

    public CachedAsyncEnumerable(IAsyncEnumerable<T> source)
    {
        _source = source ?? throw new ArgumentNullException(nameof(source));
    }

    public async IAsyncEnumerator<T> GetAsyncEnumerator(
        CancellationToken cancellationToken = default)
    {
        lock (_locker)
        {
            if (_sourceEnumerationTask == null)
            {
                _buffer = new List<T>();
                _moveNextTCS = new TaskCompletionSource<bool>();
                _sourceEnumerationTask = Task.Run(
                    () => EnumerateSourceAsync(cancellationToken));
            }
        }
        int index = 0;
        int localVersion = -1;
        while (true)
        {
            T current = default;
            Task<bool> moveNextTask = null;
            lock (_locker)
            {
                if (localVersion == -1)
                {
                    localVersion = _sourceEnumerationVersion;
                }
                else if (_sourceEnumerationVersion != localVersion)
                {
                    ExceptionDispatchInfo
                        .Capture(_sourceEnumerationException).Throw();
                }
                if (index < _buffer.Count)
                {
                    current = _buffer[index];
                    index++;
                }
                else
                {
                    moveNextTask = _moveNextTCS.Task;
                }
            }
            if (moveNextTask == null)
            {
                yield return current;
                continue;
            }
            var moved = await moveNextTask;
            if (!moved) yield break;
            lock (_locker)
            {
                current = _buffer[index];
                index++;
            }
            yield return current;
        }
    }

    private async Task EnumerateSourceAsync(CancellationToken cancellationToken)
    {
        TaskCompletionSource<bool> localMoveNextTCS;
        try
        {
            await foreach (var item in _source.WithCancellation(cancellationToken))
            {
                lock (_locker)
                {
                    _buffer.Add(item);
                    localMoveNextTCS = _moveNextTCS;
                    _moveNextTCS = new TaskCompletionSource<bool>();
                }
                localMoveNextTCS.SetResult(true);
            }
            lock (_locker)
            {
                localMoveNextTCS = _moveNextTCS;
                _buffer.TrimExcess();
                _source = null;
            }
            localMoveNextTCS.SetResult(false);
        }
        catch (Exception ex)
        {
            lock (_locker)
            {
                localMoveNextTCS = _moveNextTCS;
                _sourceEnumerationException = ex;
                _sourceEnumerationVersion++;
                _sourceEnumerationTask = null;
            }
            localMoveNextTCS.SetException(ex);
        }
    }
}

This implementation follows a specific strategy for dealing with exceptions. If an exception occurs while enumerating the source IAsyncEnumerable, the exception will be propagated to all current callers, the currently used IAsyncEnumerator will be discarded, and the incomplete cached data will be discarded too. A new worker-task may start again later, when the next enumeration request is received.
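A brief usage sketch (my own example, not from the original answer; assumes C# 8.0+ async streams): two consumers enumerate the same cached sequence, while the source producer runs only once inside the worker-task:

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

class CachedAsyncDemo
{
    // Hypothetical expensive async producer.
    static async IAsyncEnumerable<int> Source()
    {
        for (int i = 0; i < 5; i++)
        {
            await Task.Delay(100); // simulate slow production
            yield return i;
        }
    }

    static async Task Main()
    {
        var cached = new CachedAsyncEnumerable<int>(Source());

        async Task Consume(string name)
        {
            await foreach (var item in cached)
                Console.WriteLine($"{name}: {item}");
        }

        // Both consumers observe the full sequence 0..4, while Source() is
        // enumerated only once by the internal worker-task.
        await Task.WhenAll(Consume("A"), Consume("B"));
    }
}
```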

The access to the cache: yes, it is thread safe; only one thread at a time can read from the _cache object.

But in that way you can't ensure that all threads get elements in the same order in which they accessed GetEnumerator.

Check these two examples; if the behavior is what you expect, you can use the lock in that way.

Example 1:

THREAD1 calls GetEnumerator
THREAD1 initializes T current;
THREAD2 calls GetEnumerator
THREAD2 initializes T current;
THREAD2 LOCK
THREAD1 WAIT
THREAD2 safely reads _cache[0]
THREAD2 index++
THREAD2 UNLOCK
THREAD1 LOCK
THREAD1 safely reads _cache[1]
THREAD1 index++
THREAD1 UNLOCK
THREAD2 yield return current;
THREAD1 yield return current;


Example 2:

THREAD2 initializes T current;
THREAD2 LOCK
THREAD2 safely reads from cache
THREAD2 UNLOCK
THREAD1 initializes T current;
THREAD1 LOCK
THREAD1 safely reads from cache
THREAD1 UNLOCK
THREAD1 yield return current;
THREAD2 yield return current;
