From cd0e8738e99a79c92891d65653540678ec3c8aea Mon Sep 17 00:00:00 2001 From: Adrian Cole Date: Fri, 6 Mar 2026 12:33:03 +0800 Subject: [PATCH] fix: restore smart-approve mode Signed-off-by: Adrian Cole --- crates/goose-test-support/src/mcp.rs | 2 +- crates/goose/src/agents/agent.rs | 14 +- crates/goose/src/config/permission.rs | 73 ++++++++ crates/goose/src/permission/mod.rs | 1 - .../src/permission/permission_inspector.rs | 168 +++++++++++++++--- .../goose/src/permission/permission_judge.rs | 107 ----------- .../goose/src/security/security_inspector.rs | 3 +- crates/goose/src/tool_inspection.rs | 57 +++--- crates/goose/src/tool_monitor.rs | 1 + crates/goose/tests/providers.rs | 71 ++++++-- .../tests/tool_inspection_manager_tests.rs | 9 +- 11 files changed, 327 insertions(+), 179 deletions(-) diff --git a/crates/goose-test-support/src/mcp.rs b/crates/goose-test-support/src/mcp.rs index 8ccba5c52597..519f57ac051d 100644 --- a/crates/goose-test-support/src/mcp.rs +++ b/crates/goose-test-support/src/mcp.rs @@ -105,7 +105,7 @@ impl McpFixtureServer { } } - #[tool(description = "Get the code")] + #[tool(description = "Get the code", annotations(read_only_hint = true))] fn get_code(&self) -> Result { Ok(CallToolResult::success(vec![Content::text(FAKE_CODE)])) } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 42f993558a79..7e494b3f70bb 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -245,7 +245,10 @@ impl Agent { tool_result_tx: tool_tx, tool_result_rx: Arc::new(Mutex::new(tool_rx)), retry_manager: RetryManager::new(), - tool_inspection_manager: Self::create_tool_inspection_manager(permission_manager), + tool_inspection_manager: Self::create_tool_inspection_manager( + permission_manager, + provider.clone(), + ), container: Mutex::new(None), } } @@ -253,6 +256,7 @@ impl Agent { /// Create a tool inspection manager with default inspectors fn create_tool_inspection_manager( permission_manager: Arc, + provider: SharedProvider, ) -> ToolInspectionManager { let mut tool_inspection_manager = ToolInspectionManager::new(); @@ -261,9 +265,8 @@ impl Agent { // Add permission inspector (medium-high priority) tool_inspection_manager.add_inspector(Box::new(PermissionInspector::new( - std::collections::HashSet::new(), // readonly tools - will be populated from extension manager - std::collections::HashSet::new(), // regular tools - will be populated from extension manager permission_manager, + provider, ))); // Add repetition inspector (lower priority - basic repetition checking) @@ -350,6 +353,10 @@ impl Agent { .prepare_tools_and_prompt(session_id, working_dir) .await?; + if self.config.goose_mode == GooseMode::SmartApprove { + self.tool_inspection_manager.apply_tool_annotations(&tools); + } + Ok(ReplyContext { conversation, tools, @@ -1261,6 +1268,7 @@ impl Agent { // Run all tool inspectors let inspection_results = self.tool_inspection_manager .inspect_tools( + &session_config.id, &remaining_requests, conversation.messages(), goose_mode, diff --git a/crates/goose/src/config/permission.rs b/crates/goose/src/config/permission.rs index 03e6bbc70e8a..0aa8acd226e7 100644 --- a/crates/goose/src/config/permission.rs +++ b/crates/goose/src/config/permission.rs @@ -1,4 +1,5 @@ use crate::config::paths::Paths; +use rmcp::model::Tool; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs; @@ -97,6 +98,51 @@ impl PermissionManager { self.config_path.as_path() } + pub fn apply_tool_annotations(&self, tools: &[Tool]) { + let mut write_annotated = Vec::new(); + for tool in tools { + let Some(anns) = &tool.annotations else { + continue; + }; + if anns.read_only_hint == Some(false) { + write_annotated.push(tool.name.to_string()); + } + } + if !write_annotated.is_empty() { + self.bulk_update_smart_approve_permissions( + &write_annotated, + PermissionLevel::AskBefore, + ); + } + } + + fn bulk_update_smart_approve_permissions(&self, tool_names: &[String], level: PermissionLevel) { + let mut map = self.permission_map.write().unwrap(); + let permission_config = map.entry(SMART_APPROVE_PERMISSION.to_string()).or_default(); + + for tool_name in tool_names { + // Remove from all lists to avoid duplicates + permission_config.always_allow.retain(|p| p != tool_name); + permission_config.ask_before.retain(|p| p != tool_name); + permission_config.never_allow.retain(|p| p != tool_name); + + // Add to the appropriate list + match &level { + PermissionLevel::AlwaysAllow => { + permission_config.always_allow.push(tool_name.clone()) + } + PermissionLevel::AskBefore => permission_config.ask_before.push(tool_name.clone()), + PermissionLevel::NeverAllow => { + permission_config.never_allow.push(tool_name.clone()) + } + } + } + + let yaml_content = + serde_yaml::to_string(&*map).expect("Failed to serialize permission config"); + fs::write(&self.config_path, yaml_content).expect("Failed to write to permission.yaml"); + } + /// Helper function to retrieve the permission level for a specific permission category and tool. fn get_permission(&self, name: &str, principal_name: &str) -> Option { let map = self.permission_map.read().unwrap(); @@ -191,6 +237,8 @@ impl PermissionManager { #[cfg(test)] mod tests { use super::*; + use rmcp::model::ToolAnnotations; + use rmcp::object; use tempfile::TempDir; // Helper function to create a test instance of PermissionManager with a temp dir @@ -313,4 +361,29 @@ mod tests { fs::write(&permission_path, "{{invalid yaml: [broken").unwrap(); PermissionManager::new(temp_dir.path().to_path_buf()); } + + use test_case::test_case; + + #[test_case( + vec![Tool::new("tool".to_string(), String::new(), object!({"type": "object"})) + .annotate(ToolAnnotations::new().read_only(false))], + Some(PermissionLevel::AskBefore); + "write_annotation_caches_ask" + )] + #[test_case( + vec![Tool::new("tool".to_string(), String::new(), object!({"type": "object"}))], + None; + "unannotated_left_uncached" + )] + #[test_case( + vec![Tool::new("tool".to_string(), String::new(), object!({"type": "object"})) + .annotate(ToolAnnotations::new().read_only(true))], + None; + "readonly_annotation_skipped" + )] + fn test_apply_tool_annotations(tools: Vec, expect_cache: Option) { + let (manager, _temp_dir) = create_test_permission_manager(); + manager.apply_tool_annotations(&tools); + assert_eq!(manager.get_smart_approve_permission("tool"), expect_cache); + } } diff --git a/crates/goose/src/permission/mod.rs b/crates/goose/src/permission/mod.rs index d261577a99fc..dea80164d8e1 100644 --- a/crates/goose/src/permission/mod.rs +++ b/crates/goose/src/permission/mod.rs @@ -5,5 +5,4 @@ pub mod permission_store; pub use permission_confirmation::{Permission, PermissionConfirmation}; pub use permission_inspector::PermissionInspector; -pub use permission_judge::detect_read_only_tools; pub use permission_store::ToolPermissionStore; diff --git a/crates/goose/src/permission/permission_inspector.rs b/crates/goose/src/permission/permission_inspector.rs index 8510d6fc333f..6f62f5609673 100644 --- a/crates/goose/src/permission/permission_inspector.rs +++ b/crates/goose/src/permission/permission_inspector.rs @@ -1,34 +1,52 @@ use crate::agents::platform_extensions::MANAGE_EXTENSIONS_TOOL_NAME_COMPLETE; +use crate::agents::types::SharedProvider; use crate::config::permission::PermissionLevel; use crate::config::{GooseMode, PermissionManager}; use crate::conversation::message::{Message, ToolRequest}; -use crate::permission::permission_judge::PermissionCheckResult; +use crate::permission::permission_judge::{detect_read_only_tools, PermissionCheckResult}; use crate::tool_inspection::{InspectionAction, InspectionResult, ToolInspector}; use anyhow::Result; use async_trait::async_trait; +use rmcp::model::Tool; use std::collections::HashSet; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; /// Permission Inspector that handles tool permission checking pub struct PermissionInspector { - readonly_tools: HashSet, - regular_tools: HashSet, pub permission_manager: Arc, + provider: SharedProvider, + readonly_tools: RwLock>, } impl PermissionInspector { - pub fn new( - readonly_tools: HashSet, - regular_tools: HashSet, - permission_manager: Arc, - ) -> Self { + pub fn new(permission_manager: Arc, provider: SharedProvider) -> Self { Self { - readonly_tools, - regular_tools, permission_manager, + provider, + readonly_tools: RwLock::new(HashSet::new()), } } + // readonly_tools is per-agent to avoid concurrent session clobbering; write-annotated + // tools are cached globally via PermissionManager. + pub fn apply_tool_annotations(&self, tools: &[Tool]) { + let mut readonly_annotated = HashSet::new(); + for tool in tools { + let Some(anns) = &tool.annotations else { + continue; + }; + if anns.read_only_hint == Some(true) { + readonly_annotated.insert(tool.name.to_string()); + } + } + *self.readonly_tools.write().unwrap() = readonly_annotated; + self.permission_manager.apply_tool_annotations(tools); + } + + pub fn is_readonly_annotated_tool(&self, tool_name: &str) -> bool { + self.readonly_tools.read().unwrap().contains(tool_name) + } + /// Process inspection results into permission decisions /// This method takes all inspection results and converts them into a PermissionCheckResult /// that can be used by the agent to determine which tools to approve, deny, or ask for approval @@ -105,12 +123,14 @@ impl ToolInspector for PermissionInspector { async fn inspect( &self, + session_id: &str, tool_requests: &[ToolRequest], _messages: &[Message], goose_mode: GooseMode, ) -> Result> { let mut results = Vec::new(); let permission_manager = &self.permission_manager; + let mut llm_detect_candidates: Vec<&ToolRequest> = Vec::new(); for request in tool_requests { if let Ok(tool_call) = &request.tool_call { @@ -129,21 +149,28 @@ impl ToolInspector for PermissionInspector { InspectionAction::RequireApproval(None) } } - } - // 2. Check if it's a readonly or regular tool (both pre-approved) - else if self.readonly_tools.contains(&**tool_name) - || self.regular_tools.contains(&**tool_name) + // 2. Check if it's a smart-approved tool (annotation or cached LLM decision) + } else if self.is_readonly_annotated_tool(tool_name) + || (goose_mode == GooseMode::SmartApprove + && permission_manager.get_smart_approve_permission(tool_name) + == Some(PermissionLevel::AlwaysAllow)) { InspectionAction::Allow - } - // 4. Special case for extension management - else if tool_name == MANAGE_EXTENSIONS_TOOL_NAME_COMPLETE { + // 3. Special case for extension management + } else if tool_name == MANAGE_EXTENSIONS_TOOL_NAME_COMPLETE { InspectionAction::RequireApproval(Some( "Extension management requires approval for security".to_string(), )) - } + // 4. Defer to LLM detection (SmartApprove, not yet cached) + } else if goose_mode == GooseMode::SmartApprove + && permission_manager + .get_smart_approve_permission(tool_name) + .is_none() + { + llm_detect_candidates.push(request); + continue; // 5. Default: require approval for unknown tools - else { + } else { InspectionAction::RequireApproval(None) } } @@ -153,10 +180,10 @@ impl ToolInspector for PermissionInspector { InspectionAction::Allow => { if goose_mode == GooseMode::Auto { "Auto mode - all tools approved".to_string() - } else if self.readonly_tools.contains(&**tool_name) { - "Tool marked as read-only".to_string() - } else if self.regular_tools.contains(&**tool_name) { - "Tool pre-approved".to_string() + } else if self.is_readonly_annotated_tool(tool_name) { + "Tool annotated as read-only".to_string() + } else if goose_mode == GooseMode::SmartApprove { + "SmartApprove cached as read-only".to_string() } else { "User permission allows this tool".to_string() } @@ -182,6 +209,99 @@ impl ToolInspector for PermissionInspector { } } + // LLM-based read-only detection for deferred SmartApprove candidates + if !llm_detect_candidates.is_empty() { + let detected: HashSet = match self.provider.lock().await.clone() { + Some(provider) => { + detect_read_only_tools(provider, session_id, llm_detect_candidates.to_vec()) + .await + .into_iter() + .collect() + } + None => Default::default(), + }; + + for candidate in &llm_detect_candidates { + let is_readonly = candidate + .tool_call + .as_ref() + .map(|tc| detected.contains(&tc.name.to_string())) + .unwrap_or(false); + + // Cache the LLM decision for future calls + if let Ok(tc) = &candidate.tool_call { + let level = if is_readonly { + PermissionLevel::AlwaysAllow + } else { + PermissionLevel::AskBefore + }; + permission_manager.update_smart_approve_permission(&tc.name, level); + } + + results.push(InspectionResult { + tool_request_id: candidate.id.clone(), + action: if is_readonly { + InspectionAction::Allow + } else { + InspectionAction::RequireApproval(None) + }, + reason: if is_readonly { + "LLM detected as read-only".to_string() + } else { + "Tool requires user approval".to_string() + }, + confidence: 1.0, // Permission decisions are definitive + inspector_name: self.name().to_string(), + finding_id: None, + }); + } + } + Ok(results) } } + +#[cfg(test)] +mod tests { + use super::*; + use rmcp::model::CallToolRequestParams; + use rmcp::object; + use std::sync::Arc; + use test_case::test_case; + use tokio::sync::Mutex; + + #[test_case(GooseMode::Auto, false, None, InspectionAction::Allow; "auto_allows")] + #[test_case(GooseMode::SmartApprove, true, None, InspectionAction::Allow; "smart_approve_annotation_allows")] + #[test_case(GooseMode::SmartApprove, false, Some(PermissionLevel::AlwaysAllow), InspectionAction::Allow; "smart_approve_cached_allow")] + #[test_case(GooseMode::SmartApprove, false, Some(PermissionLevel::AskBefore), InspectionAction::RequireApproval(None); "smart_approve_cached_ask")] + #[test_case(GooseMode::SmartApprove, false, None, InspectionAction::RequireApproval(None); "smart_approve_unknown_defers")] + #[test_case(GooseMode::Approve, false, None, InspectionAction::RequireApproval(None); "approve_requires_approval")] + #[test_case(GooseMode::Approve, false, Some(PermissionLevel::AlwaysAllow), InspectionAction::RequireApproval(None); "approve_ignores_cache")] + #[tokio::test] + async fn test_inspect_action( + mode: GooseMode, + smart_approved: bool, + cache: Option, + expected: InspectionAction, + ) { + let pm = Arc::new(PermissionManager::new(tempfile::tempdir().unwrap().keep())); + if let Some(level) = cache { + pm.update_smart_approve_permission("tool", level); + } + let inspector = PermissionInspector::new(pm, Arc::new(Mutex::new(None))); + if smart_approved { + *inspector.readonly_tools.write().unwrap() = ["tool".to_string()].into_iter().collect(); + } + let req = ToolRequest { + id: "req".into(), + tool_call: Ok(CallToolRequestParams::new("tool").with_arguments(object!({}))), + metadata: None, + tool_meta: None, + }; + let results = inspector + .inspect(goose_test_support::TEST_SESSION_ID, &[req], &[], mode) + .await + .unwrap(); + assert_eq!(results[0].action, expected); + } +} diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index 039f18e02574..b6dbe3c96f6f 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -1,6 +1,3 @@ -use crate::agents::platform_extensions::MANAGE_EXTENSIONS_TOOL_NAME_COMPLETE; -use crate::config::permission::PermissionLevel; -use crate::config::PermissionManager; use crate::conversation::message::{Message, MessageContent, ToolRequest}; use crate::conversation::Conversation; use crate::prompt_template::render_template; @@ -11,7 +8,6 @@ use rmcp::model::{Tool, ToolAnnotations}; use rmcp::object; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::collections::HashSet; use std::sync::Arc; #[derive(Serialize)] @@ -164,106 +160,3 @@ pub struct PermissionCheckResult { pub needs_approval: Vec, pub denied: Vec, } - -pub async fn check_tool_permissions( - session_id: &str, - candidate_requests: &[ToolRequest], - mode: &str, - tools_with_readonly_annotation: HashSet, - tools_without_annotation: HashSet, - permission_manager: &mut PermissionManager, - provider: Arc, -) -> (PermissionCheckResult, Vec) { - let mut approved = vec![]; - let mut needs_approval = vec![]; - let mut denied = vec![]; - let mut llm_detect_candidates = vec![]; - let mut extension_request_ids = vec![]; - - for request in candidate_requests { - if let Ok(tool_call) = request.tool_call.clone() { - if mode == "chat" { - continue; - } else if mode == "auto" { - approved.push(request.clone()); - } else { - if tool_call.name == MANAGE_EXTENSIONS_TOOL_NAME_COMPLETE { - extension_request_ids.push(request.id.clone()); - } - - // 1. Check user-defined permission - if let Some(level) = permission_manager.get_user_permission(&tool_call.name) { - match level { - PermissionLevel::AlwaysAllow => approved.push(request.clone()), - PermissionLevel::AskBefore => needs_approval.push(request.clone()), - PermissionLevel::NeverAllow => denied.push(request.clone()), - } - continue; - } - - // 2. Fallback based on mode - match mode { - "approve" => { - needs_approval.push(request.clone()); - } - "smart_approve" => { - if let Some(level) = - permission_manager.get_smart_approve_permission(&tool_call.name) - { - match level { - PermissionLevel::AlwaysAllow => approved.push(request.clone()), - PermissionLevel::AskBefore => needs_approval.push(request.clone()), - PermissionLevel::NeverAllow => denied.push(request.clone()), - } - continue; - } - - if tools_with_readonly_annotation.contains(&tool_call.name.to_string()) { - approved.push(request.clone()); - } else if tools_without_annotation.contains(&tool_call.name.to_string()) { - llm_detect_candidates.push(request.clone()); - } else { - needs_approval.push(request.clone()); - } - } - _ => { - needs_approval.push(request.clone()); - } - } - } - } - } - - // 3. LLM detect - if !llm_detect_candidates.is_empty() && mode == "smart_approve" { - let detected_readonly_tools = - detect_read_only_tools(provider, session_id, llm_detect_candidates.iter().collect()) - .await; - for request in llm_detect_candidates { - if let Ok(tool_call) = request.tool_call.clone() { - if detected_readonly_tools.contains(&tool_call.name.to_string()) { - approved.push(request.clone()); - permission_manager.update_smart_approve_permission( - &tool_call.name, - PermissionLevel::AlwaysAllow, - ); - } else { - needs_approval.push(request.clone()); - permission_manager.update_smart_approve_permission( - &tool_call.name, - PermissionLevel::AskBefore, - ); - } - } - } - } - - ( - PermissionCheckResult { - approved, - needs_approval, - denied, - }, - extension_request_ids, - ) -} diff --git a/crates/goose/src/security/security_inspector.rs b/crates/goose/src/security/security_inspector.rs index e1f47402e21e..e2b074ca092c 100644 --- a/crates/goose/src/security/security_inspector.rs +++ b/crates/goose/src/security/security_inspector.rs @@ -58,6 +58,7 @@ impl ToolInspector for SecurityInspector { async fn inspect( &self, + _session_id: &str, tool_requests: &[ToolRequest], messages: &[Message], _goose_mode: GooseMode, @@ -113,7 +114,7 @@ mod tests { }]; let results = inspector - .inspect(&tool_requests, &[], GooseMode::Approve) + .inspect("test", &tool_requests, &[], GooseMode::Approve) .await .unwrap(); diff --git a/crates/goose/src/tool_inspection.rs b/crates/goose/src/tool_inspection.rs index 306749538aff..193773b500ea 100644 --- a/crates/goose/src/tool_inspection.rs +++ b/crates/goose/src/tool_inspection.rs @@ -38,6 +38,7 @@ pub trait ToolInspector: Send + Sync { /// Inspect tool requests and return results async fn inspect( &self, + session_id: &str, tool_requests: &[ToolRequest], messages: &[Message], goose_mode: GooseMode, @@ -73,6 +74,7 @@ impl ToolInspectionManager { /// Run all inspectors on the tool requests pub async fn inspect_tools( &self, + session_id: &str, tool_requests: &[ToolRequest], messages: &[Message], goose_mode: GooseMode, @@ -90,7 +92,10 @@ impl ToolInspectionManager { "Running tool inspector" ); - match inspector.inspect(tool_requests, messages, goose_mode).await { + match inspector + .inspect(session_id, tool_requests, messages, goose_mode) + .await + { Ok(results) => { tracing::debug!( inspector_name = inspector.name(), @@ -118,49 +123,39 @@ impl ToolInspectionManager { self.inspectors.iter().map(|i| i.name()).collect() } - /// Update the permission manager for a specific tool + fn get_permission_inspector(&self) -> Option<&PermissionInspector> { + self.inspectors + .iter() + .find(|i| i.name() == "permission") + .and_then(|i| i.as_any().downcast_ref::()) + } + + pub fn apply_tool_annotations(&self, tools: &[rmcp::model::Tool]) { + if let Some(inspector) = self.get_permission_inspector() { + inspector.apply_tool_annotations(tools); + } + } + pub async fn update_permission_manager( &self, tool_name: &str, permission_level: crate::config::permission::PermissionLevel, ) { - for inspector in &self.inspectors { - if inspector.name() == "permission" { - // Downcast to PermissionInspector to access permission manager - if let Some(permission_inspector) = - inspector.as_any().downcast_ref::() - { - permission_inspector - .permission_manager - .update_user_permission(tool_name, permission_level); - return; - } - } + if let Some(inspector) = self.get_permission_inspector() { + inspector + .permission_manager + .update_user_permission(tool_name, permission_level); } - tracing::warn!("Permission inspector not found for permission manager update"); } - /// Process inspection results using the permission inspector - /// This delegates to the permission inspector's process_inspection_results method pub fn process_inspection_results_with_permission_inspector( &self, remaining_requests: &[ToolRequest], inspection_results: &[InspectionResult], ) -> Option { - for inspector in &self.inspectors { - if inspector.name() == "permission" { - if let Some(permission_inspector) = - inspector.as_any().downcast_ref::() - { - return Some( - permission_inspector - .process_inspection_results(remaining_requests, inspection_results), - ); - } - } - } - tracing::warn!("Permission inspector not found for processing inspection results"); - None + self.get_permission_inspector().map(|inspector| { + inspector.process_inspection_results(remaining_requests, inspection_results) + }) } } diff --git a/crates/goose/src/tool_monitor.rs b/crates/goose/src/tool_monitor.rs index 10465f739896..96a2c2779ab0 100644 --- a/crates/goose/src/tool_monitor.rs +++ b/crates/goose/src/tool_monitor.rs @@ -98,6 +98,7 @@ impl ToolInspector for RepetitionInspector { async fn inspect( &self, + _session_id: &str, tool_requests: &[ToolRequest], _messages: &[Message], _goose_mode: GooseMode, diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 2cd9f94110e1..437378ba2337 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -407,13 +407,31 @@ impl ProviderTester { if self.name != "codex" { self.test_permission_allow().await?; self.test_permission_deny().await?; + // Agentic CLI providers handle tools internally, SmartApprove == Approve + if !self.is_cli_provider { + self.test_smart_approve_llm_detect().await?; + self.test_smart_approve_readonly().await?; + } } Ok(()) } - async fn run_permission_test(&self, permission: Permission, label: &str) -> Result<()> { + async fn run_permission_test( + &self, + mode: GooseMode, + permission: Permission, + expect_action_required: bool, + message: &str, + label: &str, + ) -> Result<()> { + let mode_str = match mode { + GooseMode::Approve => "approve", + GooseMode::SmartApprove => "smart_approve", + GooseMode::Auto => "auto", + GooseMode::Chat => "chat", + }; // Guard must live through agent.reply() — providers read GOOSE_MODE at spawn time. - let _guard = env_lock::lock_env([("GOOSE_MODE", Some("approve"))]); + let _guard = env_lock::lock_env([("GOOSE_MODE", Some(mode_str))]); let provider = if self.is_cli_provider { create_with_named_model( &self.name.to_lowercase(), @@ -433,7 +451,7 @@ impl ProviderTester { session_manager.clone(), permission_manager, None, - GooseMode::Approve, + mode, true, GoosePlatform::GooseCli, )); @@ -452,8 +470,7 @@ impl ProviderTester { .await .map_err(|e| anyhow::anyhow!("{}", e))?; - let message = - Message::user().with_text("Use the get_code tool and output only its result."); + let message = Message::user().with_text(message); let session_config = SessionConfig { id: session.id, schedule_id: None, @@ -486,19 +503,53 @@ impl ProviderTester { } } - assert!(saw_action_required); + assert_eq!(saw_action_required, expect_action_required); println!("=== {}::{} ===", self.name, label); Ok(()) } async fn test_permission_allow(&self) -> Result<()> { - self.run_permission_test(Permission::AllowOnce, "permission_allow") - .await + self.run_permission_test( + GooseMode::Approve, + Permission::AllowOnce, + true, + "Use the get_code tool and output only its result.", + "permission_allow", + ) + .await } async fn test_permission_deny(&self) -> Result<()> { - self.run_permission_test(Permission::DenyOnce, "permission_deny") - .await + self.run_permission_test( + GooseMode::Approve, + Permission::DenyOnce, + true, + "Use the get_code tool and output only its result.", + "permission_deny", + ) + .await + } + + async fn test_smart_approve_llm_detect(&self) -> Result<()> { + self.run_permission_test( + GooseMode::SmartApprove, + Permission::AllowOnce, + false, + "Use the get_image tool and describe what you see in its result.", + "smart_approve_llm_detect", + ) + .await + } + + async fn test_smart_approve_readonly(&self) -> Result<()> { + self.run_permission_test( + GooseMode::SmartApprove, + Permission::AllowOnce, + false, + "Use the get_code tool and output only its result.", + "smart_approve_readonly", + ) + .await } } diff --git a/crates/goose/tests/tool_inspection_manager_tests.rs b/crates/goose/tests/tool_inspection_manager_tests.rs index af832d7afb8b..c168308dd3e9 100644 --- a/crates/goose/tests/tool_inspection_manager_tests.rs +++ b/crates/goose/tests/tool_inspection_manager_tests.rs @@ -25,6 +25,7 @@ impl ToolInspector for MockInspectorOk { } async fn inspect( &self, + _session_id: &str, _tool_requests: &[ToolRequest], _messages: &[Message], _goose_mode: GooseMode, @@ -43,6 +44,7 @@ impl ToolInspector for MockInspectorErr { } async fn inspect( &self, + _session_id: &str, _tool_requests: &[ToolRequest], _messages: &[Message], _goose_mode: GooseMode, @@ -86,7 +88,12 @@ async fn test_inspect_tools_aggregates_and_handles_errors() { // Act let results = manager - .inspect_tools(&tool_requests, &messages, GooseMode::Approve) + .inspect_tools( + goose_test_support::TEST_SESSION_ID, + &tool_requests, + &messages, + GooseMode::Approve, + ) .await .expect("inspect_tools should not fail when one inspector errors");