@@ -8,12 +8,10 @@ namespace ModelContextProtocol.Client;
88/// <summary>Provides the client side of a stream-based session transport.</summary>
99internal class StreamClientSessionTransport : TransportBase
1010{
11- private static readonly byte [ ] s_newlineBytes = "\n "u8 . ToArray ( ) ;
12-
1311 internal static UTF8Encoding NoBomUtf8Encoding { get ; } = new ( encoderShouldEmitUTF8Identifier : false ) ;
1412
1513 private readonly TextReader _serverOutput ;
16- private readonly Stream _serverInputStream ;
14+ private readonly TextWriter _serverInput ;
1715 private readonly SemaphoreSlim _sendLock = new ( 1 , 1 ) ;
1816 private CancellationTokenSource ? _shutdownCts = new ( ) ;
1917 private Task ? _readTask ;
@@ -22,13 +20,12 @@ internal class StreamClientSessionTransport : TransportBase
2220 /// Initializes a new instance of the <see cref="StreamClientSessionTransport"/> class.
2321 /// </summary>
2422 /// <param name="serverInput">
25- /// The server's input stream. Messages written to this stream will be sent to the server.
23+ /// The text writer connected to the server's input stream.
24+ /// Messages written to this writer will be sent to the server.
2625 /// </param>
2726 /// <param name="serverOutput">
28- /// The server's output stream. Messages read from this stream will be received from the server.
29- /// </param>
30- /// <param name="encoding">
31- /// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null.
27+ /// The text reader connected to the server's output stream.
28+ /// Messages read from this reader will be received from the server.
3229 /// </param>
3330 /// <param name="endpointName">
3431 /// A name that identifies this transport endpoint in logs.
@@ -40,18 +37,12 @@ internal class StreamClientSessionTransport : TransportBase
4037 /// This constructor starts a background task to read messages from the server output stream.
4138 /// The transport will be marked as connected once initialized.
4239 /// </remarks>
43- public StreamClientSessionTransport ( Stream serverInput , Stream serverOutput , Encoding ? encoding , string endpointName , ILoggerFactory ? loggerFactory )
40+ public StreamClientSessionTransport (
41+ TextWriter serverInput , TextReader serverOutput , string endpointName , ILoggerFactory ? loggerFactory )
4442 : base ( endpointName , loggerFactory )
4543 {
46- Throw . IfNull ( serverInput ) ;
47- Throw . IfNull ( serverOutput ) ;
48-
49- _serverInputStream = serverInput ;
50- #if NET
51- _serverOutput = new StreamReader ( serverOutput , encoding ?? NoBomUtf8Encoding ) ;
52- #else
53- _serverOutput = new CancellableStreamReader ( serverOutput , encoding ?? NoBomUtf8Encoding ) ;
54- #endif
44+ _serverOutput = serverOutput ;
45+ _serverInput = serverInput ;
5546
5647 SetConnected ( ) ;
5748
@@ -66,6 +57,43 @@ public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, Enc
6657 readTask . Start ( ) ;
6758 }
6859
60+ /// <summary>
61+ /// Initializes a new instance of the <see cref="StreamClientSessionTransport"/> class.
62+ /// </summary>
63+ /// <param name="serverInput">
64+ /// The server's input stream. Messages written to this stream will be sent to the server.
65+ /// </param>
66+ /// <param name="serverOutput">
67+ /// The server's output stream. Messages read from this stream will be received from the server.
68+ /// </param>
69+ /// <param name="encoding">
70+ /// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null.
71+ /// </param>
72+ /// <param name="endpointName">
73+ /// A name that identifies this transport endpoint in logs.
74+ /// </param>
75+ /// <param name="loggerFactory">
76+ /// Optional factory for creating loggers. If null, a NullLogger is used.
77+ /// </param>
78+ /// <remarks>
79+ /// This constructor starts a background task to read messages from the server output stream.
80+ /// The transport will be marked as connected once initialized.
81+ /// </remarks>
82+ public StreamClientSessionTransport ( Stream serverInput , Stream serverOutput , Encoding ? encoding , string endpointName , ILoggerFactory ? loggerFactory )
83+ : this (
84+ new StreamWriter ( serverInput , encoding ?? NoBomUtf8Encoding ) ,
85+ #if NET
86+ new StreamReader ( serverOutput , encoding ?? NoBomUtf8Encoding ) ,
87+ #else
88+ new CancellableStreamReader ( serverOutput , encoding ?? NoBomUtf8Encoding ) ,
89+ #endif
90+ endpointName,
91+ loggerFactory )
92+ {
93+ Throw . IfNull ( serverInput ) ;
94+ Throw . IfNull ( serverOutput ) ;
95+ }
96+
6997 /// <inheritdoc/>
7098 public override async Task SendMessageAsync ( JsonRpcMessage message , CancellationToken cancellationToken = default )
7199 {
@@ -77,13 +105,14 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation
77105
78106 LogTransportSendingMessageSensitive ( message ) ;
79107
108+ var json = JsonSerializer . Serialize ( message , McpJsonUtilities . JsonContext . Default . JsonRpcMessage ) ;
109+
80110 using var _ = await _sendLock . LockAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
81111 try
82112 {
83- var json = JsonSerializer . SerializeToUtf8Bytes ( message , McpJsonUtilities . JsonContext . Default . JsonRpcMessage ) ;
84- await _serverInputStream . WriteAsync ( json , cancellationToken ) . ConfigureAwait ( false ) ;
85- await _serverInputStream . WriteAsync ( s_newlineBytes , cancellationToken ) . ConfigureAwait ( false ) ;
86- await _serverInputStream . FlushAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
113+ // Write the message followed by a newline using our UTF-8 writer
114+ await _serverInput . WriteLineAsync ( json ) . ConfigureAwait ( false ) ;
115+ await _serverInput . FlushAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
87116 }
88117 catch ( Exception ex )
89118 {
0 commit comments