繁体   English   中英

线程安全的缓存枚举器 - 带收益的锁

[英]Thread-safe Cached Enumerator - lock with yield

我有一个自定义的“CachedEnumerable”class(受缓存 IEnumerable启发),我需要为我的 asp.net 核心 Z2567A5EC9705EB7AC2C984033E0618 应用程序设置线程安全。

Enumerator 线程的以下实现是否安全? (对 IList _cache 的所有其他读/写操作均已适当锁定)(可能与Does the C# Yield free a lock?相关)

更具体地说,如果有 2 个线程访问枚举器,我如何防止一个线程递增“索引”导致第二个枚举线程从 _cache 获取错误的元素(即索引 + 1 处的元素而不是索引处的元素) ? 这种比赛条件是一个真正的问题吗?

public IEnumerator<T> GetEnumerator()
{
    var index = 0;

    while (true)
    {
        T current;
        lock (_enumeratorLock)
        {
            if (index >= _cache.Count && !MoveNext()) break;
            current = _cache[index];
            index++;
        }
        yield return current;
    }
}

我的 CachedEnumerable 版本的完整代码:

 public class CachedEnumerable<T> : IDisposable, IEnumerable<T>
    {
        IEnumerator<T> _enumerator;
        private IList<T> _cache = new List<T>();
        public bool CachingComplete { get; private set; } = false;

        public CachedEnumerable(IEnumerable<T> enumerable)
        {
            switch (enumerable)
            {
                case CachedEnumerable<T> cachedEnumerable: //This case is actually dealt with by the extension method.
                    _cache = cachedEnumerable._cache;
                    CachingComplete = cachedEnumerable.CachingComplete;
                    _enumerator = cachedEnumerable.GetEnumerator();

                    break;
                case IList<T> list:
                    //_cache = list; //without clone...
                    //Clone:
                    _cache = new T[list.Count];
                    list.CopyTo((T[]) _cache, 0);
                    CachingComplete = true;
                    break;
                default:
                    _enumerator = enumerable.GetEnumerator();
                    break;
            }
        }

        public CachedEnumerable(IEnumerator<T> enumerator)
        {
            _enumerator = enumerator;
        }

        private int CurCacheCount
        {
            get
            {
                lock (_enumeratorLock)
                {
                    return _cache.Count;
                }
            }
        }

        public IEnumerator<T> GetEnumerator()
        {
            var index = 0;

            while (true)
            {
                T current;
                lock (_enumeratorLock)
                {
                    if (index >= _cache.Count && !MoveNext()) break;
                    current = _cache[index];
                    index++;
                }
                yield return current;
            }
        }

        //private readonly AsyncLock _enumeratorLock = new AsyncLock();
        private readonly object _enumeratorLock = new object();

        private bool MoveNext()
        {
            if (CachingComplete) return false;

            if (_enumerator != null && _enumerator.MoveNext()) //The null check should have been unnecessary b/c of the lock...
            {
                _cache.Add(_enumerator.Current);
                return true;
            }
            else
            {
                CachingComplete = true;
                DisposeWrappedEnumerator(); //Release the enumerator, as it is no longer needed.
            }

            return false;
        }

        public T ElementAt(int index)
        {
            lock (_enumeratorLock)
            {
                if (index < _cache.Count)
                {
                    return _cache[index];
                }
            }

            EnumerateUntil(index);

            lock (_enumeratorLock)
            {
                if (_cache.Count <= index) throw new ArgumentOutOfRangeException(nameof(index));
                return _cache[index];
            }
        }


        public bool TryGetElementAt(int index, out T value)
        {
            lock (_enumeratorLock)
            {
                value = default;
                if (index < CurCacheCount)
                {
                    value = _cache[index];
                    return true;
                }
            }

            EnumerateUntil(index);

            lock (_enumeratorLock)
            {
                if (_cache.Count <= index) return false;
                value = _cache[index];
            }

            return true;
        }

        private void EnumerateUntil(int index)
        {
            while (true)
            {
                lock (_enumeratorLock)
                {
                    if (_cache.Count > index || !MoveNext()) break;
                }
            }
        }


        public void Dispose()
        {
            DisposeWrappedEnumerator();
        }

        private void DisposeWrappedEnumerator()
        {
            if (_enumerator != null)
            {
                _enumerator.Dispose();
                _enumerator = null;
                if (_cache is List<T> list)
                {
                    list.Trim();
                }
            }
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        public int CachedCount
        {
            get
            {
                lock (_enumeratorLock)
                {
                    return _cache.Count;
                }
            }
        }

        public int Count()
        {
            if (CachingComplete)
            {
                return _cache.Count;
            }

            EnsureCachingComplete();

            return _cache.Count;
        }

        private void EnsureCachingComplete()
        {
            if (CachingComplete)
            {
                return;
            }

            //Enumerate the rest of the collection
            while (!CachingComplete)
            {
                lock (_enumeratorLock)
                {
                    if (!MoveNext()) break;
                }
            }
        }

        public T[] ToArray()
        {
            EnsureCachingComplete();
            //Once Caching is complete, we don't need to lock
            if (!(_cache is T[] array))
            {
                array = _cache.ToArray();
                _cache = array;
            }

            return array;
        }

        public T this[int index] => ElementAt(index);
    }

    public static CachedEnumerable<T> Cached<T>(this IEnumerable<T> source)
    {
        //no gain in caching a cache.
        if (source is CachedEnumerable<T> cached)
        {
            return cached;
        }

        return new CachedEnumerable<T>(source);
    }
}

基本用法:(虽然不是一个有意义的用例)

var cached = expensiveEnumerable.Cached();
foreach (var element in cached) {
   Console.WriteLine(element);
}

更新

我根据@Theodors 回答https://stackoverflow.com/a/58547863/5683904测试了当前实现,并确认(AFAICT)在使用 foreach 枚举而不创建重复值时它是线程安全的( Thread-safe Cached Enumerator -锁定收益):

class Program
{
    static async Task Main(string[] args)
    {
        var enumerable = Enumerable.Range(0, 1_000_000);
        var cachedEnumerable = new CachedEnumerable<int>(enumerable);
        var c = new ConcurrentDictionary<int, List<int>>();
        var tasks = Enumerable.Range(1, 100).Select(id => Test(id, cachedEnumerable, c));
        Task.WaitAll(tasks.ToArray());
        foreach (var keyValuePair in c)
        {
            var hasDuplicates = keyValuePair.Value.Distinct().Count() != keyValuePair.Value.Count;
            Console.WriteLine($"Task #{keyValuePair.Key} count: {keyValuePair.Value.Count}. Has duplicates? {hasDuplicates}");
        }
    }

    static async Task Test(int id, IEnumerable<int> cache, ConcurrentDictionary<int, List<int>> c)
    {
        foreach (var i in cache)
        {
            //await Task.Delay(10);
            c.AddOrUpdate(id, v => new List<int>() {i}, (k, v) =>
            {
                v.Add(i);
                return v;
            });
        }
    }
}

您的 class 不是线程安全的,因为共享的 state 在 class 内的未受保护区域中发生突变。 未受保护的区域是:

  1. 构造函数
  2. Dispose方法

共享的 state 是:

  1. _enumerator私有字段
  2. _cache私有字段
  3. CachingComplete公共属性

关于 class 的其他一些问题:

  1. 实现IDisposable为调用者创建了处置您的 class 的责任。 IEnumerable不需要是一次性的。 相反, IEnumerator是一次性的,但它们的自动处理有语言支持( foreach语句的特性)。
  2. 您的 class 提供了IEnumerableElementAtCount等)无法提供的扩展功能。 也许您打算改为实现一个CachedList 如果不实现IList<T>接口,则 LINQ 方法(如Count()ToArray()无法利用您的扩展功能,并且会像使用普通的IEnumerable一样使用慢速路径。

更新:我刚刚注意到另一个线程安全问题。 这与public IEnumerator<T> GetEnumerator()方法有关。 枚举器是编译器生成的,因为该方法是一个迭代器(利用yield return )。 编译器生成的枚举器不是线程安全的。 例如,考虑以下代码:

var enumerable = Enumerable.Range(0, 1_000_000);
var cachedEnumerable = new CachedEnumerable<int>(enumerable);
var enumerator = cachedEnumerable.GetEnumerator();
var tasks = Enumerable.Range(1, 4).Select(id => Task.Run(() =>
{
    int count = 0;
    while (enumerator.MoveNext())
    {
        count++;
    }
    Console.WriteLine($"Task #{id} count: {count}");
})).ToArray();
Task.WaitAll(tasks);

四个线程同时使用相同的IEnumerator 可枚举项有 1,000,000 项。 您可能期望每个线程会枚举约 250,000 个项目,但事实并非如此。

Output:

任务 #1 计数:0
任务 #4 计数:0
任务 #3 计数:0
任务 #2 计数:1000000

行中的MoveNext while (enumerator.MoveNext())不是您安全的MoveNext 它是编译器生成的不安全MoveNext 虽然不安全,但它包含一种可能用于处理异常的机制,在调用外部提供的代码之前临时将枚举器标记为已完成。 因此,当多个线程同时调用MoveNext时,除第一个线程外,所有线程都将获得false的返回值,并且将立即终止枚举,完成零循环。 要解决这个问题,您可能必须编写自己的IEnumerator class。


更新:实际上我关于线程安全枚举的最后一点有点不公平,因为使用IEnumerator接口进行枚举本质上是一种不安全的操作,如果没有调用代码的合作,这是无法修复的。 这是因为获取下一个元素不是原子操作,因为它涉及两个步骤(调用MoveNext() + 读取Current )。 因此,您的线程安全问题仅限于保护 class 的内部 state(字段_enumerator_cacheCachingComplete )。 这些仅在构造函数和Dispose方法中未受保护,但我认为 class 的正常使用可能不会遵循创建会导致内部 state 损坏的竞争条件的代码路径。

就我个人而言,我也更愿意处理这些代码路径,我不会让它随心所欲。


更新:我为IAsyncEnumerable编写了一个缓存,以演示另一种技术。 IAsyncEnumerable的枚举不是由调用者驱动的,使用锁或信号量来获得独占访问,而是由单独的工作任务驱动。 第一个调用者启动工作任务。 每个调用者首先产生所有已缓存的项目,然后等待更多项目,或等待没有更多项目的通知。 作为通知机制,我使用了TaskCompletionSource<bool> 仍然使用lock来确保对共享资源的所有访问都是同步的。

public class CachedAsyncEnumerable<T> : IAsyncEnumerable<T>
{
    private readonly object _locker = new object();
    private IAsyncEnumerable<T> _source;
    private Task _sourceEnumerationTask;
    private List<T> _buffer;
    private TaskCompletionSource<bool> _moveNextTCS;
    private Exception _sourceEnumerationException;
    private int _sourceEnumerationVersion; // Incremented on exception

    public CachedAsyncEnumerable(IAsyncEnumerable<T> source)
    {
        _source = source ?? throw new ArgumentNullException(nameof(source));
    }

    public async IAsyncEnumerator<T> GetAsyncEnumerator(
        CancellationToken cancellationToken = default)
    {
        lock (_locker)
        {
            if (_sourceEnumerationTask == null)
            {
                _buffer = new List<T>();
                _moveNextTCS = new TaskCompletionSource<bool>();
                _sourceEnumerationTask = Task.Run(
                    () => EnumerateSourceAsync(cancellationToken));
            }
        }
        int index = 0;
        int localVersion = -1;
        while (true)
        {
            T current = default;
            Task<bool> moveNextTask = null;
            lock (_locker)
            {
                if (localVersion == -1)
                {
                    localVersion = _sourceEnumerationVersion;
                }
                else if (_sourceEnumerationVersion != localVersion)
                {
                    ExceptionDispatchInfo
                        .Capture(_sourceEnumerationException).Throw();
                }
                if (index < _buffer.Count)
                {
                    current = _buffer[index];
                    index++;
                }
                else
                {
                    moveNextTask = _moveNextTCS.Task;
                }
            }
            if (moveNextTask == null)
            {
                yield return current;
                continue;
            }
            var moved = await moveNextTask;
            if (!moved) yield break;
            lock (_locker)
            {
                current = _buffer[index];
                index++;
            }
            yield return current;
        }
    }

    private async Task EnumerateSourceAsync(CancellationToken cancellationToken)
    {
        TaskCompletionSource<bool> localMoveNextTCS;
        try
        {
            await foreach (var item in _source.WithCancellation(cancellationToken))
            {
                lock (_locker)
                {
                    _buffer.Add(item);
                    localMoveNextTCS = _moveNextTCS;
                    _moveNextTCS = new TaskCompletionSource<bool>();
                }
                localMoveNextTCS.SetResult(true);
            }
            lock (_locker)
            {
                localMoveNextTCS = _moveNextTCS;
                _buffer.TrimExcess();
                _source = null;
            }
            localMoveNextTCS.SetResult(false);
        }
        catch (Exception ex)
        {
            lock (_locker)
            {
                localMoveNextTCS = _moveNextTCS;
                _sourceEnumerationException = ex;
                _sourceEnumerationVersion++;
                _sourceEnumerationTask = null;
            }
            localMoveNextTCS.SetException(ex);
        }
    }
}

此实现遵循处理异常的特定策略。 如果在枚举源IAsyncEnumerable时发生异常,该异常将传播到所有当前调用者,当前使用的IAsyncEnumerator将被丢弃,不完整的缓存数据也将被丢弃。 当接收到下一个枚举请求时,新的工作任务可能会在稍后再次启动。

访问缓存,是的,它是线程安全的,每次只能从 _cache object 读取一个线程。

但是这样一来,您就不能保证所有线程都以与访问 GetEnumerator 相同的顺序获取元素。

检查这两个例子,如果行为是你所期望的,你可以用那种方式使用锁。

示例 1:

THREAD1 调用 GetEnumerator

THREAD1 初始化 T 电流;

THREAD2 调用 GetEnumerator

THREAD2 初始化 T 电流;

螺纹 2 锁螺纹

线程 1 等待

THREAD2 从缓存中安全读取 _cache[0]

THREAD2 索引++

线程2解锁

螺纹1锁

THREAD1 从缓存中安全读取 _cache[1]

线程1 i++

线程 1 解锁

THREAD2 产生返回电流;

THREAD1 产生返回电流;


示例 2:

THREAD2 初始化 T 电流;

螺纹 2 锁螺纹

THREAD2 安全地从缓存中读取

线程2解锁

THREAD1 初始化 T 电流;

螺纹 1 锁螺纹

THREAD1 从缓存中安全读取

线程 1 解锁

THREAD1 产生返回电流;

THREAD2 产生返回电流;

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM