簡體   English   中英

線程安全的緩存枚舉器 - 帶收益的鎖

[英]Thread-safe Cached Enumerator - lock with yield

我有一個自定義的“CachedEnumerable”class(受緩存 IEnumerable啟發),我需要為我的 asp.net 核心 Z2567A5EC9705EB7AC2C984033E0618 應用程序設置線程安全。

Enumerator 線程的以下實現是否安全? (對 IList _cache 的所有其他讀/寫操作均已適當鎖定)(可能與Does the C# Yield free a lock?相關)

更具體地說,如果有 2 個線程訪問枚舉器,我如何防止一個線程遞增“索引”導致第二個枚舉線程從 _cache 獲取錯誤的元素(即索引 + 1 處的元素而不是索引處的元素) ? 這種比賽條件是一個真正的問題嗎?

public IEnumerator<T> GetEnumerator()
{
    var index = 0;

    while (true)
    {
        T current;
        lock (_enumeratorLock)
        {
            if (index >= _cache.Count && !MoveNext()) break;
            current = _cache[index];
            index++;
        }
        yield return current;
    }
}

我的 CachedEnumerable 版本的完整代碼:

 public class CachedEnumerable<T> : IDisposable, IEnumerable<T>
    {
        IEnumerator<T> _enumerator;
        private IList<T> _cache = new List<T>();
        public bool CachingComplete { get; private set; } = false;

        public CachedEnumerable(IEnumerable<T> enumerable)
        {
            switch (enumerable)
            {
                case CachedEnumerable<T> cachedEnumerable: //This case is actually dealt with by the extension method.
                    _cache = cachedEnumerable._cache;
                    CachingComplete = cachedEnumerable.CachingComplete;
                    _enumerator = cachedEnumerable.GetEnumerator();

                    break;
                case IList<T> list:
                    //_cache = list; //without clone...
                    //Clone:
                    _cache = new T[list.Count];
                    list.CopyTo((T[]) _cache, 0);
                    CachingComplete = true;
                    break;
                default:
                    _enumerator = enumerable.GetEnumerator();
                    break;
            }
        }

        public CachedEnumerable(IEnumerator<T> enumerator)
        {
            _enumerator = enumerator;
        }

        private int CurCacheCount
        {
            get
            {
                lock (_enumeratorLock)
                {
                    return _cache.Count;
                }
            }
        }

        public IEnumerator<T> GetEnumerator()
        {
            var index = 0;

            while (true)
            {
                T current;
                lock (_enumeratorLock)
                {
                    if (index >= _cache.Count && !MoveNext()) break;
                    current = _cache[index];
                    index++;
                }
                yield return current;
            }
        }

        //private readonly AsyncLock _enumeratorLock = new AsyncLock();
        private readonly object _enumeratorLock = new object();

        private bool MoveNext()
        {
            if (CachingComplete) return false;

            if (_enumerator != null && _enumerator.MoveNext()) //The null check should have been unnecessary b/c of the lock...
            {
                _cache.Add(_enumerator.Current);
                return true;
            }
            else
            {
                CachingComplete = true;
                DisposeWrappedEnumerator(); //Release the enumerator, as it is no longer needed.
            }

            return false;
        }

        public T ElementAt(int index)
        {
            lock (_enumeratorLock)
            {
                if (index < _cache.Count)
                {
                    return _cache[index];
                }
            }

            EnumerateUntil(index);

            lock (_enumeratorLock)
            {
                if (_cache.Count <= index) throw new ArgumentOutOfRangeException(nameof(index));
                return _cache[index];
            }
        }


        public bool TryGetElementAt(int index, out T value)
        {
            lock (_enumeratorLock)
            {
                value = default;
                if (index < CurCacheCount)
                {
                    value = _cache[index];
                    return true;
                }
            }

            EnumerateUntil(index);

            lock (_enumeratorLock)
            {
                if (_cache.Count <= index) return false;
                value = _cache[index];
            }

            return true;
        }

        private void EnumerateUntil(int index)
        {
            while (true)
            {
                lock (_enumeratorLock)
                {
                    if (_cache.Count > index || !MoveNext()) break;
                }
            }
        }


        public void Dispose()
        {
            DisposeWrappedEnumerator();
        }

        private void DisposeWrappedEnumerator()
        {
            if (_enumerator != null)
            {
                _enumerator.Dispose();
                _enumerator = null;
                if (_cache is List<T> list)
                {
                    list.Trim();
                }
            }
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        public int CachedCount
        {
            get
            {
                lock (_enumeratorLock)
                {
                    return _cache.Count;
                }
            }
        }

        public int Count()
        {
            if (CachingComplete)
            {
                return _cache.Count;
            }

            EnsureCachingComplete();

            return _cache.Count;
        }

        private void EnsureCachingComplete()
        {
            if (CachingComplete)
            {
                return;
            }

            //Enumerate the rest of the collection
            while (!CachingComplete)
            {
                lock (_enumeratorLock)
                {
                    if (!MoveNext()) break;
                }
            }
        }

        public T[] ToArray()
        {
            EnsureCachingComplete();
            //Once Caching is complete, we don't need to lock
            if (!(_cache is T[] array))
            {
                array = _cache.ToArray();
                _cache = array;
            }

            return array;
        }

        public T this[int index] => ElementAt(index);
    }

    public static CachedEnumerable<T> Cached<T>(this IEnumerable<T> source)
    {
        //no gain in caching a cache.
        if (source is CachedEnumerable<T> cached)
        {
            return cached;
        }

        return new CachedEnumerable<T>(source);
    }
}

基本用法:(雖然不是一個有意義的用例)

var cached = expensiveEnumerable.Cached();
foreach (var element in cached) {
   Console.WriteLine(element);
}

更新

我根據@Theodors 回答https://stackoverflow.com/a/58547863/5683904測試了當前實現,並確認(AFAICT)在使用 foreach 枚舉而不創建重復值時它是線程安全的( Thread-safe Cached Enumerator -鎖定收益):

class Program
{
    static async Task Main(string[] args)
    {
        var enumerable = Enumerable.Range(0, 1_000_000);
        var cachedEnumerable = new CachedEnumerable<int>(enumerable);
        var c = new ConcurrentDictionary<int, List<int>>();
        var tasks = Enumerable.Range(1, 100).Select(id => Test(id, cachedEnumerable, c));
        Task.WaitAll(tasks.ToArray());
        foreach (var keyValuePair in c)
        {
            var hasDuplicates = keyValuePair.Value.Distinct().Count() != keyValuePair.Value.Count;
            Console.WriteLine($"Task #{keyValuePair.Key} count: {keyValuePair.Value.Count}. Has duplicates? {hasDuplicates}");
        }
    }

    static async Task Test(int id, IEnumerable<int> cache, ConcurrentDictionary<int, List<int>> c)
    {
        foreach (var i in cache)
        {
            //await Task.Delay(10);
            c.AddOrUpdate(id, v => new List<int>() {i}, (k, v) =>
            {
                v.Add(i);
                return v;
            });
        }
    }
}

您的 class 不是線程安全的,因為共享的 state 在 class 內的未受保護區域中發生突變。 未受保護的區域是:

  1. 構造函數
  2. Dispose方法

共享的 state 是:

  1. _enumerator私有字段
  2. _cache私有字段
  3. CachingComplete公共屬性

關於 class 的其他一些問題:

  1. 實現IDisposable為調用者創建了處置您的 class 的責任。 IEnumerable不需要是一次性的。 相反, IEnumerator是一次性的,但它們的自動處理有語言支持( foreach語句的特性)。
  2. 您的 class 提供了IEnumerableElementAtCount等)無法提供的擴展功能。 也許您打算改為實現一個CachedList 如果不實現IList<T>接口,則 LINQ 方法(如Count()ToArray()無法利用您的擴展功能,並且會像使用普通的IEnumerable一樣使用慢速路徑。

更新:我剛剛注意到另一個線程安全問題。 這與public IEnumerator<T> GetEnumerator()方法有關。 枚舉器是編譯器生成的,因為該方法是一個迭代器(利用yield return )。 編譯器生成的枚舉器不是線程安全的。 例如,考慮以下代碼:

var enumerable = Enumerable.Range(0, 1_000_000);
var cachedEnumerable = new CachedEnumerable<int>(enumerable);
var enumerator = cachedEnumerable.GetEnumerator();
var tasks = Enumerable.Range(1, 4).Select(id => Task.Run(() =>
{
    int count = 0;
    while (enumerator.MoveNext())
    {
        count++;
    }
    Console.WriteLine($"Task #{id} count: {count}");
})).ToArray();
Task.WaitAll(tasks);

四個線程同時使用相同的IEnumerator 可枚舉項有 1,000,000 項。 您可能期望每個線程會枚舉約 250,000 個項目,但事實並非如此。

Output:

任務 #1 計數:0
任務 #4 計數:0
任務 #3 計數:0
任務 #2 計數:1000000

行中的MoveNext while (enumerator.MoveNext())不是您安全的MoveNext 它是編譯器生成的不安全MoveNext 雖然不安全,但它包含一種可能用於處理異常的機制,在調用外部提供的代碼之前臨時將枚舉器標記為已完成。 因此,當多個線程同時調用MoveNext時,除第一個線程外,所有線程都將獲得false的返回值,並且將立即終止枚舉,完成零循環。 要解決這個問題,您可能必須編寫自己的IEnumerator class。


更新:實際上我關於線程安全枚舉的最后一點有點不公平,因為使用IEnumerator接口進行枚舉本質上是一種不安全的操作,如果沒有調用代碼的合作,這是無法修復的。 這是因為獲取下一個元素不是原子操作,因為它涉及兩個步驟(調用MoveNext() + 讀取Current )。 因此,您的線程安全問題僅限於保護 class 的內部 state(字段_enumerator_cacheCachingComplete )。 這些僅在構造函數和Dispose方法中未受保護,但我認為 class 的正常使用可能不會遵循創建會導致內部 state 損壞的競爭條件的代碼路徑。

就我個人而言,我也更願意處理這些代碼路徑,我不會讓它隨心所欲。


更新:我為IAsyncEnumerable編寫了一個緩存,以演示另一種技術。 IAsyncEnumerable的枚舉不是由調用者驅動的,使用鎖或信號量來獲得獨占訪問,而是由單獨的工作任務驅動。 第一個調用者啟動工作任務。 每個調用者首先產生所有已緩存的項目,然后等待更多項目,或等待沒有更多項目的通知。 作為通知機制,我使用了TaskCompletionSource<bool> 仍然使用lock來確保對共享資源的所有訪問都是同步的。

public class CachedAsyncEnumerable<T> : IAsyncEnumerable<T>
{
    private readonly object _locker = new object();
    private IAsyncEnumerable<T> _source;
    private Task _sourceEnumerationTask;
    private List<T> _buffer;
    private TaskCompletionSource<bool> _moveNextTCS;
    private Exception _sourceEnumerationException;
    private int _sourceEnumerationVersion; // Incremented on exception

    public CachedAsyncEnumerable(IAsyncEnumerable<T> source)
    {
        _source = source ?? throw new ArgumentNullException(nameof(source));
    }

    public async IAsyncEnumerator<T> GetAsyncEnumerator(
        CancellationToken cancellationToken = default)
    {
        lock (_locker)
        {
            if (_sourceEnumerationTask == null)
            {
                _buffer = new List<T>();
                _moveNextTCS = new TaskCompletionSource<bool>();
                _sourceEnumerationTask = Task.Run(
                    () => EnumerateSourceAsync(cancellationToken));
            }
        }
        int index = 0;
        int localVersion = -1;
        while (true)
        {
            T current = default;
            Task<bool> moveNextTask = null;
            lock (_locker)
            {
                if (localVersion == -1)
                {
                    localVersion = _sourceEnumerationVersion;
                }
                else if (_sourceEnumerationVersion != localVersion)
                {
                    ExceptionDispatchInfo
                        .Capture(_sourceEnumerationException).Throw();
                }
                if (index < _buffer.Count)
                {
                    current = _buffer[index];
                    index++;
                }
                else
                {
                    moveNextTask = _moveNextTCS.Task;
                }
            }
            if (moveNextTask == null)
            {
                yield return current;
                continue;
            }
            var moved = await moveNextTask;
            if (!moved) yield break;
            lock (_locker)
            {
                current = _buffer[index];
                index++;
            }
            yield return current;
        }
    }

    private async Task EnumerateSourceAsync(CancellationToken cancellationToken)
    {
        TaskCompletionSource<bool> localMoveNextTCS;
        try
        {
            await foreach (var item in _source.WithCancellation(cancellationToken))
            {
                lock (_locker)
                {
                    _buffer.Add(item);
                    localMoveNextTCS = _moveNextTCS;
                    _moveNextTCS = new TaskCompletionSource<bool>();
                }
                localMoveNextTCS.SetResult(true);
            }
            lock (_locker)
            {
                localMoveNextTCS = _moveNextTCS;
                _buffer.TrimExcess();
                _source = null;
            }
            localMoveNextTCS.SetResult(false);
        }
        catch (Exception ex)
        {
            lock (_locker)
            {
                localMoveNextTCS = _moveNextTCS;
                _sourceEnumerationException = ex;
                _sourceEnumerationVersion++;
                _sourceEnumerationTask = null;
            }
            localMoveNextTCS.SetException(ex);
        }
    }
}

此實現遵循處理異常的特定策略。 如果在枚舉源IAsyncEnumerable時發生異常,該異常將傳播到所有當前調用者,當前使用的IAsyncEnumerator將被丟棄,不完整的緩存數據也將被丟棄。 當接收到下一個枚舉請求時,新的工作任務可能會在稍后再次啟動。

訪問緩存,是的,它是線程安全的,每次只能從 _cache object 讀取一個線程。

但是這樣一來,您就不能保證所有線程都以與訪問 GetEnumerator 相同的順序獲取元素。

檢查這兩個例子,如果行為是你所期望的,你可以用那種方式使用鎖。

示例 1:

THREAD1 調用 GetEnumerator

THREAD1 初始化 T 電流;

THREAD2 調用 GetEnumerator

THREAD2 初始化 T 電流;

螺紋 2 鎖螺紋

線程 1 等待

THREAD2 從緩存中安全讀取 _cache[0]

THREAD2 索引++

線程2解鎖

螺紋1鎖

THREAD1 從緩存中安全讀取 _cache[1]

線程1 i++

線程 1 解鎖

THREAD2 產生返回電流;

THREAD1 產生返回電流;


示例 2:

THREAD2 初始化 T 電流;

螺紋 2 鎖螺紋

THREAD2 安全地從緩存中讀取

線程2解鎖

THREAD1 初始化 T 電流;

螺紋 1 鎖螺紋

THREAD1 從緩存中安全讀取

線程 1 解鎖

THREAD1 產生返回電流;

THREAD2 產生返回電流;

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM