Skip to content

Commit d73da80

Browse files
aibrahim-oai and codex committed
Fail exec client operations after disconnect
Reject new exec-server operations once the transport disconnects and convert pending RPC calls into closed errors. This lets remote MCP stdio calls surface executor loss immediately instead of waiting for the MCP tool timeout.

Co-authored-by: Codex <noreply@openai.com>
1 parent aa58931 commit d73da80

3 files changed

Lines changed: 162 additions & 94 deletions

File tree

codex-rs/exec-server/src/client.rs

Lines changed: 101 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -126,6 +126,7 @@ struct Inner {
126126
// need serialization so concurrent register/remove operations do not
127127
// overwrite each other's copy-on-write updates.
128128
sessions_write_lock: Mutex<()>,
129+
disconnected: std::sync::RwLock<Option<String>>,
129130
session_id: std::sync::RwLock<Option<String>>,
130131
reader_task: tokio::task::JoinHandle<()>,
131132
}
@@ -157,6 +158,8 @@ pub enum ExecServerError {
157158
InitializeTimedOut { timeout: Duration },
158159
#[error("exec-server transport closed")]
159160
Closed,
161+
#[error("{0}")]
162+
Disconnected(String),
160163
#[error("failed to serialize or deserialize exec-server JSON: {0}")]
161164
Json(#[from] serde_json::Error),
162165
#[error("exec-server protocol error: {0}")]
@@ -232,127 +235,85 @@ impl ExecServerClient {
232235
}
233236

234237
pub async fn exec(&self, params: ExecParams) -> Result<ExecResponse, ExecServerError> {
235-
self.inner
236-
.client
237-
.call(EXEC_METHOD, &params)
238-
.await
239-
.map_err(Into::into)
238+
self.call(EXEC_METHOD, &params).await
240239
}
241240

242241
pub async fn read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError> {
243-
self.inner
244-
.client
245-
.call(EXEC_READ_METHOD, &params)
246-
.await
247-
.map_err(Into::into)
242+
self.call(EXEC_READ_METHOD, &params).await
248243
}
249244

250245
pub async fn write(
251246
&self,
252247
process_id: &ProcessId,
253248
chunk: Vec<u8>,
254249
) -> Result<WriteResponse, ExecServerError> {
255-
self.inner
256-
.client
257-
.call(
258-
EXEC_WRITE_METHOD,
259-
&WriteParams {
260-
process_id: process_id.clone(),
261-
chunk: chunk.into(),
262-
},
263-
)
264-
.await
265-
.map_err(Into::into)
250+
self.call(
251+
EXEC_WRITE_METHOD,
252+
&WriteParams {
253+
process_id: process_id.clone(),
254+
chunk: chunk.into(),
255+
},
256+
)
257+
.await
266258
}
267259

268260
pub async fn terminate(
269261
&self,
270262
process_id: &ProcessId,
271263
) -> Result<TerminateResponse, ExecServerError> {
272-
self.inner
273-
.client
274-
.call(
275-
EXEC_TERMINATE_METHOD,
276-
&TerminateParams {
277-
process_id: process_id.clone(),
278-
},
279-
)
280-
.await
281-
.map_err(Into::into)
264+
self.call(
265+
EXEC_TERMINATE_METHOD,
266+
&TerminateParams {
267+
process_id: process_id.clone(),
268+
},
269+
)
270+
.await
282271
}
283272

284273
pub async fn fs_read_file(
285274
&self,
286275
params: FsReadFileParams,
287276
) -> Result<FsReadFileResponse, ExecServerError> {
288-
self.inner
289-
.client
290-
.call(FS_READ_FILE_METHOD, &params)
291-
.await
292-
.map_err(Into::into)
277+
self.call(FS_READ_FILE_METHOD, &params).await
293278
}
294279

295280
pub async fn fs_write_file(
296281
&self,
297282
params: FsWriteFileParams,
298283
) -> Result<FsWriteFileResponse, ExecServerError> {
299-
self.inner
300-
.client
301-
.call(FS_WRITE_FILE_METHOD, &params)
302-
.await
303-
.map_err(Into::into)
284+
self.call(FS_WRITE_FILE_METHOD, &params).await
304285
}
305286

306287
pub async fn fs_create_directory(
307288
&self,
308289
params: FsCreateDirectoryParams,
309290
) -> Result<FsCreateDirectoryResponse, ExecServerError> {
310-
self.inner
311-
.client
312-
.call(FS_CREATE_DIRECTORY_METHOD, &params)
313-
.await
314-
.map_err(Into::into)
291+
self.call(FS_CREATE_DIRECTORY_METHOD, &params).await
315292
}
316293

317294
pub async fn fs_get_metadata(
318295
&self,
319296
params: FsGetMetadataParams,
320297
) -> Result<FsGetMetadataResponse, ExecServerError> {
321-
self.inner
322-
.client
323-
.call(FS_GET_METADATA_METHOD, &params)
324-
.await
325-
.map_err(Into::into)
298+
self.call(FS_GET_METADATA_METHOD, &params).await
326299
}
327300

328301
pub async fn fs_read_directory(
329302
&self,
330303
params: FsReadDirectoryParams,
331304
) -> Result<FsReadDirectoryResponse, ExecServerError> {
332-
self.inner
333-
.client
334-
.call(FS_READ_DIRECTORY_METHOD, &params)
335-
.await
336-
.map_err(Into::into)
305+
self.call(FS_READ_DIRECTORY_METHOD, &params).await
337306
}
338307

339308
pub async fn fs_remove(
340309
&self,
341310
params: FsRemoveParams,
342311
) -> Result<FsRemoveResponse, ExecServerError> {
343-
self.inner
344-
.client
345-
.call(FS_REMOVE_METHOD, &params)
346-
.await
347-
.map_err(Into::into)
312+
self.call(FS_REMOVE_METHOD, &params).await
348313
}
349314

350315
pub async fn fs_copy(&self, params: FsCopyParams) -> Result<FsCopyResponse, ExecServerError> {
351-
self.inner
352-
.client
353-
.call(FS_COPY_METHOD, &params)
354-
.await
355-
.map_err(Into::into)
316+
self.call(FS_COPY_METHOD, &params).await
356317
}
357318

358319
pub(crate) async fn register_session(
@@ -397,7 +358,7 @@ impl ExecServerClient {
397358
&& let Err(err) =
398359
handle_server_notification(&inner, notification).await
399360
{
400-
fail_all_sessions(
361+
mark_disconnected(
401362
&inner,
402363
format!("exec-server notification handling failed: {err}"),
403364
)
@@ -407,7 +368,7 @@ impl ExecServerClient {
407368
}
408369
RpcClientEvent::Disconnected { reason } => {
409370
if let Some(inner) = weak.upgrade() {
410-
fail_all_sessions(&inner, disconnected_message(reason.as_deref()))
371+
mark_disconnected(&inner, disconnected_message(reason.as_deref()))
411372
.await;
412373
}
413374
return;
@@ -420,6 +381,7 @@ impl ExecServerClient {
420381
client: rpc_client,
421382
sessions: ArcSwap::from_pointee(HashMap::new()),
422383
sessions_write_lock: Mutex::new(()),
384+
disconnected: std::sync::RwLock::new(None),
423385
session_id: std::sync::RwLock::new(None),
424386
reader_task,
425387
}
@@ -437,6 +399,30 @@ impl ExecServerClient {
437399
.await
438400
.map_err(ExecServerError::Json)
439401
}
402+
403+
async fn call<P, T>(&self, method: &str, params: &P) -> Result<T, ExecServerError>
404+
where
405+
P: serde::Serialize,
406+
T: serde::de::DeserializeOwned,
407+
{
408+
if let Some(error) = self.inner.disconnected_error() {
409+
return Err(error);
410+
}
411+
412+
match self.inner.client.call(method, params).await {
413+
Ok(response) => Ok(response),
414+
Err(error) => {
415+
let error = ExecServerError::from(error);
416+
if is_transport_closed_error(&error) {
417+
let message = disconnected_message(/*reason*/ None);
418+
let message = mark_disconnected(&self.inner, message).await;
419+
Err(ExecServerError::Disconnected(message))
420+
} else {
421+
Err(error)
422+
}
423+
}
424+
}
425+
}
440426
}
441427

442428
impl From<RpcCallError> for ExecServerError {
@@ -573,6 +559,26 @@ impl Session {
573559
}
574560

575561
impl Inner {
562+
fn disconnected_error(&self) -> Option<ExecServerError> {
563+
self.disconnected
564+
.read()
565+
.unwrap_or_else(std::sync::PoisonError::into_inner)
566+
.clone()
567+
.map(ExecServerError::Disconnected)
568+
}
569+
570+
fn set_disconnected(&self, message: String) -> Option<String> {
571+
let mut disconnected = self
572+
.disconnected
573+
.write()
574+
.unwrap_or_else(std::sync::PoisonError::into_inner);
575+
if disconnected.is_some() {
576+
return None;
577+
}
578+
*disconnected = Some(message.clone());
579+
Some(message)
580+
}
581+
576582
fn get_session(&self, process_id: &ProcessId) -> Option<Arc<SessionState>> {
577583
self.sessions.load().get(process_id).cloned()
578584
}
@@ -583,6 +589,9 @@ impl Inner {
583589
session: Arc<SessionState>,
584590
) -> Result<(), ExecServerError> {
585591
let _sessions_write_guard = self.sessions_write_lock.lock().await;
592+
if let Some(error) = self.disconnected_error() {
593+
return Err(error);
594+
}
586595
let sessions = self.sessions.load();
587596
if sessions.contains_key(process_id) {
588597
return Err(ExecServerError::Protocol(format!(
@@ -623,14 +632,30 @@ fn disconnected_message(reason: Option<&str>) -> String {
623632
}
624633

625634
fn is_transport_closed_error(error: &ExecServerError) -> bool {
626-
matches!(error, ExecServerError::Closed)
627-
|| matches!(
628-
error,
629-
ExecServerError::Server {
630-
code: -32000,
631-
message,
632-
} if message == "JSON-RPC transport closed"
633-
)
635+
matches!(
636+
error,
637+
ExecServerError::Closed | ExecServerError::Disconnected(_)
638+
) || matches!(
639+
error,
640+
ExecServerError::Server {
641+
code: -32000,
642+
message,
643+
} if message == "JSON-RPC transport closed"
644+
)
645+
}
646+
647+
async fn mark_disconnected(inner: &Arc<Inner>, message: String) -> String {
648+
if let Some(message) = inner.set_disconnected(message.clone()) {
649+
fail_all_sessions(inner, message.clone()).await;
650+
message
651+
} else {
652+
inner
653+
.disconnected
654+
.read()
655+
.unwrap_or_else(std::sync::PoisonError::into_inner)
656+
.clone()
657+
.unwrap_or(message)
658+
}
634659
}
635660

636661
async fn fail_all_sessions(inner: &Arc<Inner>, message: String) {

0 commit comments

Comments
 (0)