Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 114 additions & 79 deletions codex-rs/exec-server/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::BTreeMap;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::OnceLock;
use std::time::Duration;

use arc_swap::ArcSwap;
Expand Down Expand Up @@ -140,6 +141,10 @@ struct Inner {
// need serialization so concurrent register/remove operations do not
// overwrite each other's copy-on-write updates.
sessions_write_lock: Mutex<()>,
// Once the transport closes, every executor operation should fail quickly
// with the same canonical message. This client never reconnects, so the
// latch only moves from unset to set once.
disconnected: OnceLock<String>,
session_id: std::sync::RwLock<Option<String>>,
reader_task: tokio::task::JoinHandle<()>,
}
Expand Down Expand Up @@ -171,6 +176,8 @@ pub enum ExecServerError {
InitializeTimedOut { timeout: Duration },
#[error("exec-server transport closed")]
Closed,
#[error("{0}")]
Disconnected(String),
#[error("failed to serialize or deserialize exec-server JSON: {0}")]
Json(#[from] serde_json::Error),
#[error("exec-server protocol error: {0}")]
Expand Down Expand Up @@ -246,127 +253,85 @@ impl ExecServerClient {
}

/// Sends an `EXEC_METHOD` request carrying `params` and returns the
/// server's `ExecResponse`, or an `ExecServerError` if the call fails.
pub async fn exec(&self, params: ExecParams) -> Result<ExecResponse, ExecServerError> {
self.inner
.client
.call(EXEC_METHOD, &params)
.await
.map_err(Into::into)
// NOTE(review): the chain above and the `self.call` form below look like
// the before/after bodies of a diff hunk — confirm which one is current.
self.call(EXEC_METHOD, &params).await
}

/// Sends an `EXEC_READ_METHOD` request carrying `params` and returns the
/// server's `ReadResponse`, or an `ExecServerError` if the call fails.
pub async fn read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError> {
self.inner
.client
.call(EXEC_READ_METHOD, &params)
.await
.map_err(Into::into)
// NOTE(review): the chain above and the `self.call` form below look like
// the before/after bodies of a diff hunk — confirm which one is current.
self.call(EXEC_READ_METHOD, &params).await
}

/// Sends an `EXEC_WRITE_METHOD` request for `process_id`.
///
/// The request is a `WriteParams` built from a clone of `process_id` and
/// `chunk` converted with `.into()`. Returns the server's `WriteResponse`,
/// or an `ExecServerError` if the call fails.
pub async fn write(
&self,
process_id: &ProcessId,
chunk: Vec<u8>,
) -> Result<WriteResponse, ExecServerError> {
self.inner
.client
.call(
EXEC_WRITE_METHOD,
&WriteParams {
process_id: process_id.clone(),
chunk: chunk.into(),
},
)
.await
.map_err(Into::into)
// NOTE(review): the chain above and the `self.call` form below look like
// the before/after bodies of a diff hunk — confirm which one is current.
self.call(
EXEC_WRITE_METHOD,
&WriteParams {
process_id: process_id.clone(),
chunk: chunk.into(),
},
)
.await
}

/// Sends an `EXEC_TERMINATE_METHOD` request for `process_id`.
///
/// The request is a `TerminateParams` holding a clone of `process_id`.
/// Returns the server's `TerminateResponse`, or an `ExecServerError` if the
/// call fails.
pub async fn terminate(
&self,
process_id: &ProcessId,
) -> Result<TerminateResponse, ExecServerError> {
self.inner
.client
.call(
EXEC_TERMINATE_METHOD,
&TerminateParams {
process_id: process_id.clone(),
},
)
.await
.map_err(Into::into)
// NOTE(review): the chain above and the `self.call` form below look like
// the before/after bodies of a diff hunk — confirm which one is current.
self.call(
EXEC_TERMINATE_METHOD,
&TerminateParams {
process_id: process_id.clone(),
},
)
.await
}

/// Sends an `FS_READ_FILE_METHOD` request carrying `params` and returns the
/// server's `FsReadFileResponse`, or an `ExecServerError` if the call fails.
pub async fn fs_read_file(
&self,
params: FsReadFileParams,
) -> Result<FsReadFileResponse, ExecServerError> {
self.inner
.client
.call(FS_READ_FILE_METHOD, &params)
.await
.map_err(Into::into)
// NOTE(review): the chain above and the `self.call` form below look like
// the before/after bodies of a diff hunk — confirm which one is current.
self.call(FS_READ_FILE_METHOD, &params).await
}

/// Sends an `FS_WRITE_FILE_METHOD` request carrying `params` and returns the
/// server's `FsWriteFileResponse`, or an `ExecServerError` if the call fails.
pub async fn fs_write_file(
&self,
params: FsWriteFileParams,
) -> Result<FsWriteFileResponse, ExecServerError> {
self.inner
.client
.call(FS_WRITE_FILE_METHOD, &params)
.await
.map_err(Into::into)
// NOTE(review): the chain above and the `self.call` form below look like
// the before/after bodies of a diff hunk — confirm which one is current.
self.call(FS_WRITE_FILE_METHOD, &params).await
}

/// Sends an `FS_CREATE_DIRECTORY_METHOD` request carrying `params` and
/// returns the server's `FsCreateDirectoryResponse`, or an
/// `ExecServerError` if the call fails.
pub async fn fs_create_directory(
&self,
params: FsCreateDirectoryParams,
) -> Result<FsCreateDirectoryResponse, ExecServerError> {
self.inner
.client
.call(FS_CREATE_DIRECTORY_METHOD, &params)
.await
.map_err(Into::into)
// NOTE(review): the chain above and the `self.call` form below look like
// the before/after bodies of a diff hunk — confirm which one is current.
self.call(FS_CREATE_DIRECTORY_METHOD, &params).await
}

/// Sends an `FS_GET_METADATA_METHOD` request carrying `params` and returns
/// the server's `FsGetMetadataResponse`, or an `ExecServerError` if the
/// call fails.
pub async fn fs_get_metadata(
&self,
params: FsGetMetadataParams,
) -> Result<FsGetMetadataResponse, ExecServerError> {
self.inner
.client
.call(FS_GET_METADATA_METHOD, &params)
.await
.map_err(Into::into)
// NOTE(review): the chain above and the `self.call` form below look like
// the before/after bodies of a diff hunk — confirm which one is current.
self.call(FS_GET_METADATA_METHOD, &params).await
}

/// Sends an `FS_READ_DIRECTORY_METHOD` request carrying `params` and returns
/// the server's `FsReadDirectoryResponse`, or an `ExecServerError` if the
/// call fails.
pub async fn fs_read_directory(
&self,
params: FsReadDirectoryParams,
) -> Result<FsReadDirectoryResponse, ExecServerError> {
self.inner
.client
.call(FS_READ_DIRECTORY_METHOD, &params)
.await
.map_err(Into::into)
// NOTE(review): the chain above and the `self.call` form below look like
// the before/after bodies of a diff hunk — confirm which one is current.
self.call(FS_READ_DIRECTORY_METHOD, &params).await
}

/// Sends an `FS_REMOVE_METHOD` request carrying `params` and returns the
/// server's `FsRemoveResponse`, or an `ExecServerError` if the call fails.
pub async fn fs_remove(
&self,
params: FsRemoveParams,
) -> Result<FsRemoveResponse, ExecServerError> {
self.inner
.client
.call(FS_REMOVE_METHOD, &params)
.await
.map_err(Into::into)
// NOTE(review): the chain above and the `self.call` form below look like
// the before/after bodies of a diff hunk — confirm which one is current.
self.call(FS_REMOVE_METHOD, &params).await
}

/// Sends an `FS_COPY_METHOD` request carrying `params` and returns the
/// server's `FsCopyResponse`, or an `ExecServerError` if the call fails.
pub async fn fs_copy(&self, params: FsCopyParams) -> Result<FsCopyResponse, ExecServerError> {
self.inner
.client
.call(FS_COPY_METHOD, &params)
.await
.map_err(Into::into)
// NOTE(review): the chain above and the `self.call` form below look like
// the before/after bodies of a diff hunk — confirm which one is current.
self.call(FS_COPY_METHOD, &params).await
}

pub(crate) async fn register_session(
Expand Down Expand Up @@ -411,18 +376,21 @@ impl ExecServerClient {
&& let Err(err) =
handle_server_notification(&inner, notification).await
{
fail_all_sessions(
let message = record_disconnected(
&inner,
format!("exec-server notification handling failed: {err}"),
)
.await;
);
fail_all_sessions(&inner, message).await;
return;
}
}
RpcClientEvent::Disconnected { reason } => {
if let Some(inner) = weak.upgrade() {
fail_all_sessions(&inner, disconnected_message(reason.as_deref()))
.await;
let message = record_disconnected(
&inner,
disconnected_message(reason.as_deref()),
);
fail_all_sessions(&inner, message).await;
}
return;
}
Expand All @@ -434,6 +402,7 @@ impl ExecServerClient {
client: rpc_client,
sessions: ArcSwap::from_pointee(HashMap::new()),
sessions_write_lock: Mutex::new(()),
disconnected: OnceLock::new(),
session_id: std::sync::RwLock::new(None),
reader_task,
}
Expand All @@ -451,6 +420,36 @@ impl ExecServerClient {
.await
.map_err(ExecServerError::Json)
}

/// Issues one JSON-RPC call through the underlying client, failing fast once
/// the disconnect latch is set.
///
/// Returns the recorded `ExecServerError::Disconnected` without touching the
/// RPC client when a disconnect was already observed. If the call itself
/// fails with a transport-closed error, the disconnect is recorded and the
/// canonical `Disconnected` error is returned; any other error is passed
/// through after conversion to `ExecServerError`.
async fn call<P, T>(&self, method: &str, params: &P) -> Result<T, ExecServerError>
where
    P: serde::Serialize,
    T: serde::de::DeserializeOwned,
{
    // Reject new work before allocating a JSON-RPC request id. MCP tool
    // calls, process writes, and fs operations all pass through here, so
    // this is the shared low-level failure path after executor disconnect.
    if let Some(latched) = self.inner.disconnected_error() {
        return Err(latched);
    }

    let failure = match self.inner.client.call(method, params).await {
        Ok(response) => return Ok(response),
        Err(rpc_error) => ExecServerError::from(rpc_error),
    };

    // Anything other than a closed transport is reported unchanged.
    if !is_transport_closed_error(&failure) {
        return Err(failure);
    }

    // A call can race with disconnect after the preflight check. Only the
    // reader task drains sessions so queued process notifications stay
    // ordered before disconnect.
    let fallback = disconnected_message(/*reason*/ None);
    let canonical = record_disconnected(&self.inner, fallback);
    Err(ExecServerError::Disconnected(canonical))
}
}

impl From<RpcCallError> for ExecServerError {
Expand Down Expand Up @@ -630,6 +629,20 @@ impl Session {
}

impl Inner {
/// Returns the recorded disconnect as an `ExecServerError::Disconnected`,
/// or `None` while the latch is still unset.
fn disconnected_error(&self) -> Option<ExecServerError> {
    let recorded = self.disconnected.get()?;
    Some(ExecServerError::Disconnected(recorded.clone()))
}

/// Tries to record `message` as the disconnect reason.
///
/// Returns `Some(message)` when this call set the latch, and `None` when a
/// reason had already been recorded (the stored value is left untouched).
fn set_disconnected(&self, message: String) -> Option<String> {
    self.disconnected
        .set(message.clone())
        .ok()
        .map(|()| message)
}

/// Looks up the session registered for `process_id` in the current sessions
/// snapshot and returns a new `Arc` handle to it, if present.
fn get_session(&self, process_id: &ProcessId) -> Option<Arc<SessionState>> {
    let snapshot = self.sessions.load();
    snapshot.get(process_id).map(Arc::clone)
}
Expand All @@ -640,6 +653,12 @@ impl Inner {
session: Arc<SessionState>,
) -> Result<(), ExecServerError> {
let _sessions_write_guard = self.sessions_write_lock.lock().await;
// Do not register a process session that can never receive executor
// notifications. Without this check, remote MCP startup could create a
// dead session and wait for process output that will never arrive.
if let Some(error) = self.disconnected_error() {
return Err(error);
}
let sessions = self.sessions.load();
if sessions.contains_key(process_id) {
return Err(ExecServerError::Protocol(format!(
Expand Down Expand Up @@ -680,20 +699,36 @@ fn disconnected_message(reason: Option<&str>) -> String {
}

/// Reports whether `error` indicates that the exec-server transport is closed.
///
/// The visible expressions match `ExecServerError::Closed`,
/// `ExecServerError::Disconnected(_)`, and an `ExecServerError::Server` whose
/// code is `-32000` and whose message is exactly "JSON-RPC transport closed".
fn is_transport_closed_error(error: &ExecServerError) -> bool {
matches!(error, ExecServerError::Closed)
|| matches!(
error,
ExecServerError::Server {
code: -32000,
message,
} if message == "JSON-RPC transport closed"
)
// NOTE(review): the expression above (without `Disconnected`) and the one
// below look like the before/after bodies of a diff hunk — confirm which
// one is current.
matches!(
error,
ExecServerError::Closed | ExecServerError::Disconnected(_)
) || matches!(
error,
ExecServerError::Server {
code: -32000,
message,
} if message == "JSON-RPC transport closed"
)
}

/// Records `message` as the disconnect reason if none is stored yet and
/// returns the reason that ends up being canonical.
///
/// When another observer already recorded a reason, that stored reason is
/// returned instead of `message`.
fn record_disconnected(inner: &Arc<Inner>, message: String) -> String {
    // The first observer records the canonical disconnect reason. Session
    // draining stays with the reader task so it can preserve notification
    // ordering before publishing the terminal failure.
    match inner.set_disconnected(message.clone()) {
        Some(recorded) => recorded,
        None => inner.disconnected.get().cloned().unwrap_or(message),
    }
}

/// Takes every registered session out of `inner` and marks each one as
/// failed with `message`.
async fn fail_all_sessions(inner: &Arc<Inner>, message: String) {
    for (_process_id, session) in inner.take_all_sessions().await {
        // Sessions synthesize a closed read response and emit a pushed Failed
        // event. That covers both polling consumers and streaming consumers
        // such as executor-backed MCP stdio.
        let failure = message.clone();
        session.set_failure(failure).await;
    }
}
Expand Down
39 changes: 38 additions & 1 deletion codex-rs/exec-server/src/remote_file_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,46 @@ fn map_remote_error(error: ExecServerError) -> io::Error {
io::Error::new(io::ErrorKind::InvalidInput, message)
}
ExecServerError::Server { message, .. } => io::Error::other(message),
ExecServerError::Closed => {
ExecServerError::Closed | ExecServerError::Disconnected(_) => {
io::Error::new(io::ErrorKind::BrokenPipe, "exec-server transport closed")
}
_ => io::Error::other(error.to_string()),
}
}

#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;

    /// Both transport-level error variants must surface as `BrokenPipe` with
    /// the same fixed message, regardless of the text carried by
    /// `Disconnected`.
    #[test]
    fn transport_errors_map_to_broken_pipe() {
        let transport_errors = [
            ExecServerError::Closed,
            ExecServerError::Disconnected("exec-server transport disconnected".to_string()),
        ];

        for transport_error in transport_errors {
            let mapped = map_remote_error(transport_error);

            assert_eq!(
                (mapped.kind(), mapped.to_string()),
                (
                    io::ErrorKind::BrokenPipe,
                    "exec-server transport closed".to_string()
                )
            );
        }
    }
}
Loading
Loading