[英]Thread-safe Cached Enumerator - lock with yield
我有一个自定义的“CachedEnumerable”class(受缓存 IEnumerable启发),我需要为我的 asp.net 核心 Z2567A5EC9705EB7AC2C984033E0618 应用程序设置线程安全。
Enumerator 线程的以下实现是否安全? (对 IList _cache 的所有其他读/写操作均已适当锁定)(可能与Does the C# Yield free a lock?相关)
更具体地说,如果有 2 个线程访问枚举器,我如何防止一个线程递增“索引”导致第二个枚举线程从 _cache 获取错误的元素(即索引 + 1 处的元素而不是索引处的元素) ? 这种比赛条件是一个真正的问题吗?
public IEnumerator<T> GetEnumerator()
{
var index = 0;
while (true)
{
T current;
lock (_enumeratorLock)
{
if (index >= _cache.Count && !MoveNext()) break;
current = _cache[index];
index++;
}
yield return current;
}
}
我的 CachedEnumerable 版本的完整代码:
public class CachedEnumerable<T> : IDisposable, IEnumerable<T>
{
IEnumerator<T> _enumerator;
private IList<T> _cache = new List<T>();
public bool CachingComplete { get; private set; } = false;
public CachedEnumerable(IEnumerable<T> enumerable)
{
switch (enumerable)
{
case CachedEnumerable<T> cachedEnumerable: //This case is actually dealt with by the extension method.
_cache = cachedEnumerable._cache;
CachingComplete = cachedEnumerable.CachingComplete;
_enumerator = cachedEnumerable.GetEnumerator();
break;
case IList<T> list:
//_cache = list; //without clone...
//Clone:
_cache = new T[list.Count];
list.CopyTo((T[]) _cache, 0);
CachingComplete = true;
break;
default:
_enumerator = enumerable.GetEnumerator();
break;
}
}
public CachedEnumerable(IEnumerator<T> enumerator)
{
_enumerator = enumerator;
}
private int CurCacheCount
{
get
{
lock (_enumeratorLock)
{
return _cache.Count;
}
}
}
public IEnumerator<T> GetEnumerator()
{
var index = 0;
while (true)
{
T current;
lock (_enumeratorLock)
{
if (index >= _cache.Count && !MoveNext()) break;
current = _cache[index];
index++;
}
yield return current;
}
}
//private readonly AsyncLock _enumeratorLock = new AsyncLock();
private readonly object _enumeratorLock = new object();
private bool MoveNext()
{
if (CachingComplete) return false;
if (_enumerator != null && _enumerator.MoveNext()) //The null check should have been unnecessary b/c of the lock...
{
_cache.Add(_enumerator.Current);
return true;
}
else
{
CachingComplete = true;
DisposeWrappedEnumerator(); //Release the enumerator, as it is no longer needed.
}
return false;
}
public T ElementAt(int index)
{
lock (_enumeratorLock)
{
if (index < _cache.Count)
{
return _cache[index];
}
}
EnumerateUntil(index);
lock (_enumeratorLock)
{
if (_cache.Count <= index) throw new ArgumentOutOfRangeException(nameof(index));
return _cache[index];
}
}
public bool TryGetElementAt(int index, out T value)
{
lock (_enumeratorLock)
{
value = default;
if (index < CurCacheCount)
{
value = _cache[index];
return true;
}
}
EnumerateUntil(index);
lock (_enumeratorLock)
{
if (_cache.Count <= index) return false;
value = _cache[index];
}
return true;
}
private void EnumerateUntil(int index)
{
while (true)
{
lock (_enumeratorLock)
{
if (_cache.Count > index || !MoveNext()) break;
}
}
}
public void Dispose()
{
DisposeWrappedEnumerator();
}
private void DisposeWrappedEnumerator()
{
if (_enumerator != null)
{
_enumerator.Dispose();
_enumerator = null;
if (_cache is List<T> list)
{
list.Trim();
}
}
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
public int CachedCount
{
get
{
lock (_enumeratorLock)
{
return _cache.Count;
}
}
}
public int Count()
{
if (CachingComplete)
{
return _cache.Count;
}
EnsureCachingComplete();
return _cache.Count;
}
private void EnsureCachingComplete()
{
if (CachingComplete)
{
return;
}
//Enumerate the rest of the collection
while (!CachingComplete)
{
lock (_enumeratorLock)
{
if (!MoveNext()) break;
}
}
}
public T[] ToArray()
{
EnsureCachingComplete();
//Once Caching is complete, we don't need to lock
if (!(_cache is T[] array))
{
array = _cache.ToArray();
_cache = array;
}
return array;
}
public T this[int index] => ElementAt(index);
}
public static CachedEnumerable<T> Cached<T>(this IEnumerable<T> source)
{
//no gain in caching a cache.
if (source is CachedEnumerable<T> cached)
{
return cached;
}
return new CachedEnumerable<T>(source);
}
}
基本用法:(虽然不是一个有意义的用例)
var cached = expensiveEnumerable.Cached();
foreach (var element in cached) {
Console.WriteLine(element);
}
更新
我根据@Theodors 回答https://stackoverflow.com/a/58547863/5683904测试了当前实现,并确认(AFAICT)在使用 foreach 枚举而不创建重复值时它是线程安全的( Thread-safe Cached Enumerator -锁定收益):
class Program
{
static async Task Main(string[] args)
{
var enumerable = Enumerable.Range(0, 1_000_000);
var cachedEnumerable = new CachedEnumerable<int>(enumerable);
var c = new ConcurrentDictionary<int, List<int>>();
var tasks = Enumerable.Range(1, 100).Select(id => Test(id, cachedEnumerable, c));
Task.WaitAll(tasks.ToArray());
foreach (var keyValuePair in c)
{
var hasDuplicates = keyValuePair.Value.Distinct().Count() != keyValuePair.Value.Count;
Console.WriteLine($"Task #{keyValuePair.Key} count: {keyValuePair.Value.Count}. Has duplicates? {hasDuplicates}");
}
}
static async Task Test(int id, IEnumerable<int> cache, ConcurrentDictionary<int, List<int>> c)
{
foreach (var i in cache)
{
//await Task.Delay(10);
c.AddOrUpdate(id, v => new List<int>() {i}, (k, v) =>
{
v.Add(i);
return v;
});
}
}
}
您的 class 不是线程安全的,因为共享的 state 在 class 内的未受保护区域中发生突变。 未受保护的区域是:
Dispose
方法共享的 state 是:
_enumerator
私有字段_cache
私有字段CachingComplete
公共属性关于 class 的其他一些问题:
IDisposable
为调用者创建了处置您的 class 的责任。 IEnumerable
不需要是一次性的。 相反, IEnumerator
是一次性的,但它们的自动处理有语言支持( foreach
语句的特性)。IEnumerable
( ElementAt
、 Count
等)无法提供的扩展功能。 也许您打算改为实现一个CachedList
? 如果不实现IList<T>
接口,则 LINQ 方法(如Count()
和ToArray()
无法利用您的扩展功能,并且会像使用普通的IEnumerable
一样使用慢速路径。 更新:我刚刚注意到另一个线程安全问题。 这与public IEnumerator<T> GetEnumerator()
方法有关。 枚举器是编译器生成的,因为该方法是一个迭代器(利用yield return
)。 编译器生成的枚举器不是线程安全的。 例如,考虑以下代码:
var enumerable = Enumerable.Range(0, 1_000_000);
var cachedEnumerable = new CachedEnumerable<int>(enumerable);
var enumerator = cachedEnumerable.GetEnumerator();
var tasks = Enumerable.Range(1, 4).Select(id => Task.Run(() =>
{
int count = 0;
while (enumerator.MoveNext())
{
count++;
}
Console.WriteLine($"Task #{id} count: {count}");
})).ToArray();
Task.WaitAll(tasks);
四个线程同时使用相同的IEnumerator
。 可枚举项有 1,000,000 项。 您可能期望每个线程会枚举约 250,000 个项目,但事实并非如此。
Output:
任务 #1 计数:0
任务 #4 计数:0
任务 #3 计数:0
任务 #2 计数:1000000
行中的MoveNext
while (enumerator.MoveNext())
不是您安全的MoveNext
。 它是编译器生成的不安全MoveNext
。 虽然不安全,但它包含一种可能用于处理异常的机制,在调用外部提供的代码之前临时将枚举器标记为已完成。 因此,当多个线程同时调用MoveNext
时,除第一个线程外,所有线程都将获得false
的返回值,并且将立即终止枚举,完成零循环。 要解决这个问题,您可能必须编写自己的IEnumerator
class。
更新:实际上我关于线程安全枚举的最后一点有点不公平,因为使用IEnumerator
接口进行枚举本质上是一种不安全的操作,如果没有调用代码的合作,这是无法修复的。 这是因为获取下一个元素不是原子操作,因为它涉及两个步骤(调用MoveNext()
+ 读取Current
)。 因此,您的线程安全问题仅限于保护 class 的内部 state(字段_enumerator
, _cache
和CachingComplete
)。 这些仅在构造函数和Dispose
方法中未受保护,但我认为 class 的正常使用可能不会遵循创建会导致内部 state 损坏的竞争条件的代码路径。
就我个人而言,我也更愿意处理这些代码路径,我不会让它随心所欲。
更新:我为IAsyncEnumerable
编写了一个缓存,以演示另一种技术。 源IAsyncEnumerable
的枚举不是由调用者驱动的,使用锁或信号量来获得独占访问,而是由单独的工作任务驱动。 第一个调用者启动工作任务。 每个调用者首先产生所有已缓存的项目,然后等待更多项目,或等待没有更多项目的通知。 作为通知机制,我使用了TaskCompletionSource<bool>
。 仍然使用lock
来确保对共享资源的所有访问都是同步的。
public class CachedAsyncEnumerable<T> : IAsyncEnumerable<T>
{
private readonly object _locker = new object();
private IAsyncEnumerable<T> _source;
private Task _sourceEnumerationTask;
private List<T> _buffer;
private TaskCompletionSource<bool> _moveNextTCS;
private Exception _sourceEnumerationException;
private int _sourceEnumerationVersion; // Incremented on exception
public CachedAsyncEnumerable(IAsyncEnumerable<T> source)
{
_source = source ?? throw new ArgumentNullException(nameof(source));
}
public async IAsyncEnumerator<T> GetAsyncEnumerator(
CancellationToken cancellationToken = default)
{
lock (_locker)
{
if (_sourceEnumerationTask == null)
{
_buffer = new List<T>();
_moveNextTCS = new TaskCompletionSource<bool>();
_sourceEnumerationTask = Task.Run(
() => EnumerateSourceAsync(cancellationToken));
}
}
int index = 0;
int localVersion = -1;
while (true)
{
T current = default;
Task<bool> moveNextTask = null;
lock (_locker)
{
if (localVersion == -1)
{
localVersion = _sourceEnumerationVersion;
}
else if (_sourceEnumerationVersion != localVersion)
{
ExceptionDispatchInfo
.Capture(_sourceEnumerationException).Throw();
}
if (index < _buffer.Count)
{
current = _buffer[index];
index++;
}
else
{
moveNextTask = _moveNextTCS.Task;
}
}
if (moveNextTask == null)
{
yield return current;
continue;
}
var moved = await moveNextTask;
if (!moved) yield break;
lock (_locker)
{
current = _buffer[index];
index++;
}
yield return current;
}
}
private async Task EnumerateSourceAsync(CancellationToken cancellationToken)
{
TaskCompletionSource<bool> localMoveNextTCS;
try
{
await foreach (var item in _source.WithCancellation(cancellationToken))
{
lock (_locker)
{
_buffer.Add(item);
localMoveNextTCS = _moveNextTCS;
_moveNextTCS = new TaskCompletionSource<bool>();
}
localMoveNextTCS.SetResult(true);
}
lock (_locker)
{
localMoveNextTCS = _moveNextTCS;
_buffer.TrimExcess();
_source = null;
}
localMoveNextTCS.SetResult(false);
}
catch (Exception ex)
{
lock (_locker)
{
localMoveNextTCS = _moveNextTCS;
_sourceEnumerationException = ex;
_sourceEnumerationVersion++;
_sourceEnumerationTask = null;
}
localMoveNextTCS.SetException(ex);
}
}
}
此实现遵循处理异常的特定策略。 如果在枚举源IAsyncEnumerable
时发生异常,该异常将传播到所有当前调用者,当前使用的IAsyncEnumerator
将被丢弃,不完整的缓存数据也将被丢弃。 当接收到下一个枚举请求时,新的工作任务可能会在稍后再次启动。
访问缓存,是的,它是线程安全的,每次只能从 _cache object 读取一个线程。
但是这样一来,您就不能保证所有线程都以与访问 GetEnumerator 相同的顺序获取元素。
检查这两个例子,如果行为是你所期望的,你可以用那种方式使用锁。
示例 1:
THREAD1 调用 GetEnumerator
THREAD1 初始化 T 电流;
THREAD2 调用 GetEnumerator
THREAD2 初始化 T 电流;
螺纹 2 锁螺纹
线程 1 等待
THREAD2 从缓存中安全读取 _cache[0]
THREAD2 索引++
线程2解锁
螺纹1锁
THREAD1 从缓存中安全读取 _cache[1]
线程1 i++
线程 1 解锁
THREAD2 产生返回电流;
THREAD1 产生返回电流;
示例 2:
THREAD2 初始化 T 电流;
螺纹 2 锁螺纹
THREAD2 安全地从缓存中读取
线程2解锁
THREAD1 初始化 T 电流;
螺纹 1 锁螺纹
THREAD1 从缓存中安全读取
线程 1 解锁
THREAD1 产生返回电流;
THREAD2 产生返回电流;
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.