Skip to content

Commit 89fcf3f

Browse files
Copilotstephentoub
andauthored
Wait for in-flight message handlers before ProcessMessagesCoreAsync returns (#1403)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: stephentoub <2642209+stephentoub@users.noreply.github.com>
1 parent 2894c0d commit 89fcf3f

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

src/ModelContextProtocol.Core/McpSessionHandler.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,12 +182,19 @@ public Task ProcessMessagesAsync(CancellationToken cancellationToken)
182182

183183
private async Task ProcessMessagesCoreAsync(CancellationToken cancellationToken)
184184
{
185+
// Track in-flight message handlers so we can wait for them to complete before returning.
186+
// Start at 1 to represent ProcessMessagesCoreAsync itself; it's decremented after the loop exits.
187+
int inFlightCount = 1;
188+
var allHandlersCompleted = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
189+
185190
try
186191
{
187192
await foreach (var message in _transport.MessageReader.ReadAllAsync(cancellationToken).ConfigureAwait(false))
188193
{
189194
LogMessageRead(EndpointName, message.GetType().Name);
190195

196+
Interlocked.Increment(ref inFlightCount);
197+
191198
// Fire and forget the message handling to avoid blocking the transport.
192199
if (message.Context?.ExecutionContext is null)
193200
{
@@ -295,6 +302,11 @@ ex is OperationCanceledException &&
295302
_handlingRequests.TryRemove(messageWithId.Id, out _);
296303
combinedCts!.Dispose();
297304
}
305+
306+
if (Interlocked.Decrement(ref inFlightCount) == 0)
307+
{
308+
allHandlersCompleted.TrySetResult(true);
309+
}
298310
}
299311
}
300312
}
@@ -306,6 +318,12 @@ ex is OperationCanceledException &&
306318
}
307319
finally
308320
{
321+
// Decrement our own count. If all handlers have already completed, this will signal completion.
322+
if (Interlocked.Decrement(ref inFlightCount) != 0)
323+
{
324+
await allHandlersCompleted.Task.ConfigureAwait(false);
325+
}
326+
309327
// Fail any pending requests, as they'll never be satisfied.
310328
foreach (var entry in _pendingRequests)
311329
{

tests/ModelContextProtocol.Tests/Server/McpServerTests.cs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,6 +1241,60 @@ await transport.SendClientMessageAsync(new JsonRpcNotification
12411241
await runTask;
12421242
}
12431243

1244+
[Fact]
1245+
public async Task RunAsync_WaitsForInFlightHandlersBeforeReturning()
1246+
{
1247+
// Arrange: Create a tool handler that blocks until we release it.
1248+
var handlerStarted = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
1249+
var releaseHandler = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
1250+
bool handlerCompleted = false;
1251+
1252+
await using var transport = new TestServerTransport();
1253+
var options = CreateOptions(new ServerCapabilities { Tools = new() });
1254+
options.Handlers.CallToolHandler = async (request, ct) =>
1255+
{
1256+
handlerStarted.SetResult(true);
1257+
await releaseHandler.Task;
1258+
handlerCompleted = true;
1259+
return new CallToolResult { Content = [new TextContentBlock { Text = "done" }] };
1260+
};
1261+
options.Handlers.ListToolsHandler = (request, ct) => throw new NotImplementedException();
1262+
1263+
await using var server = McpServer.Create(transport, options, LoggerFactory);
1264+
var runTask = server.RunAsync(TestContext.Current.CancellationToken);
1265+
1266+
// Send a tool call request.
1267+
await transport.SendClientMessageAsync(
1268+
new JsonRpcRequest
1269+
{
1270+
Method = RequestMethods.ToolsCall,
1271+
Id = new RequestId(1)
1272+
},
1273+
TestContext.Current.CancellationToken);
1274+
1275+
// Wait for the handler to start executing.
1276+
await handlerStarted.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken);
1277+
1278+
// Dispose the transport to simulate client disconnect while the handler is still running.
1279+
await transport.DisposeAsync();
1280+
1281+
// Release the handler after a delay, giving ProcessMessagesCoreAsync time to notice the
1282+
// channel closed. Without the fix, RunAsync would return before the handler completes.
1283+
var ct = TestContext.Current.CancellationToken;
1284+
_ = Task.Run(async () =>
1285+
{
1286+
await Task.Delay(200, ct);
1287+
releaseHandler.SetResult(true);
1288+
}, ct);
1289+
1290+
// Wait for RunAsync to complete.
1291+
await runTask.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken);
1292+
1293+
// With the fix, RunAsync waits for in-flight handlers. Without it, it returns immediately
1294+
// after the transport closes (before the 500ms delay releases the handler).
1295+
Assert.True(handlerCompleted, "RunAsync should wait for in-flight handlers to complete before returning.");
1296+
}
1297+
12441298
private static async Task InitializeServerAsync(TestServerTransport transport, ClientCapabilities capabilities, CancellationToken cancellationToken = default)
12451299
{
12461300
var initializeRequest = new JsonRpcRequest

0 commit comments

Comments
 (0)