Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use crate::stream_events_utils::HandleOutputCtx;
use crate::stream_events_utils::handle_non_tool_response_item;
use crate::stream_events_utils::handle_output_item_done;
use crate::stream_events_utils::last_assistant_message_from_item;
use crate::stream_events_utils::mark_thread_memory_mode_polluted_if_external_context;
use crate::stream_events_utils::raw_assistant_output_text_from_item;
use crate::stream_events_utils::record_completed_response_item;
use crate::turn_metadata::TurnMetadataState;
Expand Down Expand Up @@ -7830,8 +7831,15 @@ async fn drain_in_flight(
while let Some(res) = in_flight.next().await {
match res {
Ok(response_input) => {
sess.record_conversation_items(&turn_context, &[response_input.into()])
let response_item = response_input.into();
sess.record_conversation_items(&turn_context, std::slice::from_ref(&response_item))
.await;
mark_thread_memory_mode_polluted_if_external_context(
sess.as_ref(),
turn_context.as_ref(),
&response_item,
)
.await;
}
Err(err) => {
error_or_panic(format!("in-flight tool future failed during drain: {err}"));
Expand Down
15 changes: 12 additions & 3 deletions codex-rs/core/src/stream_events_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,20 @@ pub(crate) async fn record_completed_response_item(
sess.defer_mailbox_delivery_to_next_turn(&turn_context.sub_id)
.await;
}
maybe_mark_thread_memory_mode_polluted_from_web_search(sess, turn_context, item).await;
mark_thread_memory_mode_polluted_if_external_context(sess, turn_context, item).await;
record_stage1_output_usage_for_completed_item(turn_context, item).await;
}

async fn maybe_mark_thread_memory_mode_polluted_from_web_search(
/// Reports whether `item` is one of the response-item kinds treated as
/// possibly carrying external context.
///
/// Returns `true` for a web-search call and for either half of a tool-search
/// exchange (the call or its output); every other variant yields `false`.
fn response_item_may_include_external_context(item: &ResponseItem) -> bool {
    matches!(
        item,
        ResponseItem::WebSearchCall { .. }
            | ResponseItem::ToolSearchCall { .. }
            | ResponseItem::ToolSearchOutput { .. }
    )
}

pub(crate) async fn mark_thread_memory_mode_polluted_if_external_context(
sess: &Session,
turn_context: &TurnContext,
item: &ResponseItem,
Expand All @@ -149,7 +158,7 @@ async fn maybe_mark_thread_memory_mode_polluted_from_web_search(
.config
.memories
.no_memories_if_mcp_or_web_search
|| !matches!(item, ResponseItem::WebSearchCall { .. })
|| !response_item_may_include_external_context(item)
{
return;
}
Expand Down
83 changes: 83 additions & 0 deletions codex-rs/core/src/stream_events_utils_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@ use super::completed_item_defers_mailbox_delivery_to_next_turn;
use super::handle_non_tool_response_item;
use super::image_generation_artifact_path;
use super::last_assistant_message_from_item;
use super::response_item_may_include_external_context;
use super::save_image_generation_result;
use crate::codex::make_session_and_context;
use codex_protocol::error::CodexErr;
use codex_protocol::items::TurnItem;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::LocalShellAction;
use codex_protocol::models::LocalShellExecAction;
use codex_protocol::models::LocalShellStatus;
use codex_protocol::models::MessagePhase;
use codex_protocol::models::ResponseItem;
use codex_utils_absolute_path::test_support::PathExt;
Expand All @@ -28,6 +33,84 @@ fn assistant_output_text_with_phase(text: &str, phase: Option<MessagePhase>) ->
}
}

/// A web-search call and both halves of a tool-search exchange must each be
/// classified as items that may include external context.
#[test]
fn external_context_pollution_items_include_web_search_and_tool_search() {
    let web_search_call = ResponseItem::WebSearchCall {
        id: None,
        status: Some("completed".to_string()),
        action: None,
    };
    let tool_search_call = ResponseItem::ToolSearchCall {
        id: None,
        call_id: Some("search-1".to_string()),
        status: None,
        execution: "client".to_string(),
        arguments: serde_json::json!({"query": "calendar"}),
    };
    let tool_search_output = ResponseItem::ToolSearchOutput {
        call_id: Some("search-1".to_string()),
        status: "completed".to_string(),
        execution: "client".to_string(),
        tools: Vec::new(),
    };

    // Check each fixture on its own so every polluting variant is exercised.
    for item in [&web_search_call, &tool_search_call, &tool_search_output] {
        assert!(response_item_may_include_external_context(item));
    }
}

/// Local shell calls, function/custom tool calls with their outputs, and a
/// plain assistant message must not be classified as items that may include
/// external context.
#[test]
fn external_context_pollution_items_exclude_local_tool_calls() {
    let local_shell_call = ResponseItem::LocalShellCall {
        id: None,
        call_id: Some("shell-1".to_string()),
        status: LocalShellStatus::Completed,
        action: LocalShellAction::Exec(LocalShellExecAction {
            command: vec!["cat".to_string(), "README.md".to_string()],
            timeout_ms: None,
            working_directory: None,
            env: None,
            user: None,
        }),
    };
    let function_call = ResponseItem::FunctionCall {
        id: None,
        name: "shell".to_string(),
        namespace: None,
        arguments: "{}".to_string(),
        call_id: "call-1".to_string(),
    };
    let function_call_output = ResponseItem::FunctionCallOutput {
        call_id: "call-1".to_string(),
        output: FunctionCallOutputPayload::from_text("ok".to_string()),
    };
    let custom_tool_call = ResponseItem::CustomToolCall {
        id: None,
        status: None,
        call_id: "custom-1".to_string(),
        name: "apply_patch".to_string(),
        input: "*** Begin Patch\n*** End Patch\n".to_string(),
    };
    let custom_tool_call_output = ResponseItem::CustomToolCallOutput {
        call_id: "custom-1".to_string(),
        name: Some("apply_patch".to_string()),
        output: FunctionCallOutputPayload::from_text("ok".to_string()),
    };
    let assistant_message = assistant_output_text("plain assistant text");

    // None of these fixtures may be flagged as carrying external context.
    for item in [
        &local_shell_call,
        &function_call,
        &function_call_output,
        &custom_tool_call,
        &custom_tool_call_output,
        &assistant_message,
    ] {
        assert!(!response_item_may_include_external_context(item));
    }
}

#[tokio::test]
async fn handle_non_tool_response_item_strips_citations_from_assistant_message() {
let (session, turn_context) = make_session_and_context().await;
Expand Down
Loading