Skip to content

Commit 27b0096

Browse files
authored
feat: transparent session re-init on HTTP 404 (#743)
1 parent 5322430 commit 27b0096

File tree

3 files changed

+321
-19
lines changed

3 files changed

+321
-19
lines changed

crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ impl StreamableHttpClient for reqwest::Client {
144144
}
145145

146146
request = apply_custom_headers(request, custom_headers)?;
147+
let session_was_attached = session_id.is_some();
147148
if let Some(session_id) = session_id {
148149
request = request.header(HEADER_SESSION_ID, session_id.as_ref());
149150
}
@@ -186,6 +187,9 @@ impl StreamableHttpClient for reqwest::Client {
186187
) {
187188
return Ok(StreamableHttpPostResponse::Accepted);
188189
}
190+
if status == reqwest::StatusCode::NOT_FOUND && session_was_attached {
191+
return Err(StreamableHttpError::SessionExpired);
192+
}
189193
if !status.is_success() {
190194
let body = response
191195
.text()

crates/rmcp/src/transport/streamable_http_client.rs

Lines changed: 230 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ use tracing::debug;
1111
use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStreamReconnect};
1212
use crate::{
1313
RoleClient,
14-
model::{ClientJsonRpcMessage, ServerJsonRpcMessage, ServerResult},
14+
model::{
15+
ClientJsonRpcMessage, ClientNotification, InitializedNotification, ServerJsonRpcMessage,
16+
ServerResult,
17+
},
1518
transport::{
1619
common::client_side_sse::SseAutoReconnectStream,
1720
worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport},
@@ -79,6 +82,8 @@ pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
7982
InsufficientScope(InsufficientScopeError),
8083
#[error("Header name '{0}' is reserved and conflicts with default headers")]
8184
ReservedHeaderConflict(String),
85+
#[error("Session expired (HTTP 404)")]
86+
SessionExpired,
8287
}
8388

8489
#[derive(Debug, Clone, Error)]
@@ -307,6 +312,69 @@ impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
307312
}
308313
Ok(())
309314
}
315+
316+
/// Performs a transparent re-initialization handshake after a session-expired 404.
317+
///
318+
/// Takes an owned clone of the client (avoiding `&self` across `.await` so the
319+
/// future remains `Send` without requiring `C: Sync`). POSTs the saved
320+
/// initialize request without a session ID, extracts the new session ID and
321+
/// protocol version, sends `notifications/initialized`, and returns the new
322+
/// `(session_id, protocol_headers)` pair. The init result message is **not**
323+
/// forwarded to the handler because the handler already processed the original
324+
/// initialization.
325+
async fn perform_reinitialization(
326+
client: C,
327+
saved_init_request: ClientJsonRpcMessage,
328+
uri: Arc<str>,
329+
auth_header: Option<String>,
330+
custom_headers: HashMap<HeaderName, HeaderValue>,
331+
) -> Result<(Option<Arc<str>>, HashMap<HeaderName, HeaderValue>), StreamableHttpError<C::Error>>
332+
{
333+
let (init_msg, new_session_id_str) = client
334+
.post_message(
335+
uri.clone(),
336+
saved_init_request,
337+
None,
338+
auth_header.clone(),
339+
custom_headers.clone(),
340+
)
341+
.await?
342+
.expect_initialized::<C::Error>()
343+
.await?;
344+
345+
let new_session_id: Option<Arc<str>> = new_session_id_str.map(|s| Arc::from(s.as_str()));
346+
347+
// Start from custom_headers, then inject the negotiated MCP-Protocol-Version
348+
// so all subsequent requests carry the right version (MCP 2025-06-18 spec).
349+
let mut new_protocol_headers = custom_headers;
350+
if let ServerJsonRpcMessage::Response(response) = &init_msg {
351+
if let ServerResult::InitializeResult(init_result) = &response.result {
352+
if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) {
353+
new_protocol_headers
354+
.insert(HeaderName::from_static("mcp-protocol-version"), hv);
355+
}
356+
}
357+
}
358+
359+
let initialized_notification = ClientJsonRpcMessage::notification(
360+
ClientNotification::InitializedNotification(InitializedNotification {
361+
method: Default::default(),
362+
extensions: Default::default(),
363+
}),
364+
);
365+
client
366+
.post_message(
367+
uri,
368+
initialized_notification,
369+
new_session_id.clone(),
370+
auth_header,
371+
new_protocol_headers.clone(),
372+
)
373+
.await?
374+
.expect_accepted_or_json::<C::Error>()?;
375+
376+
Ok((new_session_id, new_protocol_headers))
377+
}
310378
}
311379

312380
impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
@@ -338,14 +406,15 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
338406
responder,
339407
message: initialize_request,
340408
} = context.recv_from_handler().await?;
409+
let saved_init_request = initialize_request.clone();
341410
let (message, session_id) = match self
342411
.client
343412
.post_message(
344413
config.uri.clone(),
345414
initialize_request,
346415
None,
347-
self.config.auth_header,
348-
self.config.custom_headers,
416+
config.auth_header.clone(),
417+
config.custom_headers.clone(),
349418
)
350419
.await
351420
{
@@ -364,7 +433,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
364433
));
365434
}
366435
};
367-
let session_id: Option<Arc<str>> = if let Some(session_id) = session_id {
436+
let mut session_id: Option<Arc<str>> = if let Some(session_id) = session_id {
368437
Some(session_id.into())
369438
} else {
370439
if !self.config.allow_stateless {
@@ -378,7 +447,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
378447
// Extract the negotiated protocol version from the init response
379448
// and build a custom headers map that includes MCP-Protocol-Version
380449
// for all subsequent HTTP requests (per MCP 2025-06-18 spec).
381-
let protocol_headers = {
450+
let mut protocol_headers = {
382451
let mut headers = config.custom_headers.clone();
383452
if let ServerJsonRpcMessage::Response(response) = &message {
384453
if let ServerResult::InitializeResult(init_result) = &response.result {
@@ -392,7 +461,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
392461
};
393462

394463
// Store session info for cleanup when run() exits (not spawned, so cleanup completes before close() returns)
395-
let session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo {
464+
let mut session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo {
396465
client: self.client.clone(),
397466
uri: config.uri.clone(),
398467
session_id: sid.clone(),
@@ -516,17 +585,171 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
516585
match event {
517586
Event::ClientMessage(send_request) => {
518587
let WorkerSendRequest { message, responder } = send_request;
588+
// Pass a clone to the first attempt so `message` is retained for a
589+
// potential re-init retry. `post_message` takes ownership and the
590+
// trait cannot be changed, so the clone is unavoidable.
519591
let response = self
520592
.client
521593
.post_message(
522594
config.uri.clone(),
523-
message,
595+
message.clone(),
524596
session_id.clone(),
525597
config.auth_header.clone(),
526598
protocol_headers.clone(),
527599
)
528600
.await;
529601
let send_result = match response {
602+
Err(StreamableHttpError::SessionExpired) => {
603+
// The server discarded the session (HTTP 404). Perform a
604+
// fresh handshake once and replay the original message.
605+
tracing::info!(
606+
"session expired (HTTP 404), attempting transparent re-initialization"
607+
);
608+
match Self::perform_reinitialization(
609+
self.client.clone(),
610+
saved_init_request.clone(),
611+
config.uri.clone(),
612+
config.auth_header.clone(),
613+
config.custom_headers.clone(),
614+
)
615+
.await
616+
{
617+
Ok((new_session_id, new_protocol_headers)) => {
618+
// Old streams hold the stale session ID; abort them
619+
// so the new standalone SSE stream takes over.
620+
streams.abort_all();
621+
622+
session_id = new_session_id;
623+
protocol_headers = new_protocol_headers;
624+
session_cleanup_info =
625+
session_id.as_ref().map(|sid| SessionCleanupInfo {
626+
client: self.client.clone(),
627+
uri: config.uri.clone(),
628+
session_id: sid.clone(),
629+
auth_header: config.auth_header.clone(),
630+
protocol_headers: protocol_headers.clone(),
631+
});
632+
633+
if let Some(new_sid) = &session_id {
634+
let client = self.client.clone();
635+
let uri = config.uri.clone();
636+
let new_sid = new_sid.clone();
637+
let auth_header = config.auth_header.clone();
638+
let retry_config = self.config.retry_config.clone();
639+
let sse_tx = sse_worker_tx.clone();
640+
let task_ct = transport_task_ct.clone();
641+
let config_uri = config.uri.clone();
642+
let config_auth = config.auth_header.clone();
643+
let spawn_headers = protocol_headers.clone();
644+
streams.spawn(async move {
645+
match client
646+
.get_stream(
647+
uri,
648+
new_sid.clone(),
649+
None,
650+
auth_header.clone(),
651+
spawn_headers.clone(),
652+
)
653+
.await
654+
{
655+
Ok(stream) => {
656+
let sse_stream = SseAutoReconnectStream::new(
657+
stream,
658+
StreamableHttpClientReconnect {
659+
client: client.clone(),
660+
session_id: new_sid,
661+
uri: config_uri,
662+
auth_header: config_auth,
663+
custom_headers: spawn_headers,
664+
},
665+
retry_config,
666+
);
667+
Self::execute_sse_stream(
668+
sse_stream,
669+
sse_tx,
670+
false,
671+
task_ct.child_token(),
672+
)
673+
.await
674+
}
675+
Err(StreamableHttpError::ServerDoesNotSupportSse) => {
676+
tracing::debug!(
677+
"server doesn't support sse after re-init"
678+
);
679+
Ok(())
680+
}
681+
Err(e) => {
682+
tracing::error!(
683+
"fail to get common stream after re-init: {e}"
684+
);
685+
Err(e)
686+
}
687+
}
688+
});
689+
}
690+
691+
let retry_response = self
692+
.client
693+
.post_message(
694+
config.uri.clone(),
695+
message,
696+
session_id.clone(),
697+
config.auth_header.clone(),
698+
protocol_headers.clone(),
699+
)
700+
.await;
701+
match retry_response {
702+
Err(e) => Err(e),
703+
Ok(StreamableHttpPostResponse::Accepted) => {
704+
tracing::trace!(
705+
"client message accepted after re-init"
706+
);
707+
Ok(())
708+
}
709+
Ok(StreamableHttpPostResponse::Json(msg, ..)) => {
710+
context.send_to_handler(msg).await?;
711+
Ok(())
712+
}
713+
Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
714+
if let Some(sid) = &session_id {
715+
let sse_stream = SseAutoReconnectStream::new(
716+
stream,
717+
StreamableHttpClientReconnect {
718+
client: self.client.clone(),
719+
session_id: sid.clone(),
720+
uri: config.uri.clone(),
721+
auth_header: config.auth_header.clone(),
722+
custom_headers: protocol_headers.clone(),
723+
},
724+
self.config.retry_config.clone(),
725+
);
726+
streams.spawn(Self::execute_sse_stream(
727+
sse_stream,
728+
sse_worker_tx.clone(),
729+
true,
730+
transport_task_ct.child_token(),
731+
));
732+
} else {
733+
let sse_stream =
734+
SseAutoReconnectStream::never_reconnect(
735+
stream,
736+
StreamableHttpError::<C::Error>::UnexpectedEndOfStream,
737+
);
738+
streams.spawn(Self::execute_sse_stream(
739+
sse_stream,
740+
sse_worker_tx.clone(),
741+
true,
742+
transport_task_ct.child_token(),
743+
));
744+
}
745+
tracing::trace!("got new sse stream after re-init");
746+
Ok(())
747+
}
748+
}
749+
}
750+
Err(reinit_err) => Err(reinit_err),
751+
}
752+
}
530753
Err(e) => Err(e),
531754
Ok(StreamableHttpPostResponse::Accepted) => {
532755
tracing::trace!("client message accepted");

0 commit comments

Comments
 (0)