diff --git a/TUnit.Mocks/MockEngine.cs b/TUnit.Mocks/MockEngine.cs index 2e69abad53..b630b9faac 100644 --- a/TUnit.Mocks/MockEngine.cs +++ b/TUnit.Mocks/MockEngine.cs @@ -24,7 +24,9 @@ public sealed class MockEngine : IMockEngineAccess where T : class { private readonly Lock _setupLock = new(); private Dictionary>? _setupsByMember; - private ConcurrentQueue? _callHistory; + private volatile Dictionary? _setupsSnapshot; + private volatile bool _hasStatefulSetups; + private volatile ConcurrentQueue _callHistory = new(); private ConcurrentDictionary? _autoTrackValues; private ConcurrentQueue<(string EventName, bool IsSubscribe)>? _eventSubscriptions; @@ -127,6 +129,19 @@ public void AddSetup(MethodSetup setup) } list.Add(setup); + + if (setup.RequiredState is not null || setup.TransitionTarget is not null) + { + _hasStatefulSetups = true; + } + + // Rebuild lock-free snapshot: shallow-copy existing, only re-array the affected member + var prev = _setupsSnapshot; + var snapshot = prev is null + ? new Dictionary() + : new Dictionary(prev); + snapshot[setup.MemberId] = list.ToArray(); + _setupsSnapshot = snapshot; } } @@ -386,13 +401,8 @@ public bool TryHandleCallWithReturn(int memberId, string memberName, ob /// public IReadOnlyList GetCallsFor(int memberId) { - if (Volatile.Read(ref _callHistory) is not { } history) - { - return []; - } - var result = new List(); - foreach (var record in history) + foreach (var record in _callHistory) { if (record.MemberId == memberId) { @@ -407,7 +417,7 @@ public IReadOnlyList GetCallsFor(int memberId) /// public IReadOnlyList GetAllCalls() { - return Volatile.Read(ref _callHistory)?.ToArray() ?? []; + return _callHistory.ToArray(); } /// @@ -416,13 +426,8 @@ public IReadOnlyList GetAllCalls() [EditorBrowsable(EditorBrowsableState.Never)] public IReadOnlyList GetUnverifiedCalls() { - if (Volatile.Read(ref _callHistory) is not { } history) - { - return []; - } - var result = new List(); - foreach (var record in history) + foreach (var record in _callHistory) { if (!record.IsVerified) { @@ -438,20 +443,18 @@ public IReadOnlyList GetUnverifiedCalls() [EditorBrowsable(EditorBrowsableState.Never)] public IReadOnlyList GetSetups() { - lock (_setupLock) + var snapshot = _setupsSnapshot; + if (snapshot is null) { - if (_setupsByMember is not { } setups) - { - return []; - } + return []; + } - var all = new List(); - foreach (var list in setups.Values) - { - all.AddRange(list); - } - return all; + var all = new List(); + foreach (var arr in snapshot.Values) + { + all.AddRange(arr); } + return all; } /// @@ -482,14 +485,11 @@ public Diagnostics.MockDiagnostics GetDiagnostics() } var unmatchedCalls = new List(); - if (Volatile.Read(ref _callHistory) is { } history) + foreach (var call in _callHistory) { - foreach (var call in history) + if (call.IsUnmatched) { - if (call.IsUnmatched) - { - unmatchedCalls.Add(call); - } + unmatchedCalls.Add(call); } } @@ -518,11 +518,13 @@ public void Reset() lock (_setupLock) { _setupsByMember = null; + _setupsSnapshot = null; + _hasStatefulSetups = false; _currentState = null; PendingRequiredState = null; } - Volatile.Write(ref _callHistory, null); + _callHistory = new ConcurrentQueue(); // volatile field — assignment is a volatile write Volatile.Write(ref _autoTrackValues, null); Volatile.Write(ref _eventSubscriptions, null); Volatile.Write(ref _onSubscribeCallbacks, null); @@ -609,7 +611,7 @@ private CallRecord RecordCall(int memberId, string memberName, object?[] args) { var seq = MockCallSequence.Next(); var record = new CallRecord(memberId, memberName, args, seq); - LazyInitializer.EnsureInitialized(ref _callHistory)!.Enqueue(record); + _callHistory.Enqueue(record); return record; } @@ -625,6 +627,37 @@ private void RaiseEventsForSetup(MethodSetup setup) } private (bool SetupFound, IBehavior? Behavior, MethodSetup? Setup) FindMatchingSetup(int memberId, object?[] args) + { + // When state machine features are in use, serialize the full match-and-transition + // to prevent concurrent invocations from consuming the same state transition + if (_hasStatefulSetups) + { + return FindMatchingSetupLocked(memberId, args); + } + + var snapshot = _setupsSnapshot; + if (snapshot is null || !snapshot.TryGetValue(memberId, out var setups)) + { + return (false, null, null); + } + + // Iterate last-added-first to implement "last wins" semantics + for (int i = setups.Length - 1; i >= 0; i--) + { + var setup = setups[i]; + + if (setup.Matches(args)) + { + setup.IncrementInvokeCount(); + setup.ApplyCaptures(args); + return (true, setup.GetNextBehavior(), setup); + } + } + + return (false, null, null); + } + + private (bool SetupFound, IBehavior? Behavior, MethodSetup? Setup) FindMatchingSetupLocked(int memberId, object?[] args) { lock (_setupLock) { @@ -633,12 +666,10 @@ private void RaiseEventsForSetup(MethodSetup setup) return (false, null, null); } - // Iterate last-added-first to implement "last wins" semantics for (int i = setups.Count - 1; i >= 0; i--) { var setup = setups[i]; - // State guard: skip setups that require a different state if (setup.RequiredState is not null && setup.RequiredState != _currentState) { continue; @@ -648,7 +679,6 @@ private void RaiseEventsForSetup(MethodSetup setup) { setup.IncrementInvokeCount(); setup.ApplyCaptures(args); - // Apply state transition inside the lock to prevent data races on _currentState if (setup.TransitionTarget is not null) { _currentState = setup.TransitionTarget; diff --git a/TUnit.Mocks/Setup/MethodSetup.cs b/TUnit.Mocks/Setup/MethodSetup.cs index 0370b2f8c2..55ad23b17e 100644 --- a/TUnit.Mocks/Setup/MethodSetup.cs +++ b/TUnit.Mocks/Setup/MethodSetup.cs @@ -11,15 +11,15 @@ namespace TUnit.Mocks.Setup; public sealed class MethodSetup { private readonly IArgumentMatcher[] _matchers; - private Lock? _behaviorLock; - private List? _behaviors; + private readonly Lock _behaviorLock = new(); + /// Fast path for the common single-behavior case. Avoids list + lock on read. + private volatile IBehavior? _singleBehavior; + private volatile List? _behaviors; private List? _eventRaises; private EventRaiseInfo[]? _eventRaisesSnapshot; private Dictionary? _outRefAssignments; private int _callIndex; - private Lock EnsureBehaviorLock() => LazyInitializer.EnsureInitialized(ref _behaviorLock)!; - public int MemberId { get; } /// @@ -62,10 +62,26 @@ public MethodSetup(int memberId, IArgumentMatcher[] matchers, string memberName public void AddBehavior(IBehavior behavior) { - lock (EnsureBehaviorLock()) + lock (_behaviorLock) { - var list = _behaviors ??= new(); - list.Add(behavior); + if (_singleBehavior is null && _behaviors is null) + { + _singleBehavior = behavior; + return; + } + + // Promote to list on second behavior. Write _behaviors before clearing + // _singleBehavior: both fields are volatile, so the volatile write to + // _singleBehavior acts as a release fence, guaranteeing that a lock-free + // reader in GetNextBehavior that sees _singleBehavior == null will also + // see the updated _behaviors reference. + if (_behaviors is null) + { + _behaviors = [_singleBehavior!]; + } + + _behaviors.Add(behavior); + _singleBehavior = null; } } @@ -84,7 +100,7 @@ public bool Matches(object?[] actualArgs) public void AddEventRaise(EventRaiseInfo raiseInfo) { - lock (EnsureBehaviorLock()) + lock (_behaviorLock) { var list = _eventRaises ??= new(); list.Add(raiseInfo); @@ -105,7 +121,7 @@ public IReadOnlyList GetEventRaises() return snapshot; } - lock (EnsureBehaviorLock()) + lock (_behaviorLock) { return _eventRaisesSnapshot ??= _eventRaises!.ToArray(); } @@ -134,7 +150,7 @@ public void ApplyCaptures(object?[] args) /// The value to assign. public void SetOutRefValue(int paramIndex, object? value) { - lock (EnsureBehaviorLock()) + lock (_behaviorLock) { _outRefAssignments ??= new Dictionary(); _outRefAssignments[paramIndex] = value; @@ -149,13 +165,7 @@ public void SetOutRefValue(int paramIndex, object? value) { get { - var lck = Volatile.Read(ref _behaviorLock); - if (lck is null) - { - return null; - } - - lock (lck) + lock (_behaviorLock) { return _outRefAssignments; } @@ -178,12 +188,18 @@ public string[] GetMatcherDescriptions() public IBehavior? GetNextBehavior() { - if (Volatile.Read(ref _behaviors) is null) + // Fast path: single behavior (most common case — no lock needed) + if (_singleBehavior is { } single) + { + return single; + } + + if (_behaviors is null) { return null; } - lock (EnsureBehaviorLock()) + lock (_behaviorLock) { if (_behaviors is not { Count: > 0 } behaviors) { diff --git a/TUnit.Mocks/Verification/CallVerificationBuilder.cs b/TUnit.Mocks/Verification/CallVerificationBuilder.cs index 2c8d3e6137..a1671f2d9d 100644 --- a/TUnit.Mocks/Verification/CallVerificationBuilder.cs +++ b/TUnit.Mocks/Verification/CallVerificationBuilder.cs @@ -48,9 +48,7 @@ public void WasCalled(Times times, string? message) var allCallsForMember = _engine.GetCallsFor(_memberId); - // Filter by matchers - var matchingCalls = FilterByMatchers(allCallsForMember); - var matchingCount = matchingCalls.Count; + var matchingCount = CountMatchingCalls(allCallsForMember, markVerified: false); if (!times.Matches(matchingCount)) { @@ -59,11 +57,8 @@ public void WasCalled(Times times, string? message) throw new MockVerificationException(expectedCall, times, matchingCount, actualCallDescriptions, message); } - // Mark matched calls as verified for VerifyNoOtherCalls - foreach (var call in matchingCalls) - { - call.IsVerified = true; - } + // Mark matched calls as verified only after assertion passes + CountMatchingCalls(allCallsForMember, markVerified: true); } /// @@ -78,22 +73,21 @@ public void WasCalled(Times times, string? message) /// public void WasCalled(string? message) => WasCalled(Times.AtLeastOnce, message); - private List FilterByMatchers(IReadOnlyList calls) + private int CountMatchingCalls(IReadOnlyList calls, bool markVerified) { - if (_matchers.Length == 0) - { - return calls.ToList(); - } - - var result = new List(); - foreach (var call in calls) + var count = 0; + for (int i = 0; i < calls.Count; i++) { - if (MatchesArguments(call.Arguments)) + if (_matchers.Length == 0 || MatchesArguments(calls[i].Arguments)) { - result.Add(call); + count++; + if (markVerified) + { + calls[i].IsVerified = true; + } } } - return result; + return count; } private bool MatchesArguments(object?[] arguments)