[英]Thread-safe Cached Enumerator - lock with yield
我有一個自定義的“CachedEnumerable”class(受緩存 IEnumerable啟發),我需要為我的 asp.net 核心 Z2567A5EC9705EB7AC2C984033E0618 應用程序設置線程安全。
Enumerator 線程的以下實現是否安全? (對 IList _cache 的所有其他讀/寫操作均已適當鎖定)(可能與Does the C# Yield free a lock?相關)
更具體地說,如果有 2 個線程訪問枚舉器,我如何防止一個線程遞增“索引”導致第二個枚舉線程從 _cache 獲取錯誤的元素(即索引 + 1 處的元素而不是索引處的元素) ? 這種比賽條件是一個真正的問題嗎?
public IEnumerator<T> GetEnumerator()
{
var index = 0;
while (true)
{
T current;
lock (_enumeratorLock)
{
if (index >= _cache.Count && !MoveNext()) break;
current = _cache[index];
index++;
}
yield return current;
}
}
我的 CachedEnumerable 版本的完整代碼:
public class CachedEnumerable<T> : IDisposable, IEnumerable<T>
{
IEnumerator<T> _enumerator;
private IList<T> _cache = new List<T>();
public bool CachingComplete { get; private set; } = false;
public CachedEnumerable(IEnumerable<T> enumerable)
{
switch (enumerable)
{
case CachedEnumerable<T> cachedEnumerable: //This case is actually dealt with by the extension method.
_cache = cachedEnumerable._cache;
CachingComplete = cachedEnumerable.CachingComplete;
_enumerator = cachedEnumerable.GetEnumerator();
break;
case IList<T> list:
//_cache = list; //without clone...
//Clone:
_cache = new T[list.Count];
list.CopyTo((T[]) _cache, 0);
CachingComplete = true;
break;
default:
_enumerator = enumerable.GetEnumerator();
break;
}
}
public CachedEnumerable(IEnumerator<T> enumerator)
{
_enumerator = enumerator;
}
private int CurCacheCount
{
get
{
lock (_enumeratorLock)
{
return _cache.Count;
}
}
}
public IEnumerator<T> GetEnumerator()
{
var index = 0;
while (true)
{
T current;
lock (_enumeratorLock)
{
if (index >= _cache.Count && !MoveNext()) break;
current = _cache[index];
index++;
}
yield return current;
}
}
//private readonly AsyncLock _enumeratorLock = new AsyncLock();
private readonly object _enumeratorLock = new object();
private bool MoveNext()
{
if (CachingComplete) return false;
if (_enumerator != null && _enumerator.MoveNext()) //The null check should have been unnecessary b/c of the lock...
{
_cache.Add(_enumerator.Current);
return true;
}
else
{
CachingComplete = true;
DisposeWrappedEnumerator(); //Release the enumerator, as it is no longer needed.
}
return false;
}
public T ElementAt(int index)
{
lock (_enumeratorLock)
{
if (index < _cache.Count)
{
return _cache[index];
}
}
EnumerateUntil(index);
lock (_enumeratorLock)
{
if (_cache.Count <= index) throw new ArgumentOutOfRangeException(nameof(index));
return _cache[index];
}
}
public bool TryGetElementAt(int index, out T value)
{
lock (_enumeratorLock)
{
value = default;
if (index < CurCacheCount)
{
value = _cache[index];
return true;
}
}
EnumerateUntil(index);
lock (_enumeratorLock)
{
if (_cache.Count <= index) return false;
value = _cache[index];
}
return true;
}
private void EnumerateUntil(int index)
{
while (true)
{
lock (_enumeratorLock)
{
if (_cache.Count > index || !MoveNext()) break;
}
}
}
public void Dispose()
{
DisposeWrappedEnumerator();
}
private void DisposeWrappedEnumerator()
{
if (_enumerator != null)
{
_enumerator.Dispose();
_enumerator = null;
if (_cache is List<T> list)
{
list.Trim();
}
}
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
public int CachedCount
{
get
{
lock (_enumeratorLock)
{
return _cache.Count;
}
}
}
public int Count()
{
if (CachingComplete)
{
return _cache.Count;
}
EnsureCachingComplete();
return _cache.Count;
}
private void EnsureCachingComplete()
{
if (CachingComplete)
{
return;
}
//Enumerate the rest of the collection
while (!CachingComplete)
{
lock (_enumeratorLock)
{
if (!MoveNext()) break;
}
}
}
public T[] ToArray()
{
EnsureCachingComplete();
//Once Caching is complete, we don't need to lock
if (!(_cache is T[] array))
{
array = _cache.ToArray();
_cache = array;
}
return array;
}
public T this[int index] => ElementAt(index);
}
public static CachedEnumerable<T> Cached<T>(this IEnumerable<T> source)
{
//no gain in caching a cache.
if (source is CachedEnumerable<T> cached)
{
return cached;
}
return new CachedEnumerable<T>(source);
}
}
基本用法:(雖然不是一個有意義的用例)
var cached = expensiveEnumerable.Cached();
foreach (var element in cached) {
Console.WriteLine(element);
}
更新
我根據@Theodors 回答https://stackoverflow.com/a/58547863/5683904測試了當前實現,並確認(AFAICT)在使用 foreach 枚舉而不創建重復值時它是線程安全的( Thread-safe Cached Enumerator -鎖定收益):
class Program
{
static async Task Main(string[] args)
{
var enumerable = Enumerable.Range(0, 1_000_000);
var cachedEnumerable = new CachedEnumerable<int>(enumerable);
var c = new ConcurrentDictionary<int, List<int>>();
var tasks = Enumerable.Range(1, 100).Select(id => Test(id, cachedEnumerable, c));
Task.WaitAll(tasks.ToArray());
foreach (var keyValuePair in c)
{
var hasDuplicates = keyValuePair.Value.Distinct().Count() != keyValuePair.Value.Count;
Console.WriteLine($"Task #{keyValuePair.Key} count: {keyValuePair.Value.Count}. Has duplicates? {hasDuplicates}");
}
}
static async Task Test(int id, IEnumerable<int> cache, ConcurrentDictionary<int, List<int>> c)
{
foreach (var i in cache)
{
//await Task.Delay(10);
c.AddOrUpdate(id, v => new List<int>() {i}, (k, v) =>
{
v.Add(i);
return v;
});
}
}
}
您的 class 不是線程安全的,因為共享的 state 在 class 內的未受保護區域中發生突變。 未受保護的區域是:
Dispose
方法共享的 state 是:
_enumerator
私有字段_cache
私有字段CachingComplete
公共屬性關於 class 的其他一些問題:
IDisposable
為調用者創建了處置您的 class 的責任。 IEnumerable
不需要是一次性的。 相反, IEnumerator
是一次性的,但它們的自動處理有語言支持( foreach
語句的特性)。IEnumerable
( ElementAt
、 Count
等)無法提供的擴展功能。 也許您打算改為實現一個CachedList
? 如果不實現IList<T>
接口,則 LINQ 方法(如Count()
和ToArray()
無法利用您的擴展功能,並且會像使用普通的IEnumerable
一樣使用慢速路徑。 更新:我剛剛注意到另一個線程安全問題。 這與public IEnumerator<T> GetEnumerator()
方法有關。 枚舉器是編譯器生成的,因為該方法是一個迭代器(利用yield return
)。 編譯器生成的枚舉器不是線程安全的。 例如,考慮以下代碼:
var enumerable = Enumerable.Range(0, 1_000_000);
var cachedEnumerable = new CachedEnumerable<int>(enumerable);
var enumerator = cachedEnumerable.GetEnumerator();
var tasks = Enumerable.Range(1, 4).Select(id => Task.Run(() =>
{
int count = 0;
while (enumerator.MoveNext())
{
count++;
}
Console.WriteLine($"Task #{id} count: {count}");
})).ToArray();
Task.WaitAll(tasks);
四個線程同時使用相同的IEnumerator
。 可枚舉項有 1,000,000 項。 您可能期望每個線程會枚舉約 250,000 個項目,但事實並非如此。
Output:
任務 #1 計數:0
任務 #4 計數:0
任務 #3 計數:0
任務 #2 計數:1000000
行中的MoveNext
while (enumerator.MoveNext())
不是您安全的MoveNext
。 它是編譯器生成的不安全MoveNext
。 雖然不安全,但它包含一種可能用於處理異常的機制,在調用外部提供的代碼之前臨時將枚舉器標記為已完成。 因此,當多個線程同時調用MoveNext
時,除第一個線程外,所有線程都將獲得false
的返回值,並且將立即終止枚舉,完成零循環。 要解決這個問題,您可能必須編寫自己的IEnumerator
class。
更新:實際上我關於線程安全枚舉的最后一點有點不公平,因為使用IEnumerator
接口進行枚舉本質上是一種不安全的操作,如果沒有調用代碼的合作,這是無法修復的。 這是因為獲取下一個元素不是原子操作,因為它涉及兩個步驟(調用MoveNext()
+ 讀取Current
)。 因此,您的線程安全問題僅限於保護 class 的內部 state(字段_enumerator
, _cache
和CachingComplete
)。 這些僅在構造函數和Dispose
方法中未受保護,但我認為 class 的正常使用可能不會遵循創建會導致內部 state 損壞的競爭條件的代碼路徑。
就我個人而言,我也更願意處理這些代碼路徑,我不會讓它隨心所欲。
更新:我為IAsyncEnumerable
編寫了一個緩存,以演示另一種技術。 源IAsyncEnumerable
的枚舉不是由調用者驅動的,使用鎖或信號量來獲得獨占訪問,而是由單獨的工作任務驅動。 第一個調用者啟動工作任務。 每個調用者首先產生所有已緩存的項目,然后等待更多項目,或等待沒有更多項目的通知。 作為通知機制,我使用了TaskCompletionSource<bool>
。 仍然使用lock
來確保對共享資源的所有訪問都是同步的。
public class CachedAsyncEnumerable<T> : IAsyncEnumerable<T>
{
private readonly object _locker = new object();
private IAsyncEnumerable<T> _source;
private Task _sourceEnumerationTask;
private List<T> _buffer;
private TaskCompletionSource<bool> _moveNextTCS;
private Exception _sourceEnumerationException;
private int _sourceEnumerationVersion; // Incremented on exception
public CachedAsyncEnumerable(IAsyncEnumerable<T> source)
{
_source = source ?? throw new ArgumentNullException(nameof(source));
}
public async IAsyncEnumerator<T> GetAsyncEnumerator(
CancellationToken cancellationToken = default)
{
lock (_locker)
{
if (_sourceEnumerationTask == null)
{
_buffer = new List<T>();
_moveNextTCS = new TaskCompletionSource<bool>();
_sourceEnumerationTask = Task.Run(
() => EnumerateSourceAsync(cancellationToken));
}
}
int index = 0;
int localVersion = -1;
while (true)
{
T current = default;
Task<bool> moveNextTask = null;
lock (_locker)
{
if (localVersion == -1)
{
localVersion = _sourceEnumerationVersion;
}
else if (_sourceEnumerationVersion != localVersion)
{
ExceptionDispatchInfo
.Capture(_sourceEnumerationException).Throw();
}
if (index < _buffer.Count)
{
current = _buffer[index];
index++;
}
else
{
moveNextTask = _moveNextTCS.Task;
}
}
if (moveNextTask == null)
{
yield return current;
continue;
}
var moved = await moveNextTask;
if (!moved) yield break;
lock (_locker)
{
current = _buffer[index];
index++;
}
yield return current;
}
}
private async Task EnumerateSourceAsync(CancellationToken cancellationToken)
{
TaskCompletionSource<bool> localMoveNextTCS;
try
{
await foreach (var item in _source.WithCancellation(cancellationToken))
{
lock (_locker)
{
_buffer.Add(item);
localMoveNextTCS = _moveNextTCS;
_moveNextTCS = new TaskCompletionSource<bool>();
}
localMoveNextTCS.SetResult(true);
}
lock (_locker)
{
localMoveNextTCS = _moveNextTCS;
_buffer.TrimExcess();
_source = null;
}
localMoveNextTCS.SetResult(false);
}
catch (Exception ex)
{
lock (_locker)
{
localMoveNextTCS = _moveNextTCS;
_sourceEnumerationException = ex;
_sourceEnumerationVersion++;
_sourceEnumerationTask = null;
}
localMoveNextTCS.SetException(ex);
}
}
}
此實現遵循處理異常的特定策略。 如果在枚舉源IAsyncEnumerable
時發生異常,該異常將傳播到所有當前調用者,當前使用的IAsyncEnumerator
將被丟棄,不完整的緩存數據也將被丟棄。 當接收到下一個枚舉請求時,新的工作任務可能會在稍后再次啟動。
訪問緩存,是的,它是線程安全的,每次只能從 _cache object 讀取一個線程。
但是這樣一來,您就不能保證所有線程都以與訪問 GetEnumerator 相同的順序獲取元素。
檢查這兩個例子,如果行為是你所期望的,你可以用那種方式使用鎖。
示例 1:
THREAD1 調用 GetEnumerator
THREAD1 初始化 T 電流;
THREAD2 調用 GetEnumerator
THREAD2 初始化 T 電流;
螺紋 2 鎖螺紋
線程 1 等待
THREAD2 從緩存中安全讀取 _cache[0]
THREAD2 索引++
線程2解鎖
螺紋1鎖
THREAD1 從緩存中安全讀取 _cache[1]
線程1 i++
線程 1 解鎖
THREAD2 產生返回電流;
THREAD1 產生返回電流;
示例 2:
THREAD2 初始化 T 電流;
螺紋 2 鎖螺紋
THREAD2 安全地從緩存中讀取
線程2解鎖
THREAD1 初始化 T 電流;
螺紋 1 鎖螺紋
THREAD1 從緩存中安全讀取
線程 1 解鎖
THREAD1 產生返回電流;
THREAD2 產生返回電流;
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.