Skip to content

Commit d6c8df4

Browse files
torosentCopilot
andcommitted
Address PR review: prevent double-wrap and add cycle detection
- Add GetUnwrappedLoggerFactory() to prevent double-wrapping replay-safe loggers when wrapper contexts delegate to inner.ReplaySafeLoggerFactory - Add max-depth guard against infinite loops from misconfigured wrappers - Replace Moq with hand-rolled TrackingLoggerProvider in tests - Add regression tests for double-wrap and cycle detection - Remove unused using directive and unnecessary override in sample Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent f888b35 commit d6c8df4

3 files changed

Lines changed: 199 additions & 18 deletions

File tree

samples/ReplaySafeLoggerFactorySample/Program.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717

1818
namespace ReplaySafeLoggerFactorySample;
1919

20-
internal static class Program
20+
static class Program
2121
{
22-
private static async Task Main(string[] args)
22+
static async Task Main(string[] args)
2323
{
2424
HostApplicationBuilder builder = Host.CreateApplicationBuilder(args);
2525

@@ -72,7 +72,7 @@ private static async Task Main(string[] args)
7272
}
7373
}
7474

75-
private static void ConfigureDurableTask(
75+
static void ConfigureDurableTask(
7676
HostApplicationBuilder builder,
7777
bool useScheduler,
7878
string? schedulerConnectionString)
@@ -108,7 +108,7 @@ private static void ConfigureDurableTask(
108108
}
109109

110110
[DurableTask(nameof(ReplaySafeLoggingOrchestration))]
111-
internal sealed class ReplaySafeLoggingOrchestration : TaskOrchestrator<string, string>
111+
sealed class ReplaySafeLoggingOrchestration : TaskOrchestrator<string, string>
112112
{
113113
public override async Task<string> RunAsync(TaskOrchestrationContext context, string input)
114114
{
@@ -125,7 +125,7 @@ public override async Task<string> RunAsync(TaskOrchestrationContext context, st
125125
}
126126

127127
[DurableTask(nameof(SayHelloActivity))]
128-
internal sealed class SayHelloActivity : TaskActivity<string, string>
128+
sealed class SayHelloActivity : TaskActivity<string, string>
129129
{
130130
readonly ILogger<SayHelloActivity> logger;
131131

@@ -142,7 +142,7 @@ public override Task<string> RunAsync(TaskActivityContext context, string input)
142142
}
143143
}
144144

145-
internal sealed class LoggingTaskOrchestrationContext : TaskOrchestrationContext
145+
sealed class LoggingTaskOrchestrationContext : TaskOrchestrationContext
146146
{
147147
readonly TaskOrchestrationContext innerContext;
148148

@@ -151,6 +151,8 @@ public LoggingTaskOrchestrationContext(TaskOrchestrationContext innerContext)
151151
this.innerContext = innerContext ?? throw new ArgumentNullException(nameof(innerContext));
152152
}
153153

154+
// Only abstract members need explicit forwarding here. Virtual helpers such as
155+
// ReplaySafeLoggerFactory and the convenience overloads continue to work through these overrides.
154156
public override TaskName Name => this.innerContext.Name;
155157

156158
public override string InstanceId => this.innerContext.InstanceId;
@@ -165,8 +167,6 @@ public LoggingTaskOrchestrationContext(TaskOrchestrationContext innerContext)
165167

166168
public override IReadOnlyDictionary<string, object?> Properties => this.innerContext.Properties;
167169

168-
public override TaskOrchestrationEntityFeature Entities => this.innerContext.Entities;
169-
170170
protected override ILoggerFactory LoggerFactory => this.innerContext.ReplaySafeLoggerFactory;
171171

172172
public ILogger CreateLogger<T>()

src/Abstractions/TaskOrchestrationContext.cs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,17 +433,17 @@ public virtual Task CallSubOrchestratorAsync(
433433
/// <param name="categoryName">The logger's category name.</param>
434434
/// <returns>An instance of <see cref="ILogger"/> that is replay-safe.</returns>
435435
public virtual ILogger CreateReplaySafeLogger(string categoryName)
436-
=> new ReplaySafeLogger(this, this.LoggerFactory.CreateLogger(categoryName));
436+
=> new ReplaySafeLogger(this, this.GetUnwrappedLoggerFactory().CreateLogger(categoryName));
437437

438438
/// <inheritdoc cref="CreateReplaySafeLogger(string)" />
439439
/// <param name="type">The type to derive the category name from.</param>
440440
public virtual ILogger CreateReplaySafeLogger(Type type)
441-
=> new ReplaySafeLogger(this, this.LoggerFactory.CreateLogger(type));
441+
=> new ReplaySafeLogger(this, this.GetUnwrappedLoggerFactory().CreateLogger(type));
442442

443443
/// <inheritdoc cref="CreateReplaySafeLogger(string)" />
444444
/// <typeparam name="T">The type to derive category name from.</typeparam>
445445
public virtual ILogger CreateReplaySafeLogger<T>()
446-
=> new ReplaySafeLogger(this, this.LoggerFactory.CreateLogger<T>());
446+
=> new ReplaySafeLogger(this, this.GetUnwrappedLoggerFactory().CreateLogger<T>());
447447

448448
/// <summary>
449449
/// Checks if the current orchestration version is greater than the specified version.
@@ -466,6 +466,30 @@ public virtual int CompareVersionTo(string version)
466466
return TaskOrchestrationVersioningUtils.CompareVersions(this.Version, version);
467467
}
468468

469+
ILoggerFactory GetUnwrappedLoggerFactory()
470+
{
471+
ILoggerFactory loggerFactory = this.LoggerFactory;
472+
int depth = 0;
473+
474+
// When a wrapper context delegates LoggerFactory to inner.ReplaySafeLoggerFactory,
475+
// the returned factory is already a ReplaySafeLoggerFactoryImpl. Unwrap it to avoid
476+
// double-wrapping loggers with redundant replay-safe checks.
477+
while (loggerFactory is ReplaySafeLoggerFactoryImpl replaySafeLoggerFactory)
478+
{
479+
if (++depth > 10)
480+
{
481+
throw new InvalidOperationException(
482+
"Cycle detected while unwrapping ReplaySafeLoggerFactory. " +
483+
"Ensure the wrapper's LoggerFactory property delegates to the inner context's " +
484+
"ReplaySafeLoggerFactory (e.g., 'inner.ReplaySafeLoggerFactory'), not 'this.ReplaySafeLoggerFactory'.");
485+
}
486+
487+
loggerFactory = replaySafeLoggerFactory.UnderlyingLoggerFactory;
488+
}
489+
490+
return loggerFactory;
491+
}
492+
469493
sealed class ReplaySafeLogger : ILogger
470494
{
471495
readonly TaskOrchestrationContext context;
@@ -506,8 +530,10 @@ internal ReplaySafeLoggerFactoryImpl(TaskOrchestrationContext context)
506530
this.context = context ?? throw new ArgumentNullException(nameof(context));
507531
}
508532

533+
internal ILoggerFactory UnderlyingLoggerFactory => this.context.LoggerFactory;
534+
509535
public ILogger CreateLogger(string categoryName)
510-
=> this.context.CreateReplaySafeLogger(categoryName);
536+
=> new ReplaySafeLogger(this.context, this.context.GetUnwrappedLoggerFactory().CreateLogger(categoryName));
511537

512538
public void AddProvider(ILoggerProvider provider)
513539
=> throw new NotSupportedException(

test/Abstractions.Tests/TaskOrchestrationContextReplaySafeLoggerFactoryTests.cs

Lines changed: 161 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// Copyright (c) Microsoft Corporation.
22
// Licensed under the MIT License.
33

4-
using Microsoft.DurableTask.Entities;
54
using Microsoft.Extensions.Logging;
65

76
namespace Microsoft.DurableTask.Tests;
@@ -65,17 +64,53 @@ public void ReplaySafeLoggerFactory_AddProvider_ThrowsWithoutMutatingUnderlyingF
6564
TrackingLoggerProvider provider = new();
6665
TrackingLoggerFactory loggerFactory = new(provider);
6766
TestTaskOrchestrationContext context = new(loggerFactory, isReplaying: false);
68-
Mock<ILoggerProvider> additionalProvider = new();
67+
TrackingLoggerProvider additionalProvider = new();
6968

7069
// Act
71-
Action act = () => context.ReplaySafeLoggerFactory.AddProvider(additionalProvider.Object);
70+
Action act = () => context.ReplaySafeLoggerFactory.AddProvider(additionalProvider);
7271

7372
// Assert
7473
act.Should().Throw<NotSupportedException>()
7574
.WithMessage("*replay-safe logger factory*not supported*");
7675
loggerFactory.AddProviderCallCount.Should().Be(0);
7776
}
7877

78+
[Fact]
79+
public void ReplaySafeLoggerFactory_CreateLogger_FromWrappedContext_ChecksReplayOnce()
80+
{
81+
// Arrange
82+
TrackingLoggerProvider provider = new();
83+
TrackingLoggerFactory loggerFactory = new(provider);
84+
TestTaskOrchestrationContext innerContext = new(loggerFactory, isReplaying: false);
85+
WrappingTaskOrchestrationContext wrappedContext = new(innerContext);
86+
ILogger logger = wrappedContext.ReplaySafeLoggerFactory.CreateLogger("ReplaySafe");
87+
88+
// Act
89+
logger.LogInformation("This log should be written.");
90+
91+
// Assert
92+
innerContext.IsReplayingAccessCount.Should().Be(1);
93+
provider.Entries.Should().ContainSingle(entry =>
94+
entry.CategoryName == "ReplaySafe" &&
95+
entry.Message.Contains("This log should be written.", StringComparison.Ordinal));
96+
}
97+
98+
[Fact]
99+
public void ReplaySafeLoggerFactory_CreateLogger_ThrowsOnCyclicLoggerFactory()
100+
{
101+
// Arrange
102+
TrackingLoggerProvider provider = new();
103+
TrackingLoggerFactory loggerFactory = new(provider);
104+
SelfReferencingContext cyclicContext = new(loggerFactory);
105+
106+
// Act
107+
Action act = () => cyclicContext.ReplaySafeLoggerFactory.CreateLogger("Test");
108+
109+
// Assert
110+
act.Should().Throw<InvalidOperationException>()
111+
.WithMessage("*Cycle detected*");
112+
}
113+
79114
[Fact]
80115
public void ReplaySafeLoggerFactory_Dispose_DoesNotDisposeUnderlyingFactory()
81116
{
@@ -110,11 +145,18 @@ public TestTaskOrchestrationContext(ILoggerFactory loggerFactory, bool isReplayi
110145

111146
public override DateTime CurrentUtcDateTime => DateTime.UnixEpoch;
112147

113-
public override bool IsReplaying => this.isReplaying;
148+
public int IsReplayingAccessCount { get; private set; }
114149

115-
public override IReadOnlyDictionary<string, object?> Properties => new Dictionary<string, object?>();
150+
public override bool IsReplaying
151+
{
152+
get
153+
{
154+
this.IsReplayingAccessCount++;
155+
return this.isReplaying;
156+
}
157+
}
116158

117-
public override TaskOrchestrationEntityFeature Entities => throw new NotSupportedException();
159+
public override IReadOnlyDictionary<string, object?> Properties => new Dictionary<string, object?>();
118160

119161
protected override ILoggerFactory LoggerFactory => this.loggerFactory;
120162

@@ -150,6 +192,119 @@ public override Guid NewGuid()
150192
=> throw new NotImplementedException();
151193
}
152194

195+
sealed class WrappingTaskOrchestrationContext : TaskOrchestrationContext
196+
{
197+
readonly TaskOrchestrationContext innerContext;
198+
199+
public WrappingTaskOrchestrationContext(TaskOrchestrationContext innerContext)
200+
{
201+
this.innerContext = innerContext ?? throw new ArgumentNullException(nameof(innerContext));
202+
}
203+
204+
public override TaskName Name => this.innerContext.Name;
205+
206+
public override string InstanceId => this.innerContext.InstanceId;
207+
208+
public override ParentOrchestrationInstance? Parent => this.innerContext.Parent;
209+
210+
public override DateTime CurrentUtcDateTime => this.innerContext.CurrentUtcDateTime;
211+
212+
public override bool IsReplaying => this.innerContext.IsReplaying;
213+
214+
public override string Version => this.innerContext.Version;
215+
216+
public override IReadOnlyDictionary<string, object?> Properties => this.innerContext.Properties;
217+
218+
protected override ILoggerFactory LoggerFactory => this.innerContext.ReplaySafeLoggerFactory;
219+
220+
public override T GetInput<T>()
221+
where T : default
222+
=> this.innerContext.GetInput<T>()!;
223+
224+
public override Task<TResult> CallActivityAsync<TResult>(TaskName name, object? input = null, TaskOptions? options = null)
225+
=> this.innerContext.CallActivityAsync<TResult>(name, input, options);
226+
227+
public override Task CreateTimer(DateTime fireAt, CancellationToken cancellationToken)
228+
=> this.innerContext.CreateTimer(fireAt, cancellationToken);
229+
230+
public override Task<T> WaitForExternalEvent<T>(string eventName, CancellationToken cancellationToken = default)
231+
=> this.innerContext.WaitForExternalEvent<T>(eventName, cancellationToken);
232+
233+
public override void SendEvent(string instanceId, string eventName, object payload)
234+
=> this.innerContext.SendEvent(instanceId, eventName, payload);
235+
236+
public override void SetCustomStatus(object? customStatus)
237+
=> this.innerContext.SetCustomStatus(customStatus);
238+
239+
public override Task<TResult> CallSubOrchestratorAsync<TResult>(
240+
TaskName orchestratorName,
241+
object? input = null,
242+
TaskOptions? options = null)
243+
=> this.innerContext.CallSubOrchestratorAsync<TResult>(orchestratorName, input, options);
244+
245+
public override void ContinueAsNew(object? newInput = null, bool preserveUnprocessedEvents = true)
246+
=> this.innerContext.ContinueAsNew(newInput, preserveUnprocessedEvents);
247+
248+
public override Guid NewGuid()
249+
=> this.innerContext.NewGuid();
250+
}
251+
252+
sealed class SelfReferencingContext : TaskOrchestrationContext
253+
{
254+
readonly ILoggerFactory loggerFactory;
255+
256+
public SelfReferencingContext(ILoggerFactory loggerFactory)
257+
{
258+
this.loggerFactory = loggerFactory;
259+
}
260+
261+
public override TaskName Name => default;
262+
263+
public override string InstanceId => "cyclic-instance";
264+
265+
public override ParentOrchestrationInstance? Parent => null;
266+
267+
public override DateTime CurrentUtcDateTime => DateTime.UnixEpoch;
268+
269+
public override bool IsReplaying => false;
270+
271+
public override IReadOnlyDictionary<string, object?> Properties => new Dictionary<string, object?>();
272+
273+
// Bug: points at self instead of an inner context — should cause cycle detection.
274+
protected override ILoggerFactory LoggerFactory => this.ReplaySafeLoggerFactory;
275+
276+
public override T GetInput<T>()
277+
where T : default
278+
=> default!;
279+
280+
public override Task<TResult> CallActivityAsync<TResult>(TaskName name, object? input = null, TaskOptions? options = null)
281+
=> throw new NotImplementedException();
282+
283+
public override Task CreateTimer(DateTime fireAt, CancellationToken cancellationToken)
284+
=> throw new NotImplementedException();
285+
286+
public override Task<T> WaitForExternalEvent<T>(string eventName, CancellationToken cancellationToken = default)
287+
=> throw new NotImplementedException();
288+
289+
public override void SendEvent(string instanceId, string eventName, object payload)
290+
=> throw new NotImplementedException();
291+
292+
public override void SetCustomStatus(object? customStatus)
293+
=> throw new NotImplementedException();
294+
295+
public override Task<TResult> CallSubOrchestratorAsync<TResult>(
296+
TaskName orchestratorName,
297+
object? input = null,
298+
TaskOptions? options = null)
299+
=> throw new NotImplementedException();
300+
301+
public override void ContinueAsNew(object? newInput = null, bool preserveUnprocessedEvents = true)
302+
=> throw new NotImplementedException();
303+
304+
public override Guid NewGuid()
305+
=> throw new NotImplementedException();
306+
}
307+
153308
sealed class TrackingLoggerFactory : ILoggerFactory
154309
{
155310
readonly TrackingLoggerProvider provider;

0 commit comments

Comments
 (0)