Skip to content

Commit 3cbb522

Browse files
committed
Revalidate cached agent identity tasks
1 parent ce74a16 commit 3cbb522

5 files changed

Lines changed: 268 additions & 31 deletions

File tree

codex-rs/core/src/agent_identity.rs

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,20 @@ impl AgentIdentityManager {
150150
return Ok(None);
151151
};
152152

153+
self.ensure_registered_identity_for_binding(&binding).await
154+
}
155+
156+
async fn ensure_registered_identity_for_binding(
157+
&self,
158+
binding: &AgentIdentityBinding,
159+
) -> Result<Option<StoredAgentIdentity>> {
160+
if !self.feature_enabled {
161+
return Ok(None);
162+
}
163+
153164
let _guard = self.ensure_lock.lock().await;
154165

155-
if let Some(stored_identity) = self.load_stored_identity(&binding)? {
166+
if let Some(stored_identity) = self.load_stored_identity(binding)? {
156167
info!(
157168
agent_runtime_id = %stored_identity.agent_runtime_id,
158169
binding_id = %binding.binding_id,
@@ -161,11 +172,21 @@ impl AgentIdentityManager {
161172
return Ok(Some(stored_identity));
162173
}
163174

164-
let stored_identity = self.register_agent_identity(&binding).await?;
165-
self.store_identity(&binding, &stored_identity)?;
175+
let stored_identity = self.register_agent_identity(binding).await?;
176+
self.store_identity(binding, &stored_identity)?;
166177
Ok(Some(stored_identity))
167178
}
168179

180+
pub(crate) async fn task_matches_current_binding(&self, task: &RegisteredAgentTask) -> bool {
181+
if !self.feature_enabled {
182+
return false;
183+
}
184+
185+
self.current_binding()
186+
.await
187+
.is_some_and(|binding| task.matches_binding(&binding))
188+
}
189+
169190
async fn current_binding(&self) -> Option<AgentIdentityBinding> {
170191
let Some(auth) = self.auth_manager.auth().await else {
171192
debug!("skipping agent identity flow because no auth is available");
@@ -360,12 +381,11 @@ impl AgentIdentityManager {
360381

361382
impl StoredAgentIdentity {
362383
fn matches_binding(&self, binding: &AgentIdentityBinding) -> bool {
363-
self.binding_id == binding.binding_id
364-
&& self.chatgpt_account_id == binding.chatgpt_account_id
365-
&& match binding.chatgpt_user_id.as_deref() {
366-
Some(chatgpt_user_id) => self.chatgpt_user_id.as_deref() == Some(chatgpt_user_id),
367-
None => true,
368-
}
384+
binding.matches_parts(
385+
&self.binding_id,
386+
&self.chatgpt_account_id,
387+
self.chatgpt_user_id.as_deref(),
388+
)
369389
}
370390

371391
fn validate_key_material(&self) -> Result<()> {
@@ -388,6 +408,20 @@ impl StoredAgentIdentity {
388408
}
389409

390410
impl AgentIdentityBinding {
411+
fn matches_parts(
412+
&self,
413+
binding_id: &str,
414+
chatgpt_account_id: &str,
415+
chatgpt_user_id: Option<&str>,
416+
) -> bool {
417+
binding_id == self.binding_id
418+
&& chatgpt_account_id == self.chatgpt_account_id
419+
&& match self.chatgpt_user_id.as_deref() {
420+
Some(expected_user_id) => chatgpt_user_id == Some(expected_user_id),
421+
None => true,
422+
}
423+
}
424+
391425
fn from_auth(auth: &CodexAuth, forced_workspace_id: Option<String>) -> Option<Self> {
392426
if !auth.is_chatgpt_auth() {
393427
return None;

codex-rs/core/src/agent_identity/task_registration.rs

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ const AGENT_TASK_REGISTRATION_TIMEOUT: Duration = Duration::from_secs(15);
1616

1717
#[derive(Clone, Debug, PartialEq, Eq)]
1818
pub(crate) struct RegisteredAgentTask {
19+
pub(crate) binding_id: String,
20+
pub(crate) chatgpt_account_id: String,
21+
pub(crate) chatgpt_user_id: Option<String>,
1922
pub(crate) agent_runtime_id: String,
2023
pub(crate) task_id: String,
2124
pub(crate) registered_at: String,
@@ -41,7 +44,18 @@ impl AgentIdentityManager {
4144
let Some(binding) = self.current_binding().await else {
4245
return Ok(None);
4346
};
44-
let Some(stored_identity) = self.ensure_registered_identity().await? else {
47+
48+
self.register_task_for_binding(binding).await
49+
}
50+
51+
async fn register_task_for_binding(
52+
&self,
53+
binding: AgentIdentityBinding,
54+
) -> Result<Option<RegisteredAgentTask>> {
55+
let Some(stored_identity) = self
56+
.ensure_registered_identity_for_binding(&binding)
57+
.await?
58+
else {
4559
return Ok(None);
4660
};
4761

@@ -72,6 +86,9 @@ impl AgentIdentityManager {
7286
.await
7387
.with_context(|| format!("failed to parse agent task response from {url}"))?;
7488
let registered_task = RegisteredAgentTask {
89+
binding_id: stored_identity.binding_id.clone(),
90+
chatgpt_account_id: stored_identity.chatgpt_account_id.clone(),
91+
chatgpt_user_id: stored_identity.chatgpt_user_id.clone(),
7592
agent_runtime_id: stored_identity.agent_runtime_id.clone(),
7693
task_id: decrypt_task_id_response(
7794
&stored_identity,
@@ -93,6 +110,22 @@ impl AgentIdentityManager {
93110
}
94111
}
95112

113+
impl RegisteredAgentTask {
114+
pub(super) fn matches_binding(&self, binding: &AgentIdentityBinding) -> bool {
115+
binding.matches_parts(
116+
&self.binding_id,
117+
&self.chatgpt_account_id,
118+
self.chatgpt_user_id.as_deref(),
119+
)
120+
}
121+
122+
pub(crate) fn has_same_binding(&self, other: &Self) -> bool {
123+
self.binding_id == other.binding_id
124+
&& self.chatgpt_account_id == other.chatgpt_account_id
125+
&& self.chatgpt_user_id == other.chatgpt_user_id
126+
}
127+
}
128+
96129
fn sign_task_registration_payload(
97130
stored_identity: &StoredAgentIdentity,
98131
timestamp: &str,
@@ -240,6 +273,9 @@ mod tests {
240273
assert_eq!(
241274
task,
242275
RegisteredAgentTask {
276+
binding_id: "chatgpt-account-account-123".to_string(),
277+
chatgpt_account_id: "account-123".to_string(),
278+
chatgpt_user_id: Some("user-123".to_string()),
243279
agent_runtime_id: "agent-123".to_string(),
244280
task_id: "task_123".to_string(),
245281
registered_at: task.registered_at.clone(),
@@ -292,6 +328,95 @@ mod tests {
292328
assert_eq!(task.task_id, "task_fallback");
293329
}
294330

331+
#[tokio::test]
332+
async fn register_task_for_binding_keeps_one_auth_snapshot() {
333+
let server = MockServer::start().await;
334+
mount_human_biscuit(&server).await;
335+
let tempdir = tempfile::tempdir().expect("tempdir");
336+
let keyring_store = Arc::new(MockKeyringStore::default());
337+
let secrets_manager = SecretsManager::new_with_keyring_store(
338+
tempdir.path().to_path_buf(),
339+
SecretsBackendKind::Local,
340+
keyring_store,
341+
);
342+
let auth_manager =
343+
AuthManager::from_auth_for_testing(make_chatgpt_auth("account-456", Some("user-456")));
344+
let manager = AgentIdentityManager::new_for_tests(
345+
auth_manager,
346+
/*feature_enabled*/ true,
347+
server.uri(),
348+
SessionSource::Cli,
349+
secrets_manager.clone(),
350+
);
351+
let stored_identity =
352+
seed_stored_identity(&manager, &secrets_manager, "agent-123", "account-123");
353+
let encrypted_task_id =
354+
encrypt_task_id_for_identity(&stored_identity, "task_123").expect("task ciphertext");
355+
let binding = AgentIdentityBinding::from_auth(
356+
&make_chatgpt_auth("account-123", Some("user-123")),
357+
/*forced_workspace_id*/ None,
358+
)
359+
.expect("binding");
360+
361+
Mock::given(method("POST"))
362+
.and(path("/v1/agent/agent-123/task/register"))
363+
.and(header("x-openai-authorization", "human-biscuit"))
364+
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
365+
"encrypted_task_id": encrypted_task_id,
366+
})))
367+
.expect(1)
368+
.mount(&server)
369+
.await;
370+
371+
let task = manager
372+
.register_task_for_binding(binding)
373+
.await
374+
.unwrap()
375+
.expect("task should be registered");
376+
377+
assert_eq!(
378+
task,
379+
RegisteredAgentTask {
380+
binding_id: "chatgpt-account-account-123".to_string(),
381+
chatgpt_account_id: "account-123".to_string(),
382+
chatgpt_user_id: Some("user-123".to_string()),
383+
agent_runtime_id: "agent-123".to_string(),
384+
task_id: "task_123".to_string(),
385+
registered_at: task.registered_at.clone(),
386+
}
387+
);
388+
}
389+
390+
#[tokio::test]
391+
async fn task_matches_current_binding_rejects_stale_auth_binding() {
392+
let tempdir = tempfile::tempdir().expect("tempdir");
393+
let keyring_store = Arc::new(MockKeyringStore::default());
394+
let secrets_manager = SecretsManager::new_with_keyring_store(
395+
tempdir.path().to_path_buf(),
396+
SecretsBackendKind::Local,
397+
keyring_store,
398+
);
399+
let auth_manager =
400+
AuthManager::from_auth_for_testing(make_chatgpt_auth("account-456", Some("user-456")));
401+
let manager = AgentIdentityManager::new_for_tests(
402+
auth_manager,
403+
/*feature_enabled*/ true,
404+
"https://chatgpt.com/backend-api/".to_string(),
405+
SessionSource::Cli,
406+
secrets_manager,
407+
);
408+
let task = RegisteredAgentTask {
409+
binding_id: "chatgpt-account-account-123".to_string(),
410+
chatgpt_account_id: "account-123".to_string(),
411+
chatgpt_user_id: Some("user-123".to_string()),
412+
agent_runtime_id: "agent-123".to_string(),
413+
task_id: "task_123".to_string(),
414+
registered_at: "2026-03-23T12:00:00Z".to_string(),
415+
};
416+
417+
assert!(!manager.task_matches_current_binding(&task).await);
418+
}
419+
295420
async fn mount_human_biscuit(server: &MockServer) {
296421
Mock::given(method("GET"))
297422
.and(path("/authenticate_app_v2"))

codex-rs/core/src/codex.rs

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,37 +1528,89 @@ impl Session {
15281528
.await;
15291529
handlers::shutdown(self, self.next_internal_sub_id()).await;
15301530
}
1531-
async fn ensure_agent_task_registered(&self) -> anyhow::Result<Option<RegisteredAgentTask>> {
1532-
{
1531+
1532+
async fn cached_agent_task_for_current_binding(&self) -> Option<RegisteredAgentTask> {
1533+
let agent_task = {
15331534
let state = self.state.lock().await;
1534-
if let Some(agent_task) = state.agent_task() {
1535+
state.agent_task()
1536+
}?;
1537+
1538+
if self
1539+
.services
1540+
.agent_identity_manager
1541+
.task_matches_current_binding(&agent_task)
1542+
.await
1543+
{
1544+
debug!(
1545+
agent_runtime_id = %agent_task.agent_runtime_id,
1546+
task_id = %agent_task.task_id,
1547+
"reusing cached agent task"
1548+
);
1549+
return Some(agent_task);
1550+
}
1551+
1552+
debug!(
1553+
agent_runtime_id = %agent_task.agent_runtime_id,
1554+
task_id = %agent_task.task_id,
1555+
"discarding cached agent task because auth binding changed"
1556+
);
1557+
let mut state = self.state.lock().await;
1558+
if state.agent_task().as_ref() == Some(&agent_task) {
1559+
state.clear_agent_task();
1560+
}
1561+
None
1562+
}
1563+
1564+
async fn ensure_agent_task_registered(&self) -> anyhow::Result<Option<RegisteredAgentTask>> {
1565+
if let Some(agent_task) = self.cached_agent_task_for_current_binding().await {
1566+
return Ok(Some(agent_task));
1567+
}
1568+
1569+
for _ in 0..2 {
1570+
let Some(agent_task) = self.services.agent_identity_manager.register_task().await?
1571+
else {
1572+
return Ok(None);
1573+
};
1574+
1575+
if !self
1576+
.services
1577+
.agent_identity_manager
1578+
.task_matches_current_binding(&agent_task)
1579+
.await
1580+
{
15351581
debug!(
15361582
agent_runtime_id = %agent_task.agent_runtime_id,
15371583
task_id = %agent_task.task_id,
1538-
"reusing cached agent task"
1584+
"discarding newly registered agent task because auth binding changed"
15391585
);
1540-
return Ok(Some(agent_task));
1586+
continue;
15411587
}
1542-
}
15431588

1544-
let Some(agent_task) = self.services.agent_identity_manager.register_task().await? else {
1545-
return Ok(None);
1546-
};
1547-
{
1548-
let mut state = self.state.lock().await;
1549-
if let Some(existing_agent_task) = state.agent_task() {
1550-
return Ok(Some(existing_agent_task));
1589+
{
1590+
let mut state = self.state.lock().await;
1591+
if let Some(existing_agent_task) = state.agent_task() {
1592+
if existing_agent_task.has_same_binding(&agent_task) {
1593+
return Ok(Some(existing_agent_task));
1594+
}
1595+
debug!(
1596+
agent_runtime_id = %existing_agent_task.agent_runtime_id,
1597+
task_id = %existing_agent_task.task_id,
1598+
"replacing cached agent task because auth binding changed"
1599+
);
1600+
}
1601+
state.set_agent_task(agent_task.clone());
15511602
}
1552-
state.set_agent_task(agent_task.clone());
1603+
1604+
info!(
1605+
thread_id = %self.conversation_id,
1606+
agent_runtime_id = %agent_task.agent_runtime_id,
1607+
task_id = %agent_task.task_id,
1608+
"registered agent task for thread"
1609+
);
1610+
return Ok(Some(agent_task));
15531611
}
15541612

1555-
info!(
1556-
thread_id = %self.conversation_id,
1557-
agent_runtime_id = %agent_task.agent_runtime_id,
1558-
task_id = %agent_task.task_id,
1559-
"registered agent task for thread"
1560-
);
1561-
Ok(Some(agent_task))
1613+
Ok(None)
15621614
}
15631615

15641616
#[allow(clippy::too_many_arguments)]

codex-rs/core/src/state/session.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ impl SessionState {
185185
self.agent_task = Some(agent_task);
186186
}
187187

188+
pub(crate) fn clear_agent_task(&mut self) {
189+
self.agent_task = None;
190+
}
191+
188192
// Adds connector IDs to the active set and returns the merged selection.
189193
pub(crate) fn merge_connector_selection<I>(&mut self, connector_ids: I) -> HashSet<String>
190194
where

0 commit comments

Comments
 (0)