diff --git a/src/codex_agent.rs b/src/codex_agent.rs index bea9c1e..577aff5 100644 --- a/src/codex_agent.rs +++ b/src/codex_agent.rs @@ -5,8 +5,8 @@ use agent_client_protocol::{ LoadSessionResponse, McpCapabilities, McpServer, McpServerHttp, McpServerStdio, NewSessionRequest, NewSessionResponse, PromptCapabilities, PromptRequest, PromptResponse, ProtocolVersion, SessionCapabilities, SessionId, SessionInfo, SessionListCapabilities, - SetSessionConfigOptionRequest, SetSessionConfigOptionResponse, SetSessionModeRequest, - SetSessionModeResponse, SetSessionModelRequest, SetSessionModelResponse, + SessionModeId, SetSessionConfigOptionRequest, SetSessionConfigOptionResponse, + SetSessionModeRequest, SetSessionModeResponse, SetSessionModelRequest, SetSessionModelResponse, }; use codex_core::{ CodexAuth, NewThread, RolloutRecorder, ThreadManager, ThreadSortKey, @@ -31,11 +31,12 @@ use std::{ rc::Rc, sync::{Arc, Mutex}, }; -use tracing::{debug, info}; +use tracing::{debug, info, warn}; use unicode_segmentation::UnicodeSegmentation; use crate::{ local_spawner::{AcpFs, LocalSpawner}, + mode_overrides::ModeOverrideStore, thread::Thread, }; @@ -56,6 +57,8 @@ pub struct CodexAgent { sessions: Rc>>>, /// Session working directories for filesystem sandboxing session_roots: Arc>>, + /// Persists per-project session mode overrides across restarts + mode_override_store: ModeOverrideStore, } const SESSION_LIST_PAGE_SIZE: usize = 25; @@ -94,6 +97,7 @@ impl CodexAgent { )) }), ); + let mode_override_store = ModeOverrideStore::new(&config.codex_home); Self { auth_manager, client_capabilities, @@ -101,6 +105,7 @@ impl CodexAgent { thread_manager, sessions: Rc::default(), session_roots, + mode_override_store, } } @@ -358,6 +363,14 @@ impl Agent for CodexAgent { )); let load = thread.load().await?; + // Replay any mode override persisted for this project directory so that + // the client's last explicit mode choice survives server restarts. + if let Some(mode_id) = self.mode_override_store.get(&config.cwd) { + if let Err(e) = thread.set_mode(SessionModeId::new(mode_id)).await { + warn!("Failed to replay stored mode override: {e}"); + } + } + self.sessions .borrow_mut() .insert(session_id.clone(), thread); @@ -428,6 +441,14 @@ impl Agent for CodexAgent { let load = thread.load().await?; + // Replay any mode override persisted for this project directory so that + // the client's last explicit mode choice survives server restarts. + if let Some(mode_id) = self.mode_override_store.get(&config.cwd) { + if let Err(e) = thread.set_mode(SessionModeId::new(mode_id)).await { + warn!("Failed to replay stored mode override: {e}"); + } + } + self.session_roots .lock() .unwrap() @@ -524,10 +545,25 @@ impl Agent for CodexAgent { &self, args: SetSessionModeRequest, ) -> Result { - info!("Setting session mode for session: {}", args.session_id); - self.get_thread(&args.session_id)? - .set_mode(args.mode_id) - .await?; + let SetSessionModeRequest { + session_id, + mode_id, + .. + } = args; + info!("Setting session mode for session: {session_id}"); + self.get_thread(&session_id)?.set_mode(mode_id.clone()).await?; + + // Persist the mode override so it survives server restarts. + if let Some(cwd) = self + .session_roots + .lock() + .unwrap() + .get(&session_id) + .cloned() + { + self.mode_override_store.set(&cwd, mode_id.0.as_ref()); + } + Ok(SetSessionModeResponse::default()) } @@ -553,10 +589,31 @@ impl Agent for CodexAgent { args.session_id, args.config_id.0, args.value.0 ); + // If this is a mode change, capture the value before it is moved so we + // can persist it after the call succeeds. + let pending_mode_override: Option = if args.config_id.0.as_ref() == "mode" { + Some(format!("{}", args.value.0)) + } else { + None + }; + let thread = self.get_thread(&args.session_id)?; thread.set_config_option(args.config_id, args.value).await?; + // Persist mode override so it survives server restarts. + if let Some(mode_id) = pending_mode_override { + if let Some(cwd) = self + .session_roots + .lock() + .unwrap() + .get(&args.session_id) + .cloned() + { + self.mode_override_store.set(&cwd, &mode_id); + } + } + let config_options = thread.config_options().await?; Ok(SetSessionConfigOptionResponse::new(config_options)) diff --git a/src/lib.rs b/src/lib.rs index 7da85ed..c4f0e47 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ use tracing_subscriber::EnvFilter; mod codex_agent; mod local_spawner; +mod mode_overrides; mod prompt_args; mod thread; diff --git a/src/mode_overrides.rs b/src/mode_overrides.rs new file mode 100644 index 0000000..0455147 --- /dev/null +++ b/src/mode_overrides.rs @@ -0,0 +1,80 @@ +//! Persistent storage for ACP session mode overrides. +//! +//! Mode overrides are persisted to `$CODEX_HOME/acp/session-mode-overrides.v1.json` +//! so that a client's explicit mode choice (e.g. `"auto"`, `"full-access"`) is +//! replayed automatically when the same project session is reloaded after a +//! server restart. +//! +//! The file format is a JSON object keyed by the project's working-directory +//! path: +//! +//! ```json +//! { +//! "/home/user/myproject": "auto", +//! "/home/user/other": "full-access" +//! } +//! ``` +//! +//! Errors are non-fatal: a warning is logged and the session continues with the +//! default mode. + +use std::{ + collections::HashMap, + path::{Path, PathBuf}, +}; + +use tracing::warn; + +/// Reads and writes per-project session mode overrides. +pub struct ModeOverrideStore { + path: PathBuf, +} + +impl ModeOverrideStore { + /// Create a new store backed by + /// `$codex_home/acp/session-mode-overrides.v1.json`. + pub fn new(codex_home: &Path) -> Self { + Self { + path: codex_home.join("acp").join("session-mode-overrides.v1.json"), + } + } + + /// Return the persisted mode ID for the given `cwd`, if any. + pub fn get(&self, cwd: &Path) -> Option { + let content = std::fs::read_to_string(&self.path).ok()?; + let map: HashMap = serde_json::from_str(&content) + .map_err(|e| warn!("Failed to parse mode overrides file: {e}")) + .ok()?; + map.get(cwd.to_string_lossy().as_ref()).cloned() + } + + /// Persist the given `mode_id` for `cwd`, creating the backing file if + /// necessary. + /// + /// Non-fatal: logs a warning and returns without panicking on any I/O or + /// serialisation error. + pub fn set(&self, cwd: &Path, mode_id: &str) { + let mut map: HashMap = std::fs::read_to_string(&self.path) + .ok() + .and_then(|content| serde_json::from_str(&content).ok()) + .unwrap_or_default(); + + map.insert(cwd.to_string_lossy().into_owned(), mode_id.to_owned()); + + if let Some(parent) = self.path.parent() { + if let Err(e) = std::fs::create_dir_all(parent) { + warn!("Failed to create acp directory for mode overrides: {e}"); + return; + } + } + + match serde_json::to_string_pretty(&map) { + Ok(content) => { + if let Err(e) = std::fs::write(&self.path, content) { + warn!("Failed to write mode overrides: {e}"); + } + } + Err(e) => warn!("Failed to serialise mode overrides: {e}"), + } + } +}