@@ -11,7 +11,10 @@ use tracing::debug;
1111use super :: common:: client_side_sse:: { ExponentialBackoff , SseRetryPolicy , SseStreamReconnect } ;
1212use crate :: {
1313 RoleClient ,
14- model:: { ClientJsonRpcMessage , ServerJsonRpcMessage , ServerResult } ,
14+ model:: {
15+ ClientJsonRpcMessage , ClientNotification , InitializedNotification , ServerJsonRpcMessage ,
16+ ServerResult ,
17+ } ,
1518 transport:: {
1619 common:: client_side_sse:: SseAutoReconnectStream ,
1720 worker:: { Worker , WorkerQuitReason , WorkerSendRequest , WorkerTransport } ,
@@ -79,6 +82,8 @@ pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
7982 InsufficientScope ( InsufficientScopeError ) ,
8083 #[ error( "Header name '{0}' is reserved and conflicts with default headers" ) ]
8184 ReservedHeaderConflict ( String ) ,
85+ #[ error( "Session expired (HTTP 404)" ) ]
86+ SessionExpired ,
8287}
8388
8489#[ derive( Debug , Clone , Error ) ]
@@ -307,6 +312,69 @@ impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
307312 }
308313 Ok ( ( ) )
309314 }
315+
316+ /// Performs a transparent re-initialization handshake after a session-expired 404.
317+ ///
318+ /// Takes an owned clone of the client (avoiding `&self` across `.await` so the
319+ /// future remains `Send` without requiring `C: Sync`). POSTs the saved
320+ /// initialize request without a session ID, extracts the new session ID and
321+ /// protocol version, sends `notifications/initialized`, and returns the new
322+ /// `(session_id, protocol_headers)` pair. The init result message is **not**
323+ /// forwarded to the handler because the handler already processed the original
324+ /// initialization.
325+ async fn perform_reinitialization (
326+ client : C ,
327+ saved_init_request : ClientJsonRpcMessage ,
328+ uri : Arc < str > ,
329+ auth_header : Option < String > ,
330+ custom_headers : HashMap < HeaderName , HeaderValue > ,
331+ ) -> Result < ( Option < Arc < str > > , HashMap < HeaderName , HeaderValue > ) , StreamableHttpError < C :: Error > >
332+ {
333+ let ( init_msg, new_session_id_str) = client
334+ . post_message (
335+ uri. clone ( ) ,
336+ saved_init_request,
337+ None ,
338+ auth_header. clone ( ) ,
339+ custom_headers. clone ( ) ,
340+ )
341+ . await ?
342+ . expect_initialized :: < C :: Error > ( )
343+ . await ?;
344+
345+ let new_session_id: Option < Arc < str > > = new_session_id_str. map ( |s| Arc :: from ( s. as_str ( ) ) ) ;
346+
347+ // Start from custom_headers, then inject the negotiated MCP-Protocol-Version
348+ // so all subsequent requests carry the right version (MCP 2025-06-18 spec).
349+ let mut new_protocol_headers = custom_headers;
350+ if let ServerJsonRpcMessage :: Response ( response) = & init_msg {
351+ if let ServerResult :: InitializeResult ( init_result) = & response. result {
352+ if let Ok ( hv) = HeaderValue :: from_str ( init_result. protocol_version . as_str ( ) ) {
353+ new_protocol_headers
354+ . insert ( HeaderName :: from_static ( "mcp-protocol-version" ) , hv) ;
355+ }
356+ }
357+ }
358+
359+ let initialized_notification = ClientJsonRpcMessage :: notification (
360+ ClientNotification :: InitializedNotification ( InitializedNotification {
361+ method : Default :: default ( ) ,
362+ extensions : Default :: default ( ) ,
363+ } ) ,
364+ ) ;
365+ client
366+ . post_message (
367+ uri,
368+ initialized_notification,
369+ new_session_id. clone ( ) ,
370+ auth_header,
371+ new_protocol_headers. clone ( ) ,
372+ )
373+ . await ?
374+ . expect_accepted_or_json :: < C :: Error > ( ) ?;
375+
376+ Ok ( ( new_session_id, new_protocol_headers) )
377+ }
310378}
311379
312380impl < C : StreamableHttpClient > Worker for StreamableHttpClientWorker < C > {
@@ -338,14 +406,15 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
338406 responder,
339407 message : initialize_request,
340408 } = context. recv_from_handler ( ) . await ?;
409+ let saved_init_request = initialize_request. clone ( ) ;
341410 let ( message, session_id) = match self
342411 . client
343412 . post_message (
344413 config. uri . clone ( ) ,
345414 initialize_request,
346415 None ,
347- self . config . auth_header ,
348- self . config . custom_headers ,
416+ config. auth_header . clone ( ) ,
417+ config. custom_headers . clone ( ) ,
349418 )
350419 . await
351420 {
@@ -364,7 +433,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
364433 ) ) ;
365434 }
366435 } ;
367- let session_id: Option < Arc < str > > = if let Some ( session_id) = session_id {
436+ let mut session_id: Option < Arc < str > > = if let Some ( session_id) = session_id {
368437 Some ( session_id. into ( ) )
369438 } else {
370439 if !self . config . allow_stateless {
@@ -378,7 +447,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
378447 // Extract the negotiated protocol version from the init response
379448 // and build a custom headers map that includes MCP-Protocol-Version
380449 // for all subsequent HTTP requests (per MCP 2025-06-18 spec).
381- let protocol_headers = {
450+ let mut protocol_headers = {
382451 let mut headers = config. custom_headers . clone ( ) ;
383452 if let ServerJsonRpcMessage :: Response ( response) = & message {
384453 if let ServerResult :: InitializeResult ( init_result) = & response. result {
@@ -392,7 +461,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
392461 } ;
393462
394463 // Store session info for cleanup when run() exits (not spawned, so cleanup completes before close() returns)
395- let session_cleanup_info = session_id. as_ref ( ) . map ( |sid| SessionCleanupInfo {
464+ let mut session_cleanup_info = session_id. as_ref ( ) . map ( |sid| SessionCleanupInfo {
396465 client : self . client . clone ( ) ,
397466 uri : config. uri . clone ( ) ,
398467 session_id : sid. clone ( ) ,
@@ -516,17 +585,171 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
516585 match event {
517586 Event :: ClientMessage ( send_request) => {
518587 let WorkerSendRequest { message, responder } = send_request;
588+ // Pass a clone to the first attempt so `message` is retained for a
589+ // potential re-init retry. `post_message` takes ownership and the
590+ // trait cannot be changed, so the clone is unavoidable.
519591 let response = self
520592 . client
521593 . post_message (
522594 config. uri . clone ( ) ,
523- message,
595+ message. clone ( ) ,
524596 session_id. clone ( ) ,
525597 config. auth_header . clone ( ) ,
526598 protocol_headers. clone ( ) ,
527599 )
528600 . await ;
529601 let send_result = match response {
602+ Err ( StreamableHttpError :: SessionExpired ) => {
603+ // The server discarded the session (HTTP 404). Perform a
604+ // fresh handshake once and replay the original message.
605+ tracing:: info!(
606+ "session expired (HTTP 404), attempting transparent re-initialization"
607+ ) ;
608+ match Self :: perform_reinitialization (
609+ self . client . clone ( ) ,
610+ saved_init_request. clone ( ) ,
611+ config. uri . clone ( ) ,
612+ config. auth_header . clone ( ) ,
613+ config. custom_headers . clone ( ) ,
614+ )
615+ . await
616+ {
617+ Ok ( ( new_session_id, new_protocol_headers) ) => {
618+ // Old streams hold the stale session ID; abort them
619+ // so the new standalone SSE stream takes over.
620+ streams. abort_all ( ) ;
621+
622+ session_id = new_session_id;
623+ protocol_headers = new_protocol_headers;
624+ session_cleanup_info =
625+ session_id. as_ref ( ) . map ( |sid| SessionCleanupInfo {
626+ client : self . client . clone ( ) ,
627+ uri : config. uri . clone ( ) ,
628+ session_id : sid. clone ( ) ,
629+ auth_header : config. auth_header . clone ( ) ,
630+ protocol_headers : protocol_headers. clone ( ) ,
631+ } ) ;
632+
633+ if let Some ( new_sid) = & session_id {
634+ let client = self . client . clone ( ) ;
635+ let uri = config. uri . clone ( ) ;
636+ let new_sid = new_sid. clone ( ) ;
637+ let auth_header = config. auth_header . clone ( ) ;
638+ let retry_config = self . config . retry_config . clone ( ) ;
639+ let sse_tx = sse_worker_tx. clone ( ) ;
640+ let task_ct = transport_task_ct. clone ( ) ;
641+ let config_uri = config. uri . clone ( ) ;
642+ let config_auth = config. auth_header . clone ( ) ;
643+ let spawn_headers = protocol_headers. clone ( ) ;
644+ streams. spawn ( async move {
645+ match client
646+ . get_stream (
647+ uri,
648+ new_sid. clone ( ) ,
649+ None ,
650+ auth_header. clone ( ) ,
651+ spawn_headers. clone ( ) ,
652+ )
653+ . await
654+ {
655+ Ok ( stream) => {
656+ let sse_stream = SseAutoReconnectStream :: new (
657+ stream,
658+ StreamableHttpClientReconnect {
659+ client : client. clone ( ) ,
660+ session_id : new_sid,
661+ uri : config_uri,
662+ auth_header : config_auth,
663+ custom_headers : spawn_headers,
664+ } ,
665+ retry_config,
666+ ) ;
667+ Self :: execute_sse_stream (
668+ sse_stream,
669+ sse_tx,
670+ false ,
671+ task_ct. child_token ( ) ,
672+ )
673+ . await
674+ }
675+ Err ( StreamableHttpError :: ServerDoesNotSupportSse ) => {
676+ tracing:: debug!(
677+ "server doesn't support sse after re-init"
678+ ) ;
679+ Ok ( ( ) )
680+ }
681+ Err ( e) => {
682+ tracing:: error!(
683+ "fail to get common stream after re-init: {e}"
684+ ) ;
685+ Err ( e)
686+ }
687+ }
688+ } ) ;
689+ }
690+
691+ let retry_response = self
692+ . client
693+ . post_message (
694+ config. uri . clone ( ) ,
695+ message,
696+ session_id. clone ( ) ,
697+ config. auth_header . clone ( ) ,
698+ protocol_headers. clone ( ) ,
699+ )
700+ . await ;
701+ match retry_response {
702+ Err ( e) => Err ( e) ,
703+ Ok ( StreamableHttpPostResponse :: Accepted ) => {
704+ tracing:: trace!(
705+ "client message accepted after re-init"
706+ ) ;
707+ Ok ( ( ) )
708+ }
709+ Ok ( StreamableHttpPostResponse :: Json ( msg, ..) ) => {
710+ context. send_to_handler ( msg) . await ?;
711+ Ok ( ( ) )
712+ }
713+ Ok ( StreamableHttpPostResponse :: Sse ( stream, ..) ) => {
714+ if let Some ( sid) = & session_id {
715+ let sse_stream = SseAutoReconnectStream :: new (
716+ stream,
717+ StreamableHttpClientReconnect {
718+ client : self . client . clone ( ) ,
719+ session_id : sid. clone ( ) ,
720+ uri : config. uri . clone ( ) ,
721+ auth_header : config. auth_header . clone ( ) ,
722+ custom_headers : protocol_headers. clone ( ) ,
723+ } ,
724+ self . config . retry_config . clone ( ) ,
725+ ) ;
726+ streams. spawn ( Self :: execute_sse_stream (
727+ sse_stream,
728+ sse_worker_tx. clone ( ) ,
729+ true ,
730+ transport_task_ct. child_token ( ) ,
731+ ) ) ;
732+ } else {
733+ let sse_stream =
734+ SseAutoReconnectStream :: never_reconnect (
735+ stream,
736+ StreamableHttpError :: < C :: Error > :: UnexpectedEndOfStream ,
737+ ) ;
738+ streams. spawn ( Self :: execute_sse_stream (
739+ sse_stream,
740+ sse_worker_tx. clone ( ) ,
741+ true ,
742+ transport_task_ct. child_token ( ) ,
743+ ) ) ;
744+ }
745+ tracing:: trace!( "got new sse stream after re-init" ) ;
746+ Ok ( ( ) )
747+ }
748+ }
749+ }
750+ Err ( reinit_err) => Err ( reinit_err) ,
751+ }
752+ }
530753 Err ( e) => Err ( e) ,
531754 Ok ( StreamableHttpPostResponse :: Accepted ) => {
532755 tracing:: trace!( "client message accepted" ) ;
0 commit comments