Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 64 additions & 7 deletions src/codex_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use agent_client_protocol::{
LoadSessionResponse, McpCapabilities, McpServer, McpServerHttp, McpServerStdio,
NewSessionRequest, NewSessionResponse, PromptCapabilities, PromptRequest, PromptResponse,
ProtocolVersion, SessionCapabilities, SessionId, SessionInfo, SessionListCapabilities,
SetSessionConfigOptionRequest, SetSessionConfigOptionResponse, SetSessionModeRequest,
SetSessionModeResponse, SetSessionModelRequest, SetSessionModelResponse,
SessionModeId, SetSessionConfigOptionRequest, SetSessionConfigOptionResponse,
SetSessionModeRequest, SetSessionModeResponse, SetSessionModelRequest, SetSessionModelResponse,
};
use codex_core::{
CodexAuth, NewThread, RolloutRecorder, ThreadManager, ThreadSortKey,
Expand All @@ -31,11 +31,12 @@ use std::{
rc::Rc,
sync::{Arc, Mutex},
};
use tracing::{debug, info};
use tracing::{debug, info, warn};
use unicode_segmentation::UnicodeSegmentation;

use crate::{
local_spawner::{AcpFs, LocalSpawner},
mode_overrides::ModeOverrideStore,
thread::Thread,
};

Expand All @@ -56,6 +57,8 @@ pub struct CodexAgent {
sessions: Rc<RefCell<HashMap<SessionId, Rc<Thread>>>>,
/// Session working directories for filesystem sandboxing
session_roots: Arc<Mutex<HashMap<SessionId, PathBuf>>>,
/// Persists per-project session mode overrides across restarts
mode_override_store: ModeOverrideStore,
}

const SESSION_LIST_PAGE_SIZE: usize = 25;
Expand Down Expand Up @@ -94,13 +97,15 @@ impl CodexAgent {
))
}),
);
let mode_override_store = ModeOverrideStore::new(&config.codex_home);
Self {
auth_manager,
client_capabilities,
config,
thread_manager,
sessions: Rc::default(),
session_roots,
mode_override_store,
}
}

Expand Down Expand Up @@ -358,6 +363,14 @@ impl Agent for CodexAgent {
));
let load = thread.load().await?;

// Replay any mode override persisted for this project directory so that
// the client's last explicit mode choice survives server restarts.
if let Some(mode_id) = self.mode_override_store.get(&config.cwd) {
if let Err(e) = thread.set_mode(SessionModeId::new(mode_id)).await {
warn!("Failed to replay stored mode override: {e}");
}
}

self.sessions
.borrow_mut()
.insert(session_id.clone(), thread);
Expand Down Expand Up @@ -428,6 +441,14 @@ impl Agent for CodexAgent {

let load = thread.load().await?;

// Replay any mode override persisted for this project directory so that
// the client's last explicit mode choice survives server restarts.
if let Some(mode_id) = self.mode_override_store.get(&config.cwd) {
if let Err(e) = thread.set_mode(SessionModeId::new(mode_id)).await {
warn!("Failed to replay stored mode override: {e}");
}
}

self.session_roots
.lock()
.unwrap()
Expand Down Expand Up @@ -524,10 +545,25 @@ impl Agent for CodexAgent {
&self,
args: SetSessionModeRequest,
) -> Result<SetSessionModeResponse, Error> {
info!("Setting session mode for session: {}", args.session_id);
self.get_thread(&args.session_id)?
.set_mode(args.mode_id)
.await?;
let SetSessionModeRequest {
session_id,
mode_id,
..
} = args;
info!("Setting session mode for session: {session_id}");
self.get_thread(&session_id)?.set_mode(mode_id.clone()).await?;

// Persist the mode override so it survives server restarts.
if let Some(cwd) = self
.session_roots
.lock()
.unwrap()
.get(&session_id)
.cloned()
{
self.mode_override_store.set(&cwd, mode_id.0.as_ref());
}

Ok(SetSessionModeResponse::default())
}

Expand All @@ -553,10 +589,31 @@ impl Agent for CodexAgent {
args.session_id, args.config_id.0, args.value.0
);

// If this is a mode change, capture the value before it is moved so we
// can persist it after the call succeeds.
let pending_mode_override: Option<String> = if args.config_id.0.as_ref() == "mode" {
Some(format!("{}", args.value.0))
} else {
None
};

let thread = self.get_thread(&args.session_id)?;

thread.set_config_option(args.config_id, args.value).await?;

// Persist mode override so it survives server restarts.
if let Some(mode_id) = pending_mode_override {
if let Some(cwd) = self
.session_roots
.lock()
.unwrap()
.get(&args.session_id)
.cloned()
{
self.mode_override_store.set(&cwd, &mode_id);
}
}

let config_options = thread.config_options().await?;

Ok(SetSessionConfigOptionResponse::new(config_options))
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use tracing_subscriber::EnvFilter;

mod codex_agent;
mod local_spawner;
mod mode_overrides;
mod prompt_args;
mod thread;

Expand Down
80 changes: 80 additions & 0 deletions src/mode_overrides.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
//! Persistent storage for ACP session mode overrides.
//!
//! Mode overrides are persisted to `$CODEX_HOME/acp/session-mode-overrides.v1.json`
//! so that a client's explicit mode choice (e.g. `"auto"`, `"full-access"`) is
//! replayed automatically when the same project session is reloaded after a
//! server restart.
//!
//! The file format is a JSON object keyed by the project's working-directory
//! path:
//!
//! ```json
//! {
//! "/home/user/myproject": "auto",
//! "/home/user/other": "full-access"
//! }
//! ```
//!
//! Errors are non-fatal: a warning is logged and the session continues with the
//! default mode.

use std::{
collections::HashMap,
path::{Path, PathBuf},
};

use tracing::warn;

/// Reads and writes per-project session mode overrides.
pub struct ModeOverrideStore {
path: PathBuf,
}

impl ModeOverrideStore {
/// Create a new store backed by
/// `$codex_home/acp/session-mode-overrides.v1.json`.
pub fn new(codex_home: &Path) -> Self {
Self {
path: codex_home.join("acp").join("session-mode-overrides.v1.json"),
}
}

/// Return the persisted mode ID for the given `cwd`, if any.
pub fn get(&self, cwd: &Path) -> Option<String> {
let content = std::fs::read_to_string(&self.path).ok()?;
let map: HashMap<String, String> = serde_json::from_str(&content)
.map_err(|e| warn!("Failed to parse mode overrides file: {e}"))
.ok()?;
map.get(cwd.to_string_lossy().as_ref()).cloned()
}

/// Persist the given `mode_id` for `cwd`, creating the backing file if
/// necessary.
///
/// Non-fatal: logs a warning and returns without panicking on any I/O or
/// serialisation error.
pub fn set(&self, cwd: &Path, mode_id: &str) {
let mut map: HashMap<String, String> = std::fs::read_to_string(&self.path)
.ok()
.and_then(|content| serde_json::from_str(&content).ok())
.unwrap_or_default();

map.insert(cwd.to_string_lossy().into_owned(), mode_id.to_owned());

if let Some(parent) = self.path.parent() {
if let Err(e) = std::fs::create_dir_all(parent) {
warn!("Failed to create acp directory for mode overrides: {e}");
return;
}
}

match serde_json::to_string_pretty(&map) {
Ok(content) => {
if let Err(e) = std::fs::write(&self.path, content) {
warn!("Failed to write mode overrides: {e}");
}
}
Err(e) => warn!("Failed to serialise mode overrides: {e}"),
}
}
}