Skip to content

Commit 418679b

Browse files
aibrahim-oai and codex committed
Fail exec client operations after disconnect
Reject new exec-server operations once the transport disconnects and convert pending RPC calls into closed errors. This lets remote MCP stdio calls surface executor loss immediately instead of waiting for the MCP tool timeout. Co-authored-by: Codex <noreply@openai.com>
1 parent 580771d commit 418679b

File tree

3 files changed

+162
-94
lines changed

3 files changed

+162
-94
lines changed

codex-rs/exec-server/src/client.rs

Lines changed: 101 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ struct Inner {
127127
// need serialization so concurrent register/remove operations do not
128128
// overwrite each other's copy-on-write updates.
129129
sessions_write_lock: Mutex<()>,
130+
disconnected: std::sync::RwLock<Option<String>>,
130131
session_id: std::sync::RwLock<Option<String>>,
131132
reader_task: tokio::task::JoinHandle<()>,
132133
}
@@ -158,6 +159,8 @@ pub enum ExecServerError {
158159
InitializeTimedOut { timeout: Duration },
159160
#[error("exec-server transport closed")]
160161
Closed,
162+
#[error("{0}")]
163+
Disconnected(String),
161164
#[error("failed to serialize or deserialize exec-server JSON: {0}")]
162165
Json(#[from] serde_json::Error),
163166
#[error("exec-server protocol error: {0}")]
@@ -233,127 +236,85 @@ impl ExecServerClient {
233236
}
234237

235238
pub async fn exec(&self, params: ExecParams) -> Result<ExecResponse, ExecServerError> {
236-
self.inner
237-
.client
238-
.call(EXEC_METHOD, &params)
239-
.await
240-
.map_err(Into::into)
239+
self.call(EXEC_METHOD, &params).await
241240
}
242241

243242
pub async fn read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError> {
244-
self.inner
245-
.client
246-
.call(EXEC_READ_METHOD, &params)
247-
.await
248-
.map_err(Into::into)
243+
self.call(EXEC_READ_METHOD, &params).await
249244
}
250245

251246
pub async fn write(
252247
&self,
253248
process_id: &ProcessId,
254249
chunk: Vec<u8>,
255250
) -> Result<WriteResponse, ExecServerError> {
256-
self.inner
257-
.client
258-
.call(
259-
EXEC_WRITE_METHOD,
260-
&WriteParams {
261-
process_id: process_id.clone(),
262-
chunk: chunk.into(),
263-
},
264-
)
265-
.await
266-
.map_err(Into::into)
251+
self.call(
252+
EXEC_WRITE_METHOD,
253+
&WriteParams {
254+
process_id: process_id.clone(),
255+
chunk: chunk.into(),
256+
},
257+
)
258+
.await
267259
}
268260

269261
pub async fn terminate(
270262
&self,
271263
process_id: &ProcessId,
272264
) -> Result<TerminateResponse, ExecServerError> {
273-
self.inner
274-
.client
275-
.call(
276-
EXEC_TERMINATE_METHOD,
277-
&TerminateParams {
278-
process_id: process_id.clone(),
279-
},
280-
)
281-
.await
282-
.map_err(Into::into)
265+
self.call(
266+
EXEC_TERMINATE_METHOD,
267+
&TerminateParams {
268+
process_id: process_id.clone(),
269+
},
270+
)
271+
.await
283272
}
284273

285274
pub async fn fs_read_file(
286275
&self,
287276
params: FsReadFileParams,
288277
) -> Result<FsReadFileResponse, ExecServerError> {
289-
self.inner
290-
.client
291-
.call(FS_READ_FILE_METHOD, &params)
292-
.await
293-
.map_err(Into::into)
278+
self.call(FS_READ_FILE_METHOD, &params).await
294279
}
295280

296281
pub async fn fs_write_file(
297282
&self,
298283
params: FsWriteFileParams,
299284
) -> Result<FsWriteFileResponse, ExecServerError> {
300-
self.inner
301-
.client
302-
.call(FS_WRITE_FILE_METHOD, &params)
303-
.await
304-
.map_err(Into::into)
285+
self.call(FS_WRITE_FILE_METHOD, &params).await
305286
}
306287

307288
pub async fn fs_create_directory(
308289
&self,
309290
params: FsCreateDirectoryParams,
310291
) -> Result<FsCreateDirectoryResponse, ExecServerError> {
311-
self.inner
312-
.client
313-
.call(FS_CREATE_DIRECTORY_METHOD, &params)
314-
.await
315-
.map_err(Into::into)
292+
self.call(FS_CREATE_DIRECTORY_METHOD, &params).await
316293
}
317294

318295
pub async fn fs_get_metadata(
319296
&self,
320297
params: FsGetMetadataParams,
321298
) -> Result<FsGetMetadataResponse, ExecServerError> {
322-
self.inner
323-
.client
324-
.call(FS_GET_METADATA_METHOD, &params)
325-
.await
326-
.map_err(Into::into)
299+
self.call(FS_GET_METADATA_METHOD, &params).await
327300
}
328301

329302
pub async fn fs_read_directory(
330303
&self,
331304
params: FsReadDirectoryParams,
332305
) -> Result<FsReadDirectoryResponse, ExecServerError> {
333-
self.inner
334-
.client
335-
.call(FS_READ_DIRECTORY_METHOD, &params)
336-
.await
337-
.map_err(Into::into)
306+
self.call(FS_READ_DIRECTORY_METHOD, &params).await
338307
}
339308

340309
pub async fn fs_remove(
341310
&self,
342311
params: FsRemoveParams,
343312
) -> Result<FsRemoveResponse, ExecServerError> {
344-
self.inner
345-
.client
346-
.call(FS_REMOVE_METHOD, &params)
347-
.await
348-
.map_err(Into::into)
313+
self.call(FS_REMOVE_METHOD, &params).await
349314
}
350315

351316
pub async fn fs_copy(&self, params: FsCopyParams) -> Result<FsCopyResponse, ExecServerError> {
352-
self.inner
353-
.client
354-
.call(FS_COPY_METHOD, &params)
355-
.await
356-
.map_err(Into::into)
317+
self.call(FS_COPY_METHOD, &params).await
357318
}
358319

359320
pub(crate) async fn register_session(
@@ -398,7 +359,7 @@ impl ExecServerClient {
398359
&& let Err(err) =
399360
handle_server_notification(&inner, notification).await
400361
{
401-
fail_all_sessions(
362+
mark_disconnected(
402363
&inner,
403364
format!("exec-server notification handling failed: {err}"),
404365
)
@@ -408,7 +369,7 @@ impl ExecServerClient {
408369
}
409370
RpcClientEvent::Disconnected { reason } => {
410371
if let Some(inner) = weak.upgrade() {
411-
fail_all_sessions(&inner, disconnected_message(reason.as_deref()))
372+
mark_disconnected(&inner, disconnected_message(reason.as_deref()))
412373
.await;
413374
}
414375
return;
@@ -421,6 +382,7 @@ impl ExecServerClient {
421382
client: rpc_client,
422383
sessions: ArcSwap::from_pointee(HashMap::new()),
423384
sessions_write_lock: Mutex::new(()),
385+
disconnected: std::sync::RwLock::new(None),
424386
session_id: std::sync::RwLock::new(None),
425387
reader_task,
426388
}
@@ -438,6 +400,30 @@ impl ExecServerClient {
438400
.await
439401
.map_err(ExecServerError::Json)
440402
}
403+
404+
async fn call<P, T>(&self, method: &str, params: &P) -> Result<T, ExecServerError>
405+
where
406+
P: serde::Serialize,
407+
T: serde::de::DeserializeOwned,
408+
{
409+
if let Some(error) = self.inner.disconnected_error() {
410+
return Err(error);
411+
}
412+
413+
match self.inner.client.call(method, params).await {
414+
Ok(response) => Ok(response),
415+
Err(error) => {
416+
let error = ExecServerError::from(error);
417+
if is_transport_closed_error(&error) {
418+
let message = disconnected_message(/*reason*/ None);
419+
let message = mark_disconnected(&self.inner, message).await;
420+
Err(ExecServerError::Disconnected(message))
421+
} else {
422+
Err(error)
423+
}
424+
}
425+
}
426+
}
441427
}
442428

443429
impl From<RpcCallError> for ExecServerError {
@@ -573,6 +559,26 @@ impl Session {
573559
}
574560

575561
impl Inner {
562+
fn disconnected_error(&self) -> Option<ExecServerError> {
563+
self.disconnected
564+
.read()
565+
.unwrap_or_else(std::sync::PoisonError::into_inner)
566+
.clone()
567+
.map(ExecServerError::Disconnected)
568+
}
569+
570+
fn set_disconnected(&self, message: String) -> Option<String> {
571+
let mut disconnected = self
572+
.disconnected
573+
.write()
574+
.unwrap_or_else(std::sync::PoisonError::into_inner);
575+
if disconnected.is_some() {
576+
return None;
577+
}
578+
*disconnected = Some(message.clone());
579+
Some(message)
580+
}
581+
576582
fn get_session(&self, process_id: &ProcessId) -> Option<Arc<SessionState>> {
577583
self.sessions.load().get(process_id).cloned()
578584
}
@@ -583,6 +589,9 @@ impl Inner {
583589
session: Arc<SessionState>,
584590
) -> Result<(), ExecServerError> {
585591
let _sessions_write_guard = self.sessions_write_lock.lock().await;
592+
if let Some(error) = self.disconnected_error() {
593+
return Err(error);
594+
}
586595
let sessions = self.sessions.load();
587596
if sessions.contains_key(process_id) {
588597
return Err(ExecServerError::Protocol(format!(
@@ -623,14 +632,30 @@ fn disconnected_message(reason: Option<&str>) -> String {
623632
}
624633

625634
fn is_transport_closed_error(error: &ExecServerError) -> bool {
626-
matches!(error, ExecServerError::Closed)
627-
|| matches!(
628-
error,
629-
ExecServerError::Server {
630-
code: -32000,
631-
message,
632-
} if message == "JSON-RPC transport closed"
633-
)
635+
matches!(
636+
error,
637+
ExecServerError::Closed | ExecServerError::Disconnected(_)
638+
) || matches!(
639+
error,
640+
ExecServerError::Server {
641+
code: -32000,
642+
message,
643+
} if message == "JSON-RPC transport closed"
644+
)
645+
}
646+
647+
async fn mark_disconnected(inner: &Arc<Inner>, message: String) -> String {
648+
if let Some(message) = inner.set_disconnected(message.clone()) {
649+
fail_all_sessions(inner, message.clone()).await;
650+
message
651+
} else {
652+
inner
653+
.disconnected
654+
.read()
655+
.unwrap_or_else(std::sync::PoisonError::into_inner)
656+
.clone()
657+
.unwrap_or(message)
658+
}
634659
}
635660

636661
async fn fail_all_sessions(inner: &Arc<Inner>, message: String) {

0 commit comments

Comments
 (0)