Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 66 additions & 36 deletions TUnit.Mocks/MockEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ public sealed class MockEngine<T> : IMockEngineAccess where T : class
{
private readonly Lock _setupLock = new();
private Dictionary<int, List<MethodSetup>>? _setupsByMember;
private ConcurrentQueue<CallRecord>? _callHistory;
private volatile Dictionary<int, MethodSetup[]>? _setupsSnapshot;
private volatile bool _hasStatefulSetups;
private volatile ConcurrentQueue<CallRecord> _callHistory = new();

private ConcurrentDictionary<string, object?>? _autoTrackValues;
private ConcurrentQueue<(string EventName, bool IsSubscribe)>? _eventSubscriptions;
Expand Down Expand Up @@ -127,6 +129,19 @@ public void AddSetup(MethodSetup setup)
}

list.Add(setup);

if (setup.RequiredState is not null || setup.TransitionTarget is not null)
{
_hasStatefulSetups = true;
}

// Rebuild lock-free snapshot: shallow-copy existing, only re-array the affected member
var prev = _setupsSnapshot;
var snapshot = prev is null
? new Dictionary<int, MethodSetup[]>()
: new Dictionary<int, MethodSetup[]>(prev);
snapshot[setup.MemberId] = list.ToArray();
_setupsSnapshot = snapshot;
}
}

Expand Down Expand Up @@ -386,13 +401,8 @@ public bool TryHandleCallWithReturn<TReturn>(int memberId, string memberName, ob
/// </summary>
public IReadOnlyList<CallRecord> GetCallsFor(int memberId)
{
if (Volatile.Read(ref _callHistory) is not { } history)
{
return [];
}

var result = new List<CallRecord>();
foreach (var record in history)
foreach (var record in _callHistory)
{
if (record.MemberId == memberId)
{
Expand All @@ -407,7 +417,7 @@ public IReadOnlyList<CallRecord> GetCallsFor(int memberId)
/// </summary>
public IReadOnlyList<CallRecord> GetAllCalls()
{
return Volatile.Read(ref _callHistory)?.ToArray() ?? [];
return _callHistory.ToArray();
}

/// <summary>
Expand All @@ -416,13 +426,8 @@ public IReadOnlyList<CallRecord> GetAllCalls()
[EditorBrowsable(EditorBrowsableState.Never)]
public IReadOnlyList<CallRecord> GetUnverifiedCalls()
{
if (Volatile.Read(ref _callHistory) is not { } history)
{
return [];
}

var result = new List<CallRecord>();
foreach (var record in history)
foreach (var record in _callHistory)
{
if (!record.IsVerified)
{
Expand All @@ -438,20 +443,18 @@ public IReadOnlyList<CallRecord> GetUnverifiedCalls()
[EditorBrowsable(EditorBrowsableState.Never)]
public IReadOnlyList<MethodSetup> GetSetups()
{
lock (_setupLock)
var snapshot = _setupsSnapshot;
if (snapshot is null)
{
if (_setupsByMember is not { } setups)
{
return [];
}
return [];
}

var all = new List<MethodSetup>();
foreach (var list in setups.Values)
{
all.AddRange(list);
}
return all;
var all = new List<MethodSetup>();
foreach (var arr in snapshot.Values)
{
all.AddRange(arr);
}
return all;
}

/// <summary>
Expand Down Expand Up @@ -482,14 +485,11 @@ public Diagnostics.MockDiagnostics GetDiagnostics()
}

var unmatchedCalls = new List<CallRecord>();
if (Volatile.Read(ref _callHistory) is { } history)
foreach (var call in _callHistory)
{
foreach (var call in history)
if (call.IsUnmatched)
{
if (call.IsUnmatched)
{
unmatchedCalls.Add(call);
}
unmatchedCalls.Add(call);
}
}

Expand Down Expand Up @@ -518,11 +518,13 @@ public void Reset()
lock (_setupLock)
{
_setupsByMember = null;
_setupsSnapshot = null;
_hasStatefulSetups = false;
_currentState = null;
PendingRequiredState = null;
}

Volatile.Write(ref _callHistory, null);
_callHistory = new ConcurrentQueue<CallRecord>(); // volatile field — assignment is a volatile write
Volatile.Write(ref _autoTrackValues, null);
Volatile.Write(ref _eventSubscriptions, null);
Volatile.Write(ref _onSubscribeCallbacks, null);
Expand Down Expand Up @@ -609,7 +611,7 @@ private CallRecord RecordCall(int memberId, string memberName, object?[] args)
{
var seq = MockCallSequence.Next();
var record = new CallRecord(memberId, memberName, args, seq);
LazyInitializer.EnsureInitialized(ref _callHistory)!.Enqueue(record);
_callHistory.Enqueue(record);
return record;
}

Expand All @@ -625,6 +627,37 @@ private void RaiseEventsForSetup(MethodSetup setup)
}

private (bool SetupFound, IBehavior? Behavior, MethodSetup? Setup) FindMatchingSetup(int memberId, object?[] args)
{
// When state machine features are in use, serialize the full match-and-transition
// to prevent concurrent invocations from consuming the same state transition
if (_hasStatefulSetups)
{
return FindMatchingSetupLocked(memberId, args);
}

var snapshot = _setupsSnapshot;
if (snapshot is null || !snapshot.TryGetValue(memberId, out var setups))
{
return (false, null, null);
}

// Iterate last-added-first to implement "last wins" semantics
for (int i = setups.Length - 1; i >= 0; i--)
{
var setup = setups[i];

if (setup.Matches(args))
{
setup.IncrementInvokeCount();
setup.ApplyCaptures(args);
return (true, setup.GetNextBehavior(), setup);
}
}

return (false, null, null);
}

private (bool SetupFound, IBehavior? Behavior, MethodSetup? Setup) FindMatchingSetupLocked(int memberId, object?[] args)
{
lock (_setupLock)
{
Expand All @@ -633,12 +666,10 @@ private void RaiseEventsForSetup(MethodSetup setup)
return (false, null, null);
}

// Iterate last-added-first to implement "last wins" semantics
for (int i = setups.Count - 1; i >= 0; i--)
{
var setup = setups[i];

// State guard: skip setups that require a different state
if (setup.RequiredState is not null && setup.RequiredState != _currentState)
{
continue;
Expand All @@ -648,7 +679,6 @@ private void RaiseEventsForSetup(MethodSetup setup)
{
setup.IncrementInvokeCount();
setup.ApplyCaptures(args);
// Apply state transition inside the lock to prevent data races on _currentState
if (setup.TransitionTarget is not null)
{
_currentState = setup.TransitionTarget;
Expand Down
54 changes: 35 additions & 19 deletions TUnit.Mocks/Setup/MethodSetup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ namespace TUnit.Mocks.Setup;
public sealed class MethodSetup
{
private readonly IArgumentMatcher[] _matchers;
private Lock? _behaviorLock;
private List<IBehavior>? _behaviors;
private readonly Lock _behaviorLock = new();
/// <summary>Fast path for the common single-behavior case. Avoids list + lock on read.</summary>
private volatile IBehavior? _singleBehavior;
private volatile List<IBehavior>? _behaviors;
private List<EventRaiseInfo>? _eventRaises;
private EventRaiseInfo[]? _eventRaisesSnapshot;
private Dictionary<int, object?>? _outRefAssignments;
private int _callIndex;

private Lock EnsureBehaviorLock() => LazyInitializer.EnsureInitialized(ref _behaviorLock)!;

public int MemberId { get; }

/// <summary>
Expand Down Expand Up @@ -62,10 +62,26 @@ public MethodSetup(int memberId, IArgumentMatcher[] matchers, string memberName

public void AddBehavior(IBehavior behavior)
{
lock (EnsureBehaviorLock())
lock (_behaviorLock)
{
var list = _behaviors ??= new();
list.Add(behavior);
if (_singleBehavior is null && _behaviors is null)
{
_singleBehavior = behavior;
return;
}

// Promote to list on second behavior. Write _behaviors before clearing
// _singleBehavior: both fields are volatile, so the volatile write to
// _singleBehavior acts as a release fence, guaranteeing that a lock-free
// reader in GetNextBehavior that sees _singleBehavior == null will also
// see the updated _behaviors reference.
if (_behaviors is null)
{
_behaviors = [_singleBehavior!];
}

_behaviors.Add(behavior);
_singleBehavior = null;
}
}

Expand All @@ -84,7 +100,7 @@ public bool Matches(object?[] actualArgs)

public void AddEventRaise(EventRaiseInfo raiseInfo)
{
lock (EnsureBehaviorLock())
lock (_behaviorLock)
{
var list = _eventRaises ??= new();
list.Add(raiseInfo);
Expand All @@ -105,7 +121,7 @@ public IReadOnlyList<EventRaiseInfo> GetEventRaises()
return snapshot;
}

lock (EnsureBehaviorLock())
lock (_behaviorLock)
{
return _eventRaisesSnapshot ??= _eventRaises!.ToArray();
}
Expand Down Expand Up @@ -134,7 +150,7 @@ public void ApplyCaptures(object?[] args)
/// <param name="value">The value to assign.</param>
public void SetOutRefValue(int paramIndex, object? value)
{
lock (EnsureBehaviorLock())
lock (_behaviorLock)
{
_outRefAssignments ??= new Dictionary<int, object?>();
_outRefAssignments[paramIndex] = value;
Expand All @@ -149,13 +165,7 @@ public void SetOutRefValue(int paramIndex, object? value)
{
get
{
var lck = Volatile.Read(ref _behaviorLock);
if (lck is null)
{
return null;
}

lock (lck)
lock (_behaviorLock)
{
return _outRefAssignments;
}
Expand All @@ -178,12 +188,18 @@ public string[] GetMatcherDescriptions()

public IBehavior? GetNextBehavior()
{
if (Volatile.Read(ref _behaviors) is null)
// Fast path: single behavior (most common case — no lock needed)
if (_singleBehavior is { } single)
{
return single;
}

if (_behaviors is null)
{
return null;
}

lock (EnsureBehaviorLock())
lock (_behaviorLock)
{
if (_behaviors is not { Count: > 0 } behaviors)
{
Expand Down
32 changes: 13 additions & 19 deletions TUnit.Mocks/Verification/CallVerificationBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ public void WasCalled(Times times, string? message)

var allCallsForMember = _engine.GetCallsFor(_memberId);

// Filter by matchers
var matchingCalls = FilterByMatchers(allCallsForMember);
var matchingCount = matchingCalls.Count;
var matchingCount = CountMatchingCalls(allCallsForMember, markVerified: false);

if (!times.Matches(matchingCount))
{
Expand All @@ -59,11 +57,8 @@ public void WasCalled(Times times, string? message)
throw new MockVerificationException(expectedCall, times, matchingCount, actualCallDescriptions, message);
}

// Mark matched calls as verified for VerifyNoOtherCalls
foreach (var call in matchingCalls)
{
call.IsVerified = true;
}
// Mark matched calls as verified only after assertion passes
CountMatchingCalls(allCallsForMember, markVerified: true);
}

/// <inheritdoc />
Expand All @@ -78,22 +73,21 @@ public void WasCalled(Times times, string? message)
/// <inheritdoc />
public void WasCalled(string? message) => WasCalled(Times.AtLeastOnce, message);

private List<CallRecord> FilterByMatchers(IReadOnlyList<CallRecord> calls)
private int CountMatchingCalls(IReadOnlyList<CallRecord> calls, bool markVerified)
{
if (_matchers.Length == 0)
{
return calls.ToList();
}

var result = new List<CallRecord>();
foreach (var call in calls)
var count = 0;
for (int i = 0; i < calls.Count; i++)
{
if (MatchesArguments(call.Arguments))
if (_matchers.Length == 0 || MatchesArguments(calls[i].Arguments))
{
result.Add(call);
count++;
if (markVerified)
{
calls[i].IsVerified = true;
}
}
}
return result;
return count;
}

private bool MatchesArguments(object?[] arguments)
Expand Down
Loading