From 21476cb70e091782ce979be33b471a507823f036 Mon Sep 17 00:00:00 2001 From: ThorstenHellert Date: Sat, 10 Jan 2026 15:10:41 +0100 Subject: [PATCH 01/14] feat(state): Add unified artifact system with register_artifact API - Add ArtifactType enum (IMAGE, NOTEBOOK, COMMAND, HTML, FILE) - Create register_artifact() as single source of truth - Legacy methods delegate to new API with dual-write pattern - Add populate_legacy_fields_from_artifacts() for backward compatibility - Python capability uses new API directly with clean accumulation --- CHANGELOG.md | 13 + src/osprey/capabilities/python.py | 67 ++- src/osprey/infrastructure/respond_node.py | 35 +- src/osprey/state/__init__.py | 15 + src/osprey/state/artifacts.py | 369 ++++++++++++ src/osprey/state/state.py | 4 + src/osprey/state/state_manager.py | 272 ++++++--- tests/test_artifacts.py | 694 ++++++++++++++++++++++ 8 files changed, 1341 insertions(+), 128 deletions(-) create mode 100644 src/osprey/state/artifacts.py create mode 100644 tests/test_artifacts.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b64a3743..ca3d9adc5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,19 @@ All notable changes to the Osprey Framework will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added +- **State**: Unified artifact system with `ArtifactType` enum and `register_artifact()` API + - Single source of truth (`ui_artifacts`) for all artifact types: IMAGE, NOTEBOOK, COMMAND, HTML, FILE + - Legacy methods (`register_figure`, `register_notebook`, `register_command`) delegate to new API + - `populate_legacy_fields_from_artifacts()` helper for backward compatibility at finalization + +### Changed +- **Capabilities**: Python capability uses unified `register_artifact()` API directly + - Clean single-accumulation pattern for figures and notebooks + - Legacy fields populated at finalization rather than registration + ## [0.10.1] - 2026-01-09 ### Added diff --git a/src/osprey/capabilities/python.py b/src/osprey/capabilities/python.py index 62a526690..519ea1eba 100644 --- a/src/osprey/capabilities/python.py +++ b/src/osprey/capabilities/python.py @@ -57,7 +57,7 @@ from osprey.registry import get_registry from osprey.services.python_executor import PythonServiceResult from osprey.services.python_executor.models import PlanningMode, PythonExecutionRequest -from osprey.state import StateManager +from osprey.state import ArtifactType, StateManager from osprey.utils.config import get_full_configuration from osprey.utils.logger import get_logger @@ -548,40 +548,38 @@ async def execute(self) -> dict[str, Any]: self._state, "PYTHON_RESULTS", step.get("context_key"), results_context ) - # Register figures in centralized UI registry - figure_updates = {} - if results_context.figure_paths: - # Register figures using StateManager with proper accumulation - accumulating_figures = None # Start with None for first registration - - for figure_path in results_context.figure_paths: - figure_update = StateManager.register_figure( - self._state, - capability="python_executor", - figure_path=str(figure_path), - display_name="Python Execution Figure", - metadata={ - "execution_folder": results_context.folder_path, - "notebook_link": results_context.notebook_link, - "execution_time": results_context.execution_time, - "context_key": step.get("context_key"), - }, - current_figures=accumulating_figures, # Pass accumulating list - ) - # Get the updated list for next iteration - accumulating_figures = figure_update["ui_captured_figures"] + # Register artifacts using unified artifact system + # Single accumulation pattern - clean and simple + artifacts = None - # Final state update with all accumulated figures - figure_updates = figure_update # Last update contains all figures + # Register figures as IMAGE artifacts + for figure_path in results_context.figure_paths: + artifact_update = StateManager.register_artifact( + self._state, + artifact_type=ArtifactType.IMAGE, + capability="python_executor", + data={"path": str(figure_path), "format": figure_path.suffix[1:].lower()}, + display_name="Python Execution Figure", + metadata={ + "execution_folder": results_context.folder_path, + "notebook_link": results_context.notebook_link, + "execution_time": results_context.execution_time, + "context_key": step.get("context_key"), + }, + current_artifacts=artifacts, + ) + artifacts = artifact_update["ui_artifacts"] - # Register notebook in centralized UI registry - notebook_updates = {} + # Register notebook as NOTEBOOK artifact if results_context.notebook_link: - notebook_updates = StateManager.register_notebook( + artifact_update = StateManager.register_artifact( self._state, + artifact_type=ArtifactType.NOTEBOOK, capability="python_executor", - notebook_path=str(results_context.notebook_path), - notebook_link=results_context.notebook_link, + data={ + "path": str(results_context.notebook_path), + "url": results_context.notebook_link, + }, display_name="Python Execution Notebook", metadata={ "execution_folder": results_context.folder_path, @@ -591,13 +589,18 @@ async def execute(self) -> dict[str, Any]: len(results_context.code.split("\n")) if results_context.code else 0 ), }, + current_artifacts=artifacts, ) + artifacts = artifact_update["ui_artifacts"] + + # Build artifact updates (only ui_artifacts - legacy fields populated at finalization) + artifact_updates = {"ui_artifacts": artifacts} if artifacts else {} # Combine all updates if has_approval_resume: - return {**result_updates, **approval_cleanup, **figure_updates, **notebook_updates} + return {**result_updates, **approval_cleanup, **artifact_updates} else: - return {**result_updates, **figure_updates, **notebook_updates} + return {**result_updates, **artifact_updates} @staticmethod def classify_error(exc: Exception, context: dict) -> ErrorClassification: diff --git a/src/osprey/infrastructure/respond_node.py b/src/osprey/infrastructure/respond_node.py index d27f27517..2b12498ee 100644 --- a/src/osprey/infrastructure/respond_node.py +++ b/src/osprey/infrastructure/respond_node.py @@ -20,7 +20,7 @@ from osprey.models import get_chat_completion from osprey.prompts.loader import get_framework_prompts from osprey.registry import get_registry -from osprey.state import AgentState, StateManager +from osprey.state import AgentState, StateManager, populate_legacy_fields_from_artifacts from osprey.utils.config import get_model_config @@ -166,8 +166,16 @@ async def execute(self) -> dict[str, Any]: ) logger.info(f"Generated response for: '{task_objective}'") + # Populate legacy fields from unified artifacts for backward compatibility + # This ensures OpenWebUI and other interfaces can access figures/notebooks/commands + ui_artifacts = state.get("ui_artifacts", []) + legacy_updates = {} + if ui_artifacts: + legacy_updates = populate_legacy_fields_from_artifacts(ui_artifacts) + # Return native LangGraph pattern: AIMessage added to messages list - return {"messages": [AIMessage(content=response_text)]} + # Include legacy field updates for backward compatibility + return {"messages": [AIMessage(content=response_text)], **legacy_updates} except Exception as e: logger.error(f"Error in response generation: {e}") @@ -252,16 +260,23 @@ def _gather_information(state: AgentState, logger=None) -> ResponseContext: if logger: logger.info(f"Using technical response mode (context type: {response_mode})") - # Get figure information from centralized registry - ui_figures = state.get("ui_captured_figures", []) - figures_available = len(ui_figures) + # Populate legacy fields from unified artifacts (for backward compatibility) + # This derives ui_captured_figures, ui_launchable_commands, ui_captured_notebooks + # from the canonical ui_artifacts field + ui_artifacts = state.get("ui_artifacts", []) + if ui_artifacts: + legacy_fields = populate_legacy_fields_from_artifacts(ui_artifacts) + ui_figures = legacy_fields["ui_captured_figures"] + ui_commands = legacy_fields["ui_launchable_commands"] + ui_notebooks = legacy_fields["ui_captured_notebooks"] + else: + # Fall back to direct legacy field access (for old capabilities) + ui_figures = state.get("ui_captured_figures", []) + ui_commands = state.get("ui_launchable_commands", []) + ui_notebooks = state.get("ui_captured_notebooks", []) - # Get command information from centralized registry - ui_commands = state.get("ui_launchable_commands", []) + figures_available = len(ui_figures) commands_available = len(ui_commands) - - # Get notebook information from centralized registry - ui_notebooks = state.get("ui_captured_notebooks", []) notebooks_available = len(ui_notebooks) # Log notebook availability for debugging diff --git a/src/osprey/state/__init__.py b/src/osprey/state/__init__.py index 830bf662f..c67525d1c 100644 --- a/src/osprey/state/__init__.py +++ b/src/osprey/state/__init__.py @@ -83,6 +83,14 @@ :mod:`osprey.infrastructure.gateway` : Main entry point for state processing """ +from .artifacts import ( + Artifact, + ArtifactType, + create_artifact, + get_artifact_type_icon, + populate_legacy_fields_from_artifacts, + validate_artifact_data, +) from .control import AgentControlState, apply_slash_commands_to_agent_control_state from .execution import ApprovalRequest # Keep as dataclass from .messages import ChatHistoryFormatter, MessageUtils, UserMemories @@ -102,6 +110,13 @@ "AgentState", "StateUpdate", "StateManager", + # Artifact system + "Artifact", + "ArtifactType", + "create_artifact", + "get_artifact_type_icon", + "populate_legacy_fields_from_artifacts", + "validate_artifact_data", # Utility functions "create_status_update", "create_progress_event", diff --git a/src/osprey/state/artifacts.py b/src/osprey/state/artifacts.py new file mode 100644 index 000000000..c4c9e2847 --- /dev/null +++ b/src/osprey/state/artifacts.py @@ -0,0 +1,369 @@ +"""Artifact System - Unified artifact management for UI display. + +This module provides a unified abstraction for all types of artifacts that can be +generated by capabilities and displayed in UI interfaces (TUI, OpenWebUI, etc.). + +**Artifact Types:** + +- **image**: Figures, plots, visualizations (PNG, JPG, SVG) +- **notebook**: Jupyter notebooks with execution results +- **command**: Launchable URIs (web apps, desktop apps, viewers) +- **html**: Interactive HTML content (Bokeh, Plotly dashboards) +- **file**: Generic downloadable files + +**Architecture:** + +The artifact system provides a single registration point (`StateManager.register_artifact`) +that replaces the previous separate methods for figures, commands, and notebooks. Each +artifact has a unique ID, type, and type-specific data payload, enabling unified tracking +and "new" artifact detection across conversation turns. + +**Migration:** + +The old `ui_captured_figures`, `ui_launchable_commands`, and `ui_captured_notebooks` +fields are maintained for backward compatibility but will be deprecated. New code +should use `ui_artifacts` exclusively. + +.. seealso:: + :class:`osprey.state.StateManager` : Registration methods for artifacts + :class:`osprey.state.AgentState` : State structure containing ui_artifacts +""" + +from enum import Enum +from typing import Any, TypedDict + + +class ArtifactType(str, Enum): + """Enumeration of supported artifact types. + + Each artifact type corresponds to a specific category of generated content + that can be displayed in UI interfaces. The type determines how the artifact + is rendered and what actions are available. + + Attributes: + IMAGE: Static images and visualizations (PNG, JPG, SVG) + NOTEBOOK: Jupyter notebooks with execution results + COMMAND: Launchable URIs for external applications + HTML: Interactive HTML content (dashboards, widgets) + FILE: Generic downloadable files + """ + + IMAGE = "image" + NOTEBOOK = "notebook" + COMMAND = "command" + HTML = "html" + FILE = "file" + + +class ImageArtifactData(TypedDict, total=False): + """Data payload for image artifacts. + + Attributes: + path: Absolute path to the image file + format: Image format (png, jpg, svg, etc.) + width: Optional width in pixels + height: Optional height in pixels + """ + + path: str + format: str + width: int + height: int + + +class NotebookArtifactData(TypedDict, total=False): + """Data payload for notebook artifacts. + + Attributes: + path: Path to the notebook file (.ipynb) + url: URL to access the notebook (e.g., Jupyter server) + """ + + path: str + url: str + + +class CommandArtifactData(TypedDict, total=False): + """Data payload for command/URI artifacts. + + Attributes: + uri: Launchable URI (http://, file://, custom://) + command_type: Type of command (web_app, desktop_app, viewer, etc.) + """ + + uri: str + command_type: str + + +class HTMLArtifactData(TypedDict, total=False): + """Data payload for interactive HTML artifacts. + + Attributes: + path: Path to the HTML file + url: URL to access the content + framework: Framework used (bokeh, plotly, custom, etc.) + """ + + path: str + url: str + framework: str + + +class FileArtifactData(TypedDict, total=False): + """Data payload for generic file artifacts. + + Attributes: + path: Absolute path to the file + mime_type: MIME type of the file + size_bytes: File size in bytes + """ + + path: str + mime_type: str + size_bytes: int + + +# Union type for all artifact data payloads +ArtifactData = ( + ImageArtifactData + | NotebookArtifactData + | CommandArtifactData + | HTMLArtifactData + | FileArtifactData +) + + +class Artifact(TypedDict, total=False): + """Unified artifact structure for all artifact types. + + This TypedDict defines the common structure for all artifacts stored in + the `ui_artifacts` state field. Each artifact has a unique ID, type, + source capability, and type-specific data payload. + + Required Attributes: + id: Unique identifier (UUID) for the artifact + type: Artifact type from ArtifactType enum + capability: Name of the capability that generated this artifact + created_at: ISO format timestamp of creation + + Optional Attributes: + display_name: Human-readable name for the artifact + data: Type-specific data payload (path, url, etc.) + metadata: Additional capability-specific metadata + + Examples: + Image artifact:: + + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "type": "image", + "capability": "python_executor", + "created_at": "2024-01-15T10:30:00", + "display_name": "Analysis Plot", + "data": {"path": "/path/to/plot.png", "format": "png"}, + "metadata": {"execution_folder": "/path/to/exec"} + } + + Notebook artifact:: + + { + "id": "550e8400-e29b-41d4-a716-446655440001", + "type": "notebook", + "capability": "python_executor", + "created_at": "2024-01-15T10:30:00", + "display_name": "Execution Notebook", + "data": {"path": "/path/to/notebook.ipynb", "url": "http://jupyter/..."} + } + """ + + # Required fields + id: str + type: str # ArtifactType value + capability: str + created_at: str + + # Optional fields + display_name: str + data: dict[str, Any] # Type-specific payload + metadata: dict[str, Any] # Additional metadata + + +def create_artifact( + artifact_type: ArtifactType, + capability: str, + data: dict[str, Any], + display_name: str | None = None, + metadata: dict[str, Any] | None = None, +) -> Artifact: + """Create a new artifact with a unique ID and timestamp. + + This is the factory function for creating properly structured artifacts. + It generates a unique ID and timestamp, and validates the artifact type. + + Args: + artifact_type: Type of artifact from ArtifactType enum + capability: Name of the capability generating this artifact + data: Type-specific data payload (path, url, etc.) + display_name: Optional human-readable name + metadata: Optional additional metadata + + Returns: + Artifact: Properly structured artifact dictionary + + Examples: + Create an image artifact:: + + >>> artifact = create_artifact( + ... ArtifactType.IMAGE, + ... "python_executor", + ... {"path": "/path/to/plot.png", "format": "png"}, + ... display_name="Analysis Plot" + ... ) + >>> artifact["type"] + 'image' + >>> "id" in artifact + True + """ + import uuid + from datetime import datetime + + artifact: Artifact = { + "id": str(uuid.uuid4()), + "type": artifact_type.value, + "capability": capability, + "created_at": datetime.now().isoformat(), + "data": data, + } + + if display_name: + artifact["display_name"] = display_name + + if metadata: + artifact["metadata"] = metadata + + return artifact + + +def get_artifact_type_icon(artifact_type: ArtifactType | str) -> str: + """Get the display icon for an artifact type. + + Args: + artifact_type: Artifact type enum or string value + + Returns: + Unicode icon character for the artifact type + """ + if isinstance(artifact_type, str): + artifact_type = ArtifactType(artifact_type) + + icons = { + ArtifactType.IMAGE: "πŸ–Ό", + ArtifactType.NOTEBOOK: "πŸ““", + ArtifactType.COMMAND: "πŸ”—", + ArtifactType.HTML: "🌐", + ArtifactType.FILE: "πŸ“„", + } + return icons.get(artifact_type, "πŸ“Ž") + + +def validate_artifact_data(artifact_type: ArtifactType, data: dict[str, Any]) -> bool: + """Validate that artifact data contains required fields for its type. + + Args: + artifact_type: Type of artifact to validate + data: Data payload to validate + + Returns: + True if data is valid for the artifact type, False otherwise + """ + required_fields: dict[ArtifactType, list[str]] = { + ArtifactType.IMAGE: ["path"], + ArtifactType.NOTEBOOK: [], # Either path or url is acceptable + ArtifactType.COMMAND: ["uri"], + ArtifactType.HTML: [], # Either path or url is acceptable + ArtifactType.FILE: ["path"], + } + + required = required_fields.get(artifact_type, []) + return all(field in data for field in required) + + +def populate_legacy_fields_from_artifacts( + artifacts: list[Artifact], +) -> dict[str, Any]: + """Populate legacy UI fields from the unified artifacts list. + + This function derives the legacy `ui_captured_figures`, `ui_launchable_commands`, + and `ui_captured_notebooks` fields from the canonical `ui_artifacts` list. + This ensures backward compatibility with interfaces that rely on the old fields + (e.g., OpenWebUI pipeline) while allowing capabilities to use the new unified API. + + Args: + artifacts: List of artifacts from ui_artifacts state field + + Returns: + Dictionary with legacy field updates: + - ui_captured_figures: List of figure entries for IMAGE artifacts + - ui_launchable_commands: List of command entries for COMMAND artifacts + - ui_captured_notebooks: List of notebook URLs for NOTEBOOK artifacts + + Example: + >>> artifacts = [ + ... {"type": "image", "capability": "python", "data": {"path": "/fig.png"}}, + ... {"type": "notebook", "capability": "python", "data": {"url": "http://..."}}, + ... ] + >>> legacy = populate_legacy_fields_from_artifacts(artifacts) + >>> len(legacy["ui_captured_figures"]) + 1 + >>> len(legacy["ui_captured_notebooks"]) + 1 + """ + figures: list[dict[str, Any]] = [] + commands: list[dict[str, Any]] = [] + notebooks: list[str] = [] + + for artifact in artifacts: + artifact_type = artifact.get("type", "") + data = artifact.get("data", {}) + capability = artifact.get("capability", "") + created_at = artifact.get("created_at", "") + display_name = artifact.get("display_name") + metadata = artifact.get("metadata") + + if artifact_type == ArtifactType.IMAGE.value: + # Convert to legacy figure entry + figure_entry: dict[str, Any] = { + "capability": capability, + "figure_path": data.get("path", ""), + "created_at": created_at, + } + if display_name: + figure_entry["display_name"] = display_name + if metadata: + figure_entry["metadata"] = metadata + figures.append(figure_entry) + + elif artifact_type == ArtifactType.COMMAND.value: + # Convert to legacy command entry + command_entry: dict[str, Any] = { + "capability": capability, + "uri": data.get("uri", ""), + "created_at": created_at, + } + if display_name: + command_entry["display_name"] = display_name + if metadata: + command_entry["metadata"] = metadata + commands.append(command_entry) + + elif artifact_type == ArtifactType.NOTEBOOK.value: + # Legacy notebooks field is just a list of URLs + url = data.get("url", "") + if url: + notebooks.append(url) + + return { + "ui_captured_figures": figures, + "ui_launchable_commands": commands, + "ui_captured_notebooks": notebooks, + } diff --git a/src/osprey/state/state.py b/src/osprey/state/state.py index 5964290a8..098556a0a 100644 --- a/src/osprey/state/state.py +++ b/src/osprey/state/state.py @@ -352,6 +352,10 @@ class AgentState(MessagesState): control_validation_timestamp: float | None # UI result fields + ui_artifacts: list[dict[str, Any]] # Unified artifact registry for all UI-displayable content + + # Legacy UI fields (deprecated - use ui_artifacts instead) + # These are maintained for backward compatibility with existing interfaces ui_captured_notebooks: list[ dict[str, Any] ] # Centralized notebook registry for displaying notebooks in the UI diff --git a/src/osprey/state/state_manager.py b/src/osprey/state/state_manager.py index cb066c1a1..bec0d7467 100644 --- a/src/osprey/state/state_manager.py +++ b/src/osprey/state/state_manager.py @@ -61,6 +61,7 @@ from osprey.utils.config import get_agent_control_defaults as _get_agent_control_defaults from osprey.utils.logger import get_logger +from .artifacts import ArtifactType, create_artifact from .messages import MessageUtils from .state import AgentState, StateUpdate @@ -325,6 +326,8 @@ def create_fresh_state(user_input: str, current_state: AgentState | None = None) control_validation_context=None, control_validation_timestamp=None, # UI result fields - reset to defaults + ui_artifacts=[], # Unified artifact registry + # Legacy fields (maintained for backward compatibility) ui_captured_notebooks=[], ui_captured_figures=[], ui_launchable_commands=[], @@ -553,6 +556,97 @@ def get_current_step(state: AgentState) -> PlannedStep: "before routing to capabilities that need step extraction." ) + # ===== UNIFIED ARTIFACT REGISTRATION ===== + + @staticmethod + def register_artifact( + state: AgentState, + artifact_type: ArtifactType, + capability: str, + data: dict[str, Any], + display_name: str | None = None, + metadata: dict[str, Any] | None = None, + current_artifacts: list[dict[str, Any]] | None = None, + ) -> dict[str, Any]: + """Register an artifact in the unified UI artifact registry. + + This is the single entry point for all capabilities to register artifacts + for UI display. It replaces the separate figure/command/notebook registration + methods with a unified interface that supports any artifact type. + + The method creates a properly structured artifact with a unique ID and timestamp, + then appends it to the artifact list. It supports accumulation for registering + multiple artifacts within a single node execution. + + Args: + state: Current agent state + artifact_type: Type of artifact (IMAGE, NOTEBOOK, COMMAND, HTML, FILE) + capability: Name of the capability generating this artifact + data: Type-specific data payload (path, url, uri, etc.) + display_name: Optional human-readable name for the artifact + metadata: Optional additional metadata dictionary + current_artifacts: Optional list to accumulate artifacts (for multiple + registrations within same node). If None, reads from state. + + Returns: + State update dictionary with ui_artifacts update + + Examples: + Register an image artifact:: + + >>> update = StateManager.register_artifact( + ... state, + ... ArtifactType.IMAGE, + ... "python_executor", + ... {"path": "/path/to/plot.png", "format": "png"}, + ... display_name="Analysis Plot" + ... ) + >>> return {**other_updates, **update} + + Register multiple artifacts in one node:: + + >>> accumulating = None + >>> for path in figure_paths: + ... update = StateManager.register_artifact( + ... state, + ... ArtifactType.IMAGE, + ... "python_executor", + ... {"path": str(path), "format": path.suffix[1:]}, + ... current_artifacts=accumulating + ... ) + ... accumulating = update["ui_artifacts"] + >>> return update # Contains all artifacts + + .. seealso:: + :class:`osprey.state.artifacts.ArtifactType` : Available artifact types + :func:`osprey.state.artifacts.create_artifact` : Artifact creation factory + """ + # Create the artifact using the factory function + artifact = create_artifact( + artifact_type=artifact_type, + capability=capability, + data=data, + display_name=display_name, + metadata=metadata, + ) + + # Use provided current_artifacts or get from state + if current_artifacts is not None: + artifacts_list = current_artifacts + else: + artifacts_list = list(state.get("ui_artifacts", [])) + + artifacts_list.append(artifact) + + logger.info( + f"StateManager: registered {artifact_type.value} artifact for {capability}: " + f"{display_name or data.get('path') or data.get('uri') or data.get('url', 'unknown')}" + ) + + return {"ui_artifacts": artifacts_list} + + # ===== LEGACY REGISTRATION METHODS (delegate to register_artifact) ===== + @staticmethod def register_figure( state: AgentState, @@ -561,13 +655,17 @@ def register_figure( display_name: str | None = None, metadata: dict[str, Any] | None = None, current_figures: list[dict[str, Any]] | None = None, + current_artifacts: list[dict[str, Any]] | None = None, ) -> dict[str, Any]: - """ - Register a figure in the centralized UI registry. + """Register a figure in the UI registry. - This is the single point of entry for all capabilities to register figures - for UI display. Provides a capability-agnostic interface that works for - Python, R, Julia, or any other figure-generating capability. + .. deprecated:: + Use :meth:`register_artifact` with ``ArtifactType.IMAGE`` instead. + This method is maintained for backward compatibility. + + This method now delegates to register_artifact() for the unified artifact + system while maintaining the legacy ui_captured_figures field for backward + compatibility with existing interfaces (e.g., OpenWebUI pipeline). Args: state: Current agent state @@ -575,61 +673,58 @@ def register_figure( figure_path: Path to the figure file (absolute or relative) display_name: Optional human-readable figure name metadata: Optional capability-specific metadata dictionary - current_figures: Optional list of current figures to accumulate (otherwise get from state) + current_figures: Optional list of current figures to accumulate + current_artifacts: Optional list of current artifacts to accumulate (for + multiple registrations in same node). If not provided but current_figures + is, artifacts from previous update are retrieved from state. Returns: - State update dictionary with ui_captured_figures update - - Examples: - Basic figure registration:: - - >>> figure_update = StateManager.register_figure( - ... state, "python_executor", "/path/to/plot.png" - ... ) - >>> return {**other_updates, **figure_update} - - Rich figure registration:: - - >>> figure_update = StateManager.register_figure( - ... state, - ... capability="python_executor", - ... figure_path="figures/analysis.png", - ... display_name="Performance Analysis", - ... metadata={ - ... "execution_folder": "/path/to/execution", - ... "notebook_link": "http://jupyter/notebook.ipynb" - ... "figure_type": "matplotlib_png" - ... } - ... ) + State update dictionary with both ui_artifacts and ui_captured_figures updates """ from datetime import datetime + from pathlib import Path + + # Determine format from file extension + path = Path(figure_path) + format_ext = path.suffix[1:].lower() if path.suffix else "unknown" + + # Determine current_artifacts for accumulation + # If current_figures is provided (accumulation mode), we need matching artifacts + if current_artifacts is None: + current_artifacts = list(state.get("ui_artifacts", [])) + + # Register in unified artifact system + artifact_update = StateManager.register_artifact( + state=state, + artifact_type=ArtifactType.IMAGE, + capability=capability, + data={"path": figure_path, "format": format_ext}, + display_name=display_name, + metadata=metadata, + current_artifacts=current_artifacts, + ) - # Create figure entry with required fields + # Also maintain legacy field for backward compatibility figure_entry = { "capability": capability, "figure_path": figure_path, "created_at": datetime.now().isoformat(), } - - # Add optional fields only if provided if display_name: figure_entry["display_name"] = display_name if metadata: figure_entry["metadata"] = metadata - # Use provided current_figures or get from state if current_figures is not None: - # Use the accumulating list (for multiple registrations within same node) figures_list = current_figures else: - # Start from state (for single registration) figures_list = list(state.get("ui_captured_figures", [])) - figures_list.append(figure_entry) logger.info(f"StateManager: prepared figure registration for {capability}: {figure_path}") - return {"ui_captured_figures": figures_list} + # Return both unified and legacy updates + return {**artifact_update, "ui_captured_figures": figures_list} @staticmethod def register_command( @@ -641,58 +736,52 @@ def register_command( metadata: dict[str, Any] | None = None, current_commands: list[dict[str, Any]] | None = None, ) -> dict[str, Any]: - """ - Register a launchable command in the centralized UI registry. + """Register a launchable command in the UI registry. + + .. deprecated:: + Use :meth:`register_artifact` with ``ArtifactType.COMMAND`` instead. + This method is maintained for backward compatibility. - This method allows capabilities to register commands that users can execute - through the UI. Commands are typically external applications, web interfaces, - or desktop tools that can be launched via URIs. + This method now delegates to register_artifact() for the unified artifact + system while maintaining the legacy ui_launchable_commands field for backward + compatibility with existing interfaces. Args: state: Current agent state capability: Name of the capability that generated this command - launch_uri: URI that can be used to launch the command (e.g., http://, file://, custom://) + launch_uri: URI that can be used to launch the command display_name: Optional human-readable name for the command - command_type: Optional type of command (e.g., 'web_app', 'desktop_app', 'viewer') + command_type: Optional type of command (e.g., 'web_app', 'desktop_app') metadata: Optional capability-specific metadata dictionary - current_commands: Optional list of current commands to accumulate (otherwise get from state) + current_commands: Optional list of current commands to accumulate Returns: - State update dictionary with ui_launchable_commands update - - Examples: - Basic command registration:: - - >>> command_update = StateManager.register_command( - ... state, "file_processor", "file:///path/to/results.html" - ... ) - >>> return {**other_updates, **command_update} - - Rich command registration:: - - >>> command_update = StateManager.register_command( - ... state, - ... capability="data_visualizer", - ... launch_uri="http://localhost:8080/dashboard", - ... display_name="Interactive Dashboard", - ... command_type="web_app", - ... metadata={ - ... "port": 8080, - ... "data_source": "analysis_results", - ... "chart_count": 3 - ... } - ... ) + State update dictionary with both ui_artifacts and ui_launchable_commands updates """ from datetime import datetime - # Create command entry with required fields (following register_figure pattern) + # Build data payload for artifact + data: dict[str, Any] = {"uri": launch_uri} + if command_type: + data["command_type"] = command_type + + # Register in unified artifact system + artifact_update = StateManager.register_artifact( + state=state, + artifact_type=ArtifactType.COMMAND, + capability=capability, + data=data, + display_name=display_name, + metadata=metadata, + current_artifacts=list(state.get("ui_artifacts", [])), + ) + + # Also maintain legacy field for backward compatibility command_entry = { "capability": capability, "launch_uri": launch_uri, "created_at": datetime.now().isoformat(), } - - # Add optional fields only if provided if display_name: command_entry["display_name"] = display_name if command_type: @@ -700,21 +789,18 @@ def register_command( if metadata: command_entry["metadata"] = metadata - # Use provided current_commands or get from state if current_commands is not None: - # Use the accumulating list (for multiple registrations within same node) commands_list = current_commands else: - # Start from state (for single registration) commands_list = list(state.get("ui_launchable_commands", [])) - commands_list.append(command_entry) logger.info( f"StateManager: prepared command registration for {capability}: {display_name or launch_uri}" ) - return {"ui_launchable_commands": commands_list} + # Return both unified and legacy updates + return {**artifact_update, "ui_launchable_commands": commands_list} @staticmethod def register_notebook( @@ -725,26 +811,39 @@ def register_notebook( display_name: str | None = None, metadata: dict[str, Any] | None = None, ) -> dict[str, Any]: - """ - Register a notebook in the centralized UI registry. + """Register a notebook in the UI registry. - This method provides notebook registration functionality that was being called - but didn't exist. It follows the same pattern as figure registration. + .. deprecated:: + Use :meth:`register_artifact` with ``ArtifactType.NOTEBOOK`` instead. + This method is maintained for backward compatibility. + + This method now delegates to register_artifact() for the unified artifact + system while maintaining the legacy ui_captured_notebooks field for backward + compatibility with existing interfaces. Args: state: Current agent state - capability: Capability identifier (e.g., "python_executor", "data_analysis") + capability: Capability identifier (e.g., "python_executor") notebook_path: Path to the notebook file notebook_link: Link to access the notebook display_name: Optional human-readable notebook name metadata: Optional capability-specific metadata dictionary Returns: - State update dictionary with ui_captured_notebooks update + State update dictionary with both ui_artifacts and ui_captured_notebooks updates """ + # Register in unified artifact system + artifact_update = StateManager.register_artifact( + state=state, + artifact_type=ArtifactType.NOTEBOOK, + capability=capability, + data={"path": notebook_path, "url": notebook_link}, + display_name=display_name, + metadata=metadata, + current_artifacts=list(state.get("ui_artifacts", [])), + ) - # For now, maintain backward compatibility with simple notebook links - # In the future, this could be enhanced to store structured notebook objects + # Also maintain legacy field for backward compatibility notebook_links = list(state.get("ui_captured_notebooks", [])) notebook_links.append(notebook_link) @@ -752,7 +851,8 @@ def register_notebook( f"StateManager: prepared notebook registration for {capability}: {display_name or notebook_path}" ) - return {"ui_captured_notebooks": notebook_links} + # Return both unified and legacy updates + return {**artifact_update, "ui_captured_notebooks": notebook_links} def get_execution_steps_summary(state: AgentState) -> list[str]: diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py new file mode 100644 index 000000000..75d24d82e --- /dev/null +++ b/tests/test_artifacts.py @@ -0,0 +1,694 @@ +"""Tests for the unified artifact system. + +This module tests the artifact abstraction layer including: +- ArtifactType enum and artifact creation +- StateManager.register_artifact() unified registration +- Legacy method delegation (register_figure/command/notebook) +- Backward compatibility with ui_captured_* fields +""" + +import uuid + +from osprey.state import AgentState, StateManager +from osprey.state.artifacts import ( + ArtifactType, + create_artifact, + get_artifact_type_icon, + validate_artifact_data, +) + + +class TestArtifactType: + """Tests for the ArtifactType enum.""" + + def test_artifact_type_values(self): + """ArtifactType should have expected string values.""" + assert ArtifactType.IMAGE.value == "image" + assert ArtifactType.NOTEBOOK.value == "notebook" + assert ArtifactType.COMMAND.value == "command" + assert ArtifactType.HTML.value == "html" + assert ArtifactType.FILE.value == "file" + + def test_artifact_type_is_string_enum(self): + """ArtifactType should be usable as string.""" + assert str(ArtifactType.IMAGE) == "ArtifactType.IMAGE" + assert ArtifactType.IMAGE.value == "image" + + def test_artifact_type_from_string(self): + """ArtifactType should be constructible from string value.""" + assert ArtifactType("image") == ArtifactType.IMAGE + assert ArtifactType("notebook") == ArtifactType.NOTEBOOK + assert ArtifactType("command") == ArtifactType.COMMAND + + +class TestCreateArtifact: + """Tests for the create_artifact factory function.""" + + def test_creates_artifact_with_required_fields(self): + """create_artifact should populate all required fields.""" + artifact = create_artifact( + artifact_type=ArtifactType.IMAGE, + capability="python_executor", + data={"path": "/path/to/plot.png", "format": "png"}, + ) + + assert "id" in artifact + assert artifact["type"] == "image" + assert artifact["capability"] == "python_executor" + assert "created_at" in artifact + assert artifact["data"] == {"path": "/path/to/plot.png", "format": "png"} + + def test_creates_artifact_with_unique_id(self): + """create_artifact should generate unique UUIDs.""" + artifact1 = create_artifact(ArtifactType.IMAGE, "test", {"path": "/a.png"}) + artifact2 = create_artifact(ArtifactType.IMAGE, "test", {"path": "/b.png"}) + + assert artifact1["id"] != artifact2["id"] + # Verify they are valid UUIDs + uuid.UUID(artifact1["id"]) + uuid.UUID(artifact2["id"]) + + def test_creates_artifact_with_optional_display_name(self): + """create_artifact should include display_name when provided.""" + artifact = create_artifact( + ArtifactType.IMAGE, "test", {"path": "/plot.png"}, display_name="Analysis Plot" + ) + + assert artifact["display_name"] == "Analysis Plot" + + def test_creates_artifact_without_display_name(self): + """create_artifact should omit display_name when not provided.""" + artifact = create_artifact(ArtifactType.IMAGE, "test", {"path": "/plot.png"}) + + assert "display_name" not in artifact + + def test_creates_artifact_with_metadata(self): + """create_artifact should include metadata when provided.""" + metadata = {"execution_folder": "/tmp/exec", "notebook_link": "http://localhost:8888"} + artifact = create_artifact( + ArtifactType.IMAGE, "test", {"path": "/plot.png"}, metadata=metadata + ) + + assert artifact["metadata"] == metadata + + def test_creates_artifact_without_metadata(self): + """create_artifact should omit metadata when not provided.""" + artifact = create_artifact(ArtifactType.IMAGE, "test", {"path": "/plot.png"}) + + assert "metadata" not in artifact + + def test_creates_notebook_artifact(self): + """create_artifact should work for NOTEBOOK type.""" + artifact = create_artifact( + ArtifactType.NOTEBOOK, + "python_executor", + {"path": "/path/to/notebook.ipynb", "url": "http://jupyter/notebook"}, + display_name="Execution Notebook", + ) + + assert artifact["type"] == "notebook" + assert artifact["data"]["path"] == "/path/to/notebook.ipynb" + assert artifact["data"]["url"] == "http://jupyter/notebook" + + def test_creates_command_artifact(self): + """create_artifact should work for COMMAND type.""" + artifact = create_artifact( + ArtifactType.COMMAND, + "dashboard_builder", + {"uri": "http://localhost:8080/dashboard", "command_type": "web_app"}, + display_name="Interactive Dashboard", + ) + + assert artifact["type"] == "command" + assert artifact["data"]["uri"] == "http://localhost:8080/dashboard" + assert artifact["data"]["command_type"] == "web_app" + + def test_creates_html_artifact(self): + """create_artifact should work for HTML type.""" + artifact = create_artifact( + ArtifactType.HTML, + "visualization", + {"path": "/path/to/dashboard.html", "framework": "bokeh"}, + ) + + assert artifact["type"] == "html" + assert artifact["data"]["framework"] == "bokeh" + + def test_creates_file_artifact(self): + """create_artifact should work for FILE type.""" + artifact = create_artifact( + ArtifactType.FILE, + "data_export", + {"path": "/path/to/data.csv", "mime_type": "text/csv", "size_bytes": 1024}, + ) + + assert artifact["type"] == "file" + assert artifact["data"]["mime_type"] == "text/csv" + + +class TestGetArtifactTypeIcon: + """Tests for the get_artifact_type_icon helper function.""" + + def test_returns_correct_icons(self): + """get_artifact_type_icon should return correct icons for each type.""" + assert get_artifact_type_icon(ArtifactType.IMAGE) == "πŸ–Ό" + assert get_artifact_type_icon(ArtifactType.NOTEBOOK) == "πŸ““" + assert get_artifact_type_icon(ArtifactType.COMMAND) == "πŸ”—" + assert get_artifact_type_icon(ArtifactType.HTML) == "🌐" + assert get_artifact_type_icon(ArtifactType.FILE) == "πŸ“„" + + def test_accepts_string_type(self): + """get_artifact_type_icon should accept string type values.""" + assert get_artifact_type_icon("image") == "πŸ–Ό" + assert get_artifact_type_icon("notebook") == "πŸ““" + assert get_artifact_type_icon("command") == "πŸ”—" + + +class TestValidateArtifactData: + """Tests for the validate_artifact_data helper function.""" + + def test_validates_image_data(self): + """validate_artifact_data should require path for IMAGE.""" + assert validate_artifact_data(ArtifactType.IMAGE, {"path": "/plot.png"}) is True + assert validate_artifact_data(ArtifactType.IMAGE, {"url": "http://example.com"}) is False + assert validate_artifact_data(ArtifactType.IMAGE, {}) is False + + def test_validates_command_data(self): + """validate_artifact_data should require uri for COMMAND.""" + assert validate_artifact_data(ArtifactType.COMMAND, {"uri": "http://localhost"}) is True + assert validate_artifact_data(ArtifactType.COMMAND, {"path": "/file"}) is False + assert validate_artifact_data(ArtifactType.COMMAND, {}) is False + + def test_validates_file_data(self): + """validate_artifact_data should require path for FILE.""" + assert validate_artifact_data(ArtifactType.FILE, {"path": "/data.csv"}) is True + assert validate_artifact_data(ArtifactType.FILE, {}) is False + + def test_validates_notebook_data_flexible(self): + """validate_artifact_data should be flexible for NOTEBOOK.""" + # Notebook can have path OR url + assert validate_artifact_data(ArtifactType.NOTEBOOK, {"path": "/nb.ipynb"}) is True + assert validate_artifact_data(ArtifactType.NOTEBOOK, {"url": "http://jupyter"}) is True + assert validate_artifact_data(ArtifactType.NOTEBOOK, {}) is True # Flexible + + def test_validates_html_data_flexible(self): + """validate_artifact_data should be flexible for HTML.""" + # HTML can have path OR url + assert validate_artifact_data(ArtifactType.HTML, {"path": "/dash.html"}) is True + assert validate_artifact_data(ArtifactType.HTML, {"url": "http://dash"}) is True + assert validate_artifact_data(ArtifactType.HTML, {}) is True # Flexible + + +class TestAgentStateArtifacts: + """Tests for ui_artifacts field in AgentState.""" + + def test_agent_state_has_ui_artifacts_field(self): + """AgentState should have ui_artifacts as a valid field.""" + assert "ui_artifacts" in AgentState.__annotations__ + + def test_create_fresh_state_initializes_ui_artifacts(self): + """StateManager.create_fresh_state should initialize ui_artifacts to empty list.""" + state = StateManager.create_fresh_state("Hello") + + assert "ui_artifacts" in state + assert state["ui_artifacts"] == [] + + def test_create_fresh_state_resets_ui_artifacts(self): + """create_fresh_state should reset ui_artifacts each turn.""" + prev_state = { + "ui_artifacts": [{"id": "123", "type": "image"}], + "capability_context_data": {}, + } + + fresh_state = StateManager.create_fresh_state("Hello", current_state=prev_state) + + # ui_artifacts should be reset (not preserved) + assert fresh_state["ui_artifacts"] == [] + + +class TestStateManagerRegisterArtifact: + """Tests for StateManager.register_artifact() unified registration.""" + + def test_register_artifact_creates_valid_artifact(self): + """register_artifact should create properly structured artifact.""" + state = StateManager.create_fresh_state("Test") + + update = StateManager.register_artifact( + state=state, + artifact_type=ArtifactType.IMAGE, + capability="python_executor", + data={"path": "/path/to/plot.png", "format": "png"}, + display_name="Analysis Plot", + ) + + assert "ui_artifacts" in update + assert len(update["ui_artifacts"]) == 1 + + artifact = update["ui_artifacts"][0] + assert artifact["type"] == "image" + assert artifact["capability"] == "python_executor" + assert artifact["data"]["path"] == "/path/to/plot.png" + assert artifact["display_name"] == "Analysis Plot" + assert "id" in artifact + assert "created_at" in artifact + + def test_register_artifact_accumulates(self): + """register_artifact should accumulate multiple artifacts.""" + state = StateManager.create_fresh_state("Test") + + # Register first artifact + update1 = StateManager.register_artifact( + state, ArtifactType.IMAGE, "test", {"path": "/a.png"} + ) + + # Register second using accumulation pattern + update2 = StateManager.register_artifact( + state, + ArtifactType.IMAGE, + "test", + {"path": "/b.png"}, + current_artifacts=update1["ui_artifacts"], + ) + + assert len(update2["ui_artifacts"]) == 2 + assert update2["ui_artifacts"][0]["data"]["path"] == "/a.png" + assert update2["ui_artifacts"][1]["data"]["path"] == "/b.png" + + def test_register_artifact_with_metadata(self): + """register_artifact should include metadata.""" + state = StateManager.create_fresh_state("Test") + metadata = {"execution_folder": "/tmp", "step": 1} + + update = StateManager.register_artifact( + state, ArtifactType.IMAGE, "test", {"path": "/plot.png"}, metadata=metadata + ) + + artifact = update["ui_artifacts"][0] + assert artifact["metadata"] == metadata + + +class TestStateManagerRegisterFigureLegacy: + """Tests for legacy register_figure() method.""" + + def test_register_figure_creates_artifact(self): + """register_figure should create artifact in ui_artifacts.""" + state = StateManager.create_fresh_state("Test") + + update = StateManager.register_figure( + state=state, + capability="python_executor", + figure_path="/path/to/plot.png", + display_name="Test Plot", + ) + + # Should have both unified and legacy fields + assert "ui_artifacts" in update + assert "ui_captured_figures" in update + + # Check unified artifact + artifact = update["ui_artifacts"][0] + assert artifact["type"] == "image" + assert artifact["data"]["path"] == "/path/to/plot.png" + assert artifact["data"]["format"] == "png" + + def test_register_figure_maintains_legacy_format(self): + """register_figure should maintain legacy ui_captured_figures format.""" + state = StateManager.create_fresh_state("Test") + + update = StateManager.register_figure( + state=state, + capability="python_executor", + figure_path="/path/to/plot.png", + display_name="Test Plot", + metadata={"notebook_link": "http://jupyter"}, + ) + + # Check legacy format + figure = update["ui_captured_figures"][0] + assert figure["capability"] == "python_executor" + assert figure["figure_path"] == "/path/to/plot.png" + assert figure["display_name"] == "Test Plot" + assert figure["metadata"]["notebook_link"] == "http://jupyter" + assert "created_at" in figure + + def test_register_figure_detects_format(self): + """register_figure should detect image format from extension.""" + state = StateManager.create_fresh_state("Test") + + # Test PNG + update1 = StateManager.register_figure(state, "test", "/plot.png") + assert update1["ui_artifacts"][0]["data"]["format"] == "png" + + # Test JPG + update2 = StateManager.register_figure(state, "test", "/photo.jpg") + assert update2["ui_artifacts"][0]["data"]["format"] == "jpg" + + # Test SVG + update3 = StateManager.register_figure(state, "test", "/vector.svg") + assert update3["ui_artifacts"][0]["data"]["format"] == "svg" + + +class TestStateManagerRegisterCommandLegacy: + """Tests for legacy register_command() method.""" + + def test_register_command_creates_artifact(self): + """register_command should create artifact in ui_artifacts.""" + state = StateManager.create_fresh_state("Test") + + update = StateManager.register_command( + state=state, + capability="dashboard_builder", + launch_uri="http://localhost:8080", + display_name="Dashboard", + command_type="web_app", + ) + + # Should have both unified and legacy fields + assert "ui_artifacts" in update + assert "ui_launchable_commands" in update + + # Check unified artifact + artifact = update["ui_artifacts"][0] + assert artifact["type"] == "command" + assert artifact["data"]["uri"] == "http://localhost:8080" + assert artifact["data"]["command_type"] == "web_app" + + def test_register_command_maintains_legacy_format(self): + """register_command should maintain legacy ui_launchable_commands format.""" + state = StateManager.create_fresh_state("Test") + + update = StateManager.register_command( + state=state, + capability="dashboard_builder", + launch_uri="http://localhost:8080", + display_name="Dashboard", + command_type="web_app", + ) + + # Check legacy format + command = update["ui_launchable_commands"][0] + assert command["capability"] == "dashboard_builder" + assert command["launch_uri"] == "http://localhost:8080" + assert command["display_name"] == "Dashboard" + assert command["command_type"] == "web_app" + assert "created_at" in command + + +class TestStateManagerRegisterNotebookLegacy: + """Tests for legacy register_notebook() method.""" + + def test_register_notebook_creates_artifact(self): + """register_notebook should create artifact in ui_artifacts.""" + state = StateManager.create_fresh_state("Test") + + update = StateManager.register_notebook( + state=state, + capability="python_executor", + notebook_path="/path/to/notebook.ipynb", + notebook_link="http://jupyter/notebook", + display_name="Execution Notebook", + ) + + # Should have both unified and legacy fields + assert "ui_artifacts" in update + assert "ui_captured_notebooks" in update + + # Check unified artifact + artifact = update["ui_artifacts"][0] + assert artifact["type"] == "notebook" + assert artifact["data"]["path"] == "/path/to/notebook.ipynb" + assert artifact["data"]["url"] == "http://jupyter/notebook" + + def test_register_notebook_maintains_legacy_format(self): + """register_notebook should maintain legacy ui_captured_notebooks format.""" + state = StateManager.create_fresh_state("Test") + + update = StateManager.register_notebook( + state=state, + capability="python_executor", + notebook_path="/path/to/notebook.ipynb", + notebook_link="http://jupyter/notebook", + ) + + # Legacy format is just the link string + assert update["ui_captured_notebooks"][0] == "http://jupyter/notebook" + + +class TestArtifactAccumulationPatterns: + """Tests for proper artifact accumulation patterns.""" + + def test_multiple_artifacts_in_single_node(self): + """Should support registering multiple artifacts in a single node.""" + state = StateManager.create_fresh_state("Test") + + # Pattern: accumulate artifacts within single node + accumulating = None + figure_paths = ["/a.png", "/b.png", "/c.png"] + + for path in figure_paths: + update = StateManager.register_artifact( + state, + ArtifactType.IMAGE, + "python_executor", + {"path": path, "format": "png"}, + current_artifacts=accumulating, + ) + accumulating = update["ui_artifacts"] + + # Final update should contain all artifacts + assert len(update["ui_artifacts"]) == 3 + + def test_mixed_artifact_types(self): + """Should support different artifact types in same execution.""" + state = StateManager.create_fresh_state("Test") + + # Register image + update1 = StateManager.register_artifact( + state, ArtifactType.IMAGE, "test", {"path": "/plot.png"} + ) + + # Register notebook + update2 = StateManager.register_artifact( + state, + ArtifactType.NOTEBOOK, + "test", + {"path": "/nb.ipynb", "url": "http://jupyter"}, + current_artifacts=update1["ui_artifacts"], + ) + + # Register command + update3 = StateManager.register_artifact( + state, + ArtifactType.COMMAND, + "test", + {"uri": "http://dashboard"}, + current_artifacts=update2["ui_artifacts"], + ) + + assert len(update3["ui_artifacts"]) == 3 + types = [a["type"] for a in update3["ui_artifacts"]] + assert types == ["image", "notebook", "command"] + + +class TestBackwardCompatibility: + """Tests to ensure backward compatibility with existing code.""" + + def test_legacy_figure_registration_still_works(self): + """Existing code using register_figure should continue to work.""" + state = StateManager.create_fresh_state("Test") + + # This is the existing pattern in capabilities + update = StateManager.register_figure( + state, + capability="python_executor", + figure_path="/path/to/plot.png", + display_name="Test", + metadata={"execution_folder": "/tmp"}, + ) + + # Old code expects ui_captured_figures + assert "ui_captured_figures" in update + assert len(update["ui_captured_figures"]) == 1 + + def test_legacy_command_registration_still_works(self): + """Existing code using register_command should continue to work.""" + state = StateManager.create_fresh_state("Test") + + update = StateManager.register_command( + state, + capability="dashboard", + launch_uri="http://localhost:8080", + display_name="Dashboard", + ) + + # Old code expects ui_launchable_commands + assert "ui_launchable_commands" in update + assert len(update["ui_launchable_commands"]) == 1 + + def test_legacy_notebook_registration_still_works(self): + """Existing code using register_notebook should continue to work.""" + state = StateManager.create_fresh_state("Test") + + update = StateManager.register_notebook( + state, + capability="python_executor", + notebook_path="/nb.ipynb", + notebook_link="http://jupyter", + ) + + # Old code expects ui_captured_notebooks + assert "ui_captured_notebooks" in update + assert len(update["ui_captured_notebooks"]) == 1 + + +class TestPopulateLegacyFieldsFromArtifacts: + """Tests for populate_legacy_fields_from_artifacts() finalization helper.""" + + def test_populates_figures_from_image_artifacts(self): + """Should convert IMAGE artifacts to ui_captured_figures format.""" + from osprey.state import populate_legacy_fields_from_artifacts + + artifacts = [ + { + "id": "1", + "type": "image", + "capability": "python_executor", + "created_at": "2024-01-15T10:00:00", + "data": {"path": "/path/to/plot.png", "format": "png"}, + "display_name": "Analysis Plot", + "metadata": {"folder": "/tmp"}, + } + ] + + legacy = populate_legacy_fields_from_artifacts(artifacts) + + assert len(legacy["ui_captured_figures"]) == 1 + figure = legacy["ui_captured_figures"][0] + assert figure["capability"] == "python_executor" + assert figure["figure_path"] == "/path/to/plot.png" + assert figure["created_at"] == "2024-01-15T10:00:00" + assert figure["display_name"] == "Analysis Plot" + assert figure["metadata"] == {"folder": "/tmp"} + + def test_populates_commands_from_command_artifacts(self): + """Should convert COMMAND artifacts to ui_launchable_commands format.""" + from osprey.state import populate_legacy_fields_from_artifacts + + artifacts = [ + { + "id": "2", + "type": "command", + "capability": "dashboard", + "created_at": "2024-01-15T10:00:00", + "data": {"uri": "http://localhost:8080", "command_type": "web_app"}, + "display_name": "Dashboard", + } + ] + + legacy = populate_legacy_fields_from_artifacts(artifacts) + + assert len(legacy["ui_launchable_commands"]) == 1 + cmd = legacy["ui_launchable_commands"][0] + assert cmd["capability"] == "dashboard" + assert cmd["uri"] == "http://localhost:8080" + assert cmd["display_name"] == "Dashboard" + + def test_populates_notebooks_from_notebook_artifacts(self): + """Should convert NOTEBOOK artifacts to ui_captured_notebooks format (URL list).""" + from osprey.state import populate_legacy_fields_from_artifacts + + artifacts = [ + { + "id": "3", + "type": "notebook", + "capability": "python_executor", + "created_at": "2024-01-15T10:00:00", + "data": {"path": "/nb.ipynb", "url": "http://jupyter/notebook"}, + } + ] + + legacy = populate_legacy_fields_from_artifacts(artifacts) + + # Legacy format is just URL strings + assert len(legacy["ui_captured_notebooks"]) == 1 + assert legacy["ui_captured_notebooks"][0] == "http://jupyter/notebook" + + def test_handles_mixed_artifact_types(self): + """Should properly separate mixed artifact types into legacy fields.""" + from osprey.state import populate_legacy_fields_from_artifacts + + artifacts = [ + { + "id": "1", + "type": "image", + "capability": "test", + "created_at": "", + "data": {"path": "/a.png"}, + }, + { + "id": "2", + "type": "notebook", + "capability": "test", + "created_at": "", + "data": {"url": "http://nb1"}, + }, + { + "id": "3", + "type": "image", + "capability": "test", + "created_at": "", + "data": {"path": "/b.png"}, + }, + { + "id": "4", + "type": "command", + "capability": "test", + "created_at": "", + "data": {"uri": "http://cmd"}, + }, + { + "id": "5", + "type": "notebook", + "capability": "test", + "created_at": "", + "data": {"url": "http://nb2"}, + }, + ] + + legacy = populate_legacy_fields_from_artifacts(artifacts) + + assert len(legacy["ui_captured_figures"]) == 2 + assert len(legacy["ui_captured_notebooks"]) == 2 + assert len(legacy["ui_launchable_commands"]) == 1 + + def test_handles_empty_artifacts_list(self): + """Should return empty legacy fields for empty artifacts list.""" + from osprey.state import populate_legacy_fields_from_artifacts + + legacy = populate_legacy_fields_from_artifacts([]) + + assert legacy["ui_captured_figures"] == [] + assert legacy["ui_launchable_commands"] == [] + assert legacy["ui_captured_notebooks"] == [] + + def test_handles_artifacts_without_optional_fields(self): + """Should handle artifacts missing optional display_name and metadata.""" + from osprey.state import populate_legacy_fields_from_artifacts + + artifacts = [ + { + "id": "1", + "type": "image", + "capability": "test", + "created_at": "2024-01-15T10:00:00", + "data": {"path": "/plot.png"}, + # No display_name or metadata + } + ] + + legacy = populate_legacy_fields_from_artifacts(artifacts) + + figure = legacy["ui_captured_figures"][0] + assert "display_name" not in figure + assert "metadata" not in figure + assert figure["figure_path"] == "/plot.png" From 430276610614635db40f70318b635b86bcc40408 Mon Sep 17 00:00:00 2001 From: ThorstenHellert Date: Sat, 10 Jan 2026 15:11:16 +0100 Subject: [PATCH 02/14] feat(tui): Add artifact gallery and viewer widgets - ArtifactGallery with keyboard navigation (Ctrl+a, j/k, Enter, o, Esc) - ArtifactViewer modal with type-specific details and external open - Native image rendering via textual-image (Sixel/Kitty protocols) - New/seen tracking with [NEW] badges for current turn artifacts - Integration with ChatDisplay for artifact state management --- CHANGELOG.md | 5 + src/osprey/interfaces/tui/app.py | 34 ++ src/osprey/interfaces/tui/styles.tcss | 191 +++++++++ src/osprey/interfaces/tui/widgets/__init__.py | 20 +- .../interfaces/tui/widgets/artifact_viewer.py | 368 +++++++++++++++++ .../interfaces/tui/widgets/artifacts.py | 335 ++++++++++++++++ .../interfaces/tui/widgets/chat_display.py | 69 ++++ tests/test_tui_artifacts.py | 376 ++++++++++++++++++ 8 files changed, 1394 insertions(+), 4 deletions(-) create mode 100644 src/osprey/interfaces/tui/widgets/artifact_viewer.py create mode 100644 src/osprey/interfaces/tui/widgets/artifacts.py create mode 100644 tests/test_tui_artifacts.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ca3d9adc5..d5563f62b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Single source of truth (`ui_artifacts`) for all artifact types: IMAGE, NOTEBOOK, COMMAND, HTML, FILE - Legacy methods (`register_figure`, `register_notebook`, `register_command`) delegate to new API - `populate_legacy_fields_from_artifacts()` helper for backward compatibility at finalization +- **TUI**: Artifact gallery and viewer widgets for interactive artifact browsing + - ArtifactGallery with keyboard navigation (Ctrl+a focus, j/k navigate, Enter view, o open external) + - ArtifactViewer modal with type-specific details and actions (copy path, open in system app) + - Native image rendering via textual-image (Sixel for iTerm2/WezTerm, Kitty Graphics Protocol) + - New/seen tracking with [NEW] badges for artifacts from current turn ### Changed - **Capabilities**: Python capability uses unified `register_artifact()` API directly diff --git a/src/osprey/interfaces/tui/app.py b/src/osprey/interfaces/tui/app.py index dddd49477..337273d77 100644 --- a/src/osprey/interfaces/tui/app.py +++ b/src/osprey/interfaces/tui/app.py @@ -20,6 +20,8 @@ from osprey.interfaces.tui.constants import EXEC_STEP_PATTERN, TASK_PREP_COMPONENTS from osprey.interfaces.tui.handlers import QueueLogHandler from osprey.interfaces.tui.widgets import ( + ArtifactItem, + ArtifactViewer, ChatDisplay, ChatInput, ClassificationBlock, @@ -72,6 +74,8 @@ class OspreyTUI(App): Binding("b", "scroll_up", "Scroll up", show=False), Binding("g", "scroll_home", "Go to top", show=False), Binding("G", "scroll_end_chat", "Go to bottom", show=False), + # Artifact gallery (priority=True to override TextArea's select-all) + Binding("ctrl+a", "focus_artifacts", "Artifacts", priority=True), ] def __init__(self, config_path: str = "config.yml"): @@ -229,6 +233,14 @@ def action_toggle_help_panel(self) -> None: except NoMatches: self.screen.mount(HelpPanel()) + def on_artifact_item_selected(self, event: ArtifactItem.Selected) -> None: + """Handle artifact selection - open the artifact viewer modal. + + Args: + event: The artifact selection event containing the artifact data. + """ + self.push_screen(ArtifactViewer(event.artifact)) + def action_exit_app(self) -> None: """Exit the application.""" self.exit() @@ -257,6 +269,20 @@ def action_scroll_end_chat(self) -> None: chat = self.query_one("#chat-display", ChatDisplay) chat.scroll_end(animate=False) + def action_focus_artifacts(self) -> None: + """Focus the artifact gallery for keyboard navigation.""" + try: + chat_display = self.query_one("#chat-display", ChatDisplay) + gallery = chat_display.get_artifact_gallery() + if gallery and gallery.display: + gallery.focus() + # Scroll to make gallery visible + chat_display.scroll_to_widget(gallery) + else: + self.notify("No artifacts available", severity="information") + except Exception: + self.notify("No artifacts available", severity="information") + def _get_version(self) -> str: """Get the framework version.""" try: @@ -421,6 +447,9 @@ def _cmd_clear(self) -> None: chat_display._component_attempt_index = {} chat_display._retry_triggered = set() chat_display._pending_messages = {} + # Reset artifact gallery + chat_display._artifact_gallery = None + chat_display.clear_artifact_history() def _cmd_help(self, option: str | None) -> None: """Show help for commands. @@ -1041,6 +1070,11 @@ async def process_with_agent(self, user_input: str) -> None: # Show final response self._show_final_response(state.values, chat_display) + # Show artifacts AFTER the response (so they appear below) + artifacts = state.values.get("ui_artifacts", []) + if artifacts: + chat_display.update_artifacts(artifacts) + except Exception as e: chat_display.add_message(f"Error: {e}", "assistant", message_type="agent") diff --git a/src/osprey/interfaces/tui/styles.tcss b/src/osprey/interfaces/tui/styles.tcss index 10ab69b14..9e36c88f1 100644 --- a/src/osprey/interfaces/tui/styles.tcss +++ b/src/osprey/interfaces/tui/styles.tcss @@ -926,3 +926,194 @@ ThemePicker { color: $text-muted; background: transparent; } + +/* ============================================================================= + Artifact Gallery + ============================================================================= */ + +ArtifactGallery { + width: 100%; + height: auto; + margin: 0 2 1 2; + padding: 0 2; + display: none; /* Hidden until populated */ + &:dark { + border: tall $surface-lighten-1; + } + &:light { + border: tall $surface-darken-1; + } +} + +#gallery-header { + width: 100%; + height: 1; + color: $text; + margin-bottom: 1; +} + +#gallery-items { + width: 100%; + height: auto; +} + +#gallery-footer { + width: 100%; + height: 1; + color: $text-muted; + margin-top: 1; + text-align: right; +} + +/* Individual artifact item */ +ArtifactItem { + width: 100%; + height: auto; + padding: 0 1; + margin-bottom: 1; +} + +ArtifactItem:hover { + background: $surface; +} + +ArtifactItem.artifact-selected { + background: $primary-muted; +} + +/* Gallery focus state - show border when focused */ +ArtifactGallery:focus { + border: round $primary; +} + +ArtifactGallery.gallery-focused { + border: round $primary; +} + +/* Selected item gets a left indicator when gallery is focused */ +ArtifactGallery:focus ArtifactItem.artifact-selected { + background: $primary-muted; + border-left: thick $primary; +} + +#artifact-name { + width: 100%; + color: $text; +} + +#artifact-meta, +.artifact-meta { + width: 100%; + color: $text-muted; +} + +/* ============================================================================= + Artifact Viewer Modal + ============================================================================= */ + +ArtifactViewer { + align: center middle; + background: $background 80%; +} + +#artifact-viewer-container { + width: 110; + height: auto; + max-height: 90%; + margin: 0; + padding: 1 4; + background: $surface; + border: none; +} + +#artifact-viewer-header { + width: 100%; + height: auto; + margin-bottom: 1; +} + +#artifact-viewer-title { + width: auto; + text-style: bold; +} + +#artifact-header-spacer { + width: 1fr; +} + +#artifact-viewer-dismiss-hint { + width: auto; + color: $text-muted; +} + +#artifact-viewer-content { + width: 100%; + height: auto; + margin-bottom: 1; +} + +#image-preview { + width: 100%; + height: auto; +} + +#artifact-viewer-actions { + width: 100%; + height: auto; + margin: 1 0; +} + +#artifact-viewer-actions Button { + margin-right: 1; +} + +#artifact-viewer-footer { + width: 100%; + height: 1; + text-align: right; + color: $text-muted; +} + +/* Detail rows in artifact viewer */ +.detail-row { + width: 100%; + height: auto; + margin-bottom: 0; +} + +.detail-label { + width: 100%; + color: $text; + margin-top: 1; +} + +.detail-value { + width: 100%; + color: $text-muted; + margin-bottom: 1; +} + +.detail-path { + color: $primary; +} + +.detail-url { + color: $accent; + text-style: underline; +} + +.detail-separator { + width: 100%; + height: 1; + color: $surface; + margin: 1 0; +} + +.detail-section-header { + color: $text; + margin-top: 1; +} + +.detail-metadata { + color: $text-muted; +} diff --git a/src/osprey/interfaces/tui/widgets/__init__.py b/src/osprey/interfaces/tui/widgets/__init__.py index 0b2342217..f149b9608 100644 --- a/src/osprey/interfaces/tui/widgets/__init__.py +++ b/src/osprey/interfaces/tui/widgets/__init__.py @@ -1,5 +1,7 @@ """TUI Widget components.""" +from osprey.interfaces.tui.widgets.artifact_viewer import ArtifactViewer +from osprey.interfaces.tui.widgets.artifacts import ArtifactGallery, ArtifactItem from osprey.interfaces.tui.widgets.blocks import ( ClassificationBlock, ClassificationStep, @@ -34,12 +36,22 @@ from osprey.interfaces.tui.widgets.welcome import WelcomeBanner, WelcomeScreen __all__ = [ + # Artifact widgets + "ArtifactGallery", + "ArtifactItem", + "ArtifactViewer", + # Chat and display "ChatMessage", + "ChatDisplay", + "ChatInput", + # Modals "CommandPalette", "ContentViewer", + "LogViewer", + "ThemePicker", + # Processing blocks and steps "DebugBlock", "LogsLink", - "LogViewer", "ProcessingBlock", "ProcessingStep", "PromptLink", @@ -54,13 +66,13 @@ "TodoList", "TodoUpdateStep", "ExecutionStep", - "ChatDisplay", - "ChatInput", + # Input widgets "StatusPanel", "CommandDropdown", - "ThemePicker", + # Welcome screen "WelcomeBanner", "WelcomeScreen", + # Utilities "WrappedLabel", "WrappedStatic", ] diff --git a/src/osprey/interfaces/tui/widgets/artifact_viewer.py b/src/osprey/interfaces/tui/widgets/artifact_viewer.py new file mode 100644 index 000000000..e8f25ad19 --- /dev/null +++ b/src/osprey/interfaces/tui/widgets/artifact_viewer.py @@ -0,0 +1,368 @@ +"""Artifact Viewer modal for displaying artifact details and actions.""" + +from __future__ import annotations + +# Try to import textual-image for native image rendering +# Only use on terminals with confirmed graphics protocol support (no pixelated fallback) +import os +import platform +import subprocess +from datetime import datetime +from pathlib import Path +from typing import Any + +from textual.app import ComposeResult +from textual.containers import Container, Horizontal, ScrollableContainer +from textual.events import Key +from textual.screen import ModalScreen +from textual.widgets import Button, Static + +from osprey.state.artifacts import ArtifactType, get_artifact_type_icon + +_TERM = os.environ.get("TERM_PROGRAM", "").lower() +_IS_KITTY = "kitty" in _TERM or "KITTY_WINDOW_ID" in os.environ +_IS_ITERM = "iterm" in _TERM or "ITERM_SESSION_ID" in os.environ +_IS_WEZTERM = "wezterm" in _TERM or "WEZTERM_PANE" in os.environ + +try: + from textual_image.widget import SixelImage, TGPImage + + # Pick renderer based on terminal - only for supported terminals + if _IS_KITTY: + TextualImage = TGPImage + elif _IS_ITERM or _IS_WEZTERM: + TextualImage = SixelImage + else: + TextualImage = None # No fallback to pixelated AutoImage + +except ImportError: + TextualImage = None + + +class ArtifactViewer(ModalScreen[None]): + """Modal screen for viewing artifact details with type-specific actions. + + Displays artifact metadata and provides actions like: + - Open in external application + - Copy path/URL to clipboard + - Navigate to related resources + """ + + BINDINGS = [ + ("escape", "dismiss_viewer", "Close"), + ("o", "open_external", "Open External"), + ("c", "copy_path", "Copy Path"), + ] + + AUTO_FOCUS = "#artifact-viewer-content" + + def __init__(self, artifact: dict[str, Any], auto_open: bool = True) -> None: + """Initialize the artifact viewer. + + Args: + artifact: The artifact dictionary to display + auto_open: Whether to auto-open viewable artifacts (images, notebooks) + """ + super().__init__() + self.artifact = artifact + self._artifact_type = ArtifactType(artifact.get("type", "file")) + self._auto_open = auto_open + self._opened_externally = False + + def on_mount(self) -> None: + """Auto-open viewable artifacts on mount.""" + if self._auto_open and self._should_auto_open(): + self._open_external_silent() + self._opened_externally = True + + def _should_auto_open(self) -> bool: + """Check if this artifact type should auto-open. + + Currently disabled - users can press 'o' to open externally. + """ + return False + + def _open_external_silent(self) -> None: + """Open externally without notifications (for auto-open).""" + target = self._get_openable_target() + if not target: + return + + try: + system = platform.system() + if system == "Darwin": # macOS + subprocess.Popen(["open", target]) + elif system == "Linux": + subprocess.Popen(["xdg-open", target]) + elif system == "Windows": + subprocess.Popen(["start", target], shell=True) + except Exception: + pass # Silent failure for auto-open + + def compose(self) -> ComposeResult: + """Compose the artifact viewer layout.""" + with Container(id="artifact-viewer-container"): + # Header + with Horizontal(id="artifact-viewer-header"): + icon = get_artifact_type_icon(self._artifact_type) + title = self.artifact.get("display_name") or self._get_default_name() + yield Static(f"{icon} {title}", id="artifact-viewer-title") + yield Static("", id="artifact-header-spacer") + yield Static("esc", id="artifact-viewer-dismiss-hint") + + # Content area + with ScrollableContainer(id="artifact-viewer-content"): + yield from self._compose_details() + + # Action buttons + with Horizontal(id="artifact-viewer-actions"): + yield Button("Open External (o)", id="btn-open-external", variant="primary") + yield Button("Copy Path (c)", id="btn-copy-path") + + # Footer with hints + yield Static( + "[dim]o[/] open Β· [dim]c[/] copy path Β· [dim]esc[/] close", + id="artifact-viewer-footer", + ) + + def _get_default_name(self) -> str: + """Get default name based on artifact type and data.""" + data = self.artifact.get("data", {}) + if self._artifact_type == ArtifactType.IMAGE: + path = data.get("path", "") + return path.split("/")[-1] if path else "Figure" + elif self._artifact_type == ArtifactType.NOTEBOOK: + path = data.get("path", "") + return path.split("/")[-1] if path else "Notebook" + elif self._artifact_type == ArtifactType.COMMAND: + return data.get("command_type", "Command") + elif self._artifact_type == ArtifactType.HTML: + return data.get("framework", "Interactive Content") + elif self._artifact_type == ArtifactType.FILE: + path = data.get("path", "") + return path.split("/")[-1] if path else "File" + return "Artifact" + + def _compose_details(self) -> ComposeResult: + """Compose artifact details based on type.""" + data = self.artifact.get("data", {}) + metadata = self.artifact.get("metadata", {}) + capability = self.artifact.get("capability", "unknown") + created_at = self.artifact.get("created_at", "") + + # Type badge + type_display = self._artifact_type.value.upper() + yield Static(f"[bold]Type:[/] {type_display}", classes="detail-row") + + # Capability + yield Static(f"[bold]Source:[/] {capability}", classes="detail-row") + + # Timestamp + if created_at: + try: + dt = datetime.fromisoformat(created_at) + formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S") + except (ValueError, TypeError): + formatted_time = created_at + yield Static(f"[bold]Created:[/] {formatted_time}", classes="detail-row") + + # Separator + yield Static("─" * 50, classes="detail-separator") + + # Type-specific details + yield from self._compose_type_specific_details(data) + + # Metadata section (if any) + if metadata: + yield Static("") + yield Static("[bold]Metadata:[/]", classes="detail-section-header") + for key, value in metadata.items(): + # Truncate long values + value_str = str(value) + if len(value_str) > 60: + value_str = value_str[:57] + "..." + yield Static(f" {key}: {value_str}", classes="detail-metadata") + + def _compose_type_specific_details(self, data: dict[str, Any]) -> ComposeResult: + """Compose type-specific detail rows.""" + if self._artifact_type == ArtifactType.IMAGE: + path = data.get("path", "N/A") + format_ext = data.get("format", "unknown") + + # Show hint about opening externally (always useful) + yield Static( + "[dim]Press [/]o[dim] to open in system viewer[/]", + classes="detail-row image-hint", + ) + yield Static("") + + # Try to render the image inline using native terminal graphics + # Only supported on modern terminals (Kitty, iTerm2, WezTerm) - no pixelated fallback + if path and path != "N/A": + image_path = Path(path) + if not image_path.exists(): + yield Static( + "[dim]Image file not found[/]", + classes="detail-row image-fallback", + ) + elif TextualImage is not None: + # Native graphics support available + yield TextualImage(path, id="image-preview") + yield Static("") + else: + # No native graphics support - show helpful message + yield Static( + "[dim]Inline preview requires a modern terminal:[/]", + classes="detail-row image-fallback", + ) + yield Static( + "[dim] iTerm2, Kitty, or WezTerm + textual-image[/]", + classes="detail-row image-fallback", + ) + yield Static("") + + yield Static("[bold]Path:[/]", classes="detail-label") + yield Static(f" {path}", classes="detail-value detail-path") + yield Static(f"[bold]Format:[/] {format_ext.upper()}", classes="detail-row") + if "width" in data and "height" in data: + yield Static( + f"[bold]Dimensions:[/] {data['width']}x{data['height']}", + classes="detail-row", + ) + + elif self._artifact_type == ArtifactType.NOTEBOOK: + if "path" in data: + yield Static("[bold]Path:[/]", classes="detail-label") + yield Static(f" {data['path']}", classes="detail-value detail-path") + if "url" in data: + yield Static("[bold]URL:[/]", classes="detail-label") + yield Static(f" {data['url']}", classes="detail-value detail-url") + + elif self._artifact_type == ArtifactType.COMMAND: + uri = data.get("uri", "N/A") + command_type = data.get("command_type", "unknown") + yield Static("[bold]URI:[/]", classes="detail-label") + yield Static(f" {uri}", classes="detail-value detail-url") + yield Static(f"[bold]Command Type:[/] {command_type}", classes="detail-row") + + elif self._artifact_type == ArtifactType.HTML: + if "path" in data: + yield Static("[bold]Path:[/]", classes="detail-label") + yield Static(f" {data['path']}", classes="detail-value detail-path") + if "url" in data: + yield Static("[bold]URL:[/]", classes="detail-label") + yield Static(f" {data['url']}", classes="detail-value detail-url") + if "framework" in data: + yield Static(f"[bold]Framework:[/] {data['framework']}", classes="detail-row") + + elif self._artifact_type == ArtifactType.FILE: + path = data.get("path", "N/A") + mime_type = data.get("mime_type", "unknown") + yield Static("[bold]Path:[/]", classes="detail-label") + yield Static(f" {path}", classes="detail-value detail-path") + yield Static(f"[bold]MIME Type:[/] {mime_type}", classes="detail-row") + if "size_bytes" in data: + size = self._format_size(data["size_bytes"]) + yield Static(f"[bold]Size:[/] {size}", classes="detail-row") + + def _format_size(self, size_bytes: int) -> str: + """Format file size in human-readable form.""" + for unit in ["B", "KB", "MB", "GB"]: + if size_bytes < 1024: + return f"{size_bytes:.1f} {unit}" + size_bytes /= 1024 + return f"{size_bytes:.1f} TB" + + def _get_openable_target(self) -> str | None: + """Get the path or URL that can be opened externally.""" + data = self.artifact.get("data", {}) + + if self._artifact_type in (ArtifactType.IMAGE, ArtifactType.FILE): + return data.get("path") + elif self._artifact_type == ArtifactType.HTML: + return data.get("url") or data.get("path") + elif self._artifact_type == ArtifactType.NOTEBOOK: + return data.get("url") or data.get("path") + elif self._artifact_type == ArtifactType.COMMAND: + return data.get("uri") + return None + + def _get_copyable_path(self) -> str | None: + """Get the path or URL that can be copied.""" + data = self.artifact.get("data", {}) + return data.get("path") or data.get("url") or data.get("uri") + + def on_key(self, event: Key) -> None: + """Handle key events.""" + if event.key == "enter": + self.dismiss(None) + event.stop() + elif event.key == "space": + container = self.query_one("#artifact-viewer-content", ScrollableContainer) + container.scroll_page_down(animate=False) + event.stop() + elif event.key == "b": + container = self.query_one("#artifact-viewer-content", ScrollableContainer) + container.scroll_page_up(animate=False) + event.stop() + + def on_button_pressed(self, event: Button.Pressed) -> None: + """Handle button presses.""" + if event.button.id == "btn-open-external": + self.action_open_external() + elif event.button.id == "btn-copy-path": + self.action_copy_path() + + def action_dismiss_viewer(self) -> None: + """Dismiss the artifact viewer.""" + self.dismiss(None) + + def action_open_external(self) -> None: + """Open the artifact in the system's default application.""" + target = self._get_openable_target() + if not target: + self.notify("No path or URL available to open", severity="warning") + return + + try: + system = platform.system() + if system == "Darwin": # macOS + subprocess.Popen(["open", target]) + elif system == "Linux": + subprocess.Popen(["xdg-open", target]) + elif system == "Windows": + subprocess.Popen(["start", target], shell=True) + self.notify(f"Opening: {target[:50]}...") + except Exception as e: + self.notify(f"Failed to open: {e}", severity="error") + + def action_copy_path(self) -> None: + """Copy the artifact path/URL to clipboard.""" + path = self._get_copyable_path() + if not path: + self.notify("No path available to copy", severity="warning") + return + + try: + system = platform.system() + if system == "Darwin": # macOS + subprocess.run(["pbcopy"], input=path.encode(), check=True) + elif system == "Linux": + # Try xclip first, then xsel + try: + subprocess.run( + ["xclip", "-selection", "clipboard"], + input=path.encode(), + check=True, + ) + except FileNotFoundError: + subprocess.run( + ["xsel", "--clipboard", "--input"], + input=path.encode(), + check=True, + ) + elif system == "Windows": + subprocess.run(["clip"], input=path.encode(), check=True) + self.notify("Path copied to clipboard") + except Exception as e: + self.notify(f"Failed to copy: {e}", severity="error") diff --git a/src/osprey/interfaces/tui/widgets/artifacts.py b/src/osprey/interfaces/tui/widgets/artifacts.py new file mode 100644 index 000000000..c0f7df657 --- /dev/null +++ b/src/osprey/interfaces/tui/widgets/artifacts.py @@ -0,0 +1,335 @@ +"""Artifact gallery widgets for the TUI. + +This module provides widgets for displaying artifacts (figures, notebooks, commands, etc.) +in the TUI interface. It supports "new" vs "seen" tracking to highlight artifacts +that were generated in the current conversation turn. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, Any + +from textual.app import ComposeResult +from textual.containers import Vertical +from textual.message import Message +from textual.widgets import Static + +from osprey.state.artifacts import ArtifactType, get_artifact_type_icon + +if TYPE_CHECKING: + from textual.events import Click + + +class ArtifactItem(Static): + """A single artifact item in the gallery. + + Displays artifact type icon, name, capability, and timestamp. + Shows [NEW] badge for artifacts from the current turn. + """ + + class Selected(Message): + """Message sent when an artifact is selected.""" + + def __init__(self, artifact: dict[str, Any]) -> None: + super().__init__() + self.artifact = artifact + + def __init__(self, artifact: dict[str, Any], is_new: bool = False, **kwargs) -> None: + """Initialize an artifact item. + + Args: + artifact: The artifact dictionary from state + is_new: Whether this artifact is new (from current turn) + """ + super().__init__(**kwargs) + self.artifact = artifact + self.is_new = is_new + self._artifact_type = ArtifactType(artifact.get("type", "file")) + + def compose(self) -> ComposeResult: + """Compose the artifact item layout.""" + icon = get_artifact_type_icon(self._artifact_type) + display_name = self.artifact.get("display_name") or self._get_default_name() + capability = self.artifact.get("capability", "unknown") + created_at = self._format_timestamp() + + # Build the display text + new_badge = "[bold $accent][NEW][/] " if self.is_new else " " + line1 = f"{new_badge}{icon} {display_name}" + line2 = f" {capability} [dim]Β·[/] {created_at}" + + yield Static(line1, id="artifact-name") + yield Static(line2, id="artifact-meta", classes="artifact-meta") + + def _get_default_name(self) -> str: + """Get default name based on artifact type and data.""" + data = self.artifact.get("data", {}) + if self._artifact_type == ArtifactType.IMAGE: + path = data.get("path", "") + return path.split("/")[-1] if path else "Figure" + elif self._artifact_type == ArtifactType.NOTEBOOK: + path = data.get("path", "") + return path.split("/")[-1] if path else "Notebook" + elif self._artifact_type == ArtifactType.COMMAND: + return data.get("command_type", "Command") + elif self._artifact_type == ArtifactType.HTML: + return data.get("framework", "Interactive Content") + elif self._artifact_type == ArtifactType.FILE: + path = data.get("path", "") + return path.split("/")[-1] if path else "File" + return "Artifact" + + def _format_timestamp(self) -> str: + """Format the created_at timestamp for display.""" + created_at = self.artifact.get("created_at", "") + if not created_at: + return "" + try: + dt = datetime.fromisoformat(created_at) + return dt.strftime("%H:%M:%S") + except (ValueError, TypeError): + return created_at[:8] if len(created_at) >= 8 else created_at + + def on_click(self, event: Click) -> None: + """Handle click to select this artifact.""" + event.stop() + self.post_message(self.Selected(self.artifact)) + + +class ArtifactGallery(Static, can_focus=True): + """Gallery widget displaying all artifacts from the current execution. + + Tracks which artifacts have been "seen" across conversation turns + to highlight new artifacts with a [NEW] badge. + + Keyboard Navigation: + - Ctrl+a: Focus gallery (global, defined in app) + - j/↓: Select next artifact + - k/↑: Select previous artifact + - Enter: Open selected in viewer modal + - o: Open selected in external application + - Esc/q: Return focus to input + """ + + BINDINGS = [ + ("enter", "open_selected", "Open"), + ("o", "open_external", "Open External"), + ("j", "select_next", "Next"), + ("k", "select_previous", "Previous"), + ("down", "select_next", "Next"), + ("up", "select_previous", "Previous"), + ("escape", "exit_gallery", "Exit"), + ("q", "exit_gallery", "Exit"), + ] + + def __init__(self, **kwargs) -> None: + """Initialize the artifact gallery.""" + super().__init__(**kwargs) + self._artifacts: list[dict[str, Any]] = [] + self._seen_ids: set[str] = set() + self._selected_index: int = 0 + self._mounted: bool = False + + def compose(self) -> ComposeResult: + """Compose the gallery layout.""" + yield Static("", id="gallery-header") + yield Vertical(id="gallery-items") + yield Static("", id="gallery-footer") + + def on_mount(self) -> None: + """Initialize the gallery display.""" + self._mounted = True + self._update_display() + + def update_artifacts(self, artifacts: list[dict[str, Any]]) -> None: + """Update the gallery with new artifacts. + + Marks artifacts as "new" if their ID hasn't been seen before, + then adds all IDs to the seen set. + + Args: + artifacts: List of artifact dictionaries from state + """ + self._artifacts = artifacts + + # Mark new artifacts and update seen set + for artifact in artifacts: + artifact_id = artifact.get("id", "") + artifact["_is_new"] = artifact_id not in self._seen_ids + if artifact_id: + self._seen_ids.add(artifact_id) + + # Only update display if mounted (children exist) + if self._mounted: + self._update_display() + + def clear_seen(self) -> None: + """Clear the seen artifacts set (e.g., on new session).""" + self._seen_ids.clear() + + def _count_new(self) -> int: + """Count how many artifacts are new.""" + return sum(1 for a in self._artifacts if a.get("_is_new", False)) + + def _update_display(self) -> None: + """Update the gallery display with current artifacts.""" + if not self._artifacts: + self.display = False + return + + self.display = True + + # Check if children exist (widget must be mounted and composed) + try: + header = self.query_one("#gallery-header", Static) + items_container = self.query_one("#gallery-items", Vertical) + _footer = self.query_one("#gallery-footer", Static) # noqa: F841 verify exists + except Exception: + # Children don't exist yet - will be updated in on_mount + return + + # Update header + new_count = self._count_new() + if new_count > 0: + header.update(f"[bold]Artifacts[/] [dim]({new_count} new)[/]") + else: + header.update(f"[bold]Artifacts[/] [dim]({len(self._artifacts)})[/]") + + # Remove existing items + for child in list(items_container.children): + child.remove() + + # Add artifact items + for i, artifact in enumerate(self._artifacts): + is_new = artifact.get("_is_new", False) + item = ArtifactItem(artifact, is_new=is_new, id=f"artifact-{i}") + if i == self._selected_index: + item.add_class("artifact-selected") + items_container.mount(item) + + # Update footer based on current focus state + self._update_footer_for_focus(focused=self.has_focus) + + def get_selected_artifact(self) -> dict[str, Any] | None: + """Get the currently selected artifact.""" + if 0 <= self._selected_index < len(self._artifacts): + return self._artifacts[self._selected_index] + return None + + def action_open_selected(self) -> None: + """Open the selected artifact in the viewer modal.""" + artifact = self.get_selected_artifact() + if artifact: + self.post_message(ArtifactItem.Selected(artifact)) + + def action_open_external(self) -> None: + """Open the selected artifact in the system's default application.""" + artifact = self.get_selected_artifact() + if artifact: + self._open_external(artifact) + + def _open_external(self, artifact: dict[str, Any]) -> None: + """Open an artifact in the system's default application.""" + import platform + import subprocess + + data = artifact.get("data", {}) + artifact_type = ArtifactType(artifact.get("type", "file")) + + # Determine what to open + target = None + if artifact_type in (ArtifactType.IMAGE, ArtifactType.FILE, ArtifactType.HTML): + target = data.get("path") + elif artifact_type == ArtifactType.NOTEBOOK: + # Prefer URL for notebooks, fall back to path + target = data.get("url") or data.get("path") + elif artifact_type == ArtifactType.COMMAND: + target = data.get("uri") + + if not target: + return + + # Open based on platform + try: + system = platform.system() + if system == "Darwin": # macOS + subprocess.Popen(["open", target]) + elif system == "Linux": + subprocess.Popen(["xdg-open", target]) + elif system == "Windows": + subprocess.Popen(["start", target], shell=True) + except Exception: + pass # Silently fail if can't open + + def on_artifact_item_selected(self, event: ArtifactItem.Selected) -> None: + """Handle artifact item selection - update selection index. + + The event bubbles up to the app automatically, no need to re-post. + """ + # Update selected index + for i, artifact in enumerate(self._artifacts): + if artifact.get("id") == event.artifact.get("id"): + self._selected_index = i + break + self._update_selection_visual() + # Don't re-post - the event bubbles up automatically + + def action_select_next(self) -> None: + """Select the next artifact in the list.""" + if not self._artifacts: + return + self._selected_index = min(self._selected_index + 1, len(self._artifacts) - 1) + self._update_selection_visual() + + def action_select_previous(self) -> None: + """Select the previous artifact in the list.""" + if not self._artifacts: + return + self._selected_index = max(self._selected_index - 1, 0) + self._update_selection_visual() + + def action_exit_gallery(self) -> None: + """Exit the gallery and return focus to the input.""" + # Find and focus the input widget + try: + from osprey.interfaces.tui.widgets.chat_input import ChatInput + + chat_input = self.app.query_one(ChatInput) + chat_input.focus() + except Exception: + # Fallback: just blur ourselves + self.blur() + + def _update_selection_visual(self) -> None: + """Update the visual selection indicator.""" + try: + items_container = self.query_one("#gallery-items", Vertical) + for i, child in enumerate(items_container.children): + if i == self._selected_index: + child.add_class("artifact-selected") + else: + child.remove_class("artifact-selected") + except Exception: + pass + + def on_focus(self) -> None: + """Handle focus - update footer to show navigation hints.""" + self._update_footer_for_focus(focused=True) + self.add_class("gallery-focused") + + def on_blur(self) -> None: + """Handle blur - update footer to show entry hint.""" + self._update_footer_for_focus(focused=False) + self.remove_class("gallery-focused") + + def _update_footer_for_focus(self, focused: bool) -> None: + """Update footer text based on focus state.""" + try: + footer = self.query_one("#gallery-footer", Static) + if focused: + footer.update("[dim]j/k[/] navigate Β· [dim]Enter[/] view Β· [dim]o[/] open Β· [dim]Esc[/] exit") + else: + footer.update("[dim]Press [/]Ctrl+a[dim] to browse artifacts[/]") + except Exception: + pass diff --git a/src/osprey/interfaces/tui/widgets/chat_display.py b/src/osprey/interfaces/tui/widgets/chat_display.py index c50733537..2cb16d5af 100644 --- a/src/osprey/interfaces/tui/widgets/chat_display.py +++ b/src/osprey/interfaces/tui/widgets/chat_display.py @@ -7,6 +7,7 @@ from textual.containers import ScrollableContainer +from osprey.interfaces.tui.widgets.artifacts import ArtifactGallery from osprey.interfaces.tui.widgets.blocks import ProcessingBlock from osprey.interfaces.tui.widgets.debug import DebugBlock from osprey.interfaces.tui.widgets.messages import ChatMessage @@ -36,6 +37,10 @@ def __init__(self, **kwargs): # Plan progress tracking for flow-style updates self._plan_steps: list[dict] = [] self._plan_step_states: list[str] = [] + # Artifact gallery - tracks "new" vs "seen" across conversation turns + self._artifact_gallery: ArtifactGallery | None = None + # Seen artifact IDs persist across conversation turns (session-scoped) + self._seen_artifact_ids: set[str] = set() def start_new_query(self, user_query: str) -> None: """Reset blocks for a new query and add user message. @@ -52,6 +57,9 @@ def start_new_query(self, user_query: str) -> None: self._plan_step_states = [] if self._debug_block: self._debug_block.clear() + # Hide artifact gallery for new query (will be shown when artifacts arrive) + if self._artifact_gallery: + self._artifact_gallery.display = False self.add_message(user_query, "user") def get_or_create_debug_block(self) -> DebugBlock | None: @@ -78,3 +86,64 @@ def add_message(self, content: str, role: str = "user", message_type: str = "") message = ChatMessage(content, role, message_type=message_type) self.mount(message) self.scroll_end(animate=False) + + # ===== ARTIFACT GALLERY METHODS ===== + + def get_artifact_gallery(self) -> ArtifactGallery | None: + """Get the artifact gallery widget if it exists. + + Returns: + The artifact gallery widget, or None if not yet created + """ + return self._artifact_gallery + + def get_or_create_artifact_gallery(self) -> ArtifactGallery: + """Get or create the artifact gallery widget. + + The gallery is lazily created on first use and reused across + conversation turns, maintaining "seen" state for artifact tracking. + + Returns: + The artifact gallery widget + """ + if not self._artifact_gallery: + self._artifact_gallery = ArtifactGallery(id="artifact-gallery") + # Transfer any previously seen IDs + self._artifact_gallery._seen_ids = self._seen_artifact_ids + self.mount(self._artifact_gallery) + return self._artifact_gallery + + def update_artifacts(self, artifacts: list[dict[str, Any]]) -> None: + """Update the artifact gallery with artifacts from the current execution. + + Marks artifacts as "new" if their ID hasn't been seen before in this session, + then adds all IDs to the seen set for future reference. + + Args: + artifacts: List of artifact dictionaries from state.ui_artifacts + """ + if not artifacts: + return + + gallery = self.get_or_create_artifact_gallery() + + # Mark new artifacts (before updating gallery's seen set) + for artifact in artifacts: + artifact_id = artifact.get("id", "") + artifact["_is_new"] = artifact_id not in self._seen_artifact_ids + if artifact_id: + self._seen_artifact_ids.add(artifact_id) + + # Update gallery with marked artifacts + gallery._artifacts = artifacts + gallery._update_display() + self.scroll_end(animate=False) + + def clear_artifact_history(self) -> None: + """Clear the seen artifacts history (e.g., on new session). + + This resets the "new" tracking so all artifacts will appear as new. + """ + self._seen_artifact_ids.clear() + if self._artifact_gallery: + self._artifact_gallery._seen_ids.clear() diff --git a/tests/test_tui_artifacts.py b/tests/test_tui_artifacts.py new file mode 100644 index 000000000..bb6da73ad --- /dev/null +++ b/tests/test_tui_artifacts.py @@ -0,0 +1,376 @@ +"""Tests for TUI artifact widgets. + +This module tests the TUI artifact display system including: +- ArtifactItem widget rendering +- ArtifactGallery tracking of new vs seen artifacts +- Integration with ChatDisplay +""" + +from osprey.state.artifacts import ArtifactType, create_artifact + + +class TestArtifactItemWidget: + """Tests for the ArtifactItem widget.""" + + def test_artifact_item_creation(self): + """ArtifactItem should be creatable with an artifact.""" + from osprey.interfaces.tui.widgets.artifacts import ArtifactItem + + artifact = create_artifact( + ArtifactType.IMAGE, "python_executor", {"path": "/test/plot.png", "format": "png"} + ) + item = ArtifactItem(artifact, is_new=True) + + assert item.artifact == artifact + assert item.is_new is True + assert item._artifact_type == ArtifactType.IMAGE + + def test_artifact_item_default_name_image(self): + """ArtifactItem should derive default name from path for images.""" + from osprey.interfaces.tui.widgets.artifacts import ArtifactItem + + artifact = create_artifact( + ArtifactType.IMAGE, "test", {"path": "/path/to/analysis_plot.png", "format": "png"} + ) + item = ArtifactItem(artifact) + + assert item._get_default_name() == "analysis_plot.png" + + def test_artifact_item_default_name_notebook(self): + """ArtifactItem should derive default name from path for notebooks.""" + from osprey.interfaces.tui.widgets.artifacts import ArtifactItem + + artifact = create_artifact( + ArtifactType.NOTEBOOK, + "test", + {"path": "/path/to/execution.ipynb", "url": "http://jupyter"}, + ) + item = ArtifactItem(artifact) + + assert item._get_default_name() == "execution.ipynb" + + def test_artifact_item_default_name_command(self): + """ArtifactItem should use command_type for commands.""" + from osprey.interfaces.tui.widgets.artifacts import ArtifactItem + + artifact = create_artifact( + ArtifactType.COMMAND, + "test", + {"uri": "http://localhost:8080", "command_type": "web_app"}, + ) + item = ArtifactItem(artifact) + + assert item._get_default_name() == "web_app" + + def test_artifact_item_uses_display_name_if_provided(self): + """ArtifactItem should use display_name when available.""" + from osprey.interfaces.tui.widgets.artifacts import ArtifactItem + + artifact = create_artifact( + ArtifactType.IMAGE, + "test", + {"path": "/plot.png"}, + display_name="My Custom Plot", + ) + item = ArtifactItem(artifact) + + # display_name takes precedence in rendering + assert artifact.get("display_name") == "My Custom Plot" + + +class TestArtifactGallery: + """Tests for the ArtifactGallery widget.""" + + def test_gallery_creation(self): + """ArtifactGallery should be creatable.""" + from osprey.interfaces.tui.widgets.artifacts import ArtifactGallery + + gallery = ArtifactGallery() + + assert gallery._artifacts == [] + assert gallery._seen_ids == set() + assert gallery._selected_index == 0 + + def test_gallery_tracks_new_artifacts(self): + """ArtifactGallery should mark unseen artifacts as new. + + Note: We test the tracking logic directly without calling update_artifacts() + since that requires mounted Textual widgets. + """ + from osprey.interfaces.tui.widgets.artifacts import ArtifactGallery + + gallery = ArtifactGallery() + + # First batch of artifacts + artifact1 = create_artifact(ArtifactType.IMAGE, "test", {"path": "/a.png"}) + artifact2 = create_artifact(ArtifactType.IMAGE, "test", {"path": "/b.png"}) + + # Simulate the tracking logic from update_artifacts + artifacts = [artifact1, artifact2] + for artifact in artifacts: + artifact_id = artifact.get("id", "") + artifact["_is_new"] = artifact_id not in gallery._seen_ids + if artifact_id: + gallery._seen_ids.add(artifact_id) + gallery._artifacts = artifacts + + # Both should be marked new (first time seeing them) + assert artifact1.get("_is_new") is True + assert artifact2.get("_is_new") is True + assert len(gallery._seen_ids) == 2 + + def test_gallery_marks_seen_artifacts_as_not_new(self): + """ArtifactGallery should not mark previously seen artifacts as new.""" + from osprey.interfaces.tui.widgets.artifacts import ArtifactGallery + + gallery = ArtifactGallery() + + # First artifact + artifact1 = create_artifact(ArtifactType.IMAGE, "test", {"path": "/a.png"}) + artifact1["_is_new"] = artifact1["id"] not in gallery._seen_ids + gallery._seen_ids.add(artifact1["id"]) + + assert artifact1.get("_is_new") is True + + # Same artifact again (by ID) + artifact1_copy = artifact1.copy() + + # New artifact + artifact2 = create_artifact(ArtifactType.IMAGE, "test", {"path": "/b.png"}) + + # Simulate update + artifacts = [artifact1_copy, artifact2] + for artifact in artifacts: + artifact_id = artifact.get("id", "") + artifact["_is_new"] = artifact_id not in gallery._seen_ids + if artifact_id: + gallery._seen_ids.add(artifact_id) + + # artifact1 should NOT be new (already seen), artifact2 should be new + assert artifact1_copy.get("_is_new") is False + assert artifact2.get("_is_new") is True + + def test_gallery_count_new(self): + """ArtifactGallery should correctly count new artifacts.""" + from osprey.interfaces.tui.widgets.artifacts import ArtifactGallery + + gallery = ArtifactGallery() + + artifact1 = create_artifact(ArtifactType.IMAGE, "test", {"path": "/a.png"}) + artifact2 = create_artifact(ArtifactType.IMAGE, "test", {"path": "/b.png"}) + + # Simulate first update + artifacts = [artifact1, artifact2] + for artifact in artifacts: + artifact_id = artifact.get("id", "") + artifact["_is_new"] = artifact_id not in gallery._seen_ids + if artifact_id: + gallery._seen_ids.add(artifact_id) + gallery._artifacts = artifacts + + assert gallery._count_new() == 2 + + # Add one more, reuse artifact1 + artifact3 = create_artifact(ArtifactType.IMAGE, "test", {"path": "/c.png"}) + + # Simulate second update + artifacts = [artifact1.copy(), artifact2.copy(), artifact3] + for artifact in artifacts: + artifact_id = artifact.get("id", "") + artifact["_is_new"] = artifact_id not in gallery._seen_ids + if artifact_id: + gallery._seen_ids.add(artifact_id) + gallery._artifacts = artifacts + + # Only artifact3 is new now + assert gallery._count_new() == 1 + + def test_gallery_clear_seen(self): + """clear_seen should reset the seen artifacts set.""" + from osprey.interfaces.tui.widgets.artifacts import ArtifactGallery + + gallery = ArtifactGallery() + + # Add some IDs directly + gallery._seen_ids.add("test-id-1") + gallery._seen_ids.add("test-id-2") + + assert len(gallery._seen_ids) == 2 + + gallery.clear_seen() + + assert len(gallery._seen_ids) == 0 + + def test_gallery_get_selected_artifact(self): + """get_selected_artifact should return the selected artifact.""" + from osprey.interfaces.tui.widgets.artifacts import ArtifactGallery + + gallery = ArtifactGallery() + + artifact1 = create_artifact(ArtifactType.IMAGE, "test", {"path": "/a.png"}) + artifact2 = create_artifact(ArtifactType.IMAGE, "test", {"path": "/b.png"}) + + gallery._artifacts = [artifact1, artifact2] + gallery._selected_index = 0 + + assert gallery.get_selected_artifact() == artifact1 + + gallery._selected_index = 1 + assert gallery.get_selected_artifact() == artifact2 + + def test_gallery_get_selected_artifact_empty(self): + """get_selected_artifact should return None when no artifacts.""" + from osprey.interfaces.tui.widgets.artifacts import ArtifactGallery + + gallery = ArtifactGallery() + + assert gallery.get_selected_artifact() is None + + +class TestArtifactViewer: + """Tests for the ArtifactViewer modal.""" + + def test_viewer_creation_image(self): + """ArtifactViewer should be creatable with an image artifact.""" + from osprey.interfaces.tui.widgets.artifact_viewer import ArtifactViewer + + artifact = create_artifact( + ArtifactType.IMAGE, + "python_executor", + {"path": "/path/to/plot.png", "format": "png"}, + display_name="Analysis Plot", + ) + viewer = ArtifactViewer(artifact) + + assert viewer.artifact == artifact + assert viewer._artifact_type == ArtifactType.IMAGE + + def test_viewer_creation_notebook(self): + """ArtifactViewer should be creatable with a notebook artifact.""" + from osprey.interfaces.tui.widgets.artifact_viewer import ArtifactViewer + + artifact = create_artifact( + ArtifactType.NOTEBOOK, + "python_executor", + {"path": "/path/to/notebook.ipynb", "url": "http://jupyter/notebook"}, + ) + viewer = ArtifactViewer(artifact) + + assert viewer._artifact_type == ArtifactType.NOTEBOOK + + def test_viewer_creation_command(self): + """ArtifactViewer should be creatable with a command artifact.""" + from osprey.interfaces.tui.widgets.artifact_viewer import ArtifactViewer + + artifact = create_artifact( + ArtifactType.COMMAND, + "dashboard", + {"uri": "http://localhost:8080/dashboard", "command_type": "web_app"}, + ) + viewer = ArtifactViewer(artifact) + + assert viewer._artifact_type == ArtifactType.COMMAND + + def test_viewer_get_openable_target_image(self): + """ArtifactViewer should return path for images.""" + from osprey.interfaces.tui.widgets.artifact_viewer import ArtifactViewer + + artifact = create_artifact( + ArtifactType.IMAGE, "test", {"path": "/path/to/plot.png", "format": "png"} + ) + viewer = ArtifactViewer(artifact) + + assert viewer._get_openable_target() == "/path/to/plot.png" + + def test_viewer_get_openable_target_notebook_prefers_url(self): + """ArtifactViewer should prefer URL for notebooks.""" + from osprey.interfaces.tui.widgets.artifact_viewer import ArtifactViewer + + artifact = create_artifact( + ArtifactType.NOTEBOOK, + "test", + {"path": "/path/to/notebook.ipynb", "url": "http://jupyter/notebook"}, + ) + viewer = ArtifactViewer(artifact) + + assert viewer._get_openable_target() == "http://jupyter/notebook" + + def test_viewer_get_openable_target_command(self): + """ArtifactViewer should return URI for commands.""" + from osprey.interfaces.tui.widgets.artifact_viewer import ArtifactViewer + + artifact = create_artifact(ArtifactType.COMMAND, "test", {"uri": "http://localhost:8080"}) + viewer = ArtifactViewer(artifact) + + assert viewer._get_openable_target() == "http://localhost:8080" + + def test_viewer_get_copyable_path(self): + """ArtifactViewer should return copyable path/url.""" + from osprey.interfaces.tui.widgets.artifact_viewer import ArtifactViewer + + artifact = create_artifact( + ArtifactType.IMAGE, "test", {"path": "/path/to/plot.png", "format": "png"} + ) + viewer = ArtifactViewer(artifact) + + assert viewer._get_copyable_path() == "/path/to/plot.png" + + def test_viewer_format_size(self): + """ArtifactViewer should format file sizes correctly.""" + from osprey.interfaces.tui.widgets.artifact_viewer import ArtifactViewer + + artifact = create_artifact(ArtifactType.FILE, "test", {"path": "/file.csv"}) + viewer = ArtifactViewer(artifact) + + assert viewer._format_size(100) == "100.0 B" + assert viewer._format_size(1024) == "1.0 KB" + assert viewer._format_size(1024 * 1024) == "1.0 MB" + assert viewer._format_size(1024 * 1024 * 1024) == "1.0 GB" + + +class TestChatDisplayArtifactIntegration: + """Tests for ChatDisplay artifact integration.""" + + def test_chat_display_has_artifact_tracking(self): + """ChatDisplay should have artifact tracking attributes.""" + from osprey.interfaces.tui.widgets.chat_display import ChatDisplay + + display = ChatDisplay() + + assert hasattr(display, "_artifact_gallery") + assert hasattr(display, "_seen_artifact_ids") + assert display._artifact_gallery is None + assert display._seen_artifact_ids == set() + + def test_chat_display_clear_artifact_history(self): + """ChatDisplay should clear artifact history.""" + from osprey.interfaces.tui.widgets.chat_display import ChatDisplay + + display = ChatDisplay() + display._seen_artifact_ids.add("test-id-1") + display._seen_artifact_ids.add("test-id-2") + + display.clear_artifact_history() + + assert display._seen_artifact_ids == set() + + +class TestArtifactTypeIcons: + """Tests for artifact type icons.""" + + def test_all_types_have_icons(self): + """All artifact types should have icons.""" + from osprey.state.artifacts import ArtifactType, get_artifact_type_icon + + for artifact_type in ArtifactType: + icon = get_artifact_type_icon(artifact_type) + assert icon is not None + assert len(icon) > 0 + + def test_icon_lookup_by_string(self): + """Icons should be retrievable by string type.""" + from osprey.state.artifacts import get_artifact_type_icon + + assert get_artifact_type_icon("image") == get_artifact_type_icon(ArtifactType.IMAGE) + assert get_artifact_type_icon("notebook") == get_artifact_type_icon(ArtifactType.NOTEBOOK) + assert get_artifact_type_icon("command") == get_artifact_type_icon(ArtifactType.COMMAND) From 266cd015eb3c77f5ed2fec236ffa3549de11d0a0 Mon Sep 17 00:00:00 2001 From: ThorstenHellert Date: Sat, 10 Jan 2026 15:34:13 +0100 Subject: [PATCH 03/14] feat(cli): Modernize artifact display to use unified ui_artifacts - Replace _extract_figures/commands/notebooks_for_cli with single method - Read from ui_artifacts directly instead of legacy fields - Support all artifact types: IMAGE, NOTEBOOK, COMMAND, HTML, FILE - Add type-specific formatting with icons and grouped display - Remove ~150 lines of redundant legacy extraction code --- CHANGELOG.md | 4 + .../interfaces/cli/direct_conversation.py | 257 ++++++++---------- .../interfaces/tui/widgets/artifacts.py | 4 +- tests/test_tui_artifacts.py | 2 +- 4 files changed, 115 insertions(+), 152 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d5563f62b..c5a89fc67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Capabilities**: Python capability uses unified `register_artifact()` API directly - Clean single-accumulation pattern for figures and notebooks - Legacy fields populated at finalization rather than registration +- **CLI**: Modernized artifact display to use unified `ui_artifacts` registry + - Single `_extract_artifacts_for_cli()` replaces three legacy extraction methods + - Supports all artifact types: IMAGE, NOTEBOOK, COMMAND, HTML, FILE + - Grouped display with type-specific formatting and icons ## [0.10.1] - 2026-01-09 diff --git a/src/osprey/interfaces/cli/direct_conversation.py b/src/osprey/interfaces/cli/direct_conversation.py index fa18b9160..bd00fc86a 100644 --- a/src/osprey/interfaces/cli/direct_conversation.py +++ b/src/osprey/interfaces/cli/direct_conversation.py @@ -39,6 +39,7 @@ from osprey.graph import create_graph from osprey.infrastructure.gateway import Gateway from osprey.registry import get_registry, initialize_registry +from osprey.state.artifacts import ArtifactType, get_artifact_type_icon from osprey.utils.config import get_full_configuration from osprey.utils.logger import get_logger, quiet_logging @@ -890,21 +891,11 @@ async def _show_final_result(self, result: dict[str, Any]): # Fallback if no messages found self.console.print(f"[{Styles.SUCCESS}]βœ… Execution completed[/{Styles.SUCCESS}]") - # Extract and display additional content - figures_output = self._extract_figures_for_cli(result) - if figures_output: + # Extract and display artifacts from unified registry + artifacts_output = self._extract_artifacts_for_cli(result) + if artifacts_output: self.console.print() # Add spacing - self.console.print(f"[{Styles.INFO}]{figures_output}[/{Styles.INFO}]") - - commands_output = self._extract_commands_for_cli(result) - if commands_output: - self.console.print() # Add spacing - self.console.print(f"[{Styles.COMMAND}]{commands_output}[/{Styles.COMMAND}]") - - notebooks_output = self._extract_notebooks_for_cli(result) - if notebooks_output: - self.console.print() # Add spacing - self.console.print(f"[{Styles.INFO}]{notebooks_output}[/{Styles.INFO}]") + self.console.print(f"[{Styles.INFO}]{artifacts_output}[/{Styles.INFO}]") async def _handle_stream_event(self, event: dict[str, Any]): """Handle and display streaming events from LangGraph execution. @@ -973,172 +964,138 @@ async def _handle_stream_event(self, event: dict[str, Any]): # If no response found, show completion self.console.print(f"[{Styles.SUCCESS}]βœ… Execution completed[/{Styles.SUCCESS}]") - def _extract_figures_for_cli(self, state: dict[str, Any]) -> str | None: - """Extract figures from centralized registry and format for CLI display. + def _extract_artifacts_for_cli(self, state: dict[str, Any]) -> str | None: + """Extract artifacts from unified registry and format for CLI display. - Extracts generated figures from the state and formats them for terminal - display with file paths and metadata. Unlike the OpenWebUI version that - converts to base64 images, this provides file paths that users can - access directly from their terminal. + Reads from the unified ui_artifacts field and formats all artifact types + for terminal display with appropriate icons and type-specific details. - :param state: Complete agent state containing figure registry + :param state: Complete agent state containing ui_artifacts :type state: dict[str, Any] - :return: Formatted string with figure information or None if no figures + :return: Formatted string with artifact information or None if no artifacts :rtype: str | None Examples: - Display figures in terminal:: + Display artifacts in terminal:: πŸ“Š Generated Figures: - β€’ /path/to/analysis_plot.png (created by python_executor at 2024-01-01 12:00:00) - β€’ /path/to/data_visualization.jpg (created by data_analysis at 2024-01-01 12:01:00) - """ - try: - # Get figures from centralized registry - ui_figures = state.get("ui_captured_figures", []) - - if not ui_figures: - logger.debug("No figures found in ui_captured_figures registry") - return None - - logger.info( - f"Processing {len(ui_figures)} figures from centralized registry for CLI display" - ) - figure_lines = ["πŸ“Š Generated Figures:"] - - for figure_entry in ui_figures: - try: - # Extract figure information - capability = figure_entry.get("capability", "unknown") - figure_path = figure_entry["figure_path"] - created_at = figure_entry.get("created_at", "unknown") - - # Format created_at if it's available - created_at_str = ( - str(created_at)[:19] - if created_at and created_at != "unknown" - else "unknown time" - ) - - # Create CLI-friendly display - figure_line = f"β€’ {figure_path} (created by {capability} at {created_at_str})" - figure_lines.append(figure_line) - - except Exception as e: - logger.warning(f"Failed to process figure entry {figure_entry}: {e}") - # Continue processing other figures - continue - - if len(figure_lines) > 1: # More than just the header - return "\n".join(figure_lines) - - return None - - except Exception as e: - logger.error(f"Critical error in CLI figure extraction: {e}") - return f"❌ Figure display error: {str(e)}" - - def _extract_commands_for_cli(self, state: dict[str, Any]) -> str | None: - """Extract launchable commands from centralized registry and format for CLI display. - - Extracts registered commands from the state and formats them for terminal - display with launch URIs and descriptions. Provides clickable links for - terminal emulators that support them, or copy-paste URLs for others. - - :param state: Complete agent state containing command registry - :type state: dict[str, Any] - :return: Formatted string with command information or None if no commands - :rtype: str | None + /path/to/plot.png (created by python_executor at 2024-01-01 12:00:00) - Examples: - Display commands in terminal:: + πŸ““ Generated Notebooks: + Execution Notebook: http://localhost:8888/notebooks/analysis.ipynb πŸš€ Executable Commands: - β€’ Launch Jupyter Lab: http://localhost:8888/lab - β€’ Open Dashboard: http://localhost:3000/dashboard + Launch Jupyter Lab: http://localhost:8888/lab """ try: - # Get commands from centralized registry - ui_commands = state.get("ui_launchable_commands", []) + ui_artifacts = state.get("ui_artifacts", []) - if not ui_commands: - logger.debug("No commands found in ui_launchable_commands registry") + if not ui_artifacts: + logger.debug("No artifacts found in ui_artifacts registry") return None - logger.info( - f"Processing {len(ui_commands)} commands from centralized registry for CLI display" - ) - command_lines = ["πŸš€ Executable Commands:"] + logger.info(f"Processing {len(ui_artifacts)} artifacts for CLI display") - for i, command_entry in enumerate(ui_commands, 1): + # Group artifacts by type for organized display + artifacts_by_type: dict[ArtifactType, list[dict[str, Any]]] = {} + for artifact in ui_artifacts: try: - # Extract command information - launch_uri = command_entry["launch_uri"] - display_name = command_entry.get("display_name", f"Launch Command {i}") - - # Create CLI-friendly display - command_line = f"β€’ {display_name}: {launch_uri}" - command_lines.append(command_line) - - except Exception as e: - logger.warning(f"Failed to process command entry {command_entry}: {e}") - # Continue processing other commands + artifact_type = ArtifactType(artifact.get("type", "file")) + if artifact_type not in artifacts_by_type: + artifacts_by_type[artifact_type] = [] + artifacts_by_type[artifact_type].append(artifact) + except ValueError: + logger.warning(f"Unknown artifact type: {artifact.get('type')}") continue - if len(command_lines) > 1: # More than just the header - return "\n".join(command_lines) - - return None - - except Exception as e: - logger.error(f"Critical error in CLI command extraction: {e}") - return f"❌ Command display error: {str(e)}" - - def _extract_notebooks_for_cli(self, state: dict[str, Any]) -> str | None: - """Extract notebook links from centralized registry and format for CLI display. - - Extracts registered notebook links from the state and formats them for - terminal display. Provides direct URLs that users can copy-paste or - click in terminal emulators that support link clicking. + if not artifacts_by_type: + return None - :param state: Complete agent state containing notebook registry - :type state: dict[str, Any] - :return: Formatted string with notebook information or None if no notebooks - :rtype: str | None + # Build output with type-specific formatting + output_lines: list[str] = [] - Examples: - Display notebooks in terminal:: + # Define display order and headers + type_config = [ + (ArtifactType.IMAGE, "Generated Figures"), + (ArtifactType.NOTEBOOK, "Generated Notebooks"), + (ArtifactType.COMMAND, "Executable Commands"), + (ArtifactType.HTML, "Interactive Content"), + (ArtifactType.FILE, "Generated Files"), + ] - πŸ““ Generated Notebooks: - β€’ Jupyter Notebook 1: http://localhost:8888/notebooks/analysis.ipynb - β€’ Jupyter Notebook 2: http://localhost:8888/notebooks/results.ipynb - """ - try: - # Get notebook links from centralized registry - ui_notebooks = state.get("ui_notebook_links", []) + for artifact_type, header in type_config: + if artifact_type not in artifacts_by_type: + continue - if not ui_notebooks: - logger.debug("No notebook links found in ui_notebook_links registry") - return None + artifacts = artifacts_by_type[artifact_type] + icon = get_artifact_type_icon(artifact_type) + output_lines.append(f"{icon} {header}:") - logger.info( - f"Processing {len(ui_notebooks)} notebook links from centralized registry for CLI display" - ) - notebook_lines = ["πŸ““ Generated Notebooks:"] + for artifact in artifacts: + line = self._format_artifact_line(artifact, artifact_type) + output_lines.append(line) - for i, notebook_link in enumerate(ui_notebooks, 1): - # Create CLI-friendly display - notebook_line = f"β€’ Jupyter Notebook {i}: {notebook_link}" - notebook_lines.append(notebook_line) + output_lines.append("") # Blank line between sections - if len(notebook_lines) > 1: # More than just the header - return "\n".join(notebook_lines) + # Remove trailing blank line + if output_lines and output_lines[-1] == "": + output_lines.pop() - return None + return "\n".join(output_lines) if output_lines else None except Exception as e: - logger.error(f"Critical error in CLI notebook extraction: {e}") - return f"❌ Notebook display error: {str(e)}" + logger.error(f"Critical error in CLI artifact extraction: {e}") + return f"❌ Artifact display error: {str(e)}" + + def _format_artifact_line(self, artifact: dict[str, Any], artifact_type: ArtifactType) -> str: + """Format a single artifact for CLI display. + + :param artifact: Artifact dictionary from ui_artifacts + :type artifact: dict[str, Any] + :param artifact_type: The artifact's type + :type artifact_type: ArtifactType + :return: Formatted line string + :rtype: str + """ + data = artifact.get("data", {}) + display_name = artifact.get("display_name", "") + capability = artifact.get("capability", "unknown") + created_at = artifact.get("created_at", "") + + # Format timestamp + created_at_str = str(created_at)[:19] if created_at else "unknown time" + + if artifact_type == ArtifactType.IMAGE: + path = data.get("path", "N/A") + return f" β€’ {path} (created by {capability} at {created_at_str})" + + elif artifact_type == ArtifactType.NOTEBOOK: + # Prefer URL for notebooks, fallback to path + url = data.get("url", "") + path = data.get("path", "") + target = url if url else path if path else "N/A" + name = display_name if display_name else "Jupyter Notebook" + return f" β€’ {name}: {target}" + + elif artifact_type == ArtifactType.COMMAND: + uri = data.get("uri", "N/A") + name = display_name if display_name else "Launch Command" + return f" β€’ {name}: {uri}" + + elif artifact_type == ArtifactType.HTML: + url = data.get("url", "") + path = data.get("path", "") + target = url if url else path if path else "N/A" + name = display_name if display_name else "Interactive Content" + return f" β€’ {name}: {target}" + + elif artifact_type == ArtifactType.FILE: + path = data.get("path", "N/A") + name = display_name if display_name else "File" + return f" β€’ {name}: {path}" + + else: + return f" β€’ {display_name or 'Artifact'}: {data}" async def run_cli(config_path="config.yml", show_streaming_updates=False): diff --git a/src/osprey/interfaces/tui/widgets/artifacts.py b/src/osprey/interfaces/tui/widgets/artifacts.py index c0f7df657..d94e85ad7 100644 --- a/src/osprey/interfaces/tui/widgets/artifacts.py +++ b/src/osprey/interfaces/tui/widgets/artifacts.py @@ -328,7 +328,9 @@ def _update_footer_for_focus(self, focused: bool) -> None: try: footer = self.query_one("#gallery-footer", Static) if focused: - footer.update("[dim]j/k[/] navigate Β· [dim]Enter[/] view Β· [dim]o[/] open Β· [dim]Esc[/] exit") + footer.update( + "[dim]j/k[/] navigate Β· [dim]Enter[/] view Β· [dim]o[/] open Β· [dim]Esc[/] exit" + ) else: footer.update("[dim]Press [/]Ctrl+a[dim] to browse artifacts[/]") except Exception: diff --git a/tests/test_tui_artifacts.py b/tests/test_tui_artifacts.py index bb6da73ad..e4d5cc57d 100644 --- a/tests/test_tui_artifacts.py +++ b/tests/test_tui_artifacts.py @@ -72,7 +72,7 @@ def test_artifact_item_uses_display_name_if_provided(self): {"path": "/plot.png"}, display_name="My Custom Plot", ) - item = ArtifactItem(artifact) + _item = ArtifactItem(artifact) # noqa: F841 - verifies item can be created # display_name takes precedence in rendering assert artifact.get("display_name") == "My Custom Plot" From 88987314283dbef98390edd6b7c1f6b7b4466756 Mon Sep 17 00:00:00 2001 From: ThorstenHellert Date: Tue, 13 Jan 2026 13:58:45 -0800 Subject: [PATCH 04/14] feat(state): Add multi-iteration approval support with custom reducers Enable approval_approved and approved_payload fields to be overwritten across multiple approval iterations. Without this, LangGraph's default LastValue channel throws errors when setting approval fields multiple times in multi-iteration workflows (e.g., XOpt optimization). --- CHANGELOG.md | 3 ++ src/osprey/state/state.py | 75 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c5a89fc67..ecfc2b5fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ArtifactViewer modal with type-specific details and actions (copy path, open in system app) - Native image rendering via textual-image (Sixel for iTerm2/WezTerm, Kitty Graphics Protocol) - New/seen tracking with [NEW] badges for artifacts from current turn +- **State**: Multi-iteration approval support with custom reducers + - `overwrite_approval_bool` and `overwrite_approval_payload` reducers enable approval fields to be reset between iterations + - Required for services like XOpt optimization that request approval multiple times per session ### Changed - **Capabilities**: Python capability uses unified `register_artifact()` API directly diff --git a/src/osprey/state/state.py b/src/osprey/state/state.py index 098556a0a..ea02dd70c 100644 --- a/src/osprey/state/state.py +++ b/src/osprey/state/state.py @@ -128,6 +128,71 @@ def merge_session_state(existing: dict[str, Any] | None, new: dict[str, Any]) -> return result +def overwrite_approval_bool(existing: bool | None, new: bool | None) -> bool | None: + """Overwrite approval boolean state for multi-iteration approval flows. + + This custom reducer enables the approval_approved field to be overwritten + across multiple approval iterations. Without this reducer, LangGraph's + default LastValue channel throws an error when trying to set approval_approved + multiple times (e.g., in multi-iteration optimization workflows). + + The reducer simply returns the new value, allowing clean overwrites between + approval cycles. This is essential for services that loop and request + approval for each iteration (like XOpt optimization). + + :param existing: Existing approval state from previous writes + :type existing: Optional[bool] + :param new: New approval state to set + :type new: Optional[bool] + :return: The new approval state value + :rtype: Optional[bool] + + .. note:: + This reducer intentionally ignores the existing value. Each approval + cycle starts fresh, with the new value completely replacing the old. + + Examples: + Multi-iteration approval flow:: + + >>> # Iteration 1: user approves + >>> result = overwrite_approval_bool(None, True) + >>> result + True + >>> # Iteration 2: user approves again (would fail without reducer) + >>> result = overwrite_approval_bool(True, True) + >>> result + True + + .. seealso:: + :class:`AgentState` : Main state class using this reducer + :func:`overwrite_approval_payload` : Companion reducer for approval payload + """ + return new + + +def overwrite_approval_payload( + existing: dict[str, Any] | None, new: dict[str, Any] | None +) -> dict[str, Any] | None: + """Overwrite approval payload for multi-iteration approval flows. + + This custom reducer enables the approved_payload field to be overwritten + across multiple approval iterations. Works in conjunction with + overwrite_approval_bool to support multi-iteration approval workflows. + + :param existing: Existing approval payload from previous writes + :type existing: Optional[Dict[str, Any]] + :param new: New approval payload to set + :type new: Optional[Dict[str, Any]] + :return: The new approval payload value + :rtype: Optional[Dict[str, Any]] + + .. seealso:: + :class:`AgentState` : Main state class using this reducer + :func:`overwrite_approval_bool` : Companion reducer for approval boolean + """ + return new + + def merge_capability_context_data( existing: dict[str, dict[str, dict[str, Any]]] | None, new: dict[str, dict[str, dict[str, Any]]] ) -> dict[str, dict[str, dict[str, Any]]]: @@ -331,8 +396,14 @@ class AgentState(MessagesState): execution_total_time: float | None # Approval handling fields (for interrupt flows) - approval_approved: bool | None # True/False/None for approved/rejected/no-approval - approved_payload: dict[str, Any] | None # Direct payload access + # These use custom reducers to support multi-iteration approval workflows + # (e.g., XOpt optimization that requests approval for each iteration) + approval_approved: Annotated[ + bool | None, overwrite_approval_bool + ] # True/False/None for approved/rejected/no-approval + approved_payload: Annotated[ + dict[str, Any] | None, overwrite_approval_payload + ] # Direct payload access # Control flow fields control_reclassification_reason: str | None From f4133c2f54c18cdd5316e4b1e507711c29df9f8b Mon Sep 17 00:00:00 2001 From: ThorstenHellert Date: Tue, 13 Jan 2026 14:02:53 -0800 Subject: [PATCH 05/14] fix(cli): Style multi-iteration approval prompts with Panel Display repeat approval requests in styled Panel matching the first approval prompt, rather than plain text. Ensures consistent visual treatment for multi-iteration approval workflows. --- CHANGELOG.md | 3 +++ src/osprey/interfaces/cli/direct_conversation.py | 14 +++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ecfc2b5fc..e3b438fd4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Required for services like XOpt optimization that request approval multiple times per session ### Changed +- **CLI**: Improved approval panel styling for multi-iteration approval flows + - Repeat approval requests now display in styled Panel matching first approval + - Consistent visual treatment improves UX for multi-iteration workflows - **Capabilities**: Python capability uses unified `register_artifact()` API directly - Clean single-accumulation pattern for figures and notebooks - Legacy fields populated at finalization rather than registration diff --git a/src/osprey/interfaces/cli/direct_conversation.py b/src/osprey/interfaces/cli/direct_conversation.py index bd00fc86a..b3de6a991 100644 --- a/src/osprey/interfaces/cli/direct_conversation.py +++ b/src/osprey/interfaces/cli/direct_conversation.py @@ -666,7 +666,19 @@ async def _do_process_user_input(self, user_input: str) -> bool: user_message = interrupt.value.get( "user_message", "Additional approval required" ) - self.console.print(f"\n[{Styles.WARNING}]{user_message}[/{Styles.WARNING}]") + + # Display approval message in a stylish panel (same as first interrupt) + self.console.print("\n") # Add spacing before panel + self.console.print( + Panel( + user_message, + title="[bold red]⚠️ HUMAN APPROVAL REQUIRED[/bold red]", + subtitle="[dim]Respond with 'yes' or 'no'[/dim]", + border_style="yellow", + box=HEAVY, + padding=(1, 2), + ) + ) user_input = await self.prompt_session.prompt_async( self._get_prompt(), style=self.prompt_style From 6dc4c87c50b518c65abb5421b395ab6410a6b8b7 Mon Sep 17 00:00:00 2001 From: ThorstenHellert Date: Tue, 13 Jan 2026 14:03:45 -0800 Subject: [PATCH 06/14] feat(services): Add XOpt Optimizer Service for autonomous optimization Implement XOpt-based autonomous machine parameter optimization service with complete workflow: state identification, strategy decision, YAML generation, human approval, and execution. Key components: - State identification: ReAct agent assesses machine readiness - Decision: LLM selects exploration vs optimization strategy - YAML generation: ReAct agent creates XOpt configurations - Approval: Human-in-the-loop with structured interrupt - Configurable modes: react (LLM) or mock (fast tests) Includes optimization capability, prompt builder, config templates, approval helpers, and comprehensive test suite. --- CHANGELOG.md | 10 + src/osprey/approval/__init__.py | 2 + src/osprey/approval/approval_system.py | 104 ++++ src/osprey/capabilities/optimization.py | 460 +++++++++++++++ .../control_system/mock_connector.py | 6 + src/osprey/prompts/defaults/__init__.py | 6 + src/osprey/prompts/defaults/optimization.py | 272 +++++++++ src/osprey/prompts/loader.py | 22 + src/osprey/registry/registry.py | 34 ++ .../services/xopt_optimizer/__init__.py | 62 +++ .../xopt_optimizer/analysis/__init__.py | 9 + .../services/xopt_optimizer/analysis/node.py | 102 ++++ .../xopt_optimizer/approval/__init__.py | 9 + .../services/xopt_optimizer/approval/node.py | 63 +++ .../xopt_optimizer/decision/__init__.py | 14 + .../services/xopt_optimizer/decision/node.py | 331 +++++++++++ .../services/xopt_optimizer/exceptions.py | 147 +++++ .../xopt_optimizer/execution/__init__.py | 13 + .../services/xopt_optimizer/execution/node.py | 89 +++ src/osprey/services/xopt_optimizer/models.py | 273 +++++++++ src/osprey/services/xopt_optimizer/service.py | 369 ++++++++++++ .../state_identification/__init__.py | 26 + .../state_identification/agent.py | 337 +++++++++++ .../state_identification/node.py | 214 +++++++ .../state_identification/tools/__init__.py | 36 ++ .../tools/channel_access.py | 114 ++++ .../tools/reference_files.py | 236 ++++++++ .../yaml_generation/__init__.py | 25 + .../xopt_optimizer/yaml_generation/agent.py | 484 ++++++++++++++++ .../xopt_optimizer/yaml_generation/node.py | 304 ++++++++++ .../apps/control_assistant/config.yml.j2 | 45 ++ src/osprey/utils/config.py | 29 + tests/conftest.py | 30 + tests/services/xopt_optimizer/__init__.py | 1 + .../test_state_identification.py | 274 +++++++++ .../xopt_optimizer/test_xopt_approval.py | 99 ++++ .../xopt_optimizer/test_xopt_exceptions.py | 149 +++++ .../xopt_optimizer/test_xopt_service.py | 523 ++++++++++++++++++ .../xopt_optimizer/test_xopt_workflow.py | 93 ++++ 39 files changed, 5416 insertions(+) create mode 100644 src/osprey/capabilities/optimization.py create mode 100644 src/osprey/prompts/defaults/optimization.py create mode 100644 src/osprey/services/xopt_optimizer/__init__.py create mode 100644 src/osprey/services/xopt_optimizer/analysis/__init__.py create mode 100644 src/osprey/services/xopt_optimizer/analysis/node.py create mode 100644 src/osprey/services/xopt_optimizer/approval/__init__.py create mode 100644 src/osprey/services/xopt_optimizer/approval/node.py create mode 100644 src/osprey/services/xopt_optimizer/decision/__init__.py create mode 100644 src/osprey/services/xopt_optimizer/decision/node.py create mode 100644 src/osprey/services/xopt_optimizer/exceptions.py create mode 100644 src/osprey/services/xopt_optimizer/execution/__init__.py create mode 100644 src/osprey/services/xopt_optimizer/execution/node.py create mode 100644 src/osprey/services/xopt_optimizer/models.py create mode 100644 src/osprey/services/xopt_optimizer/service.py create mode 100644 src/osprey/services/xopt_optimizer/state_identification/__init__.py create mode 100644 src/osprey/services/xopt_optimizer/state_identification/agent.py create mode 100644 src/osprey/services/xopt_optimizer/state_identification/node.py create mode 100644 src/osprey/services/xopt_optimizer/state_identification/tools/__init__.py create mode 100644 src/osprey/services/xopt_optimizer/state_identification/tools/channel_access.py create mode 100644 src/osprey/services/xopt_optimizer/state_identification/tools/reference_files.py create mode 100644 src/osprey/services/xopt_optimizer/yaml_generation/__init__.py create mode 100644 src/osprey/services/xopt_optimizer/yaml_generation/agent.py create mode 100644 src/osprey/services/xopt_optimizer/yaml_generation/node.py create mode 100644 tests/services/xopt_optimizer/__init__.py create mode 100644 tests/services/xopt_optimizer/test_state_identification.py create mode 100644 tests/services/xopt_optimizer/test_xopt_approval.py create mode 100644 tests/services/xopt_optimizer/test_xopt_exceptions.py create mode 100644 tests/services/xopt_optimizer/test_xopt_service.py create mode 100644 tests/services/xopt_optimizer/test_xopt_workflow.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e3b438fd4..21773bca7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **State**: Multi-iteration approval support with custom reducers - `overwrite_approval_bool` and `overwrite_approval_payload` reducers enable approval fields to be reset between iterations - Required for services like XOpt optimization that request approval multiple times per session +- **Services**: XOpt Optimizer Service for autonomous machine parameter optimization + - State identification node assesses machine readiness using ReAct agent + - Strategy decision node selects exploration vs. optimization approach + - YAML generation agent creates XOpt configurations with ReAct pattern + - Approval node integrates with human-in-the-loop workflow + - Execution and analysis nodes for running and evaluating optimizations + - Configurable modes: `react` (LLM-powered) or `mock` (fast testing) +- **Capabilities**: Optimization capability for routing optimization requests + - Provides `OPTIMIZATION_RESULT` context type for result handling +- **Prompts**: Default optimization prompt builder for XOpt workflows ### Changed - **CLI**: Improved approval panel styling for multi-iteration approval flows diff --git a/src/osprey/approval/__init__.py b/src/osprey/approval/__init__.py index 54f069783..5d24469ab 100644 --- a/src/osprey/approval/__init__.py +++ b/src/osprey/approval/__init__.py @@ -56,6 +56,7 @@ create_code_approval_interrupt, create_memory_approval_interrupt, create_plan_approval_interrupt, + create_xopt_approval_interrupt, get_approval_resume_data, get_approved_payload_from_state, handle_service_with_interrupts, @@ -75,6 +76,7 @@ "create_code_approval_interrupt", "create_memory_approval_interrupt", "create_channel_write_approval_interrupt", + "create_xopt_approval_interrupt", "get_approved_payload_from_state", "get_approval_resume_data", "clear_approval_state", diff --git a/src/osprey/approval/approval_system.py b/src/osprey/approval/approval_system.py index 8b73c0673..e2c2a03fc 100644 --- a/src/osprey/approval/approval_system.py +++ b/src/osprey/approval/approval_system.py @@ -472,6 +472,110 @@ def create_code_approval_interrupt( } +def create_xopt_approval_interrupt( + yaml_config: str, + strategy: str, + objective: str, + machine_state_details: dict[str, Any] | None = None, + step_objective: str = "Execute XOpt optimization", +) -> dict[str, Any]: + """Create structured interrupt data for XOpt optimization approval. + + Generates LangGraph-compatible interrupt data for XOpt configurations that + require human approval before execution. The interrupt provides comprehensive + context including the generated YAML, optimization strategy, and machine + state details. + + :param yaml_config: Generated XOpt YAML configuration + :type yaml_config: str + :param strategy: Selected optimization strategy (exploration/optimization) + :type strategy: str + :param objective: Optimization objective description + :type objective: str + :param machine_state_details: Optional machine state assessment details + :type machine_state_details: Dict[str, Any], optional + :param step_objective: High-level objective for user context + :type step_objective: str + :return: Dictionary containing user_message and resume_payload for LangGraph + :rtype: Dict[str, Any] + + Examples: + Basic XOpt approval:: + + >>> interrupt_data = create_xopt_approval_interrupt( + ... yaml_config="xopt:\\n generator: random", + ... strategy="exploration", + ... objective="Maximize injection efficiency", + ... step_objective="Execute XOpt optimization" + ... ) + >>> 'yes' in interrupt_data['user_message'] + True + + .. warning:: + This function is used for security-critical approval decisions for + optimization operations that may affect machine parameters. + """ + # Format machine state if available + machine_state_section = "" + if machine_state_details: + machine_state_section = f""" +**Machine State Assessment:** +``` +{_format_machine_state(machine_state_details)} +``` +""" + + user_message = f""" +⚠️ **HUMAN APPROVAL REQUIRED** ⚠️ + +**Task:** {step_objective} +**Optimization Objective:** {objective} +**Strategy:** {strategy.upper()} +{machine_state_section} +**Generated XOpt Configuration:** +```yaml +{yaml_config} +``` + +**Review the configuration above carefully.** + +**To proceed, respond with:** +- **`yes`** to approve and execute the optimization +- **`no`** to cancel this operation +""".strip() + + return { + "user_message": user_message, + "resume_payload": { + "approval_type": create_approval_type("xopt_optimizer"), + "step_objective": step_objective, + "yaml_config": yaml_config, + "strategy": strategy, + "objective": objective, + "machine_state_details": machine_state_details, + }, + } + + +def _format_machine_state(details: dict[str, Any]) -> str: + """Format machine state details for display. + + :param details: Machine state details dictionary + :type details: Dict[str, Any] + :return: Formatted string for display + :rtype: str + """ + lines = [] + for key, value in details.items(): + if isinstance(value, dict): + lines.append(f"{key}:") + for k, v in value.items(): + lines.append(f" {k}: {v}") + else: + lines.append(f"{key}: {value}") + return "\n".join(lines) + + # ============================================================================= # STREAMLINED APPROVAL HELPERS # ============================================================================= diff --git a/src/osprey/capabilities/optimization.py b/src/osprey/capabilities/optimization.py new file mode 100644 index 000000000..4059155a9 --- /dev/null +++ b/src/osprey/capabilities/optimization.py @@ -0,0 +1,460 @@ +"""Optimization Capability - Service Gateway for XOpt Machine Optimization. + +This capability acts as the gateway between the main agent graph and the +XOpt optimizer service, providing seamless integration for autonomous +machine optimization workflows with human approval, machine state awareness, +and result analysis. + +The capability provides a clean abstraction layer that: +1. **Service Integration**: Manages communication with the XOpt optimizer service +2. **Approval Workflows**: Integrates with the approval system for optimization control +3. **Context Management**: Handles context data passing and result context creation +4. **Error Handling**: Provides sophisticated error classification and recovery + +Key architectural features: + - Service gateway pattern for clean separation of concerns + - LangGraph-native approval workflow integration + - Comprehensive context management for cross-capability data flow + - Structured result processing with optimization metadata + - Error classification with domain-specific recovery strategies + +.. note:: + This capability requires the XOpt optimizer service to be available in the + framework registry. All optimization execution is managed by the separate service. + +.. warning:: + Optimization operations may require user approval depending on the configured + approval policies. Execution may be suspended pending user confirmation. + +.. seealso:: + :class:`osprey.services.xopt_optimizer.XOptOptimizerService` : Optimization service + :class:`OptimizationResultContext` : Optimization result context structure +""" + +from typing import Any, ClassVar + +from langgraph.types import Command +from pydantic import Field + +from osprey.approval import ( + clear_approval_state, + create_approval_type, + get_approval_resume_data, + handle_service_with_interrupts, +) +from osprey.base.capability import BaseCapability +from osprey.base.decorators import capability_node +from osprey.base.errors import ErrorClassification, ErrorSeverity +from osprey.base.examples import OrchestratorGuide, TaskClassifierGuide +from osprey.context.base import CapabilityContext +from osprey.prompts.loader import get_framework_prompts +from osprey.registry import get_registry +from osprey.services.xopt_optimizer import XOptExecutionRequest, XOptServiceResult +from osprey.state import StateManager +from osprey.utils.config import get_full_configuration +from osprey.utils.logger import get_logger + +# Module-level logger for helper functions +logger = get_logger("optimization") + + +# ======================================================== +# Context Class +# ======================================================== + + +class OptimizationResultContext(CapabilityContext): + """Context for XOpt optimization results. + + Provides structured context for optimization execution results including + the run artifact, strategy used, iteration count, and analysis summary. + This context enables other capabilities to access optimization outcomes + for downstream processing or response generation. + + :param run_artifact: Optimization run output data + :type run_artifact: Dict[str, Any] + :param strategy: Strategy used (exploration/optimization) + :type strategy: str + :param total_iterations: Number of iterations completed + :type total_iterations: int + :param analysis_summary: Summary of optimization analysis + :type analysis_summary: Dict[str, Any] + :param recommendations: List of recommendations from analysis + :type recommendations: List[str] + :param generated_yaml: XOpt YAML configuration used + :type generated_yaml: str + + .. note:: + The run_artifact contains the primary optimization outputs that + other capabilities can use for further processing or analysis. + + .. seealso:: + :class:`osprey.context.base.CapabilityContext` : Base context functionality + :class:`osprey.services.xopt_optimizer.XOptServiceResult` : Service result structure + """ + + run_artifact: dict[str, Any] = Field(default_factory=dict) + strategy: str = "" + total_iterations: int = 0 + analysis_summary: dict[str, Any] = Field(default_factory=dict) + recommendations: list[str] = Field(default_factory=list) + generated_yaml: str = "" + + CONTEXT_TYPE: ClassVar[str] = "OPTIMIZATION_RESULT" + CONTEXT_CATEGORY: ClassVar[str] = "OPTIMIZATION_DATA" + + @property + def context_type(self) -> str: + """Return the context type identifier.""" + return self.CONTEXT_TYPE + + def get_access_details(self, key: str) -> dict[str, Any]: + """Provide access information for optimization results. + + :param key: Context key name for access pattern generation + :type key: str + :return: Dictionary containing access details and patterns + :rtype: Dict[str, Any] + """ + return { + "run_artifact": "Optimization run output data", + "strategy": f"Strategy used: {self.strategy}", + "total_iterations": f"Completed {self.total_iterations} iterations", + "analysis_summary": "Summary of optimization analysis", + "recommendations": "List of recommendations from analysis", + "generated_yaml": "XOpt YAML configuration used", + "access_pattern": f"context.OPTIMIZATION_RESULT.{key}", + } + + def get_summary(self) -> dict[str, Any]: + """Generate summary for display and LLM processing. + + :return: Dictionary containing summarized optimization results + :rtype: Dict[str, Any] + """ + return { + "type": "Optimization Result", + "strategy": self.strategy, + "iterations": self.total_iterations, + "recommendations_count": len(self.recommendations), + "has_artifact": bool(self.run_artifact), + } + + +def _create_optimization_context(service_result: XOptServiceResult) -> OptimizationResultContext: + """Create OptimizationResultContext from service result. + + This helper function transforms the XOpt service result into a structured + context object that can be stored in state and accessed by other capabilities. + + :param service_result: Result from XOpt optimizer service + :type service_result: XOptServiceResult + :return: Structured context for optimization results + :rtype: OptimizationResultContext + """ + return OptimizationResultContext( + run_artifact=service_result.run_artifact, + strategy=service_result.strategy.value, + total_iterations=service_result.total_iterations, + analysis_summary=service_result.analysis_summary, + recommendations=list(service_result.recommendations), + generated_yaml=service_result.generated_yaml, + ) + + +# ======================================================== +# Convention-Based Capability Implementation +# ======================================================== + + +@capability_node +class OptimizationCapability(BaseCapability): + """Machine optimization capability using XOpt. + + Acts as the gateway between the main agent graph and the XOpt optimizer + service, providing seamless integration for optimization workflows + with approval handling and result processing. + + This is a framework-level capability that can be customized per facility + through the prompt builder system and configuration. + + Key architectural features: + - Service gateway pattern for clean separation between capability and service + - Comprehensive context management for cross-capability data access + - LangGraph-native approval workflow integration with interrupt handling + - Structured result processing with optimization metadata + + .. note:: + Requires XOpt optimizer service availability in framework registry. + All actual optimization execution is delegated to the service. + + .. warning:: + Optimization operations may trigger approval interrupts that suspend + execution until user confirmation is received. + """ + + name = "optimization" + description = "Optimize machine parameters using XOpt autonomous optimization" + provides = ["OPTIMIZATION_RESULT"] + requires = [] # Can optionally use CHANNEL_ADDRESSES if available + + # ======================================== + # ORCHESTRATOR / CLASSIFIER GUIDES + # ======================================== + + def _create_orchestrator_guide(self) -> OrchestratorGuide | None: + """Create orchestrator guide from prompt builder system. + + Retrieves orchestrator guidance from the application's prompt builder + system. This guide teaches the orchestrator when and how to include + optimization steps in execution plans. + + :return: Orchestrator guide with examples and notes + :rtype: Optional[OrchestratorGuide] + """ + prompt_provider = get_framework_prompts() + optimization_builder = prompt_provider.get_optimization_prompt_builder() + + return optimization_builder.get_orchestrator_guide() + + def _create_classifier_guide(self) -> TaskClassifierGuide | None: + """Create task classification guide from prompt builder system. + + Retrieves task classification guidance from the application's prompt + builder system. This guide teaches the classifier when user requests + should be routed to optimization operations. + + :return: Classifier guide with examples + :rtype: Optional[TaskClassifierGuide] + """ + prompt_provider = get_framework_prompts() + optimization_builder = prompt_provider.get_optimization_prompt_builder() + + return optimization_builder.get_classifier_guide() + + # ======================================== + # MAIN EXECUTION + # ======================================== + + async def execute(self) -> dict[str, Any]: + """Execute optimization with service integration and approval handling. + + Implements the complete optimization workflow including service invocation, + approval management, and result processing. The method handles both normal + execution scenarios and approval resume scenarios with proper state management. + + :return: State updates with optimization results and context data + :rtype: Dict[str, Any] + + :raises RuntimeError: If XOpt optimizer service is not available in registry + :raises XOptExecutionError: If optimization execution fails + """ + + # ======================================== + # GENERIC SETUP (needed for both paths) + # ======================================== + + # Get unified logger with automatic streaming support + cap_logger = self.get_logger() + step = self._step + + cap_logger.status("Initializing XOpt optimizer service...") + + # Get XOpt service from registry (runtime lookup) + registry = get_registry() + xopt_service = registry.get_service("xopt_optimizer") + + if not xopt_service: + raise RuntimeError("XOpt optimizer service not available in registry") + + # Get the full configurable from main graph + main_configurable = get_full_configuration() + + # Create service config by extending main graph's configurable + service_config = { + "configurable": { + **main_configurable, + "thread_id": f"xopt_service_{step.get('context_key', 'default')}", + "checkpoint_ns": "xopt_optimizer", + } + } + + # ======================================== + # APPROVAL CASE (handle first) + # ======================================== + + # Check if this is a resume from approval + has_approval_resume, approved_payload = get_approval_resume_data( + self._state, create_approval_type("xopt_optimizer") + ) + + if has_approval_resume: + if approved_payload: + cap_logger.resume("Sending approval response to XOpt optimizer service") + resume_response = {"approved": True} + resume_response.update(approved_payload) + else: + cap_logger.key_info("XOpt optimization was rejected by user") + resume_response = {"approved": False} + + try: + service_result = await xopt_service.ainvoke( + Command(resume=resume_response), config=service_config + ) + + cap_logger.info("XOpt optimizer service completed successfully after approval") + approval_cleanup = clear_approval_state() + + # Process results + results_context = _create_optimization_context(service_result) + cap_logger.success( + f"Optimization complete - {service_result.total_iterations} iterations" + ) + + # Store context and merge cleanup + result_updates = StateManager.store_context( + self._state, + "OPTIMIZATION_RESULT", + step.get("context_key"), + results_context, + ) + result_updates.update(approval_cleanup) + return result_updates + + except Exception as e: + # Import here to avoid circular imports + from langgraph.errors import GraphInterrupt + from langgraph.types import interrupt + + # Check if this is a GraphInterrupt (service looped and needs approval for next iteration) + if isinstance(e, GraphInterrupt): + cap_logger.info( + "XOptOptimizer: Service completed iteration and requests approval for next" + ) + + try: + # Extract interrupt data from GraphInterrupt + interrupt_data = e.args[0][0].value + cap_logger.debug( + f"XOptOptimizer: Extracted interrupt data with keys: {list(interrupt_data.keys())}" + ) + + # Re-raise interrupt in main graph context for next iteration + cap_logger.info( + "⏸️ XOptOptimizer: Creating approval interrupt for next iteration" + ) + interrupt(interrupt_data) + + # This line should never be reached + cap_logger.error( + "UNEXPECTED: interrupt() returned instead of pausing execution" + ) + raise RuntimeError("Interrupt mechanism failed in XOptOptimizer") + + except (IndexError, KeyError, AttributeError) as extract_error: + cap_logger.error( + f"XOptOptimizer: Failed to extract interrupt data: {extract_error}" + ) + raise RuntimeError( + f"XOptOptimizer: Failed to handle service interrupt: {extract_error}" + ) from extract_error + else: + # Re-raise non-interrupt exceptions + raise + + # ======================================== + # NORMAL EXECUTION (new request) + # ======================================== + + user_query = self._state.get("input_output", {}).get("user_query", "") + task_objective = self.get_task_objective(default="") + capability_contexts = self._state.get("capability_context_data", {}) + + # Create execution request + execution_request = XOptExecutionRequest( + user_query=user_query, + optimization_objective=task_objective, + capability_context_data=capability_contexts, + require_approval=True, + max_iterations=3, + ) + + cap_logger.status("Invoking XOpt optimizer service...") + + # Handle service invocation with interrupt support + service_result = await handle_service_with_interrupts( + service=xopt_service, + request=execution_request, + config=service_config, + logger=cap_logger, + capability_name="XOptOptimizer", + ) + + # ======================================== + # RESULT PROCESSING + # ======================================== + + cap_logger.status("Processing optimization results...") + + results_context = _create_optimization_context(service_result) + + cap_logger.success(f"Optimization complete - {service_result.total_iterations} iterations") + + return StateManager.store_context( + self._state, + "OPTIMIZATION_RESULT", + step.get("context_key"), + results_context, + ) + + # ======================================== + # ERROR CLASSIFICATION + # ======================================== + + @staticmethod + def classify_error(exc: Exception, context: dict) -> ErrorClassification: + """Classify optimization errors for appropriate handling. + + :param exc: The exception that occurred + :type exc: Exception + :param context: Additional context about the error + :type context: dict + :return: Error classification with severity and user message + :rtype: ErrorClassification + """ + from osprey.services.xopt_optimizer.exceptions import ( + MachineStateAssessmentError, + XOptExecutionError, + YamlGenerationError, + ) + + if isinstance(exc, MachineStateAssessmentError): + return ErrorClassification( + severity=ErrorSeverity.REPLANNING, + user_message=f"Machine not ready for optimization: {exc}", + metadata={ + "replanning_reason": str(exc), + "suggestions": ["Wait for machine conditions to improve", "Check interlocks"], + }, + ) + + elif isinstance(exc, YamlGenerationError): + return ErrorClassification( + severity=ErrorSeverity.REPLANNING, + user_message=f"Failed to generate optimization configuration: {exc}", + metadata={"replanning_reason": str(exc)}, + ) + + elif isinstance(exc, XOptExecutionError): + return ErrorClassification( + severity=ErrorSeverity.CRITICAL, + user_message=f"Optimization execution failed: {exc}", + metadata={"safety_abort_reason": str(exc)}, + ) + + else: + return ErrorClassification( + severity=ErrorSeverity.CRITICAL, + user_message=f"Unexpected optimization error: {exc}", + metadata={"safety_abort_reason": str(exc)}, + ) diff --git a/src/osprey/connectors/control_system/mock_connector.py b/src/osprey/connectors/control_system/mock_connector.py index 15985b8ed..9b24c7189 100644 --- a/src/osprey/connectors/control_system/mock_connector.py +++ b/src/osprey/connectors/control_system/mock_connector.py @@ -438,6 +438,12 @@ def _generate_initial_value(self, channel_name: str) -> float: return 0.0 elif "energy" in ch_lower: return 1900.0 # MeV for typical storage ring + elif "status" in ch_lower or "ready" in ch_lower or "enable" in ch_lower: + return 1.0 # Status/ready/enable flags default to 1 (on/ready) + elif "interlock" in ch_lower: + return 0.0 # Safety interlocks default to 0 (clear/no active interlocks) + elif "mode" in ch_lower: + return 1.0 # Operating mode defaults to 1 (operational) else: return 100.0 diff --git a/src/osprey/prompts/defaults/__init__.py b/src/osprey/prompts/defaults/__init__.py index 0b179c83c..504ddb69f 100644 --- a/src/osprey/prompts/defaults/__init__.py +++ b/src/osprey/prompts/defaults/__init__.py @@ -8,6 +8,7 @@ from .classification import DefaultClassificationPromptBuilder from .error_analysis import DefaultErrorAnalysisPromptBuilder from .memory_extraction import DefaultMemoryExtractionPromptBuilder +from .optimization import DefaultOptimizationPromptBuilder from .orchestrator import DefaultOrchestratorPromptBuilder from .python import DefaultPythonPromptBuilder from .response_generation import DefaultResponseGenerationPromptBuilder @@ -35,6 +36,7 @@ def __init__(self): self._memory_extraction_builder = DefaultMemoryExtractionPromptBuilder() self._time_range_parsing_builder = DefaultTimeRangeParsingPromptBuilder() self._python_builder = DefaultPythonPromptBuilder() + self._optimization_builder = DefaultOptimizationPromptBuilder() # ================================================================= # Infrastructure prompts @@ -71,6 +73,9 @@ def get_time_range_parsing_prompt_builder(self) -> "FrameworkPromptBuilder": def get_python_prompt_builder(self) -> "FrameworkPromptBuilder": return self._python_builder + def get_optimization_prompt_builder(self) -> "FrameworkPromptBuilder": + return self._optimization_builder + __all__ = [ "DefaultClassificationPromptBuilder", @@ -81,6 +86,7 @@ def get_python_prompt_builder(self) -> "FrameworkPromptBuilder": "DefaultMemoryExtractionPromptBuilder", "DefaultTimeRangeParsingPromptBuilder", "DefaultPythonPromptBuilder", + "DefaultOptimizationPromptBuilder", "DefaultOrchestratorPromptBuilder", "DefaultPromptProvider", "TaskExtractionExample", diff --git a/src/osprey/prompts/defaults/optimization.py b/src/osprey/prompts/defaults/optimization.py new file mode 100644 index 000000000..35face1c5 --- /dev/null +++ b/src/osprey/prompts/defaults/optimization.py @@ -0,0 +1,272 @@ +""" +Optimization Capability Prompt Builder + +Default prompts for XOpt optimization capability. +Provides baseline prompts that can be overridden by facility-specific implementations. +""" + +import textwrap + +from osprey.base import ( + ClassifierActions, + ClassifierExample, + OrchestratorExample, + OrchestratorGuide, + PlannedStep, + TaskClassifierGuide, +) +from osprey.prompts.base import FrameworkPromptBuilder +from osprey.registry import get_registry + + +class DefaultOptimizationPromptBuilder(FrameworkPromptBuilder): + """Default optimization capability prompt builder. + + Provides baseline prompts for XOpt optimization workflows. Facilities can + override this builder to inject domain-specific instructions, machine + state definitions, and optimization strategies. + + Override Points: + - get_instructions(): Domain-specific optimization guidance + - get_machine_state_definitions(): Facility-specific machine states + - get_yaml_generation_guidance(): Historical patterns and templates + - get_strategy_selection_guidance(): Strategy selection criteria + """ + + PROMPT_TYPE = "optimization" + + def get_role_definition(self) -> str: + """Get the role definition for optimization. + + :return: Role definition string + :rtype: str + """ + return "You are an expert optimization assistant helping to configure and execute autonomous machine optimization using XOpt." + + def get_task_definition(self) -> str: + """Get the task definition for optimization. + + :return: Task definition or None if task is provided externally + :rtype: Optional[str] + """ + return None # Task is provided via request + + def get_instructions(self) -> str: + """Get domain-specific optimization instructions. + + These instructions are domain-agnostic and apply to all optimization operations. + Facilities should override to provide machine-specific guidance. + + :return: Instructions string for optimization workflows + :rtype: str + """ + # Placeholder - facilities override with domain-specific instructions + return textwrap.dedent( + """ + === OPTIMIZATION INSTRUCTIONS === + + This is a placeholder for domain-specific optimization instructions. + + When implementing actual optimization: + 1. Assess machine readiness before proceeding + 2. Use appropriate optimization strategy (exploration vs optimization) + 3. Generate valid XOpt YAML configuration + 4. Request human approval before execution + 5. Analyze results and provide recommendations + + NOTE: Actual optimization parameters, channel addresses, and safety limits + will be defined based on facility-specific requirements. + """ + ).strip() + + def get_machine_state_definitions(self) -> dict[str, str]: + """Get facility-specific machine state definitions. + + :return: Mapping of state names to descriptions + :rtype: Dict[str, str] + """ + # Placeholder - default states, facilities override + return { + "ready": "Machine is ready for optimization", + "not_ready": "Machine cannot proceed with optimization", + "unknown": "Machine state assessment inconclusive", + } + + def get_yaml_generation_guidance(self) -> str: + """Get guidance for XOpt YAML configuration generation. + + :return: Domain-specific YAML generation guidance + :rtype: str + """ + # Placeholder - facilities override with facility-specific templates + return textwrap.dedent( + """ + YAML Generation Guidance (Placeholder): + + When generating XOpt YAML configurations: + - Use valid XOpt schema structure + - Define appropriate variables, objectives, and constraints + - Select suitable generator and evaluator + + NOTE: Actual YAML templates and parameter definitions will be + provided based on facility-specific requirements and historical examples. + """ + ).strip() + + def get_strategy_selection_guidance(self) -> str: + """Get guidance for exploration vs optimization strategy selection. + + :return: Decision criteria for strategy selection + :rtype: str + """ + # Placeholder - facilities override + return textwrap.dedent( + """ + Strategy Selection Guidance (Placeholder): + + - EXPLORATION: Use when exploring unknown parameter space + - OPTIMIZATION: Use when refining known good regions + + NOTE: Actual strategy selection criteria will be defined + based on operational requirements and machine state. + """ + ).strip() + + def get_orchestrator_guide(self) -> OrchestratorGuide | None: + """Create orchestrator guide for optimization capability. + + :return: Orchestrator guide with examples and instructions + :rtype: Optional[OrchestratorGuide] + """ + registry = get_registry() + + # Define structured examples + basic_optimization_example = OrchestratorExample( + step=PlannedStep( + context_key="optimization_results", + capability="optimization", + task_objective="Optimize the injection efficiency using XOpt", + expected_output=registry.context_types.OPTIMIZATION_RESULT, + success_criteria="Optimization completed with improved efficiency metrics", + inputs=[], + ), + scenario_description="Autonomous optimization of machine parameters", + notes=f"SINGLE STEP ONLY. The optimization service internally handles all machine investigation, channel discovery, and analysis. Output stored under {registry.context_types.OPTIMIZATION_RESULT}.", + ) + + tuning_example = OrchestratorExample( + step=PlannedStep( + context_key="tuning_results", + capability="optimization", + task_objective="Tune magnet settings for improved beam quality", + expected_output=registry.context_types.OPTIMIZATION_RESULT, + success_criteria="Magnet settings optimized with measurable improvement", + inputs=[], + ), + scenario_description="Parameter tuning for specific performance goals", + notes="SINGLE STEP ONLY. Do NOT pre-plan channel_finding or python steps - the optimization service handles this internally.", + ) + + return OrchestratorGuide( + instructions=textwrap.dedent( + f""" + **CRITICAL: Optimization is a SELF-CONTAINED, AUTONOMOUS service.** + + The optimization capability is NOT a simple executor - it is an intelligent + agent service that INTERNALLY handles: + - Machine state investigation (finding channels, reading values) + - Strategy selection (exploration vs optimization) + - Configuration generation (creating XOpt YAML) + - Execution and result analysis + + **DO NOT orchestrate pre-requisite steps before optimization.** + + WRONG approach (do NOT do this): + 1. channel_finding -> find injection channels + 2. python -> analyze current state + 3. optimization -> run optimization + + CORRECT approach (do this): + 1. optimization -> "Optimize injection efficiency" (single step) + + The optimization service will autonomously investigate the machine, + find relevant channels, assess readiness, and handle everything. + + **When to plan "optimization" steps:** + - User requests autonomous tuning or optimization of machine parameters + - Need to maximize or minimize a performance metric + - User wants to explore parameter space or find optimal settings + - Multi-parameter search or exploration is required + + **Step Structure:** + - context_key: Unique identifier for output (e.g., "optimization_results") + - task_objective: Clear, high-level description of the optimization goal + - inputs: Empty or minimal - the service investigates on its own + + **Output: {registry.context_types.OPTIMIZATION_RESULT}** + - Contains: Run artifact, strategy used, iteration count, recommendations + - Available to downstream steps via context system + - Includes generated XOpt configuration and analysis + + **Important Notes:** + - Human approval is ALWAYS required before execution + - The service includes its own result analysis and recommendations + - Just describe WHAT to optimize, not HOW to investigate the machine + """ + ), + examples=[basic_optimization_example, tuning_example], + priority=50, + ) + + def get_classifier_guide(self) -> TaskClassifierGuide | None: + """Create classifier guide for optimization capability. + + :return: Classifier guide with examples + :rtype: Optional[TaskClassifierGuide] + """ + return TaskClassifierGuide( + instructions="Determine if the user query requires autonomous machine optimization, parameter tuning, or multi-parameter search.", + examples=[ + ClassifierExample( + query="Optimize the injection efficiency", + result=True, + reason="This requires autonomous optimization of machine parameters.", + ), + ClassifierExample( + query="What is the current beam current?", + result=False, + reason="This is a read operation, not optimization.", + ), + ClassifierExample( + query="Tune the magnets for better beam quality", + result=True, + reason="This requires parameter tuning/optimization.", + ), + ClassifierExample( + query="Set the magnet to 5 amps", + result=False, + reason="This is a direct write operation, not autonomous optimization.", + ), + ClassifierExample( + query="Find the optimal settings for maximum intensity", + result=True, + reason="This requires optimization to find optimal parameters.", + ), + ClassifierExample( + query="Plot the beam current over time", + result=False, + reason="This is a visualization request, not optimization.", + ), + ClassifierExample( + query="Run an optimization campaign on the injector", + result=True, + reason="This explicitly requests an optimization campaign.", + ), + ClassifierExample( + query="Maximize the charge at the end of the linac", + result=True, + reason="This requires optimization to maximize a metric.", + ), + ], + actions_if_true=ClassifierActions(), + ) diff --git a/src/osprey/prompts/loader.py b/src/osprey/prompts/loader.py index bffe98496..52da2fef7 100644 --- a/src/osprey/prompts/loader.py +++ b/src/osprey/prompts/loader.py @@ -300,6 +300,28 @@ def get_python_prompt_builder(self) -> FrameworkPromptBuilder: """ raise NotImplementedError + def get_optimization_prompt_builder(self) -> FrameworkPromptBuilder: + """Provide prompt builder for XOpt optimization capability. + + This prompt builder is used by the optimization capability to + configure and execute autonomous machine optimization using XOpt. + It includes guidance for machine state assessment, YAML generation, + and strategy selection. + + :return: Optimization capability prompt builder instance + :rtype: FrameworkPromptBuilder + :raises NotImplementedError: Must be implemented by concrete providers + + .. note:: + Optimization prompts should include facility-specific machine states, + historical YAML examples, and domain-specific optimization patterns. + + .. seealso:: + :class:`OptimizationCapability` : Framework capability that uses this prompt + :class:`XOptOptimizerService` : Optimization service infrastructure + """ + raise NotImplementedError + class FrameworkPromptLoader: """Global registry and dependency injection system for framework prompt providers. diff --git a/src/osprey/registry/registry.py b/src/osprey/registry/registry.py index ecb7f38e1..1a7a704a4 100644 --- a/src/osprey/registry/registry.py +++ b/src/osprey/registry/registry.py @@ -316,6 +316,16 @@ def get_registry_config(self) -> RegistryConfig: requires=[], functional_node="python_node", ), + # Optimization capability (framework-level) + CapabilityRegistration( + name="optimization", + module_path="osprey.capabilities.optimization", + class_name="OptimizationCapability", + description="Autonomous machine parameter optimization using XOpt", + provides=["OPTIMIZATION_RESULT"], + requires=[], + functional_node="optimization_node", + ), # Communication capabilities (framework-level) - always active CapabilityRegistration( name="respond", @@ -367,6 +377,12 @@ def get_registry_config(self) -> RegistryConfig: module_path="osprey.capabilities.python", class_name="PythonResultsContext", ), + # Optimization result context (framework-level) + ContextClassRegistration( + context_type="OPTIMIZATION_RESULT", + module_path="osprey.capabilities.optimization", + class_name="OptimizationResultContext", + ), ], # Framework-level data sources data_sources=[ @@ -396,6 +412,23 @@ def get_registry_config(self) -> RegistryConfig: "python_approval_node", ], ), + # XOpt optimizer service (framework-level) + ServiceRegistration( + name="xopt_optimizer", + module_path="osprey.services.xopt_optimizer.service", + class_name="XOptOptimizerService", + description="XOpt-based autonomous machine optimization service", + provides=["OPTIMIZATION_RESULT"], + requires=[], + internal_nodes=[ + "state_identification", + "decision", + "yaml_generation", + "approval", + "execution", + "analysis", + ], + ), ], # Framework prompt providers (defaults - typically overridden by applications) framework_prompt_providers=[ @@ -411,6 +444,7 @@ def get_registry_config(self) -> RegistryConfig: "memory_extraction": "DefaultMemoryExtractionPromptBuilder", "time_range_parsing": "DefaultTimeRangeParsingPromptBuilder", "python": "DefaultPythonPromptBuilder", + "optimization": "DefaultOptimizationPromptBuilder", }, ) ], diff --git a/src/osprey/services/xopt_optimizer/__init__.py b/src/osprey/services/xopt_optimizer/__init__.py new file mode 100644 index 000000000..42bb0f262 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/__init__.py @@ -0,0 +1,62 @@ +"""XOpt Optimizer Service - Autonomous Machine Optimization. + +This service provides a LangGraph-based subsystem for autonomous machine optimization +in accelerator control environments. It follows the same architectural patterns as +the Python Executor Service, providing intelligent optimization workflows with +human approval, machine state awareness, and result analysis. + +Key Components: + - XOptOptimizerService: Main LangGraph orchestrator for optimization workflows + - XOptExecutionRequest: Request model for optimization execution + - XOptServiceResult: Structured result from optimization execution + - XOptExecutionState: Internal LangGraph state for workflow tracking + +Design Principles: + - Framework-level service adaptable to any facility through configuration + - Prompt builder system for facility-specific customization + - Configuration-driven machine state definitions + - Pluggable tools that leverage existing Osprey capabilities + +.. seealso:: + :mod:`osprey.services.python_executor` : Similar service for Python execution + :mod:`osprey.capabilities.optimization` : Capability that uses this service +""" + +from .exceptions import ( + ConfigurationError, + ErrorCategory, + MachineStateAssessmentError, + MaxIterationsExceededError, + XOptExecutionError, + XOptExecutorException, + YamlGenerationError, +) +from .models import ( + MachineState, + XOptError, + XOptExecutionRequest, + XOptExecutionState, + XOptServiceResult, + XOptStrategy, +) +from .service import XOptOptimizerService + +__all__ = [ + # Service + "XOptOptimizerService", + # Models + "XOptExecutionRequest", + "XOptExecutionState", + "XOptServiceResult", + "XOptError", + "MachineState", + "XOptStrategy", + # Exceptions + "XOptExecutorException", + "ErrorCategory", + "MachineStateAssessmentError", + "YamlGenerationError", + "XOptExecutionError", + "MaxIterationsExceededError", + "ConfigurationError", +] diff --git a/src/osprey/services/xopt_optimizer/analysis/__init__.py b/src/osprey/services/xopt_optimizer/analysis/__init__.py new file mode 100644 index 000000000..c07598012 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/analysis/__init__.py @@ -0,0 +1,9 @@ +"""Analysis Subsystem for XOpt Optimizer. + +This subsystem analyzes XOpt results and decides whether to continue +with additional iterations or complete the optimization. +""" + +from .node import create_analysis_node + +__all__ = ["create_analysis_node"] diff --git a/src/osprey/services/xopt_optimizer/analysis/node.py b/src/osprey/services/xopt_optimizer/analysis/node.py new file mode 100644 index 000000000..784720574 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/analysis/node.py @@ -0,0 +1,102 @@ +"""Analysis Node for XOpt Optimizer Service. + +This node analyzes XOpt results and decides whether to continue +with additional iterations or complete the optimization. +""" + +from typing import Any + +from osprey.utils.logger import get_logger + +from ..models import XOptExecutionState + +logger = get_logger("xopt_optimizer") + + +def create_analysis_node(): + """Create the analysis node for LangGraph integration. + + This factory function creates a node that analyzes XOpt results + and decides whether to continue the optimization loop. + + Returns: + Async function that takes XOptExecutionState and returns state updates + """ + + async def analysis_node(state: XOptExecutionState) -> dict[str, Any]: + """Analyze XOpt results and decide whether to continue. + + Simple continuation logic based on iteration count. + Future implementation may include: + - Convergence detection + - Improvement rate analysis + - Domain-specific completion criteria + """ + node_logger = get_logger("xopt_optimizer", state=state) + node_logger.status("Analyzing optimization results...") + + run_artifact = state.get("run_artifact") + iteration = state.get("iteration_count", 0) + 1 + max_iterations = state.get("max_iterations", 3) + + # Simple continuation logic (can be refined) + # Future: Add convergence detection, improvement rate analysis, etc. + should_continue = iteration < max_iterations + + # Generate analysis result + # NOTE: This is a placeholder implementation for testing the workflow + analysis_result = { + "status": "PLACEHOLDER_TEST_SUCCESS", + "message": "XOpt optimizer service workflow test completed successfully", + "iteration": iteration, + "max_iterations": max_iterations, + "run_artifact": run_artifact, + "should_continue": should_continue, + "note": ( + "This is a placeholder implementation. All subsystems (state identification, " + "decision, YAML generation, approval, execution, analysis) executed successfully " + "with placeholder logic. Real optimization will be implemented when domain " + "requirements are defined by facility operators." + ), + } + + # Generate recommendations (placeholder - clearly indicate test status) + recommendations = [] + if should_continue: + recommendations.append(f"[TEST] Continuing to iteration {iteration + 1}") + else: + recommendations.append( + f"[TEST SUCCESS] XOpt optimizer workflow completed {iteration} iterations successfully" + ) + recommendations.append( + "[PLACEHOLDER] All subsystems executed with placeholder logic - " + "ready for real implementation when domain requirements are defined" + ) + recommendations.append( + "[NEXT STEPS] Implement real machine state assessment, YAML generation, " + "and XOpt execution based on facility-specific requirements" + ) + + node_logger.info(f"Iteration {iteration}/{max_iterations} complete") + + if should_continue: + node_logger.info("Continuing to next iteration") + return { + "analysis_result": analysis_result, + "recommendations": recommendations, + "iteration_count": iteration, + "should_continue": True, + "current_stage": "state_id", + } + else: + node_logger.info("Optimization complete") + return { + "analysis_result": analysis_result, + "recommendations": recommendations, + "iteration_count": iteration, + "should_continue": False, + "is_successful": True, + "current_stage": "complete", + } + + return analysis_node diff --git a/src/osprey/services/xopt_optimizer/approval/__init__.py b/src/osprey/services/xopt_optimizer/approval/__init__.py new file mode 100644 index 000000000..f75df93b4 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/approval/__init__.py @@ -0,0 +1,9 @@ +"""Approval Subsystem for XOpt Optimizer. + +This subsystem handles human approval for XOpt configurations using +the standard Osprey LangGraph interrupt pattern. +""" + +from .node import create_approval_node + +__all__ = ["create_approval_node"] diff --git a/src/osprey/services/xopt_optimizer/approval/node.py b/src/osprey/services/xopt_optimizer/approval/node.py new file mode 100644 index 000000000..4d9a49eed --- /dev/null +++ b/src/osprey/services/xopt_optimizer/approval/node.py @@ -0,0 +1,63 @@ +"""Approval Node for XOpt Optimizer Service. + +This node handles human approval for XOpt configurations using the standard +Osprey LangGraph interrupt pattern. The approval interrupt data is pre-created +by the yaml_generation node, following the pattern from Python executor's +analyzer node. +""" + +from typing import Any + +from langgraph.types import interrupt + +from osprey.utils.logger import get_logger + +from ..models import XOptExecutionState + +logger = get_logger("xopt_optimizer") + + +def create_approval_node(): + """Create a pure approval node function for LangGraph integration. + + This factory function creates a specialized approval node that serves as a + clean interrupt handler. The node is designed with single responsibility: + processing LangGraph interrupts for user approval. + + The approval interrupt data is pre-created by the yaml_generation node, + following the pattern from Python executor's analyzer node. + + Returns: + Async function that takes XOptExecutionState and returns state updates + """ + + async def approval_node(state: XOptExecutionState) -> dict[str, Any]: + """Process approval interrupt and return user response for workflow routing.""" + + # Get logger with streaming support + node_logger = get_logger("xopt_optimizer", state=state) + node_logger.status("Requesting human approval...") + + # Get the pre-created interrupt data from yaml_generation node + interrupt_data = state.get("approval_interrupt_data") + if not interrupt_data: + raise RuntimeError( + "No approval interrupt data found in state. " + "The yaml_generation node should create this data." + ) + + node_logger.info("Requesting human approval for XOpt configuration") + + # This is the ONLY critical line - everything else is routing + human_response = interrupt(interrupt_data) + + # Simple approval processing for routing + approved = human_response.get("approved", False) + node_logger.info(f"Approval result: {approved}") + + return { + "approval_result": human_response, + "approved": approved, + } + + return approval_node diff --git a/src/osprey/services/xopt_optimizer/decision/__init__.py b/src/osprey/services/xopt_optimizer/decision/__init__.py new file mode 100644 index 000000000..aaeace597 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/decision/__init__.py @@ -0,0 +1,14 @@ +"""Decision Subsystem for XOpt Optimizer. + +This subsystem routes the workflow based on machine state assessment, +selecting the appropriate optimization strategy (exploration, optimization, +or abort). + +Supports two modes (configured via xopt_optimizer.decision.mode): +- "llm": LLM-based decision making with structured output +- "mock": Fast placeholder that always selects exploration (default for tests) +""" + +from .node import DECISION_SYSTEM_PROMPT, StrategyDecision, create_decision_node + +__all__ = ["create_decision_node", "StrategyDecision", "DECISION_SYSTEM_PROMPT"] diff --git a/src/osprey/services/xopt_optimizer/decision/node.py b/src/osprey/services/xopt_optimizer/decision/node.py new file mode 100644 index 000000000..a4ed56def --- /dev/null +++ b/src/osprey/services/xopt_optimizer/decision/node.py @@ -0,0 +1,331 @@ +"""Decision Node for XOpt Optimizer Service. + +This node routes the workflow based on machine state assessment, +selecting the appropriate optimization strategy. + +Supports two modes (configured via xopt_optimizer.decision.mode): +- "llm": LLM-based decision making with structured output +- "mock": Fast placeholder that always selects exploration (default) +""" + +from typing import Any + +from pydantic import BaseModel, Field + +from osprey.utils.config import get_model_config, get_xopt_optimizer_config +from osprey.utils.logger import get_logger + +from ..models import MachineState, XOptExecutionState, XOptStrategy + +logger = get_logger("xopt_optimizer") + + +# ============================================================================= +# STRUCTURED OUTPUT MODEL FOR LLM DECISION +# ============================================================================= + + +class StrategyDecision(BaseModel): + """Structured output for LLM strategy decision. + + This model is used with LangChain's `with_structured_output` to ensure + the LLM returns a valid strategy selection with reasoning. + """ + + strategy: XOptStrategy = Field( + description="The optimization strategy: 'exploration' or 'optimization'" + ) + reasoning: str = Field( + description="Brief explanation of why this strategy was selected" + ) + + +# ============================================================================= +# SYSTEM PROMPT FOR LLM DECISION +# ============================================================================= + +DECISION_SYSTEM_PROMPT = """You are an expert accelerator optimizer decision system. + +Your task is to select the appropriate optimization strategy based on: +1. The **Machine State** value (ready, not_ready, or unknown) - this is the authoritative assessment +2. The user's optimization objective + +## Available Strategies + +- **exploration**: Use when the machine is READY and starting a new optimization campaign. + This strategy prioritizes coverage and discovery over immediate optimization. + +- **optimization**: Use when the machine is READY, the objective is clear, + and you want to aggressively optimize toward the goal. This strategy prioritizes + finding the best solution quickly. + +- **abort**: Use ONLY when machine_state is NOT_READY or UNKNOWN. + +## Decision Guidelines + +IMPORTANT: Trust the Machine State value. The state assessment has already been performed +by a dedicated agent that checked all safety criteria. Your job is to select a strategy +based on that assessment. + +1. If machine_state is "ready" -> Select EXPLORATION or OPTIMIZATION based on objective +2. If machine_state is "not_ready" or "unknown" -> Select ABORT +3. For new optimization objectives, prefer EXPLORATION +4. For well-defined objectives with good starting conditions, consider OPTIMIZATION + +Select the most appropriate strategy and explain your reasoning briefly.""" + + +# ============================================================================= +# CONFIGURATION +# ============================================================================= + + +def _get_decision_config() -> dict[str, Any]: + """Get decision node configuration from osprey config. + + Reads from config structure: + xopt_optimizer: + decision: + mode: "mock" # or "llm" + model_config_name: "xopt_decision" # References models section + + Returns: + Configuration dict with mode and model_config + """ + xopt_config = get_xopt_optimizer_config() + decision_config = xopt_config.get("decision", {}) + + # Resolve model config from name reference + # Falls back to "orchestrator" model if xopt-specific model not configured + model_config = None + model_config_name = decision_config.get("model_config_name", "xopt_decision") + try: + model_config = get_model_config(model_config_name) + # Check if the model config is valid (has provider) + if not model_config or not model_config.get("provider"): + logger.debug( + f"Model '{model_config_name}' not configured, falling back to orchestrator" + ) + model_config = get_model_config("orchestrator") + except Exception as e: + logger.warning( + f"Could not load model config '{model_config_name}': {e}, " + "falling back to orchestrator" + ) + model_config = get_model_config("orchestrator") + + return { + "mode": decision_config.get("mode", "mock"), # Default to mock for fast testing + "model_config": model_config, + } + + +# ============================================================================= +# LLM-BASED DECISION +# ============================================================================= + + +async def _make_llm_decision( + objective: str, + machine_state: MachineState, + machine_state_details: dict[str, Any] | None, + model_config: dict[str, Any], +) -> StrategyDecision: + """Make strategy decision using LLM with structured output. + + Args: + objective: The optimization objective + machine_state: Current machine state assessment + machine_state_details: Additional machine state details + model_config: Model configuration for the LLM + + Returns: + StrategyDecision with selected strategy and reasoning + """ + from osprey.models.langchain import get_langchain_model + + # Get the model and configure for structured output + model = get_langchain_model(model_config=model_config) + structured_model = model.with_structured_output(StrategyDecision) + + # Build the user message with context + # Only include key summary info, not raw agent response (which may contain confusing intermediate reasoning) + details_text = "" + if machine_state_details: + # Extract only the relevant summary fields, excluding raw_response + summary_fields = {} + if "channels_checked" in machine_state_details: + summary_fields["channels_checked"] = machine_state_details["channels_checked"] + if "key_observations" in machine_state_details: + summary_fields["key_observations"] = machine_state_details["key_observations"] + if summary_fields: + details_text = f"\n\nMachine State Details:\n{summary_fields}" + + user_message = f"""Please select the appropriate optimization strategy. + +**Optimization Objective:** {objective} +**Machine State:** {machine_state.value} +{details_text} + +Based on this information, which strategy should we use?""" + + # Invoke the model + result = await structured_model.ainvoke( + [ + {"role": "system", "content": DECISION_SYSTEM_PROMPT}, + {"role": "user", "content": user_message}, + ] + ) + + return result + + +# ============================================================================= +# MOCK DECISION +# ============================================================================= + + +def _make_mock_decision( + machine_state: MachineState, + machine_state_details: dict[str, Any] | None, +) -> StrategyDecision: + """Make strategy decision using mock logic (for testing). + + Always selects exploration unless machine is not ready. + + Args: + machine_state: Current machine state assessment + machine_state_details: Additional machine state details + + Returns: + StrategyDecision with selected strategy and reasoning + """ + if machine_state == MachineState.NOT_READY: + reason = ( + machine_state_details.get("reason", "Machine not ready") + if machine_state_details + else "Machine not ready" + ) + return StrategyDecision( + strategy=XOptStrategy.ABORT, + reasoning=reason, + ) + + if machine_state == MachineState.UNKNOWN: + reason = ( + machine_state_details.get("reason", "Machine state unknown") + if machine_state_details + else "Machine state unknown" + ) + return StrategyDecision( + strategy=XOptStrategy.ABORT, + reasoning=reason, + ) + + # Default to exploration for mock mode + return StrategyDecision( + strategy=XOptStrategy.EXPLORATION, + reasoning="Machine ready, starting with exploration", + ) + + +# ============================================================================= +# NODE FACTORY +# ============================================================================= + + +def create_decision_node(): + """Create the decision node for LangGraph integration. + + This factory function creates a node that routes based on machine + state assessment, selecting the appropriate optimization strategy. + + The decision mode is controlled via configuration: + - xopt_optimizer.decision.mode: "mock" | "llm" + + Returns: + Async function that takes XOptExecutionState and returns state updates + """ + + async def decision_node(state: XOptExecutionState) -> dict[str, Any]: + """Route based on machine state assessment. + + Supports two modes: + - "mock": Fast placeholder that always selects exploration (default) + - "llm": LLM-based decision making with structured output + """ + node_logger = get_logger("xopt_optimizer", state=state) + # Get configuration + decision_config = _get_decision_config() + mode = decision_config.get("mode", "mock") + is_mock = mode == "mock" + mode_indicator = " (mock)" if is_mock else "" + + node_logger.status(f"Selecting optimization strategy{mode_indicator}...") + + machine_state = state.get("machine_state") + machine_state_details = state.get("machine_state_details") + request = state.get("request") + objective = request.optimization_objective if request else "Unknown objective" + + try: + # Make decision based on mode + if mode == "llm": + decision = await _make_llm_decision( + objective=objective, + machine_state=machine_state, + machine_state_details=machine_state_details, + model_config=decision_config.get("model_config"), + ) + else: + decision = _make_mock_decision( + machine_state=machine_state, + machine_state_details=machine_state_details, + ) + + # Handle abort strategy + if decision.strategy == XOptStrategy.ABORT: + node_logger.key_info(f"Strategy: ABORT{mode_indicator}") + return { + "selected_strategy": XOptStrategy.ABORT, + "decision_reasoning": decision.reasoning, + "is_failed": True, + "failure_reason": f"Strategy decision: {decision.reasoning}", + "current_stage": "failed", + } + + # Strategy selected successfully - log as key_info for visibility + node_logger.key_info(f"Strategy: {decision.strategy.value.upper()}{mode_indicator}") + + return { + "selected_strategy": decision.strategy, + "decision_reasoning": decision.reasoning, + "current_stage": "yaml_gen", + } + + except Exception as e: + node_logger.error(f"Strategy decision failed: {e}") + + # Fall back to mock decision on error + node_logger.warning("Falling back to mock decision due to error") + decision = _make_mock_decision( + machine_state=machine_state, + machine_state_details=machine_state_details, + ) + + if decision.strategy == XOptStrategy.ABORT: + return { + "selected_strategy": XOptStrategy.ABORT, + "decision_reasoning": f"Fallback: {decision.reasoning}", + "is_failed": True, + "failure_reason": f"Strategy decision failed: {e}", + "current_stage": "failed", + } + + return { + "selected_strategy": decision.strategy, + "decision_reasoning": f"Fallback: {decision.reasoning}", + "current_stage": "yaml_gen", + } + + return decision_node diff --git a/src/osprey/services/xopt_optimizer/exceptions.py b/src/osprey/services/xopt_optimizer/exceptions.py new file mode 100644 index 000000000..8d259ffa1 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/exceptions.py @@ -0,0 +1,147 @@ +"""Exception Hierarchy for XOpt Optimizer Service. + +This module defines a clean, categorized exception hierarchy that provides precise +error classification for all failure modes in the XOpt optimizer service. The +exceptions are designed to support intelligent retry logic, user-friendly error +reporting, and comprehensive debugging information. + +Error Categories: + - MACHINE_STATE: Machine not ready - may retry after delay + - YAML_GENERATION: Code generation issues - retry with feedback + - EXECUTION: XOpt runtime errors + - CONFIGURATION: Invalid configuration + - WORKFLOW: Service-level workflow issues +""" + +from enum import Enum +from typing import Any + + +class ErrorCategory(str, Enum): + """Categorization of errors for retry logic.""" + + MACHINE_STATE = "machine_state" # Machine not ready - may retry after delay + YAML_GENERATION = "yaml_generation" # Code generation issues - retry with feedback + EXECUTION = "execution" # XOpt runtime errors + CONFIGURATION = "configuration" # Invalid configuration + WORKFLOW = "workflow" # Service-level workflow issues + + +class XOptExecutorException(Exception): + """Base exception for all XOpt optimizer service errors. + + Provides categorization and structured error information for + proper error handling and retry logic. + + :param message: Human-readable error description + :param category: Error category for recovery strategy + :param technical_details: Additional technical information for debugging + """ + + def __init__( + self, + message: str, + category: ErrorCategory = ErrorCategory.WORKFLOW, + technical_details: dict[str, Any] | None = None, + ): + super().__init__(message) + self.message = message + self.category = category + self.technical_details = technical_details or {} + + def is_retriable(self) -> bool: + """Check if this error type typically warrants a retry.""" + return self.category in (ErrorCategory.MACHINE_STATE, ErrorCategory.YAML_GENERATION) + + def should_retry_yaml_generation(self) -> bool: + """Check if YAML should be regenerated.""" + return self.category == ErrorCategory.YAML_GENERATION + + +class MachineStateAssessmentError(XOptExecutorException): + """Failed to assess machine state. + + Raised when the state identification agent cannot determine + machine readiness. May be retryable after addressing machine issues. + + :param message: Error description + :param assessment_details: Details from the assessment attempt + """ + + def __init__( + self, + message: str, + assessment_details: dict[str, Any] | None = None, + **kwargs, + ): + super().__init__(message, category=ErrorCategory.MACHINE_STATE, **kwargs) + self.assessment_details = assessment_details or {} + + +class YamlGenerationError(XOptExecutorException): + """Failed to generate valid XOpt YAML configuration. + + Raised when the YAML generation agent produces invalid configuration. + Usually retryable with error feedback. + + :param message: Error description + :param generated_yaml: The invalid YAML that was generated + :param validation_errors: List of validation errors found + """ + + def __init__( + self, + message: str, + generated_yaml: str | None = None, + validation_errors: list[str] | None = None, + **kwargs, + ): + super().__init__(message, category=ErrorCategory.YAML_GENERATION, **kwargs) + self.generated_yaml = generated_yaml + self.validation_errors = validation_errors or [] + + +class XOptExecutionError(XOptExecutorException): + """XOpt execution failed at runtime. + + Raised when XOpt itself fails during execution. + + :param message: Error description + :param yaml_used: The YAML configuration that was used + :param xopt_error: The original XOpt error message + """ + + def __init__( + self, + message: str, + yaml_used: str | None = None, + xopt_error: str | None = None, + **kwargs, + ): + super().__init__(message, category=ErrorCategory.EXECUTION, **kwargs) + self.yaml_used = yaml_used + self.xopt_error = xopt_error + + +class MaxIterationsExceededError(XOptExecutorException): + """Maximum optimization iterations exceeded without convergence. + + :param message: Error description + :param iterations_completed: Number of iterations that were completed + """ + + def __init__(self, message: str, iterations_completed: int = 0, **kwargs): + super().__init__(message, category=ErrorCategory.WORKFLOW, **kwargs) + self.iterations_completed = iterations_completed + + +class ConfigurationError(XOptExecutorException): + """Invalid service configuration. + + :param message: Error description + :param config_key: The configuration key that is invalid + """ + + def __init__(self, message: str, config_key: str | None = None, **kwargs): + super().__init__(message, category=ErrorCategory.CONFIGURATION, **kwargs) + self.config_key = config_key diff --git a/src/osprey/services/xopt_optimizer/execution/__init__.py b/src/osprey/services/xopt_optimizer/execution/__init__.py new file mode 100644 index 000000000..1d154ac87 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/execution/__init__.py @@ -0,0 +1,13 @@ +"""Execution Subsystem for XOpt Optimizer. + +This subsystem executes XOpt optimization runs using the generated +YAML configuration. + +PLACEHOLDER: Current implementation is a no-op placeholder. +Actual XOpt execution will be implemented when XOpt prototype +integration is ready. +""" + +from .node import create_executor_node + +__all__ = ["create_executor_node"] diff --git a/src/osprey/services/xopt_optimizer/execution/node.py b/src/osprey/services/xopt_optimizer/execution/node.py new file mode 100644 index 000000000..fed5b09f1 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/execution/node.py @@ -0,0 +1,89 @@ +"""Execution Node for XOpt Optimizer Service. + +This node executes XOpt optimization runs using the generated YAML configuration. + +PLACEHOLDER: This implementation is a no-op that returns placeholder results. + +TODO: Replace with actual XOpt prototype integration when ready. +This will require: +- Integration with existing XOpt Python prototype +- Proper error handling for XOpt execution failures +- Result artifact capture + +DO NOT add accelerator-specific execution logic without operator input. +""" + +from typing import Any + +from osprey.utils.logger import get_logger + +from ..models import XOptExecutionState + +logger = get_logger("xopt_optimizer") + + +async def _run_xopt_placeholder(yaml_config: str) -> dict[str, Any]: + """Placeholder for XOpt execution. + + PLACEHOLDER: Returns mock results. + + TODO: Replace with actual XOpt prototype integration. + This will involve: + - Parsing the YAML configuration + - Setting up XOpt with proper generator and evaluator + - Running the optimization loop + - Capturing results and artifacts + """ + return { + "status": "completed", + "evaluations": 0, + "best_value": None, + "best_parameters": {}, + "yaml_used": yaml_config, + "note": "This is a placeholder result. Actual XOpt execution will be " + "implemented when XOpt prototype integration is ready.", + } + + +def create_executor_node(): + """Create the execution node for LangGraph integration. + + This factory function creates a node that executes XOpt optimization + runs. Currently implements a placeholder. + + Returns: + Async function that takes XOptExecutionState and returns state updates + """ + + async def executor_node(state: XOptExecutionState) -> dict[str, Any]: + """Execute XOpt optimization. + + PLACEHOLDER: Returns mock results. + """ + node_logger = get_logger("xopt_optimizer", state=state) + node_logger.status("Executing XOpt optimization...") + + yaml_config = state.get("generated_yaml") + + try: + # PLACEHOLDER: Call placeholder XOpt execution + run_artifact = await _run_xopt_placeholder(yaml_config) + + node_logger.info("XOpt execution completed") + return { + "run_artifact": run_artifact, + "execution_failed": False, + "current_stage": "analysis", + } + + except Exception as e: + node_logger.error(f"XOpt execution failed: {e}") + return { + "execution_error": str(e), + "execution_failed": True, + "is_failed": True, + "failure_reason": f"XOpt execution error: {e}", + "current_stage": "failed", + } + + return executor_node diff --git a/src/osprey/services/xopt_optimizer/models.py b/src/osprey/services/xopt_optimizer/models.py new file mode 100644 index 000000000..2f72d7f79 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/models.py @@ -0,0 +1,273 @@ +"""Core Models and State Management for XOpt Optimizer Service. + +This module provides the foundational data structures, state management classes, +and configuration utilities that support the XOpt optimizer service's +LangGraph-based workflow. + +The module is organized into several key areas: + +**Type Definitions**: Core data structures for execution requests, results, and +metadata tracking. These provide type-safe interfaces for service communication +and ensure consistent data handling across the optimization pipeline. + +**State Management**: LangGraph-compatible state classes that track execution +progress, approval workflows, and intermediate results throughout the service +execution lifecycle. + +**Enumerations**: Machine state and strategy enums that drive workflow decisions +and routing logic. + +Key Design Principles: + - **Type Safety**: All public interfaces use Pydantic models or dataclasses + with comprehensive type annotations + - **LangGraph Integration**: State classes implement TypedDict patterns for + seamless integration with LangGraph's state management and checkpointing + - **Placeholder-First**: Machine-affecting components use placeholders until + domain requirements are defined by operators +""" + +from __future__ import annotations + +import dataclasses +from dataclasses import dataclass, field +from enum import Enum +from typing import Annotated, Any, TypedDict + +from pydantic import BaseModel, Field + +from osprey.utils.logger import get_logger + +logger = get_logger("xopt_optimizer") + + +# ============================================================================= +# CUSTOM REDUCERS FOR STATE MANAGEMENT +# ============================================================================= + + +def preserve_once_set(existing: Any | None, new: Any | None) -> Any | None: + """Preserve field value once set - never allow it to be replaced or lost. + + This reducer ensures that critical fields like 'request' are never lost during + LangGraph state updates, including checkpoint resumption with Command objects. + + Args: + existing: Current value of the field (may be None) + new: New value being applied to the field (may be None) + + Returns: + The existing value if it's set, otherwise the new value + """ + if existing is not None: + return existing + return new + + +# ============================================================================= +# ENUMERATIONS +# ============================================================================= + + +class MachineState(str, Enum): + """Machine states for optimization readiness. + + NOTE: These are placeholders. Actual states will be determined + based on facility requirements and operator feedback. + + DO NOT add accelerator-specific states without operator input. + """ + + READY = "ready" # Machine ready for optimization + NOT_READY = "not_ready" # Cannot proceed (reason in details) + UNKNOWN = "unknown" # Assessment inconclusive + + # Domain-specific states to be added based on facility requirements, e.g.: + # NO_CHARGE = "no_charge" + # NO_BEAM = "no_beam" + # INTERLOCK_ACTIVE = "interlock_active" + + +class XOptStrategy(str, Enum): + """Optimization strategy to execute.""" + + EXPLORATION = "exploration" # Explore parameter space + OPTIMIZATION = "optimization" # Optimize toward goal + ABORT = "abort" # Cannot proceed + + +# ============================================================================= +# ERROR TRACKING +# ============================================================================= + + +@dataclass +class XOptError: + """Structured error information for debugging and iteration refinement. + + Captures error context to help subsequent nodes understand what failed + and potentially adjust their approach. + + :param error_type: Category of error (state_assessment, yaml_generation, execution, analysis) + :param error_message: Human-readable error message + :param stage: Pipeline stage where error occurred + :param attempt_number: Which attempt this error occurred in + :param details: Additional error details for debugging + """ + + error_type: str + error_message: str + stage: str + attempt_number: int = 1 + details: dict[str, Any] = field(default_factory=dict) + + def to_prompt_text(self) -> str: + """Format error for inclusion in agent prompts.""" + parts = [f"**Attempt {self.attempt_number} - {self.stage.upper()} FAILED**"] + parts.append(f"\n**Error Type:** {self.error_type}") + parts.append(f"**Error:** {self.error_message}") + if self.details: + parts.append(f"\n**Details:** {self.details}") + return "\n".join(parts) + + +# ============================================================================= +# REQUEST MODEL +# ============================================================================= + + +class XOptExecutionRequest(BaseModel): + """Request model for XOpt optimization service. + + Serializable request that captures all information needed to run + an optimization workflow. Matches the pattern from PythonExecutionRequest. + + :param user_query: User's optimization request + :param optimization_objective: What to optimize + :param max_iterations: Maximum optimization iterations + :param retries: Maximum YAML generation retries per iteration + :param reference_files_path: Path to reference documentation + :param historical_yamls_path: Path to historical YAML configurations + :param capability_context_data: Capability context from main graph state + :param require_approval: Whether human approval is required + :param session_context: Session info including chat_id, user_id + """ + + user_query: str = Field(..., description="User's optimization request") + optimization_objective: str = Field(..., description="What to optimize") + + max_iterations: int = Field(default=3, description="Maximum optimization iterations") + retries: int = Field(default=3, description="Maximum YAML generation retries per iteration") + + # Paths to reference data (configured per deployment) + reference_files_path: str | None = None + historical_yamls_path: str | None = None + + # Capability context (for cross-capability data access) + capability_context_data: dict[str, Any] | None = Field( + None, description="Capability context data from main graph state" + ) + + # Standard Osprey fields + require_approval: bool = Field(default=True) + session_context: dict[str, Any] | None = Field( + None, description="Session info including chat_id, user_id" + ) + + +# ============================================================================= +# SERVICE RESULT +# ============================================================================= + + +@dataclasses.dataclass(frozen=True, slots=True) +class XOptServiceResult: + """Structured, type-safe result from XOpt optimizer service. + + This eliminates the need for validation and error checking in capabilities. + The service guarantees this structure is always returned on success. + On failure, the service raises appropriate exceptions. + + :param run_artifact: Optimization run output data + :param generated_yaml: XOpt YAML configuration used + :param strategy: Strategy used (exploration/optimization) + :param total_iterations: Number of iterations completed + :param analysis_summary: Summary of optimization analysis + :param recommendations: List of recommendations from analysis + """ + + run_artifact: dict[str, Any] + generated_yaml: str + strategy: XOptStrategy + total_iterations: int + analysis_summary: dict[str, Any] + recommendations: tuple[str, ...] # Use tuple for frozen dataclass + + def __post_init__(self): + """Validate immutable structure.""" + # Frozen dataclass handles immutability + + +# ============================================================================= +# STATE MANAGEMENT +# ============================================================================= + + +class XOptExecutionState(TypedDict): + """LangGraph state for XOpt optimizer service. + + This state is used internally by the service and includes both the + original request and execution tracking fields. + + CRITICAL: The 'request' field uses the preserve_once_set reducer to ensure + it's never lost during state updates or checkpoint resumption (e.g., approval workflows). + + NOTE: capability_context_data is extracted to top level for ContextManager compatibility. + """ + + # Original request (preserved via reducer) - NEVER lost once set + request: Annotated[XOptExecutionRequest, preserve_once_set] + + # Capability context data (for cross-capability integration) + capability_context_data: dict[str, dict[str, dict[str, Any]]] | None + + # Error tracking (matches Python executor pattern) + error_chain: list[XOptError] + yaml_generation_attempt: int # For YAML regeneration retries + + # Machine state assessment + machine_state: MachineState | None + machine_state_details: dict[str, Any] | None # Readings, reasoning, etc. + + # Decision + selected_strategy: XOptStrategy | None + decision_reasoning: str | None + + # YAML configuration + generated_yaml: str | None + yaml_generation_failed: bool | None + + # Approval state (standard Osprey pattern) + requires_approval: bool | None + approval_interrupt_data: dict[str, Any] | None # LangGraph interrupt data + approval_result: dict[str, Any] | None # Response from interrupt + approved: bool | None # Final approval status + + # Execution + run_artifact: Any | None # In-memory result from XOpt + execution_error: str | None + execution_failed: bool | None + + # Analysis + analysis_result: dict[str, Any] | None + recommendations: list[str] | None + + # Loop control + iteration_count: int + max_iterations: int + should_continue: bool + + # Control flags + is_successful: bool + is_failed: bool + failure_reason: str | None + current_stage: str # "state_id", "decision", "yaml_gen", "approval", "execution", "analysis", "complete", "failed" diff --git a/src/osprey/services/xopt_optimizer/service.py b/src/osprey/services/xopt_optimizer/service.py new file mode 100644 index 000000000..01f3cdf7a --- /dev/null +++ b/src/osprey/services/xopt_optimizer/service.py @@ -0,0 +1,369 @@ +"""XOpt Optimizer Service - LangGraph-based Orchestrator. + +This module provides the main service class that orchestrates the XOpt optimization +workflow using LangGraph. It follows the same patterns as PythonExecutorService +for consistency across the Osprey framework. + +The service implements a multi-stage workflow: +1. State Identification - Assess machine readiness +2. Decision - Select optimization strategy +3. YAML Generation - Create XOpt configuration +4. Approval - Human approval of configuration +5. Execution - Run XOpt optimization +6. Analysis - Analyze results and decide continuation + +The workflow supports iteration loops where analysis can route back to +state identification for multi-iteration optimization campaigns. +""" + +from typing import Any + +from langgraph.graph import StateGraph +from langgraph.types import Command + +from osprey.graph.graph_builder import ( + create_async_postgres_checkpointer, + create_memory_checkpointer, +) +from osprey.utils.config import get_full_configuration +from osprey.utils.logger import get_logger + +from .analysis import create_analysis_node +from .approval import create_approval_node +from .decision import create_decision_node +from .exceptions import XOptExecutionError +from .execution import create_executor_node +from .models import ( + XOptExecutionRequest, + XOptExecutionState, + XOptServiceResult, + XOptStrategy, +) +from .state_identification import create_state_identification_node +from .yaml_generation import create_yaml_generation_node + +logger = get_logger("xopt_optimizer") + + +class XOptOptimizerService: + """XOpt Optimizer Service - LangGraph-based orchestrator. + + Follows the same patterns as PythonExecutorService for consistency + across the Osprey framework. + + The service provides: + - Multi-stage optimization workflow with approval gates + - Iterative optimization with configurable loop control + - Machine state assessment before optimization + - Configuration-driven strategy selection + """ + + def __init__(self): + """Initialize the XOpt optimizer service.""" + self.config = self._load_config() + self._compiled_graph = None + + def get_compiled_graph(self): + """Get the compiled LangGraph for this service. + + Lazily compiles the graph on first access. + + Returns: + CompiledGraph: The compiled LangGraph workflow + """ + if self._compiled_graph is None: + self._compiled_graph = self._build_and_compile_graph() + return self._compiled_graph + + async def ainvoke(self, input_data, config): + """Main service entry point handling execution requests and workflow resumption. + + This method serves as the primary interface for the XOpt optimizer service, + accepting both fresh execution requests and workflow resumption commands. + + Args: + input_data: XOptExecutionRequest for new execution, or Command for resumption + config: LangGraph configuration including thread_id and service settings + + Returns: + XOptServiceResult on success + + Raises: + XOptExecutionError: If optimization fails + TypeError: If input_data is not a supported type + """ + if isinstance(input_data, Command): + # This is a resume command (approval response) + if hasattr(input_data, "resume") and input_data.resume: + logger.info("Resuming XOpt service execution after approval") + approval_result = input_data.resume.get("approved", False) + logger.info(f"Approval result: {approval_result}") + + # Pass Command directly to let LangGraph handle checkpoint resume + compiled_graph = self.get_compiled_graph() + result = await compiled_graph.ainvoke(input_data, config) + + # Check for execution failure and raise exception + self._handle_execution_failure(result) + + return self._create_service_result(result) + else: + raise ValueError( + "Invalid Command received by service - missing or invalid resume data" + ) + + elif isinstance(input_data, XOptExecutionRequest): + logger.debug("Converting XOptExecutionRequest to internal state") + internal_state = self._create_initial_state(input_data) + + compiled_graph = self.get_compiled_graph() + result = await compiled_graph.ainvoke(internal_state, config) + + # Check for execution failure and raise exception + self._handle_execution_failure(result) + + return self._create_service_result(result) + + else: + supported_types = [XOptExecutionRequest.__name__, "Command"] + raise TypeError( + f"XOpt optimizer service received unsupported input type: {type(input_data).__name__}. " + f"Supported types: {', '.join(supported_types)}" + ) + + def _create_initial_state(self, request: XOptExecutionRequest) -> XOptExecutionState: + """Convert XOptExecutionRequest to internal service state. + + Initialize ALL state fields to avoid KeyError during execution. + + Args: + request: The execution request from the capability + + Returns: + XOptExecutionState: Initialized state for the LangGraph workflow + """ + return XOptExecutionState( + # Request (preserved via reducer) + request=request, + # Capability context + capability_context_data=request.capability_context_data, + # Error tracking + error_chain=[], + yaml_generation_attempt=0, + # Machine state + machine_state=None, + machine_state_details=None, + # Decision + selected_strategy=None, + decision_reasoning=None, + # YAML + generated_yaml=None, + yaml_generation_failed=None, + # Approval + requires_approval=None, + approval_interrupt_data=None, + approval_result=None, + approved=None, + # Execution + run_artifact=None, + execution_error=None, + execution_failed=None, + # Analysis + analysis_result=None, + recommendations=None, + # Loop control + iteration_count=0, + max_iterations=request.max_iterations, + should_continue=False, + # Control flags + is_successful=False, + is_failed=False, + failure_reason=None, + current_stage="state_id", + ) + + def _build_and_compile_graph(self): + """Build and compile the XOpt optimizer LangGraph. + + Creates a StateGraph with all nodes and conditional edges for the + optimization workflow, then compiles it with checkpointing support. + + Returns: + CompiledGraph: The compiled workflow graph + """ + workflow = StateGraph(XOptExecutionState) + + # Add nodes + workflow.add_node("state_identification", create_state_identification_node()) + workflow.add_node("decision", create_decision_node()) + workflow.add_node("yaml_generation", create_yaml_generation_node()) + workflow.add_node("approval", create_approval_node()) + workflow.add_node("execution", create_executor_node()) + workflow.add_node("analysis", create_analysis_node()) + + # Define flow + workflow.set_entry_point("state_identification") + workflow.add_edge("state_identification", "decision") + + workflow.add_conditional_edges( + "decision", + self._decision_router, + {"continue": "yaml_generation", "abort": "__end__"}, + ) + + workflow.add_conditional_edges( + "yaml_generation", + self._yaml_generation_router, + { + "approve": "approval", + "execute": "execution", + "retry": "yaml_generation", + "__end__": "__end__", + }, + ) + + workflow.add_conditional_edges( + "approval", + self._approval_router, + {"approved": "execution", "rejected": "__end__"}, + ) + + workflow.add_edge("execution", "analysis") + + workflow.add_conditional_edges( + "analysis", + self._loop_router, + {"continue": "state_identification", "complete": "__end__"}, + ) + + # Compile with checkpointer for interrupt support + checkpointer = self._create_checkpointer() + compiled = workflow.compile(checkpointer=checkpointer) + + logger.info("XOpt optimizer service graph compiled successfully") + return compiled + + def _decision_router(self, state: XOptExecutionState) -> str: + """Route after machine state decision. + + Args: + state: Current execution state + + Returns: + str: "abort" if failed or abort strategy, "continue" otherwise + """ + if state.get("is_failed"): + return "abort" + if state.get("selected_strategy") == XOptStrategy.ABORT: + return "abort" + return "continue" + + def _yaml_generation_router(self, state: XOptExecutionState) -> str: + """Route after YAML generation. + + Args: + state: Current execution state + + Returns: + str: Routing decision based on generation result + """ + if state.get("is_failed"): + return "__end__" + if state.get("yaml_generation_failed"): + return "retry" + if state.get("requires_approval"): + return "approve" + return "execute" + + def _approval_router(self, state: XOptExecutionState) -> str: + """Route after approval process. + + Args: + state: Current execution state + + Returns: + str: "approved" if approved, "rejected" otherwise + """ + return "approved" if state.get("approved") else "rejected" + + def _loop_router(self, state: XOptExecutionState) -> str: + """Route after analysis - continue loop or complete. + + Args: + state: Current execution state + + Returns: + str: "continue" to loop back, "complete" to end + """ + if state.get("is_failed"): + return "complete" + return "continue" if state.get("should_continue") else "complete" + + def _handle_execution_failure(self, result: dict) -> None: + """Check result and raise exception if execution failed. + + Args: + result: Final state from graph execution + + Raises: + XOptExecutionError: If optimization failed + """ + if not result.get("is_successful", False) and result.get("is_failed", False): + failure_reason = result.get("failure_reason", "XOpt optimization failed") + logger.error(f"XOpt execution failed: {failure_reason}") + raise XOptExecutionError( + message=f"XOpt optimization failed: {failure_reason}", + xopt_error=result.get("execution_error"), + ) + + def _create_service_result(self, result: dict) -> XOptServiceResult: + """Create structured service result from final state. + + Args: + result: Final state from graph execution + + Returns: + XOptServiceResult: Structured result for capability consumption + """ + recommendations = result.get("recommendations") or [] + return XOptServiceResult( + run_artifact=result.get("run_artifact", {}), + generated_yaml=result.get("generated_yaml", ""), + strategy=result.get("selected_strategy", XOptStrategy.EXPLORATION), + total_iterations=result.get("iteration_count", 0), + analysis_summary=result.get("analysis_result", {}), + recommendations=tuple(recommendations), # Convert to tuple for frozen dataclass + ) + + def _create_checkpointer(self): + """Create checkpointer using same logic as main graph. + + Returns: + Checkpointer: PostgreSQL or in-memory checkpointer + """ + # Check if we should use PostgreSQL (production mode) + use_postgres = self.config.get("langgraph", {}).get("use_postgres", False) + + if use_postgres: + try: + # Try PostgreSQL when explicitly requested + checkpointer = create_async_postgres_checkpointer() + logger.info("XOpt optimizer service using async PostgreSQL checkpointer") + return checkpointer + except Exception as e: + # Fall back to memory saver if PostgreSQL fails + logger.warning(f"PostgreSQL checkpointer failed for XOpt optimizer service: {e}") + logger.info("XOpt optimizer service falling back to in-memory checkpointer") + return create_memory_checkpointer() + else: + # Default to memory saver for R&D mode + logger.info("XOpt optimizer service using in-memory checkpointer") + return create_memory_checkpointer() + + def _load_config(self) -> dict[str, Any]: + """Load service configuration. + + Returns: + dict: Full configuration dictionary + """ + return get_full_configuration() diff --git a/src/osprey/services/xopt_optimizer/state_identification/__init__.py b/src/osprey/services/xopt_optimizer/state_identification/__init__.py new file mode 100644 index 000000000..64fb7f876 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/state_identification/__init__.py @@ -0,0 +1,26 @@ +"""State Identification Subsystem for XOpt Optimizer. + +This subsystem assesses machine readiness for optimization using a ReAct agent +with tools for reading reference files and channel values. + +Supports two modes: +- "react": ReAct agent with tools for reading reference files and channels (default) +- "mock": Fast placeholder that always returns READY + +Configuration: + xopt_optimizer: + state_identification: + mode: "react" # or "mock" + mock_files: true # Use mock file data (for testing without real files) + reference_path: "path/to/docs" # Optional path to reference files + model_config_name: "xopt_state_identification" +""" + +from .agent import StateIdentificationAgent, create_state_identification_agent +from .node import create_state_identification_node + +__all__ = [ + "create_state_identification_node", + "create_state_identification_agent", + "StateIdentificationAgent", +] diff --git a/src/osprey/services/xopt_optimizer/state_identification/agent.py b/src/osprey/services/xopt_optimizer/state_identification/agent.py new file mode 100644 index 000000000..c73b1f555 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/state_identification/agent.py @@ -0,0 +1,337 @@ +"""ReAct Agent for Machine State Identification. + +This module provides a ReAct agent that assesses machine readiness for optimization. +The agent uses tools to: +1. Read reference documentation about machine ready criteria +2. Read current channel values from the control system +3. Determine if the machine is READY, NOT_READY, or UNKNOWN + +The agent adapts based on available resources: +- Reference files can be mock (for testing) or real (from configured path) +- Channel access uses the existing control system connector (mock or real via config) +""" + +from __future__ import annotations + +from typing import Any + +from langchain_core.messages import HumanMessage, SystemMessage +from langgraph.prebuilt import create_react_agent +from pydantic import BaseModel, Field + +from osprey.models.langchain import get_langchain_model +from osprey.utils.logger import get_logger + +from ..models import MachineState +from .tools import create_channel_access_tools, create_reference_file_tools + +logger = get_logger("xopt_optimizer") + + +# ============================================================================= +# STRUCTURED OUTPUT MODEL +# ============================================================================= + + +class MachineStateAssessment(BaseModel): + """Structured output for machine state assessment. + + This model is used with LangChain's `with_structured_output` for the + final assessment after the agent has gathered information. + """ + + state: MachineState = Field( + description="The assessed machine state: 'ready', 'not_ready', or 'unknown'" + ) + reasoning: str = Field( + description="Explanation of why this state was determined, including key observations" + ) + channels_checked: list[str] = Field( + default_factory=list, + description="List of channel names that were checked during assessment", + ) + key_observations: dict[str, Any] = Field( + default_factory=dict, + description="Key observations from channel readings and reference docs", + ) + + +# ============================================================================= +# SYSTEM PROMPT +# ============================================================================= + +STATE_IDENTIFICATION_PROMPT = """You are a machine state assessment agent for accelerator optimization. + +Your task is to determine if the machine is ready for optimization by: +1. Reading reference documentation to understand the ready criteria +2. Checking current channel values against those criteria +3. Providing a clear assessment with reasoning + +## Your Workflow + +1. **Read Documentation First**: Use `list_reference_files` to see available docs, then + `read_reference_file` to understand the machine ready criteria. + +2. **Check Channel Values**: Based on what you learn from the docs, use `read_channel_values` + to check the relevant channels. Pass channel names as comma-separated values. + +3. **Make Assessment**: Based on the criteria and current values, determine: + - **READY**: All criteria are met, machine can proceed with optimization + - **NOT_READY**: One or more criteria are not met, optimization should not proceed + - **UNKNOWN**: Unable to determine state (missing data, conflicting info, etc.) + +## Important Guidelines + +- Always read the reference docs first to understand what criteria to check +- Check all relevant channels mentioned in the documentation +- Be conservative: if unsure, report UNKNOWN rather than guessing +- Include specific channel values in your reasoning +- List all channels you checked in your response + +## Response Format + +After gathering information, provide your assessment with: +- The machine state (ready/not_ready/unknown) +- Clear reasoning explaining your decision +- List of channels you checked +- Key observations from your investigation +""" + + +# ============================================================================= +# AGENT CLASS +# ============================================================================= + + +class StateIdentificationAgent: + """ReAct agent for assessing machine readiness for optimization. + + This agent: + 1. Reads reference documentation about machine ready criteria + 2. Checks current channel values from the control system + 3. Determines if the machine is READY, NOT_READY, or UNKNOWN + + The tools adapt based on configuration: + - mock_files=True: Uses hardcoded mock reference data + - mock_files=False: Reads real files from reference_path + - Channel access uses ConnectorFactory (mock or real via control_system.type config) + """ + + def __init__( + self, + reference_path: str | None = None, + mock_files: bool = False, + model_config: dict[str, Any] | None = None, + ): + """Initialize the state identification agent. + + Args: + reference_path: Path to reference documentation directory. + Ignored if mock_files=True. + mock_files: If True, use mock reference file data for testing. + model_config: Configuration for the LLM model to use. + """ + self.reference_path = reference_path + self.mock_files = mock_files + self.model_config = model_config + self._agent = None + + def _get_tools(self) -> list[Any]: + """Get tools for the agent. + + Returns: + List of LangChain tools for file reading and channel access + """ + tools = [] + + # Reference file tools (mock or real) + tools.extend( + create_reference_file_tools( + reference_path=self.reference_path, + mock_mode=self.mock_files, + ) + ) + + # Channel access tools (uses existing mock connector via config) + tools.extend(create_channel_access_tools()) + + return tools + + def _get_model(self): + """Get the LangChain model for the agent. + + Returns: + LangChain BaseChatModel instance + + Raises: + ValueError: If no model_config is available + """ + if self.model_config: + return get_langchain_model(model_config=self.model_config) + + raise ValueError( + "No model_config provided to StateIdentificationAgent. " + "Ensure xopt_optimizer.state_identification.model_config_name is set in config.yml " + "or that 'orchestrator' model is configured as fallback." + ) + + def _get_agent(self): + """Get or create the ReAct agent. + + Returns: + Compiled ReAct agent graph + """ + if self._agent is None: + model = self._get_model() + tools = self._get_tools() + + self._agent = create_react_agent( + model=model, + tools=tools, + ) + + return self._agent + + async def assess_state( + self, + objective: str, + additional_context: dict[str, Any] | None = None, + ) -> tuple[MachineState, dict[str, Any]]: + """Assess machine readiness for optimization. + + Args: + objective: The optimization objective (provides context for assessment) + additional_context: Optional additional context + + Returns: + Tuple of (MachineState, details dict with reasoning and observations) + + Raises: + ValueError: If assessment fails + """ + agent = self._get_agent() + + # Build the user message + user_message = f"""Assess whether the machine is ready for optimization. + +**Optimization Objective:** {objective} + +Please: +1. First read the reference documentation to understand the machine ready criteria +2. Then check the relevant channel values +3. Provide your assessment of the machine state + +Remember to check ALL relevant criteria before making your assessment. +""" + + if additional_context: + user_message += f"\n**Additional Context:** {additional_context}" + + logger.info("Starting state identification agent...") + + try: + result = await agent.ainvoke( + { + "messages": [ + SystemMessage(content=STATE_IDENTIFICATION_PROMPT), + HumanMessage(content=user_message), + ] + } + ) + + # Extract the final response + messages = result.get("messages", []) + if not messages: + raise ValueError("Agent did not produce any output") + + # Get the last message content + last_message = messages[-1] + content = ( + last_message.content + if hasattr(last_message, "content") + else str(last_message) + ) + + # Parse the assessment from the response + assessment = self._parse_assessment(content) + + logger.info(f"State assessment complete: {assessment['state'].value}") + return assessment["state"], { + "reasoning": assessment["reasoning"], + "channels_checked": assessment.get("channels_checked", []), + "key_observations": assessment.get("key_observations", {}), + "raw_response": content, + } + + except Exception as e: + logger.error(f"State identification agent failed: {e}") + raise ValueError(f"State assessment failed: {e}") from e + + def _parse_assessment(self, content: str) -> dict[str, Any]: + """Parse machine state assessment from agent response. + + Args: + content: The agent's response text + + Returns: + Dict with state, reasoning, channels_checked, and key_observations + """ + content_lower = content.lower() + + # Determine state from response + if "not_ready" in content_lower or "not ready" in content_lower: + state = MachineState.NOT_READY + elif "unknown" in content_lower and ( + "state: unknown" in content_lower + or "state is unknown" in content_lower + or "cannot determine" in content_lower + or "unable to determine" in content_lower + ): + state = MachineState.UNKNOWN + elif "ready" in content_lower: + # Check it's not "not ready" + if "not ready" not in content_lower and "not_ready" not in content_lower: + state = MachineState.READY + else: + state = MachineState.NOT_READY + else: + # Default to unknown if we can't parse + logger.warning("Could not parse state from response, defaulting to UNKNOWN") + state = MachineState.UNKNOWN + + # Extract channel names mentioned in the response + channels_checked = [] + # Look for common channel patterns + import re + + channel_pattern = r"[A-Z][A-Z0-9_:]+:[A-Z0-9_:]+" + channels_checked = list(set(re.findall(channel_pattern, content))) + + return { + "state": state, + "reasoning": content, + "channels_checked": channels_checked, + "key_observations": {}, + } + + +def create_state_identification_agent( + reference_path: str | None = None, + mock_files: bool = False, + model_config: dict[str, Any] | None = None, +) -> StateIdentificationAgent: + """Factory function to create a state identification agent. + + Args: + reference_path: Path to reference documentation directory + mock_files: If True, use mock reference file data + model_config: Optional model configuration + + Returns: + Configured StateIdentificationAgent instance + """ + return StateIdentificationAgent( + reference_path=reference_path, + mock_files=mock_files, + model_config=model_config, + ) diff --git a/src/osprey/services/xopt_optimizer/state_identification/node.py b/src/osprey/services/xopt_optimizer/state_identification/node.py new file mode 100644 index 000000000..bf3b960b9 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/state_identification/node.py @@ -0,0 +1,214 @@ +"""State Identification Node for XOpt Optimizer Service. + +This node assesses machine readiness for optimization. + +Supports two modes (configured via xopt_optimizer.state_identification.mode): +- "react": ReAct agent with tools for reading reference files, querying channels, + and determining machine state (default) +- "mock": Fast placeholder that always returns READY + +Configuration: + xopt_optimizer: + state_identification: + mode: "react" # or "mock" + mock_files: true # Use mock file data (for testing without real files) + reference_path: "path/to/docs" # Optional path to reference files + model_config_name: "xopt_state_identification" # Model config reference +""" + +from typing import Any + +from osprey.utils.config import get_model_config, get_xopt_optimizer_config +from osprey.utils.logger import get_logger + +from ..models import MachineState, XOptExecutionState + +logger = get_logger("xopt_optimizer") + + +# ============================================================================= +# CONFIGURATION +# ============================================================================= + + +def _get_state_identification_config() -> dict[str, Any]: + """Get state identification configuration from osprey config. + + Reads from config structure: + xopt_optimizer: + state_identification: + mode: "mock" # or "react" + mock_files: true # Use mock file data (default: true for testing) + reference_path: "path/to/docs" # Optional path to reference files + model_config_name: "xopt_state_identification" # References models section + + Returns: + Configuration dict with mode, mock_files, reference_path, and model_config + """ + xopt_config = get_xopt_optimizer_config() + state_id_config = xopt_config.get("state_identification", {}) + + # Resolve model config from name reference + # Falls back to "orchestrator" model if xopt-specific model not configured + model_config = None + model_config_name = state_id_config.get("model_config_name", "xopt_state_identification") + try: + model_config = get_model_config(model_config_name) + # Check if the model config is valid (has provider) + if not model_config or not model_config.get("provider"): + logger.debug( + f"Model '{model_config_name}' not configured, falling back to orchestrator" + ) + model_config = get_model_config("orchestrator") + except Exception as e: + logger.warning( + f"Could not load model config '{model_config_name}': {e}, " + "falling back to orchestrator" + ) + model_config = get_model_config("orchestrator") + + return { + "mode": state_id_config.get("mode", "react"), # Default to react agent + "mock_files": state_id_config.get("mock_files", True), # Default to mock files + "reference_path": state_id_config.get("reference_path"), # Optional + "model_config": model_config, + } + + +# ============================================================================= +# MOCK STATE ASSESSMENT +# ============================================================================= + + +def _assess_state_mock() -> tuple[MachineState, dict[str, Any]]: + """Assess machine state using mock logic (for testing). + + Always returns READY with placeholder details. + + Returns: + Tuple of (MachineState, details dict) + """ + return MachineState.READY, { + "assessment": "mock", + "note": "Mock implementation - always returns READY", + } + + +# ============================================================================= +# REACT AGENT STATE ASSESSMENT +# ============================================================================= + + +async def _assess_state_react( + objective: str, + model_config: dict[str, Any], + mock_files: bool = True, + reference_path: str | None = None, +) -> tuple[MachineState, dict[str, Any]]: + """Assess machine state using ReAct agent with tools. + + The agent: + 1. Reads reference documentation about machine ready criteria + 2. Checks current channel values from the control system + 3. Determines if the machine is READY, NOT_READY, or UNKNOWN + + Args: + objective: The optimization objective + model_config: Model configuration for the agent + mock_files: If True, use mock reference file data + reference_path: Path to reference documentation (if not using mock) + + Returns: + Tuple of (MachineState, details dict) + """ + from .agent import create_state_identification_agent + + agent = create_state_identification_agent( + reference_path=reference_path, + mock_files=mock_files, + model_config=model_config, + ) + + try: + machine_state, details = await agent.assess_state(objective=objective) + return machine_state, details + except Exception as e: + logger.warning(f"ReAct agent failed, falling back to mock: {e}") + return _assess_state_mock() + + +# ============================================================================= +# NODE FACTORY +# ============================================================================= + + +def create_state_identification_node(): + """Create the state identification node for LangGraph integration. + + This factory function creates a node that assesses machine readiness + for optimization. + + The assessment mode is controlled via configuration: + - xopt_optimizer.state_identification.mode: "mock" | "react" + + Returns: + Async function that takes XOptExecutionState and returns state updates + """ + + async def state_identification_node(state: XOptExecutionState) -> dict[str, Any]: + """Assess machine readiness for optimization. + + Supports two modes: + - "mock": Fast placeholder that always returns READY (default) + - "react": ReAct agent with tools for reading reference files and channels + """ + node_logger = get_logger("xopt_optimizer", state=state) + + # Get configuration + state_id_config = _get_state_identification_config() + mode = state_id_config.get("mode", "react") + is_mock = mode == "mock" + + request = state.get("request") + objective = request.optimization_objective if request else "Unknown objective" + + # Log with (mock) indicator if in mock mode + mode_indicator = " (mock)" if is_mock else "" + node_logger.status(f"Assessing machine state{mode_indicator}...") + + try: + if mode == "react": + machine_state, details = await _assess_state_react( + objective=objective, + model_config=state_id_config.get("model_config"), + mock_files=state_id_config.get("mock_files", True), + reference_path=state_id_config.get("reference_path"), + ) + else: + machine_state, details = _assess_state_mock() + + # Log result with (mock) indicator + node_logger.key_info(f"Machine state: {machine_state.value.upper()}{mode_indicator}") + + return { + "machine_state": machine_state, + "machine_state_details": details, + "current_stage": "decision", + } + + except Exception as e: + node_logger.error(f"State assessment failed: {e}") + + # Fall back to mock on error + node_logger.warning("Falling back to mock state assessment due to error") + machine_state, details = _assess_state_mock() + + node_logger.key_info(f"Machine state: {machine_state.value.upper()} (mock)") + + return { + "machine_state": machine_state, + "machine_state_details": details, + "current_stage": "decision", + } + + return state_identification_node diff --git a/src/osprey/services/xopt_optimizer/state_identification/tools/__init__.py b/src/osprey/services/xopt_optimizer/state_identification/tools/__init__.py new file mode 100644 index 000000000..68f0b23ab --- /dev/null +++ b/src/osprey/services/xopt_optimizer/state_identification/tools/__init__.py @@ -0,0 +1,36 @@ +"""Tools for State Identification Agent. + +This module provides tools that the State Identification ReAct agent +uses to assess machine readiness for optimization: + +- Reference File Tools: List and read documentation about machine ready criteria +- Channel Access Tools: Read current values from control system channels + +Tool Modes: + Reference files support mock mode for testing without real files. + Channel access uses the existing MockConnector when control_system.type: mock. +""" + +from .channel_access import ( + clear_connector_cache, + create_channel_access_tools, + create_read_channels_tool, +) +from .reference_files import ( + MOCK_REFERENCE_FILES, + create_list_files_tool, + create_read_file_tool, + create_reference_file_tools, +) + +__all__ = [ + # Reference file tools + "create_reference_file_tools", + "create_list_files_tool", + "create_read_file_tool", + "MOCK_REFERENCE_FILES", + # Channel access tools + "create_channel_access_tools", + "create_read_channels_tool", + "clear_connector_cache", +] diff --git a/src/osprey/services/xopt_optimizer/state_identification/tools/channel_access.py b/src/osprey/services/xopt_optimizer/state_identification/tools/channel_access.py new file mode 100644 index 000000000..6d37a6a89 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/state_identification/tools/channel_access.py @@ -0,0 +1,114 @@ +"""Channel Access Tools for State Identification Agent. + +Provides tools for reading control system channel values to assess +machine state for optimization readiness. + +Uses the existing ConnectorFactory which automatically selects the +appropriate connector based on configuration: +- control_system.type: "mock" -> MockConnector (synthetic data) +- control_system.type: "epics" -> EPICSConnector (real control system) + +No additional mocking needed - the MockConnector already provides +realistic synthetic data for any channel name. +""" + +from typing import Any + +from langchain_core.tools import tool + +from osprey.connectors.factory import ConnectorFactory +from osprey.utils.logger import get_logger + +logger = get_logger("xopt_optimizer") + +# Module-level connector cache for reuse within a session +_connector_cache: dict[str, Any] = {} + + +async def _get_connector(): + """Get or create control system connector. + + Uses ConnectorFactory which reads from config to determine + connector type (mock, epics, etc.). + + Returns: + Connected ControlSystemConnector instance + """ + if "control_system" not in _connector_cache: + connector = await ConnectorFactory.create_control_system_connector() + _connector_cache["control_system"] = connector + return _connector_cache["control_system"] + + +def create_read_channels_tool(): + """Create a tool for reading control system channel values. + + Returns: + LangChain tool function for reading channels + """ + + @tool + async def read_channel_values(channel_names: str) -> str: + """Read current values from control system channels. + + Use this tool to check the current state of machine parameters + that are relevant for determining optimization readiness. + + Args: + channel_names: Comma-separated list of channel names to read. + Example: "BEAM:CURRENT,VACUUM:PRESSURE,SAFETY:INTERLOCK" + + Returns: + Formatted string with channel values and metadata, or error message + """ + # Parse channel names + channels = [name.strip() for name in channel_names.split(",") if name.strip()] + + if not channels: + return "Error: No channel names provided. Provide comma-separated channel names." + + try: + connector = await _get_connector() + + results = [] + for channel in channels: + try: + value = await connector.read_channel(channel) + # Format the result with relevant metadata + result_line = f"{channel}: {value.value}" + if value.metadata: + if value.metadata.units: + result_line += f" {value.metadata.units}" + if value.metadata.severity: + result_line += f" (severity: {value.metadata.severity})" + results.append(result_line) + except Exception as e: + results.append(f"{channel}: ERROR - {e}") + + logger.debug(f"Read {len(channels)} channels for state assessment") + return "\n".join(results) + + except Exception as e: + logger.error(f"Failed to read channels: {e}") + return f"Error connecting to control system: {e}" + + return read_channel_values + + +def create_channel_access_tools() -> list[Any]: + """Create all channel access tools. + + Returns: + List of LangChain tools [read_channel_values] + """ + return [ + create_read_channels_tool(), + ] + + +def clear_connector_cache(): + """Clear the connector cache. + + Useful for testing to ensure fresh connector creation. + """ + _connector_cache.clear() diff --git a/src/osprey/services/xopt_optimizer/state_identification/tools/reference_files.py b/src/osprey/services/xopt_optimizer/state_identification/tools/reference_files.py new file mode 100644 index 000000000..10f356179 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/state_identification/tools/reference_files.py @@ -0,0 +1,236 @@ +"""Reference File Tools for State Identification Agent. + +Provides tools for listing and reading reference documentation files +that inform machine state assessment decisions. + +Supports two modes: +- mock_mode=True: Returns hardcoded mock data for testing +- mock_mode=False: Reads actual files from reference_path + +The mock data provides realistic examples that help test the agent's +reasoning without requiring actual reference files to be set up. +""" + +from pathlib import Path +from typing import Any + +from langchain_core.tools import tool + +from osprey.utils.logger import get_logger + +logger = get_logger("xopt_optimizer") + + +# ============================================================================= +# MOCK DATA FOR TESTING +# ============================================================================= + +MOCK_REFERENCE_FILES = { + "machine_ready_criteria.md": """# Machine Ready Criteria + +This document defines the criteria for determining if the machine is ready for optimization. + +## Ready Conditions + +The machine is considered **READY** for optimization when ALL of the following are true: + +1. **Beam Current**: Above 10 mA (channel: `BEAM:CURRENT`) +2. **Vacuum Pressure**: Below 1e-8 Torr (channel: `VACUUM:PRESSURE`) +3. **Interlock Status**: No active interlocks (channel: `SAFETY:INTERLOCK`, value should be 0) +4. **Machine Mode**: Not in maintenance mode (channel: `MACHINE:MODE`, value should be 1 for "operational") + +## Not Ready Conditions + +The machine is **NOT_READY** when ANY of the following are true: + +1. Beam current is zero or below 1 mA +2. Any safety interlock is active (SAFETY:INTERLOCK != 0) +3. Machine is in maintenance mode (MACHINE:MODE == 0) +4. Vacuum pressure exceeds safe limits + +## Unknown State + +Report **UNKNOWN** if: +- Unable to read critical channels +- Channel values are stale or unreliable +- Conflicting information from different sources +""", + "optimization_channels.md": """# Optimization Channels Reference + +This document lists the control system channels relevant to optimization operations. + +## Primary Monitoring Channels + +| Channel Name | Description | Units | Normal Range | +|--------------|-------------|-------|--------------| +| BEAM:CURRENT | Beam current | mA | 10-500 | +| VACUUM:PRESSURE | Vacuum level | Torr | < 1e-8 | +| SAFETY:INTERLOCK | Interlock status | - | 0 (clear) | +| MACHINE:MODE | Operating mode | - | 1 (operational) | + +## How to Use + +1. Read SAFETY:INTERLOCK first - if non-zero, machine is NOT_READY +2. Check MACHINE:MODE - must be 1 for operational +3. Verify BEAM:CURRENT is in acceptable range +4. Confirm VACUUM:PRESSURE is within limits + +## Channel Naming Convention + +All channels follow the pattern: `SYSTEM:SUBSYSTEM:PARAMETER` +- BEAM: Beam-related measurements +- VACUUM: Vacuum system readings +- SAFETY: Safety and interlock status +- MACHINE: Overall machine state +""", + "safety_procedures.md": """# Safety Procedures for Optimization + +## Pre-Optimization Checklist + +Before starting any optimization run: + +1. Verify no personnel in restricted areas +2. Confirm all interlocks are clear (SAFETY:INTERLOCK == 0) +3. Check beam current stability over last 5 minutes +4. Ensure vacuum levels are nominal + +## Abort Conditions + +Immediately abort optimization if: +- Any interlock activates +- Beam current drops below 5 mA +- Operator requests stop +- Unexpected machine behavior detected + +## Contact Information + +For questions about machine readiness, contact the control room operator. +""", +} + + +# ============================================================================= +# TOOL IMPLEMENTATIONS +# ============================================================================= + + +def create_list_files_tool(reference_path: str | None = None, mock_mode: bool = False): + """Create a tool for listing available reference files. + + Args: + reference_path: Path to reference files directory (ignored in mock mode) + mock_mode: If True, return mock file list + + Returns: + LangChain tool function + """ + + @tool + def list_reference_files() -> str: + """List available reference documentation files. + + Returns a list of file names that can be read with the read_reference_file tool. + These files contain important information about machine ready criteria, + channel definitions, and safety procedures. + + Returns: + Newline-separated list of available file names + """ + if mock_mode: + files = list(MOCK_REFERENCE_FILES.keys()) + logger.debug(f"[mock] Listing {len(files)} reference files") + return "\n".join(files) + + if not reference_path: + return "No reference files configured. Reference path not specified." + + path = Path(reference_path) + if not path.exists(): + return f"Reference path does not exist: {reference_path}" + + # List markdown and text files + files = [] + for ext in ["*.md", "*.txt", "*.yaml", "*.yml"]: + files.extend(p.name for p in path.glob(ext)) + + if not files: + return f"No reference files found in {reference_path}" + + logger.debug(f"Found {len(files)} reference files in {reference_path}") + return "\n".join(sorted(files)) + + return list_reference_files + + +def create_read_file_tool(reference_path: str | None = None, mock_mode: bool = False): + """Create a tool for reading reference file contents. + + Args: + reference_path: Path to reference files directory (ignored in mock mode) + mock_mode: If True, return mock file contents + + Returns: + LangChain tool function + """ + + @tool + def read_reference_file(filename: str) -> str: + """Read the contents of a reference documentation file. + + Use list_reference_files first to see available files, then use this + tool to read specific files that are relevant to assessing machine state. + + Args: + filename: Name of the file to read (from list_reference_files output) + + Returns: + Contents of the file, or error message if file not found + """ + if mock_mode: + if filename in MOCK_REFERENCE_FILES: + logger.debug(f"[mock] Reading reference file: {filename}") + return MOCK_REFERENCE_FILES[filename] + else: + available = ", ".join(MOCK_REFERENCE_FILES.keys()) + return f"File not found: {filename}. Available files: {available}" + + if not reference_path: + return "Cannot read file - reference path not configured." + + file_path = Path(reference_path) / filename + if not file_path.exists(): + return f"File not found: {filename}" + + # Security check - ensure file is within reference_path + try: + file_path.resolve().relative_to(Path(reference_path).resolve()) + except ValueError: + return f"Access denied - file is outside reference directory: {filename}" + + try: + content = file_path.read_text() + logger.debug(f"Read reference file: {filename} ({len(content)} chars)") + return content + except Exception as e: + return f"Error reading file {filename}: {e}" + + return read_reference_file + + +def create_reference_file_tools( + reference_path: str | None = None, + mock_mode: bool = False, +) -> list[Any]: + """Create all reference file tools. + + Args: + reference_path: Path to reference files directory + mock_mode: If True, use mock data instead of real files + + Returns: + List of LangChain tools [list_reference_files, read_reference_file] + """ + return [ + create_list_files_tool(reference_path, mock_mode), + create_read_file_tool(reference_path, mock_mode), + ] diff --git a/src/osprey/services/xopt_optimizer/yaml_generation/__init__.py b/src/osprey/services/xopt_optimizer/yaml_generation/__init__.py new file mode 100644 index 000000000..1f76a4ab6 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/yaml_generation/__init__.py @@ -0,0 +1,25 @@ +"""YAML Generation Subsystem for XOpt Optimizer. + +This subsystem generates XOpt YAML configurations using either: +1. ReAct mode (default): Agent-based generation that dynamically adapts: + - If example files exist: Agent reads them and learns patterns + - If no examples: Agent generates from built-in XOpt knowledge +2. Mock mode: Placeholder YAML for quick testing (use for fast iteration) + +The mode is controlled via configuration: + osprey.xopt_optimizer.yaml_generation.mode: "react" | "mock" + +When using ReAct mode with examples, place YAML files in: + osprey.xopt_optimizer.yaml_generation.examples_path: "path/to/yamls" + +Example files are optional - the agent adapts its behavior based on availability. +""" + +from .agent import YamlGenerationAgent, create_yaml_generation_agent +from .node import create_yaml_generation_node + +__all__ = [ + "create_yaml_generation_node", + "YamlGenerationAgent", + "create_yaml_generation_agent", +] diff --git a/src/osprey/services/xopt_optimizer/yaml_generation/agent.py b/src/osprey/services/xopt_optimizer/yaml_generation/agent.py new file mode 100644 index 000000000..1e09cf929 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/yaml_generation/agent.py @@ -0,0 +1,484 @@ +"""ReAct Agent for XOpt YAML Configuration Generation. + +This module provides a ReAct agent that generates XOpt YAML configurations. +The agent dynamically adapts based on whether example files are available: + +- **With examples**: Agent gets file reading tools and is instructed to + learn from historical configurations before generating new ones. +- **Without examples**: Agent generates YAML from its built-in knowledge + of XOpt configuration patterns. + +This design avoids requiring pre-created example files while still +benefiting from them when available. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from langchain_core.tools import tool +from langgraph.prebuilt import create_react_agent + +from osprey.models.langchain import get_langchain_model +from osprey.utils.logger import get_logger + +logger = get_logger("xopt_optimizer") + + +# ============================================================================= +# DYNAMIC PROMPTS +# ============================================================================= + +# Prompt when example files ARE available +PROMPT_WITH_EXAMPLES = """You are an expert XOpt configuration generator for accelerator optimization. + +You have access to example XOpt YAML configurations that you should read and learn from. + +## Your Workflow + +1. **READ EXAMPLES FIRST**: Use the `list_yaml_files` tool to see what examples are available, + then use `read_yaml_file` to read them. Study the structure carefully: + - Variable definitions (types, bounds) + - Objective specifications + - Generator selection patterns + - Comments explaining configuration choices + +2. **Understand the Objective**: Parse the user's optimization request to understand: + - What they want to optimize + - What strategy is appropriate (exploration vs optimization) + - Any constraints mentioned + +3. **Generate Configuration**: Create a valid XOpt YAML based on: + - The patterns you learned from examples + - The specific optimization objective + - Best practices for XOpt + +## Output Format + +Your final output MUST be a complete, valid YAML configuration wrapped in ```yaml``` code blocks. +Include comments explaining key configuration choices. + +## Important Notes + +- Always read the examples first - they show the expected structure +- Use placeholder names (param_1, param_2, objective_1) unless the user provides specific names +- Do NOT invent specific accelerator channel names or parameters +- Always include: generator, evaluator, vocs sections +""" + +# Prompt when NO example files are available +PROMPT_WITHOUT_EXAMPLES = """You are an expert XOpt configuration generator for accelerator optimization. + +No example configurations are available, so you must generate YAML from your knowledge of XOpt. + +## XOpt Configuration Structure + +A valid XOpt YAML configuration includes: + +```yaml +# Generator - how to sample new points +generator: + name: random # Options: random, latin_hypercube, sobol, bayesian + +# Evaluator - how to assess points +evaluator: + function: objective_function_name + +# VOCS - Variables, Objectives, Constraints, Statics +vocs: + variables: + param_name: + type: continuous # or discrete, ordinal + lower: 0.0 + upper: 10.0 + objectives: + objective_name: + type: minimize # or maximize + constraints: {} + statics: {} + +# Runtime settings +n_initial: 5 +max_evaluations: 20 +``` + +## Your Workflow + +1. **Understand the Objective**: Parse the user's optimization request +2. **Select Generator**: Based on strategy (exploration β†’ random/latin_hypercube, optimization β†’ bayesian) +3. **Define Variables**: Create placeholder variables with reasonable defaults +4. **Define Objectives**: Based on what user wants to optimize +5. **Generate YAML**: Complete, valid configuration + +## Output Format + +Your final output MUST be a complete, valid YAML configuration wrapped in ```yaml``` code blocks. +Include comments explaining key configuration choices. + +## Important Notes + +- Use placeholder names (param_1, param_2, objective_1) unless the user provides specific names +- Do NOT invent specific accelerator channel names or parameters +- Include reasonable default bounds (e.g., 0.0 to 10.0 for continuous variables) +- Always include: generator, evaluator, vocs sections +""" + + +# ============================================================================= +# FILE TOOLS (only created when examples exist) +# ============================================================================= + + +def _create_file_tools(examples_path: Path) -> list[Any]: + """Create file reading tools for the agent. + + These tools are only created when example files exist. + + Args: + examples_path: Path to directory containing example YAML files + + Returns: + List of LangChain tools for file operations + """ + + @tool + def list_yaml_files() -> str: + """List available YAML configuration files in the examples directory. + + Use this tool first to see what examples are available before reading them. + + Returns: + List of available YAML files with brief descriptions from their first comment. + """ + yaml_files = list(examples_path.glob("**/*.yaml")) + list(examples_path.glob("**/*.yml")) + + if not yaml_files: + return "No YAML files found in examples directory." + + results = ["Available YAML configurations:"] + for yaml_file in yaml_files: + rel_path = yaml_file.relative_to(examples_path) + # Try to extract description from first comment line + try: + first_lines = yaml_file.read_text(encoding="utf-8").split("\n")[:5] + description = "" + for line in first_lines: + if line.startswith("#") and not line.startswith("# ="): + description = line.lstrip("# ").strip() + break + if description: + results.append(f" - {rel_path}: {description}") + else: + results.append(f" - {rel_path}") + except Exception: + results.append(f" - {rel_path}") + + return "\n".join(results) + + @tool + def read_yaml_file(filename: str) -> str: + """Read the contents of a YAML configuration file. + + Use this after listing files to read specific examples and learn their structure. + + Args: + filename: Name of the YAML file to read (e.g., 'exploration_basic.yaml') + + Returns: + Contents of the YAML file, or error message if not found. + """ + # Security: only allow reading from examples directory + file_path = examples_path / filename + + # Check for path traversal attacks + try: + file_path = file_path.resolve() + examples_resolved = examples_path.resolve() + if not str(file_path).startswith(str(examples_resolved)): + return f"Error: Cannot read files outside examples directory." + except Exception: + return f"Error: Invalid file path." + + if not file_path.exists(): + # Try searching subdirectories + matches = list(examples_path.glob(f"**/{filename}")) + if matches: + file_path = matches[0] + else: + return f"Error: File '{filename}' not found. Use list_yaml_files to see available files." + + try: + content = file_path.read_text(encoding="utf-8") + return f"=== {filename} ===\n{content}" + except Exception as e: + return f"Error reading file: {e}" + + return [list_yaml_files, read_yaml_file] + + +# ============================================================================= +# AGENT CLASS +# ============================================================================= + + +class YamlGenerationAgent: + """ReAct agent for generating XOpt YAML configurations. + + This agent dynamically adapts based on whether example files are available: + - With examples: Gets file tools and prompt to read examples first + - Without examples: Generates from knowledge with appropriate prompt + + Attributes: + examples_path: Path to directory containing example YAML files (optional) + model_config: Configuration for the LLM model to use + """ + + def __init__( + self, + examples_path: str | Path | None = None, + model_config: dict[str, Any] | None = None, + ): + """Initialize the YAML generation agent. + + Args: + examples_path: Path to directory containing example YAML files. + If None or directory doesn't exist/is empty, agent generates from knowledge. + model_config: Optional model configuration. If not provided, + uses the 'fast' model from osprey config. + """ + self.examples_path = Path(examples_path) if examples_path else None + self.model_config = model_config + self._agent = None + self._has_examples = False + + def _check_examples_exist(self) -> bool: + """Check if example YAML files exist. + + Returns: + True if examples directory exists and contains YAML files + """ + if not self.examples_path: + return False + + if not self.examples_path.exists(): + return False + + yaml_files = list(self.examples_path.glob("**/*.yaml")) + list( + self.examples_path.glob("**/*.yml") + ) + return len(yaml_files) > 0 + + def _get_tools(self) -> list[Any]: + """Get tools for the agent based on file availability. + + Returns: + List of tools (file tools if examples exist, empty otherwise) + """ + if self._has_examples and self.examples_path: + return _create_file_tools(self.examples_path) + else: + return [] + + def _get_prompt(self) -> str: + """Get the appropriate system prompt based on file availability. + + Returns: + System prompt string + """ + if self._has_examples: + return PROMPT_WITH_EXAMPLES + else: + return PROMPT_WITHOUT_EXAMPLES + + def _get_model(self): + """Get the LangChain model for the agent. + + Uses model_config provided during initialization. + The node.py handles fallback to orchestrator model if xopt-specific + model is not configured. + + Returns: + LangChain BaseChatModel instance + + Raises: + ValueError: If no model_config is available + """ + if self.model_config: + return get_langchain_model(model_config=self.model_config) + + # This shouldn't happen if node.py fallback is working + raise ValueError( + "No model_config provided to YamlGenerationAgent. " + "Ensure xopt_optimizer.yaml_generation.model_config_name is set in config.yml " + "or that 'orchestrator' model is configured as fallback." + ) + + def _get_agent(self): + """Get or create the ReAct agent with dynamic configuration. + + Returns: + Compiled ReAct agent graph + """ + if self._agent is None: + # Check for examples at agent creation time + self._has_examples = self._check_examples_exist() + + model = self._get_model() + tools = self._get_tools() + + # Create agent with or without tools + self._agent = create_react_agent( + model=model, + tools=tools, + ) + + return self._agent + + async def generate_yaml( + self, + objective: str, + strategy: str, + additional_context: dict[str, Any] | None = None, + ) -> str: + """Generate XOpt YAML configuration using the ReAct agent. + + Args: + objective: The optimization objective (e.g., "maximize injection efficiency") + strategy: The selected strategy ("exploration" or "optimization") + additional_context: Optional additional context to include in the prompt + + Returns: + Generated YAML configuration as a string + + Raises: + ValueError: If YAML generation fails or produces invalid output + """ + agent = self._get_agent() + + # Build the user message + user_message = f"""Generate an XOpt YAML configuration for the following: + +**Optimization Objective:** {objective} +**Strategy:** {strategy} + +{"First, use the tools to read available example configurations. " if self._has_examples else ""}Generate a complete, valid YAML configuration based on { + "what you learn from the examples" if self._has_examples else "your knowledge of XOpt configuration patterns" +}. + +Remember: +- Use generic parameter names unless specific names are provided +- Include comments explaining your configuration choices +- Output the final YAML in ```yaml``` code blocks +""" + + if additional_context: + user_message += f"\n**Additional Context:** {additional_context}" + + # Run the agent + logger.info("Starting YAML generation agent...") + + try: + result = await agent.ainvoke( + { + "messages": [ + {"role": "system", "content": self._get_prompt()}, + {"role": "user", "content": user_message}, + ] + } + ) + + # Extract the final response + messages = result.get("messages", []) + if not messages: + raise ValueError("Agent did not produce any output") + + # Get the last message content + last_message = messages[-1] + content = ( + last_message.content + if hasattr(last_message, "content") + else str(last_message) + ) + + # Extract YAML from response + yaml_content = self._extract_yaml(content) + + if not yaml_content: + logger.warning(f"Could not extract YAML from response. Response: {content[:500]}") + raise ValueError("Agent did not produce valid YAML output") + + logger.info(f"YAML generation complete: {len(yaml_content)} characters") + return yaml_content + + except Exception as e: + logger.error(f"YAML generation agent failed: {e}") + raise ValueError(f"YAML generation failed: {e}") from e + + def _extract_yaml(self, content: str) -> str | None: + """Extract YAML content from agent response. + + Args: + content: The agent's response text + + Returns: + Extracted YAML content or None if not found + """ + import re + + # Try to find YAML code blocks + yaml_pattern = r"```yaml\n(.*?)```" + matches = re.findall(yaml_pattern, content, re.DOTALL) + + if matches: + return matches[-1].strip() + + # Try generic code blocks + code_pattern = r"```\n(.*?)```" + matches = re.findall(code_pattern, content, re.DOTALL) + + for match in matches: + # Check if it looks like YAML + if "generator:" in match or "vocs:" in match or "evaluator:" in match: + return match.strip() + + # If no code blocks, check if the whole response is YAML-like + if "generator:" in content and "vocs:" in content: + # Try to extract just the YAML part + lines = content.split("\n") + yaml_lines = [] + in_yaml = False + + for line in lines: + if line.strip().startswith(("#", "generator:", "evaluator:", "vocs:")): + in_yaml = True + if in_yaml: + yaml_lines.append(line) + + if yaml_lines: + return "\n".join(yaml_lines).strip() + + return None + + +def create_yaml_generation_agent( + examples_path: str | Path | None = None, + model_config: dict[str, Any] | None = None, +) -> YamlGenerationAgent: + """Factory function to create a YAML generation agent. + + The agent dynamically adapts based on whether example files exist: + - If examples_path has YAML files: Agent gets tools to read them + - If no examples: Agent generates from its built-in knowledge + + Args: + examples_path: Path to directory containing example YAML files (optional) + model_config: Optional model configuration + + Returns: + Configured YamlGenerationAgent instance + """ + return YamlGenerationAgent( + examples_path=examples_path, + model_config=model_config, + ) diff --git a/src/osprey/services/xopt_optimizer/yaml_generation/node.py b/src/osprey/services/xopt_optimizer/yaml_generation/node.py new file mode 100644 index 000000000..27584b0dc --- /dev/null +++ b/src/osprey/services/xopt_optimizer/yaml_generation/node.py @@ -0,0 +1,304 @@ +"""YAML Generation Node for XOpt Optimizer Service. + +This node generates XOpt YAML configurations and prepares approval interrupt data. +It follows the Python executor's analyzer pattern where the node that generates +content also creates the approval interrupt data. + +Supports two modes (configured via osprey.xopt_optimizer.yaml_generation.mode): +- "react": ReAct agent generates YAML (default) - adapts based on file availability: + - If example files exist: Agent reads them and learns patterns + - If no examples: Agent generates from built-in XOpt knowledge +- "mock": Placeholder YAML for quick testing (use for fast iteration) + +Example YAML files are optional. If provided, place them in: + osprey.xopt_optimizer.yaml_generation.examples_path: "path/to/yamls" + +DO NOT add accelerator-specific YAML parameters without operator input. +""" + +from pathlib import Path +from typing import Any + +from osprey.utils.config import get_model_config, get_xopt_optimizer_config +from osprey.utils.logger import get_logger + +from ..exceptions import YamlGenerationError +from ..models import XOptError, XOptExecutionState, XOptStrategy + +logger = get_logger("xopt_optimizer") + +# Default path for example YAML files (relative to working directory) +DEFAULT_EXAMPLES_PATH = "_agent_data/xopt_examples/yaml_templates" + + +def _get_yaml_generation_config() -> dict[str, Any]: + """Get YAML generation configuration from osprey config. + + Reads from config structure: + xopt_optimizer: + yaml_generation: + mode: "react" + examples_path: "..." + model_config_name: "xopt_yaml_generation" # References models section + + Returns: + Configuration dict with mode, examples_path, and model_config + """ + xopt_config = get_xopt_optimizer_config() + yaml_config = xopt_config.get("yaml_generation", {}) + + # Resolve model config from name reference + # Falls back to "orchestrator" model if xopt-specific model not configured + model_config = None + model_config_name = yaml_config.get("model_config_name", "xopt_yaml_generation") + try: + model_config = get_model_config(model_config_name) + # Check if the model config is valid (has provider) + if not model_config or not model_config.get("provider"): + logger.debug(f"Model '{model_config_name}' not configured, falling back to orchestrator") + model_config = get_model_config("orchestrator") + except Exception as e: + logger.warning(f"Could not load model config '{model_config_name}': {e}, falling back to orchestrator") + model_config = get_model_config("orchestrator") + + return { + "mode": yaml_config.get("mode", "react"), # Default to react (agent-based) + "examples_path": yaml_config.get("examples_path"), # None if not specified + "model_config": model_config, + } + + +def _generate_placeholder_yaml(objective: str, strategy: XOptStrategy) -> str: + """Generate placeholder XOpt YAML configuration. + + PLACEHOLDER: This generates a minimal valid YAML structure. + Used when yaml_generation.mode is "mock". + + DO NOT add accelerator-specific parameters without operator input. + """ + return f"""# XOpt Optimization Configuration +# PLACEHOLDER - Generated for: {objective} +# Strategy: {strategy.value} + +# NOTE: This is a MOCK configuration for testing the workflow. +# Set yaml_generation.mode: "react" to use the ReAct agent. + +generator: + name: random # Placeholder generator + # Real implementation would use appropriate generator based on strategy + +evaluator: + function: placeholder_objective + # Real implementation would define actual objective function + +vocs: + variables: + param_1: + type: continuous + lower: 0.0 + upper: 10.0 + param_2: + type: continuous + lower: -1.0 + upper: 1.0 + objectives: + objective_1: + type: minimize + constraints: {{}} + statics: {{}} + +n_initial: 5 +max_evaluations: 20 + +# NOTE: This is a placeholder configuration. +# Actual XOpt parameters will be determined based on: +# - Historical YAML examples from the facility +# - Operator-defined parameter bounds +# - Machine-specific safety constraints +""" + + +async def _generate_yaml_with_react_agent( + objective: str, + strategy: XOptStrategy, + examples_path: str | None, + model_config: dict[str, Any] | None = None, +) -> str: + """Generate YAML using the ReAct agent. + + The agent dynamically adapts: + - If examples_path has YAML files: Agent gets file tools and reads examples + - If no examples: Agent generates from built-in XOpt knowledge + + Args: + objective: The optimization objective + strategy: The selected strategy + examples_path: Path to example YAML files (optional) + model_config: Optional model configuration for the agent + + Returns: + Generated YAML configuration string + """ + from .agent import create_yaml_generation_agent + + # Check if examples path exists - if not, agent will work without file tools + if examples_path: + path = Path(examples_path) + if not path.exists(): + examples_path = None + + # Create and run the agent (it adapts based on whether examples exist) + agent = create_yaml_generation_agent( + examples_path=examples_path, + model_config=model_config, + ) + + try: + yaml_config = await agent.generate_yaml( + objective=objective, + strategy=strategy.value, + ) + return yaml_config + except Exception as e: + logger.warning(f"ReAct agent failed, falling back to mock: {e}") + return _generate_placeholder_yaml(objective, strategy) + + +def _validate_yaml(yaml_config: str) -> None: + """Validate generated YAML configuration. + + PLACEHOLDER: Basic validation only. + Real implementation would use XOpt schema validation. + """ + if not yaml_config or not yaml_config.strip(): + raise YamlGenerationError( + "Generated YAML is empty", + generated_yaml=yaml_config, + validation_errors=["Empty YAML configuration"], + ) + # Future: Add XOpt schema validation + + +def create_yaml_generation_node(): + """Create the YAML generation node for LangGraph integration. + + This factory function creates a node that generates XOpt YAML + configurations and prepares approval interrupt data. + + The generation mode is controlled via configuration: + - osprey.xopt_optimizer.yaml_generation.mode: "mock" | "react" + + Returns: + Async function that takes XOptExecutionState and returns state updates + """ + + async def yaml_generation_node(state: XOptExecutionState) -> dict[str, Any]: + """Generate XOpt YAML configuration. + + Supports two modes: + - "mock": Fast placeholder generation for testing (default) + - "react": ReAct agent reads examples and generates YAML + + Also prepares approval interrupt data following the Python + executor's analyzer pattern. + """ + node_logger = get_logger("xopt_optimizer", state=state) + + # Get configuration + yaml_gen_config = _get_yaml_generation_config() + mode = yaml_gen_config.get("mode", "mock") + is_mock = mode == "mock" + mode_indicator = " (mock)" if is_mock else "" + + node_logger.status(f"Generating XOpt configuration{mode_indicator}...") + + # Track generation attempts + attempt = state.get("yaml_generation_attempt", 0) + 1 + request = state.get("request") + strategy = state.get("selected_strategy", XOptStrategy.EXPLORATION) + objective = request.optimization_objective if request else "Unknown objective" + + try: + # Generate YAML configuration based on mode + if mode == "react": + yaml_config = await _generate_yaml_with_react_agent( + objective=objective, + strategy=strategy, + examples_path=yaml_gen_config.get("examples_path", DEFAULT_EXAMPLES_PATH), + model_config=yaml_gen_config.get("model_config"), + ) + else: + yaml_config = _generate_placeholder_yaml(objective, strategy) + + # Validate YAML + _validate_yaml(yaml_config) + + node_logger.key_info(f"YAML configuration generated{mode_indicator}") + + # Prepare approval interrupt data (following Python executor pattern) + requires_approval = request.require_approval if request else True + + if requires_approval: + # Import here to avoid circular imports + from osprey.approval.approval_system import create_xopt_approval_interrupt + + machine_state_details = state.get("machine_state_details") + + approval_interrupt_data = create_xopt_approval_interrupt( + yaml_config=yaml_config, + strategy=strategy.value, + objective=objective, + machine_state_details=machine_state_details, + step_objective=f"Execute XOpt optimization: {objective}", + ) + + return { + "generated_yaml": yaml_config, + "yaml_generation_attempt": attempt, + "yaml_generation_failed": False, + "requires_approval": True, + "approval_interrupt_data": approval_interrupt_data, + "current_stage": "approval", + } + else: + return { + "generated_yaml": yaml_config, + "yaml_generation_attempt": attempt, + "yaml_generation_failed": False, + "requires_approval": False, + "current_stage": "execution", + } + + except YamlGenerationError: + # Re-raise YAML generation errors + raise + + except Exception as e: + node_logger.warning(f"YAML generation failed: {e}") + + error = XOptError( + error_type="yaml_generation", + error_message=str(e), + stage="yaml_generation", + attempt_number=attempt, + ) + error_chain = list(state.get("error_chain", [])) + [error] + + # Check retry limit + max_retries = request.retries if request else 3 + retry_limit_exceeded = len(error_chain) >= max_retries + + return { + "yaml_generation_attempt": attempt, + "yaml_generation_failed": True, + "error_chain": error_chain, + "is_failed": retry_limit_exceeded, + "failure_reason": ( + f"YAML generation failed after {max_retries} attempts" + if retry_limit_exceeded + else None + ), + "current_stage": "yaml_gen" if not retry_limit_exceeded else "failed", + } + + return yaml_generation_node diff --git a/src/osprey/templates/apps/control_assistant/config.yml.j2 b/src/osprey/templates/apps/control_assistant/config.yml.j2 index b6bb0c574..b5ff424dc 100644 --- a/src/osprey/templates/apps/control_assistant/config.yml.j2 +++ b/src/osprey/templates/apps/control_assistant/config.yml.j2 @@ -57,6 +57,14 @@ models: provider: {{ default_provider }} model_id: {{ default_model }} max_tokens: 4096 # For channel finder semantic search + xopt_yaml_generation: + provider: {{ default_provider }} + model_id: {{ default_model }} + max_tokens: 4096 # For XOpt YAML configuration generation + xopt_decision: + provider: {{ default_provider }} + model_id: {{ default_model }} + max_tokens: 1024 # For XOpt strategy decision # ============================================================ # API CONFIGURATION @@ -495,6 +503,43 @@ python_executor: max_execution_retries: 3 execution_timeout_seconds: 600 +# ============================================================ +# XOPT OPTIMIZER CONFIGURATION +# ============================================================ +# XOpt optimization service for autonomous machine tuning +# Each sub-agent can have its own model configuration + +xopt_optimizer: + # Maximum optimization iterations per request + max_iterations: 3 + + # State Identification Node + # Assesses machine readiness before optimization using ReAct agent + state_identification: + mode: "react" # "react" (agent with tools, default) or "mock" (always returns READY) + mock_files: true # Use mock reference files (set false to use real files at reference_path) + # reference_path: "path/to/reference/docs" # Path to real reference files (when mock_files: false) + # model_config_name: "xopt_state_identification" # Reference to models section (falls back to orchestrator) + + # Strategy Decision Node + # Selects optimization strategy based on machine state and objective + decision: + mode: "llm" # "llm" (LLM-based with structured output) or "mock" (defaults to exploration) + model_config_name: "xopt_decision" # Reference to models section + + # YAML Generation Agent + # Generates XOpt configuration files using ReAct pattern + yaml_generation: + mode: "react" # "react" (agent-based, default) or "mock" (fast placeholder) + # examples_path: "_agent_data/xopt_examples/yaml_templates" # Optional - agent adapts if missing + model_config_name: "xopt_yaml_generation" # Reference to models section + + # Analysis Agent (placeholder for future implementation) + # Analyzes optimization results and decides continuation + # analysis: + # mode: "mock" # "mock" or "react" + # model_config_name: "xopt_analysis" + # ============================================================ # APPLICATION METADATA # ============================================================ diff --git a/src/osprey/utils/config.py b/src/osprey/utils/config.py index af793944f..45c4965f2 100644 --- a/src/osprey/utils/config.py +++ b/src/osprey/utils/config.py @@ -372,6 +372,8 @@ def _build_configurable(self) -> dict[str, Any]: "applications": self.get("applications", []), "current_application": self._get_current_application(), "registry_path": self.get("registry_path"), + # ===== XOPT OPTIMIZER ===== + "xopt_optimizer": self.get("xopt_optimizer", {}), } return configurable @@ -724,6 +726,33 @@ def get_agent_control_defaults() -> dict[str, Any]: return configurable.get("agent_control_defaults", {}) +def get_xopt_optimizer_config(config_path: str | None = None) -> dict[str, Any]: + """Get XOpt optimizer configuration with automatic context detection. + + This function provides access to the xopt_optimizer configuration section, + which controls the XOpt optimization service for autonomous machine tuning. + + Configuration structure: + xopt_optimizer: + max_iterations: 3 + state_identification: + mode: "react" # or "mock" + mock_files: true + decision: + mode: "llm" # or "mock" + yaml_generation: + mode: "react" # or "mock" + + Args: + config_path: Optional explicit path to configuration file + + Returns: + Dictionary with xopt_optimizer configuration + """ + configurable = _get_configurable(config_path) + return configurable.get("xopt_optimizer", {}) + + def get_session_info() -> dict[str, Any]: """Get session information with automatic context detection.""" configurable = _get_configurable() diff --git a/tests/conftest.py b/tests/conftest.py index 8598cf0c7..fe9799767 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -393,6 +393,21 @@ def get_registry_config(self): "models": { "orchestrator": {"provider": "openai", "model_id": "gpt-4"}, "python_code_generator": {"provider": "openai", "model_id": "gpt-4"}, + "xopt_yaml_generation": {"provider": "openai", "model_id": "gpt-4"}, + "xopt_decision": {"provider": "openai", "model_id": "gpt-4"}, + }, + "xopt_optimizer": { + "state_identification": { + "mode": "mock", # Use mock for fast tests + }, + "decision": { + "mode": "mock", # Use mock for fast tests + "model_config_name": "xopt_decision", + }, + "yaml_generation": { + "mode": "mock", # Use mock for fast tests + "model_config_name": "xopt_yaml_generation", + }, }, } @@ -488,6 +503,21 @@ def get_registry_config(self): "models": { "orchestrator": {"provider": "openai", "model_id": "gpt-4"}, "python_code_generator": {"provider": "openai", "model_id": "gpt-4"}, + "xopt_yaml_generation": {"provider": "openai", "model_id": "gpt-4"}, + "xopt_decision": {"provider": "openai", "model_id": "gpt-4"}, + }, + "xopt_optimizer": { + "state_identification": { + "mode": "mock", # Use mock for fast tests + }, + "decision": { + "mode": "mock", # Use mock for fast tests + "model_config_name": "xopt_decision", + }, + "yaml_generation": { + "mode": "mock", # Use mock for fast tests + "model_config_name": "xopt_yaml_generation", + }, }, } diff --git a/tests/services/xopt_optimizer/__init__.py b/tests/services/xopt_optimizer/__init__.py new file mode 100644 index 000000000..5a9c87b73 --- /dev/null +++ b/tests/services/xopt_optimizer/__init__.py @@ -0,0 +1 @@ +"""Tests for XOpt Optimizer Service.""" diff --git a/tests/services/xopt_optimizer/test_state_identification.py b/tests/services/xopt_optimizer/test_state_identification.py new file mode 100644 index 000000000..dc7f1c6a3 --- /dev/null +++ b/tests/services/xopt_optimizer/test_state_identification.py @@ -0,0 +1,274 @@ +"""Unit Tests for State Identification Agent and Tools. + +This module tests the state identification subsystem including: +- Reference file tools (mock and real modes) +- Channel access tools +- State identification agent +""" + +import pytest + +from osprey.services.xopt_optimizer.models import MachineState +from osprey.services.xopt_optimizer.state_identification.tools.reference_files import ( + MOCK_REFERENCE_FILES, + create_list_files_tool, + create_read_file_tool, + create_reference_file_tools, +) + + +class TestReferenceFileTools: + """Test reference file tools.""" + + def test_mock_list_files_returns_files(self): + """Mock list files tool should return mock file names.""" + list_files = create_list_files_tool(mock_mode=True) + result = list_files.invoke({}) + + assert "machine_ready_criteria.md" in result + assert "optimization_channels.md" in result + assert "safety_procedures.md" in result + + def test_mock_read_file_returns_content(self): + """Mock read file tool should return mock content.""" + read_file = create_read_file_tool(mock_mode=True) + result = read_file.invoke({"filename": "machine_ready_criteria.md"}) + + assert "Machine Ready Criteria" in result + assert "BEAM:CURRENT" in result + assert "READY" in result + assert "NOT_READY" in result + + def test_mock_read_file_not_found(self): + """Mock read file tool should handle missing files.""" + read_file = create_read_file_tool(mock_mode=True) + result = read_file.invoke({"filename": "nonexistent.md"}) + + assert "not found" in result.lower() + assert "machine_ready_criteria.md" in result # Shows available files + + def test_real_mode_no_path_returns_message(self): + """Real mode without path should return informative message.""" + list_files = create_list_files_tool(reference_path=None, mock_mode=False) + result = list_files.invoke({}) + + assert "not configured" in result.lower() or "not specified" in result.lower() + + def test_real_mode_missing_path_returns_message(self): + """Real mode with non-existent path should return error message.""" + list_files = create_list_files_tool( + reference_path="/nonexistent/path", mock_mode=False + ) + result = list_files.invoke({}) + + assert "does not exist" in result.lower() + + def test_create_reference_file_tools_returns_two_tools(self): + """create_reference_file_tools should return list and read tools.""" + tools = create_reference_file_tools(mock_mode=True) + + assert len(tools) == 2 + tool_names = [t.name for t in tools] + assert "list_reference_files" in tool_names + assert "read_reference_file" in tool_names + + def test_mock_files_contain_expected_content(self): + """Mock files should contain machine state assessment content.""" + # Check that mock files have the essential content for the agent + criteria_file = MOCK_REFERENCE_FILES["machine_ready_criteria.md"] + assert "BEAM:CURRENT" in criteria_file + assert "SAFETY:INTERLOCK" in criteria_file + assert "READY" in criteria_file + assert "NOT_READY" in criteria_file + + channels_file = MOCK_REFERENCE_FILES["optimization_channels.md"] + assert "Channel Name" in channels_file + assert "VACUUM:PRESSURE" in channels_file + + +class TestReferenceFileToolsWithRealPath: + """Test reference file tools with real file paths.""" + + def test_real_mode_with_temp_dir(self, tmp_path): + """Real mode should read files from provided path.""" + # Create test file + test_file = tmp_path / "test_criteria.md" + test_file.write_text("# Test Criteria\nThis is a test file.") + + list_files = create_list_files_tool( + reference_path=str(tmp_path), mock_mode=False + ) + result = list_files.invoke({}) + + assert "test_criteria.md" in result + + def test_real_mode_read_file(self, tmp_path): + """Real mode should read file contents correctly.""" + # Create test file + test_file = tmp_path / "test_criteria.md" + test_content = "# Test Criteria\nBeam current must be > 10 mA" + test_file.write_text(test_content) + + read_file = create_read_file_tool( + reference_path=str(tmp_path), mock_mode=False + ) + result = read_file.invoke({"filename": "test_criteria.md"}) + + assert "Test Criteria" in result + assert "Beam current" in result + + def test_real_mode_prevents_path_traversal(self, tmp_path): + """Real mode should prevent reading files outside reference path.""" + read_file = create_read_file_tool( + reference_path=str(tmp_path), mock_mode=False + ) + + # Try to read parent directory file + result = read_file.invoke({"filename": "../../../etc/passwd"}) + + assert "denied" in result.lower() or "not found" in result.lower() + + +class TestChannelAccessTools: + """Test channel access tools.""" + + @pytest.mark.asyncio + async def test_read_channels_tool_creation(self): + """read_channels tool should be creatable.""" + from osprey.services.xopt_optimizer.state_identification.tools.channel_access import ( + create_read_channels_tool, + ) + + tool = create_read_channels_tool() + assert tool.name == "read_channel_values" + + @pytest.mark.asyncio + async def test_create_channel_access_tools_returns_list(self): + """create_channel_access_tools should return tool list.""" + from osprey.services.xopt_optimizer.state_identification.tools.channel_access import ( + create_channel_access_tools, + ) + + tools = create_channel_access_tools() + assert len(tools) == 1 + assert tools[0].name == "read_channel_values" + + +class TestMockConnectorStatusChannel: + """Test that MockConnector returns expected values for status channels.""" + + @pytest.mark.asyncio + async def test_machine_status_returns_one(self): + """MockConnector should return 1 for MACHINE:STATUS channel.""" + from osprey.connectors.control_system.mock_connector import MockConnector + + connector = MockConnector() + await connector.connect({}) + + result = await connector.read_channel("MACHINE:STATUS") + + # Value should be approximately 1 (with small noise) + assert 0.9 < result.value < 1.1, f"Expected ~1, got {result.value}" + + await connector.disconnect() + + @pytest.mark.asyncio + async def test_status_channels_return_one(self): + """MockConnector should return 1 for various status-like channels.""" + from osprey.connectors.control_system.mock_connector import MockConnector + + connector = MockConnector() + await connector.connect({"noise_level": 0}) # No noise for exact test + + for channel in ["MACHINE:STATUS", "BEAM:READY", "SYSTEM:ENABLE"]: + result = await connector.read_channel(channel) + assert result.value == 1.0, f"Expected 1 for {channel}, got {result.value}" + + await connector.disconnect() + + +class TestStateIdentificationAgent: + """Test the state identification agent.""" + + def test_agent_creation(self): + """Agent should be creatable with mock mode.""" + from osprey.services.xopt_optimizer.state_identification import ( + create_state_identification_agent, + ) + + # This should not raise - agent is created but not invoked + agent = create_state_identification_agent( + mock_files=True, + model_config={"provider": "openai", "model_id": "gpt-4"}, + ) + + assert agent is not None + assert agent.mock_files is True + + def test_agent_creation_with_reference_path(self, tmp_path): + """Agent should accept reference path.""" + from osprey.services.xopt_optimizer.state_identification import ( + create_state_identification_agent, + ) + + agent = create_state_identification_agent( + reference_path=str(tmp_path), + mock_files=False, + model_config={"provider": "openai", "model_id": "gpt-4"}, + ) + + assert agent.reference_path == str(tmp_path) + assert agent.mock_files is False + + def test_agent_tools_include_file_and_channel_tools(self): + """Agent should have both file and channel tools.""" + from osprey.services.xopt_optimizer.state_identification import ( + create_state_identification_agent, + ) + + agent = create_state_identification_agent( + mock_files=True, + model_config={"provider": "openai", "model_id": "gpt-4"}, + ) + + tools = agent._get_tools() + + tool_names = [t.name for t in tools] + assert "list_reference_files" in tool_names + assert "read_reference_file" in tool_names + assert "read_channel_values" in tool_names + + +class TestMachineStateAssessmentModel: + """Test the MachineStateAssessment Pydantic model.""" + + def test_assessment_model_creation(self): + """MachineStateAssessment should be creatable.""" + from osprey.services.xopt_optimizer.state_identification.agent import ( + MachineStateAssessment, + ) + + assessment = MachineStateAssessment( + state=MachineState.READY, + reasoning="All criteria met", + channels_checked=["BEAM:CURRENT", "VACUUM:PRESSURE"], + key_observations={"beam_current": 100.0}, + ) + + assert assessment.state == MachineState.READY + assert "All criteria" in assessment.reasoning + assert len(assessment.channels_checked) == 2 + + def test_assessment_model_defaults(self): + """MachineStateAssessment should have sensible defaults.""" + from osprey.services.xopt_optimizer.state_identification.agent import ( + MachineStateAssessment, + ) + + assessment = MachineStateAssessment( + state=MachineState.NOT_READY, + reasoning="Interlock active", + ) + + assert assessment.channels_checked == [] + assert assessment.key_observations == {} diff --git a/tests/services/xopt_optimizer/test_xopt_approval.py b/tests/services/xopt_optimizer/test_xopt_approval.py new file mode 100644 index 000000000..e5e2c37b5 --- /dev/null +++ b/tests/services/xopt_optimizer/test_xopt_approval.py @@ -0,0 +1,99 @@ +"""Unit Tests for XOpt Approval Interrupt Function. + +This module tests the create_xopt_approval_interrupt function. +""" + +from osprey.approval import create_xopt_approval_interrupt + + +class TestCreateXOptApprovalInterrupt: + """Test create_xopt_approval_interrupt function.""" + + def test_basic_interrupt_creation(self): + """Should create interrupt data with required fields.""" + result = create_xopt_approval_interrupt( + yaml_config="xopt:\n generator: random", + strategy="exploration", + objective="Maximize efficiency", + ) + + assert "user_message" in result + assert "resume_payload" in result + + # Check user message content + assert "HUMAN APPROVAL REQUIRED" in result["user_message"] + assert "Maximize efficiency" in result["user_message"] + assert "EXPLORATION" in result["user_message"] + assert "xopt:" in result["user_message"] + + # Check resume payload + payload = result["resume_payload"] + assert payload["approval_type"] == "xopt_optimizer" + assert payload["yaml_config"] == "xopt:\n generator: random" + assert payload["strategy"] == "exploration" + assert payload["objective"] == "Maximize efficiency" + + def test_interrupt_with_machine_state_details(self): + """Should include machine state details when provided.""" + machine_details = { + "beam_current": 50.0, + "status": "ready", + } + + result = create_xopt_approval_interrupt( + yaml_config="test: yaml", + strategy="optimization", + objective="Test objective", + machine_state_details=machine_details, + ) + + # Machine state should appear in message + assert "Machine State Assessment" in result["user_message"] + assert "beam_current" in result["user_message"] + + # Should be in payload + assert result["resume_payload"]["machine_state_details"] == machine_details + + def test_interrupt_with_custom_step_objective(self): + """Should use custom step objective.""" + result = create_xopt_approval_interrupt( + yaml_config="test: yaml", + strategy="exploration", + objective="Test", + step_objective="Custom optimization task", + ) + + assert "Custom optimization task" in result["user_message"] + assert result["resume_payload"]["step_objective"] == "Custom optimization task" + + def test_interrupt_contains_approval_instructions(self): + """Should contain clear approval instructions.""" + result = create_xopt_approval_interrupt( + yaml_config="test: yaml", + strategy="exploration", + objective="Test", + ) + + message = result["user_message"] + assert "yes" in message.lower() + assert "no" in message.lower() + assert "approve" in message.lower() + + def test_interrupt_yaml_displayed_correctly(self): + """Should display YAML in code block.""" + yaml_config = """xopt: + generator: + name: bayesian + vocs: + variables: + x1: [0, 10] +""" + result = create_xopt_approval_interrupt( + yaml_config=yaml_config, + strategy="optimization", + objective="Test", + ) + + # YAML should be in code block + assert "```yaml" in result["user_message"] + assert yaml_config in result["user_message"] diff --git a/tests/services/xopt_optimizer/test_xopt_exceptions.py b/tests/services/xopt_optimizer/test_xopt_exceptions.py new file mode 100644 index 000000000..4f526bd05 --- /dev/null +++ b/tests/services/xopt_optimizer/test_xopt_exceptions.py @@ -0,0 +1,149 @@ +"""Unit Tests for XOpt Optimizer Exceptions. + +This module tests the exception hierarchy for the XOpt optimizer service. +""" + +from osprey.services.xopt_optimizer.exceptions import ( + ConfigurationError, + ErrorCategory, + MachineStateAssessmentError, + MaxIterationsExceededError, + XOptExecutionError, + XOptExecutorException, + YamlGenerationError, +) + + +class TestErrorCategory: + """Test ErrorCategory enum.""" + + def test_error_categories_exist(self): + """All expected error categories should exist.""" + assert ErrorCategory.MACHINE_STATE.value == "machine_state" + assert ErrorCategory.YAML_GENERATION.value == "yaml_generation" + assert ErrorCategory.EXECUTION.value == "execution" + assert ErrorCategory.CONFIGURATION.value == "configuration" + assert ErrorCategory.WORKFLOW.value == "workflow" + + +class TestXOptExecutorException: + """Test base exception class.""" + + def test_base_exception_creation(self): + """Base exception should be creatable with message.""" + exc = XOptExecutorException("Test error message") + assert str(exc) == "Test error message" + assert exc.message == "Test error message" + assert exc.category == ErrorCategory.WORKFLOW # Default + + def test_base_exception_with_category(self): + """Base exception should accept custom category.""" + exc = XOptExecutorException("Test error", category=ErrorCategory.MACHINE_STATE) + assert exc.category == ErrorCategory.MACHINE_STATE + + def test_is_retriable(self): + """is_retriable should return True for retriable categories.""" + machine_exc = XOptExecutorException("Test", category=ErrorCategory.MACHINE_STATE) + yaml_exc = XOptExecutorException("Test", category=ErrorCategory.YAML_GENERATION) + workflow_exc = XOptExecutorException("Test", category=ErrorCategory.WORKFLOW) + + assert machine_exc.is_retriable() is True + assert yaml_exc.is_retriable() is True + assert workflow_exc.is_retriable() is False + + def test_should_retry_yaml_generation(self): + """should_retry_yaml_generation should return True for YAML errors.""" + yaml_exc = XOptExecutorException("Test", category=ErrorCategory.YAML_GENERATION) + other_exc = XOptExecutorException("Test", category=ErrorCategory.EXECUTION) + + assert yaml_exc.should_retry_yaml_generation() is True + assert other_exc.should_retry_yaml_generation() is False + + +class TestMachineStateAssessmentError: + """Test MachineStateAssessmentError.""" + + def test_creation(self): + """Should be creatable with message and details.""" + exc = MachineStateAssessmentError( + "Machine not ready", + assessment_details={"reason": "No beam"}, + ) + assert exc.message == "Machine not ready" + assert exc.category == ErrorCategory.MACHINE_STATE + assert exc.assessment_details == {"reason": "No beam"} + + def test_is_retriable(self): + """Machine state errors should be retriable.""" + exc = MachineStateAssessmentError("Test") + assert exc.is_retriable() is True + + +class TestYamlGenerationError: + """Test YamlGenerationError.""" + + def test_creation(self): + """Should be creatable with message and yaml details.""" + exc = YamlGenerationError( + "Invalid YAML", + generated_yaml="bad: yaml", + validation_errors=["Missing field X"], + ) + assert exc.message == "Invalid YAML" + assert exc.category == ErrorCategory.YAML_GENERATION + assert exc.generated_yaml == "bad: yaml" + assert exc.validation_errors == ["Missing field X"] + + def test_should_retry_yaml_generation(self): + """YAML generation errors should trigger retry.""" + exc = YamlGenerationError("Test") + assert exc.should_retry_yaml_generation() is True + + +class TestXOptExecutionError: + """Test XOptExecutionError.""" + + def test_creation(self): + """Should be creatable with message and execution details.""" + exc = XOptExecutionError( + "XOpt failed", + yaml_used="test: yaml", + xopt_error="Runtime error", + ) + assert exc.message == "XOpt failed" + assert exc.category == ErrorCategory.EXECUTION + assert exc.yaml_used == "test: yaml" + assert exc.xopt_error == "Runtime error" + + def test_not_retriable(self): + """Execution errors should not be retriable.""" + exc = XOptExecutionError("Test") + assert exc.is_retriable() is False + + +class TestMaxIterationsExceededError: + """Test MaxIterationsExceededError.""" + + def test_creation(self): + """Should be creatable with message and iteration count.""" + exc = MaxIterationsExceededError( + "Max iterations reached", + iterations_completed=5, + ) + assert exc.message == "Max iterations reached" + assert exc.category == ErrorCategory.WORKFLOW + assert exc.iterations_completed == 5 + + +class TestConfigurationError: + """Test ConfigurationError.""" + + def test_creation(self): + """Should be creatable with message and config key.""" + exc = ConfigurationError( + "Invalid config", + config_key="xopt_optimizer.max_iterations", + ) + assert exc.message == "Invalid config" + assert exc.category == ErrorCategory.CONFIGURATION + assert exc.config_key == "xopt_optimizer.max_iterations" diff --git a/tests/services/xopt_optimizer/test_xopt_service.py b/tests/services/xopt_optimizer/test_xopt_service.py new file mode 100644 index 000000000..ad10f9d1f --- /dev/null +++ b/tests/services/xopt_optimizer/test_xopt_service.py @@ -0,0 +1,523 @@ +"""Unit Tests for XOpt Optimizer Service. + +This module provides unit tests for the XOpt optimizer service including +service initialization, graph compilation, and basic workflow validation. + +Test Coverage: + - Service initialization and configuration + - LangGraph compilation + - State initialization + - Basic workflow routing +""" + +import os + +import pytest + +from osprey.services.xopt_optimizer import ( + MachineState, + XOptExecutionRequest, + XOptExecutionState, + XOptServiceResult, + XOptStrategy, +) +from osprey.services.xopt_optimizer.models import XOptError + +# ============================================================================= +# MODEL TESTS +# ============================================================================= + + +class TestXOptModels: + """Test XOpt data models.""" + + def test_machine_state_enum(self): + """MachineState enum should have expected values.""" + assert MachineState.READY.value == "ready" + assert MachineState.NOT_READY.value == "not_ready" + assert MachineState.UNKNOWN.value == "unknown" + + def test_xopt_strategy_enum(self): + """XOptStrategy enum should have expected values.""" + assert XOptStrategy.EXPLORATION.value == "exploration" + assert XOptStrategy.OPTIMIZATION.value == "optimization" + assert XOptStrategy.ABORT.value == "abort" + + def test_xopt_execution_request_creation(self): + """XOptExecutionRequest should be creatable with required fields.""" + request = XOptExecutionRequest( + user_query="Optimize injection efficiency", + optimization_objective="Maximize injection efficiency", + ) + assert request.user_query == "Optimize injection efficiency" + assert request.optimization_objective == "Maximize injection efficiency" + assert request.max_iterations == 3 # Default + assert request.require_approval is True # Default + + def test_xopt_execution_request_custom_params(self): + """XOptExecutionRequest should accept custom parameters.""" + request = XOptExecutionRequest( + user_query="Test query", + optimization_objective="Test objective", + max_iterations=5, + retries=2, + require_approval=False, + ) + assert request.max_iterations == 5 + assert request.retries == 2 + assert request.require_approval is False + + def test_xopt_error_dataclass(self): + """XOptError should be creatable and formattable.""" + error = XOptError( + error_type="test_error", + error_message="Test error message", + stage="yaml_generation", + attempt_number=1, + details={"key": "value"}, + ) + assert error.error_type == "test_error" + assert error.error_message == "Test error message" + assert error.stage == "yaml_generation" + assert error.attempt_number == 1 + + # Test prompt text formatting + prompt_text = error.to_prompt_text() + assert "YAML_GENERATION FAILED" in prompt_text + assert "Test error message" in prompt_text + + def test_xopt_service_result_creation(self): + """XOptServiceResult should be creatable with all fields.""" + result = XOptServiceResult( + run_artifact={"status": "completed"}, + generated_yaml="test: yaml", + strategy=XOptStrategy.EXPLORATION, + total_iterations=3, + analysis_summary={"summary": "test"}, + recommendations=("Recommendation 1", "Recommendation 2"), + ) + assert result.run_artifact == {"status": "completed"} + assert result.strategy == XOptStrategy.EXPLORATION + assert result.total_iterations == 3 + assert len(result.recommendations) == 2 + + +# ============================================================================= +# SERVICE INITIALIZATION TESTS +# ============================================================================= + + +class TestServiceInitialization: + """Test service initialization and configuration.""" + + def test_service_initializes(self, test_config): + """Service should initialize without errors.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer import XOptOptimizerService + + service = XOptOptimizerService() + assert service is not None + assert service.config is not None + + def test_service_builds_graph(self, test_config): + """Service should build LangGraph on initialization.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer import XOptOptimizerService + + service = XOptOptimizerService() + graph = service.get_compiled_graph() + assert graph is not None + + def test_service_creates_initial_state(self, test_config): + """Service should create proper initial state from request.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer import XOptOptimizerService + + service = XOptOptimizerService() + request = XOptExecutionRequest( + user_query="Test query", + optimization_objective="Test objective", + max_iterations=5, + ) + + state = service._create_initial_state(request) + + assert state["request"] == request + assert state["max_iterations"] == 5 + assert state["iteration_count"] == 0 + assert state["is_successful"] is False + assert state["is_failed"] is False + assert state["current_stage"] == "state_id" + assert state["error_chain"] == [] + + +# ============================================================================= +# NODE TESTS +# ============================================================================= + + +class TestStateIdentificationNode: + """Test state identification node.""" + + @pytest.mark.asyncio + async def test_state_identification_returns_ready(self, test_config): + """State identification should return READY (mock mode).""" + os.environ["CONFIG_FILE"] = str(test_config) + + # Ensure config cache is cleared after setting CONFIG_FILE + from osprey.utils import config as config_module + + config_module._default_config = None + config_module._default_configurable = None + config_module._config_cache.clear() + + # Register MockConnector for channel access (if react mode runs) + from osprey.connectors.control_system.mock_connector import MockConnector + from osprey.connectors.factory import ConnectorFactory + + ConnectorFactory.register_control_system("mock", MockConnector) + + from osprey.services.xopt_optimizer.state_identification import ( + create_state_identification_node, + ) + + node = create_state_identification_node() + + # Create minimal state + state = XOptExecutionState( + request=XOptExecutionRequest( + user_query="Test", optimization_objective="Test objective" + ), + capability_context_data=None, + error_chain=[], + yaml_generation_attempt=0, + machine_state=None, + machine_state_details=None, + selected_strategy=None, + decision_reasoning=None, + generated_yaml=None, + yaml_generation_failed=None, + requires_approval=None, + approval_interrupt_data=None, + approval_result=None, + approved=None, + run_artifact=None, + execution_error=None, + execution_failed=None, + analysis_result=None, + recommendations=None, + iteration_count=0, + max_iterations=3, + should_continue=False, + is_successful=False, + is_failed=False, + failure_reason=None, + current_stage="state_id", + ) + + result = await node(state) + + assert result["machine_state"] == MachineState.READY + assert result["current_stage"] == "decision" + assert "machine_state_details" in result + + +class TestDecisionNode: + """Test decision node.""" + + @pytest.mark.asyncio + async def test_decision_routes_to_yaml_gen_when_ready(self, test_config): + """Decision node should route to yaml_gen when machine is READY.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer.decision import create_decision_node + + node = create_decision_node() + + state = XOptExecutionState( + request=XOptExecutionRequest( + user_query="Test", optimization_objective="Test objective" + ), + capability_context_data=None, + error_chain=[], + yaml_generation_attempt=0, + machine_state=MachineState.READY, + machine_state_details={"assessment": "test"}, + selected_strategy=None, + decision_reasoning=None, + generated_yaml=None, + yaml_generation_failed=None, + requires_approval=None, + approval_interrupt_data=None, + approval_result=None, + approved=None, + run_artifact=None, + execution_error=None, + execution_failed=None, + analysis_result=None, + recommendations=None, + iteration_count=0, + max_iterations=3, + should_continue=False, + is_successful=False, + is_failed=False, + failure_reason=None, + current_stage="decision", + ) + + result = await node(state) + + assert result["selected_strategy"] == XOptStrategy.EXPLORATION + assert result["current_stage"] == "yaml_gen" + assert "decision_reasoning" in result + + @pytest.mark.asyncio + async def test_decision_aborts_when_not_ready(self, test_config): + """Decision node should abort when machine is NOT_READY.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer.decision import create_decision_node + + node = create_decision_node() + + state = XOptExecutionState( + request=XOptExecutionRequest( + user_query="Test", optimization_objective="Test objective" + ), + capability_context_data=None, + error_chain=[], + yaml_generation_attempt=0, + machine_state=MachineState.NOT_READY, + machine_state_details={"reason": "Machine offline"}, + selected_strategy=None, + decision_reasoning=None, + generated_yaml=None, + yaml_generation_failed=None, + requires_approval=None, + approval_interrupt_data=None, + approval_result=None, + approved=None, + run_artifact=None, + execution_error=None, + execution_failed=None, + analysis_result=None, + recommendations=None, + iteration_count=0, + max_iterations=3, + should_continue=False, + is_successful=False, + is_failed=False, + failure_reason=None, + current_stage="decision", + ) + + result = await node(state) + + assert result["selected_strategy"] == XOptStrategy.ABORT + assert result["is_failed"] is True + assert result["current_stage"] == "failed" + + +class TestAnalysisNode: + """Test analysis node.""" + + @pytest.mark.asyncio + async def test_analysis_continues_when_under_max_iterations(self, test_config): + """Analysis should continue when under max iterations.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer.analysis import create_analysis_node + + node = create_analysis_node() + + state = XOptExecutionState( + request=XOptExecutionRequest( + user_query="Test", optimization_objective="Test objective" + ), + capability_context_data=None, + error_chain=[], + yaml_generation_attempt=0, + machine_state=MachineState.READY, + machine_state_details={}, + selected_strategy=XOptStrategy.EXPLORATION, + decision_reasoning="test", + generated_yaml="test: yaml", + yaml_generation_failed=False, + requires_approval=False, + approval_interrupt_data=None, + approval_result=None, + approved=True, + run_artifact={"status": "completed"}, + execution_error=None, + execution_failed=False, + analysis_result=None, + recommendations=None, + iteration_count=0, # First iteration + max_iterations=3, + should_continue=False, + is_successful=False, + is_failed=False, + failure_reason=None, + current_stage="analysis", + ) + + result = await node(state) + + assert result["should_continue"] is True + assert result["iteration_count"] == 1 + assert result["current_stage"] == "state_id" + + @pytest.mark.asyncio + async def test_analysis_completes_at_max_iterations(self, test_config): + """Analysis should complete when reaching max iterations.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer.analysis import create_analysis_node + + node = create_analysis_node() + + state = XOptExecutionState( + request=XOptExecutionRequest( + user_query="Test", optimization_objective="Test objective" + ), + capability_context_data=None, + error_chain=[], + yaml_generation_attempt=0, + machine_state=MachineState.READY, + machine_state_details={}, + selected_strategy=XOptStrategy.EXPLORATION, + decision_reasoning="test", + generated_yaml="test: yaml", + yaml_generation_failed=False, + requires_approval=False, + approval_interrupt_data=None, + approval_result=None, + approved=True, + run_artifact={"status": "completed"}, + execution_error=None, + execution_failed=False, + analysis_result=None, + recommendations=None, + iteration_count=2, # Already at iteration 2, next will be 3 (max) + max_iterations=3, + should_continue=False, + is_successful=False, + is_failed=False, + failure_reason=None, + current_stage="analysis", + ) + + result = await node(state) + + assert result["should_continue"] is False + assert result["iteration_count"] == 3 + assert result["is_successful"] is True + assert result["current_stage"] == "complete" + + +# ============================================================================= +# ROUTING TESTS +# ============================================================================= + + +class TestServiceRouting: + """Test service routing logic.""" + + def test_decision_router_continues_on_ready(self, test_config): + """Decision router should return 'continue' when ready.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer import XOptOptimizerService + + service = XOptOptimizerService() + + state = { + "is_failed": False, + "selected_strategy": XOptStrategy.EXPLORATION, + } + + result = service._decision_router(state) + assert result == "continue" + + def test_decision_router_aborts_on_failed(self, test_config): + """Decision router should return 'abort' when failed.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer import XOptOptimizerService + + service = XOptOptimizerService() + + state = { + "is_failed": True, + "selected_strategy": XOptStrategy.EXPLORATION, + } + + result = service._decision_router(state) + assert result == "abort" + + def test_decision_router_aborts_on_abort_strategy(self, test_config): + """Decision router should return 'abort' when strategy is ABORT.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer import XOptOptimizerService + + service = XOptOptimizerService() + + state = { + "is_failed": False, + "selected_strategy": XOptStrategy.ABORT, + } + + result = service._decision_router(state) + assert result == "abort" + + def test_approval_router_approved(self, test_config): + """Approval router should return 'approved' when approved.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer import XOptOptimizerService + + service = XOptOptimizerService() + + state = {"approved": True} + result = service._approval_router(state) + assert result == "approved" + + def test_approval_router_rejected(self, test_config): + """Approval router should return 'rejected' when not approved.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer import XOptOptimizerService + + service = XOptOptimizerService() + + state = {"approved": False} + result = service._approval_router(state) + assert result == "rejected" + + def test_loop_router_continues(self, test_config): + """Loop router should return 'continue' when should_continue is True.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer import XOptOptimizerService + + service = XOptOptimizerService() + + state = {"is_failed": False, "should_continue": True} + result = service._loop_router(state) + assert result == "continue" + + def test_loop_router_completes(self, test_config): + """Loop router should return 'complete' when should_continue is False.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer import XOptOptimizerService + + service = XOptOptimizerService() + + state = {"is_failed": False, "should_continue": False} + result = service._loop_router(state) + assert result == "complete" diff --git a/tests/services/xopt_optimizer/test_xopt_workflow.py b/tests/services/xopt_optimizer/test_xopt_workflow.py new file mode 100644 index 000000000..d648d1656 --- /dev/null +++ b/tests/services/xopt_optimizer/test_xopt_workflow.py @@ -0,0 +1,93 @@ +"""Integration test for XOpt optimizer service workflow. + +This test runs the full service workflow without approval to verify +the placeholder implementation works end-to-end. +""" + +import os + +import pytest + + +class TestXOptWorkflow: + """Test complete XOpt workflow execution.""" + + @pytest.mark.asyncio + async def test_full_workflow_without_approval(self, test_config): + """Test complete workflow execution without approval. + + This runs the service through all nodes: + state_id -> decision -> yaml_gen -> execution -> analysis + """ + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer import ( + XOptExecutionRequest, + XOptOptimizerService, + XOptServiceResult, + XOptStrategy, + ) + + service = XOptOptimizerService() + + # Create request with approval disabled + request = XOptExecutionRequest( + user_query="Optimize injection efficiency", + optimization_objective="Maximize injection efficiency", + max_iterations=2, + require_approval=False, # Skip approval for this test + ) + + # Configure for execution + config = { + "configurable": { + "thread_id": "test_workflow", + "checkpoint_ns": "xopt_test", + } + } + + # Run the service + result = await service.ainvoke(request, config) + + # Verify result structure + assert isinstance(result, XOptServiceResult) + assert result.strategy == XOptStrategy.EXPLORATION + assert result.total_iterations == 2 # We set max_iterations=2 + assert result.generated_yaml is not None + # Check for valid XOpt YAML structure (generator and vocs are required) + yaml_lower = result.generated_yaml.lower() + assert "generator" in yaml_lower, "Generated YAML should contain generator config" + assert "vocs" in yaml_lower, "Generated YAML should contain vocs config" + assert len(result.recommendations) > 0 + + @pytest.mark.asyncio + async def test_single_iteration_workflow(self, test_config): + """Test workflow with single iteration.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer import ( + XOptExecutionRequest, + XOptOptimizerService, + XOptServiceResult, + ) + + service = XOptOptimizerService() + + request = XOptExecutionRequest( + user_query="Quick test", + optimization_objective="Test objective", + max_iterations=1, + require_approval=False, + ) + + config = { + "configurable": { + "thread_id": "test_single", + "checkpoint_ns": "xopt_test", + } + } + + result = await service.ainvoke(request, config) + + assert isinstance(result, XOptServiceResult) + assert result.total_iterations == 1 From 52d28f5b9d371b2573836fcb23566029de462a24 Mon Sep 17 00:00:00 2001 From: ThorstenHellert Date: Mon, 19 Jan 2026 07:17:06 -0800 Subject: [PATCH 07/14] docs(xopt): Add Badger Environment integration notes to execution node Document how XOpt/Badger's Environment abstraction maps to Osprey's ConnectorFactory. The Environment defines variables/observables while the Interface (OspreyInterface) bridges to Osprey's connector system. This provides the architectural insight for XOpt integration without implementing domain-specific logic that requires expert input. Also includes ruff formatting fixes from quick_check. --- .../services/xopt_optimizer/decision/node.py | 7 +--- .../services/xopt_optimizer/execution/node.py | 37 +++++++++++++++++++ .../state_identification/agent.py | 4 +- .../state_identification/node.py | 3 +- .../xopt_optimizer/yaml_generation/agent.py | 20 ++++++---- .../xopt_optimizer/yaml_generation/node.py | 8 +++- .../test_state_identification.py | 16 ++------ 7 files changed, 63 insertions(+), 32 deletions(-) diff --git a/src/osprey/services/xopt_optimizer/decision/node.py b/src/osprey/services/xopt_optimizer/decision/node.py index a4ed56def..220ad4dd4 100644 --- a/src/osprey/services/xopt_optimizer/decision/node.py +++ b/src/osprey/services/xopt_optimizer/decision/node.py @@ -35,9 +35,7 @@ class StrategyDecision(BaseModel): strategy: XOptStrategy = Field( description="The optimization strategy: 'exploration' or 'optimization'" ) - reasoning: str = Field( - description="Brief explanation of why this strategy was selected" - ) + reasoning: str = Field(description="Brief explanation of why this strategy was selected") # ============================================================================= @@ -109,8 +107,7 @@ def _get_decision_config() -> dict[str, Any]: model_config = get_model_config("orchestrator") except Exception as e: logger.warning( - f"Could not load model config '{model_config_name}': {e}, " - "falling back to orchestrator" + f"Could not load model config '{model_config_name}': {e}, falling back to orchestrator" ) model_config = get_model_config("orchestrator") diff --git a/src/osprey/services/xopt_optimizer/execution/node.py b/src/osprey/services/xopt_optimizer/execution/node.py index fed5b09f1..0b9442c32 100644 --- a/src/osprey/services/xopt_optimizer/execution/node.py +++ b/src/osprey/services/xopt_optimizer/execution/node.py @@ -11,6 +11,39 @@ - Result artifact capture DO NOT add accelerator-specific execution logic without operator input. + +## Badger/XOpt Environment Integration + +XOpt/Badger uses an "Environment" abstraction that defines the optimization problem: +- **variables**: Tunable parameters with bounds (e.g., magnet setpoints) +- **observables**: Measurable outputs (e.g., beam position, emittance) + +The Environment communicates with the control system via an "Interface" that +implements `get_values()` and `set_values()`. This maps naturally to Osprey's +ConnectorFactory: + +```python +# Example: Badger Interface using Osprey's ConnectorFactory +from osprey.connectors.factory import ConnectorFactory + +class OspreyInterface(Interface): + name = "osprey" + + async def get_values(self, channel_names): + connector = await ConnectorFactory.create_control_system_connector() + return {name: (await connector.read_channel(name)).value + for name in channel_names} + + async def set_values(self, channel_inputs): + connector = await ConnectorFactory.create_control_system_connector() + for name, value in channel_inputs.items(): + await connector.write_channel(name, value) +``` + +This allows XOpt to work with any control system backend (EPICS, mock, etc.) +configured in Osprey's config.yml, including generated soft IOCs. + +See: https://github.com/xopt-org/Badger (Environment and Interface classes) """ from typing import Any @@ -30,9 +63,13 @@ async def _run_xopt_placeholder(yaml_config: str) -> dict[str, Any]: TODO: Replace with actual XOpt prototype integration. This will involve: - Parsing the YAML configuration + - Creating a Badger Environment with OspreyInterface (see module docstring) - Setting up XOpt with proper generator and evaluator - Running the optimization loop - Capturing results and artifacts + + The Environment defines variables/observables; the OspreyInterface + bridges to Osprey's ConnectorFactory for control system access. """ return { "status": "completed", diff --git a/src/osprey/services/xopt_optimizer/state_identification/agent.py b/src/osprey/services/xopt_optimizer/state_identification/agent.py index c73b1f555..0cddb219a 100644 --- a/src/osprey/services/xopt_optimizer/state_identification/agent.py +++ b/src/osprey/services/xopt_optimizer/state_identification/agent.py @@ -247,9 +247,7 @@ async def assess_state( # Get the last message content last_message = messages[-1] content = ( - last_message.content - if hasattr(last_message, "content") - else str(last_message) + last_message.content if hasattr(last_message, "content") else str(last_message) ) # Parse the assessment from the response diff --git a/src/osprey/services/xopt_optimizer/state_identification/node.py b/src/osprey/services/xopt_optimizer/state_identification/node.py index bf3b960b9..db643943a 100644 --- a/src/osprey/services/xopt_optimizer/state_identification/node.py +++ b/src/osprey/services/xopt_optimizer/state_identification/node.py @@ -62,8 +62,7 @@ def _get_state_identification_config() -> dict[str, Any]: model_config = get_model_config("orchestrator") except Exception as e: logger.warning( - f"Could not load model config '{model_config_name}': {e}, " - "falling back to orchestrator" + f"Could not load model config '{model_config_name}': {e}, falling back to orchestrator" ) model_config = get_model_config("orchestrator") diff --git a/src/osprey/services/xopt_optimizer/yaml_generation/agent.py b/src/osprey/services/xopt_optimizer/yaml_generation/agent.py index 1e09cf929..c33557277 100644 --- a/src/osprey/services/xopt_optimizer/yaml_generation/agent.py +++ b/src/osprey/services/xopt_optimizer/yaml_generation/agent.py @@ -196,9 +196,9 @@ def read_yaml_file(filename: str) -> str: file_path = file_path.resolve() examples_resolved = examples_path.resolve() if not str(file_path).startswith(str(examples_resolved)): - return f"Error: Cannot read files outside examples directory." + return "Error: Cannot read files outside examples directory." except Exception: - return f"Error: Invalid file path." + return "Error: Invalid file path." if not file_path.exists(): # Try searching subdirectories @@ -362,9 +362,15 @@ async def generate_yaml( **Optimization Objective:** {objective} **Strategy:** {strategy} -{"First, use the tools to read available example configurations. " if self._has_examples else ""}Generate a complete, valid YAML configuration based on { - "what you learn from the examples" if self._has_examples else "your knowledge of XOpt configuration patterns" -}. +{ + "First, use the tools to read available example configurations. " + if self._has_examples + else "" + }Generate a complete, valid YAML configuration based on { + "what you learn from the examples" + if self._has_examples + else "your knowledge of XOpt configuration patterns" + }. Remember: - Use generic parameter names unless specific names are provided @@ -396,9 +402,7 @@ async def generate_yaml( # Get the last message content last_message = messages[-1] content = ( - last_message.content - if hasattr(last_message, "content") - else str(last_message) + last_message.content if hasattr(last_message, "content") else str(last_message) ) # Extract YAML from response diff --git a/src/osprey/services/xopt_optimizer/yaml_generation/node.py b/src/osprey/services/xopt_optimizer/yaml_generation/node.py index 27584b0dc..9291e41cb 100644 --- a/src/osprey/services/xopt_optimizer/yaml_generation/node.py +++ b/src/osprey/services/xopt_optimizer/yaml_generation/node.py @@ -55,10 +55,14 @@ def _get_yaml_generation_config() -> dict[str, Any]: model_config = get_model_config(model_config_name) # Check if the model config is valid (has provider) if not model_config or not model_config.get("provider"): - logger.debug(f"Model '{model_config_name}' not configured, falling back to orchestrator") + logger.debug( + f"Model '{model_config_name}' not configured, falling back to orchestrator" + ) model_config = get_model_config("orchestrator") except Exception as e: - logger.warning(f"Could not load model config '{model_config_name}': {e}, falling back to orchestrator") + logger.warning( + f"Could not load model config '{model_config_name}': {e}, falling back to orchestrator" + ) model_config = get_model_config("orchestrator") return { diff --git a/tests/services/xopt_optimizer/test_state_identification.py b/tests/services/xopt_optimizer/test_state_identification.py index dc7f1c6a3..20f95114a 100644 --- a/tests/services/xopt_optimizer/test_state_identification.py +++ b/tests/services/xopt_optimizer/test_state_identification.py @@ -56,9 +56,7 @@ def test_real_mode_no_path_returns_message(self): def test_real_mode_missing_path_returns_message(self): """Real mode with non-existent path should return error message.""" - list_files = create_list_files_tool( - reference_path="/nonexistent/path", mock_mode=False - ) + list_files = create_list_files_tool(reference_path="/nonexistent/path", mock_mode=False) result = list_files.invoke({}) assert "does not exist" in result.lower() @@ -95,9 +93,7 @@ def test_real_mode_with_temp_dir(self, tmp_path): test_file = tmp_path / "test_criteria.md" test_file.write_text("# Test Criteria\nThis is a test file.") - list_files = create_list_files_tool( - reference_path=str(tmp_path), mock_mode=False - ) + list_files = create_list_files_tool(reference_path=str(tmp_path), mock_mode=False) result = list_files.invoke({}) assert "test_criteria.md" in result @@ -109,9 +105,7 @@ def test_real_mode_read_file(self, tmp_path): test_content = "# Test Criteria\nBeam current must be > 10 mA" test_file.write_text(test_content) - read_file = create_read_file_tool( - reference_path=str(tmp_path), mock_mode=False - ) + read_file = create_read_file_tool(reference_path=str(tmp_path), mock_mode=False) result = read_file.invoke({"filename": "test_criteria.md"}) assert "Test Criteria" in result @@ -119,9 +113,7 @@ def test_real_mode_read_file(self, tmp_path): def test_real_mode_prevents_path_traversal(self, tmp_path): """Real mode should prevent reading files outside reference path.""" - read_file = create_read_file_tool( - reference_path=str(tmp_path), mock_mode=False - ) + read_file = create_read_file_tool(reference_path=str(tmp_path), mock_mode=False) # Try to read parent directory file result = read_file.invoke({"filename": "../../../etc/passwd"}) From d4bd69414e96ed28378e603fe089159dc70ca502 Mon Sep 17 00:00:00 2001 From: Gianluca Martino Date: Thu, 26 Feb 2026 11:10:48 -0800 Subject: [PATCH 08/14] fix(xopt): replace str+Enum with StrEnum to resolve ruff UP042 lint errors --- src/osprey/approval/__init__.py | 2 +- src/osprey/services/xopt_optimizer/exceptions.py | 4 ++-- src/osprey/services/xopt_optimizer/models.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/osprey/approval/__init__.py b/src/osprey/approval/__init__.py index dd2d87271..8903dbe70 100644 --- a/src/osprey/approval/__init__.py +++ b/src/osprey/approval/__init__.py @@ -56,8 +56,8 @@ create_code_approval_interrupt, create_memory_approval_interrupt, create_plan_approval_interrupt, - create_xopt_approval_interrupt, create_step_approval_interrupt, + create_xopt_approval_interrupt, get_approval_resume_data, get_approved_payload_from_state, handle_service_with_interrupts, diff --git a/src/osprey/services/xopt_optimizer/exceptions.py b/src/osprey/services/xopt_optimizer/exceptions.py index 8d259ffa1..e9ba68a83 100644 --- a/src/osprey/services/xopt_optimizer/exceptions.py +++ b/src/osprey/services/xopt_optimizer/exceptions.py @@ -13,11 +13,11 @@ - WORKFLOW: Service-level workflow issues """ -from enum import Enum +from enum import StrEnum from typing import Any -class ErrorCategory(str, Enum): +class ErrorCategory(StrEnum): """Categorization of errors for retry logic.""" MACHINE_STATE = "machine_state" # Machine not ready - may retry after delay diff --git a/src/osprey/services/xopt_optimizer/models.py b/src/osprey/services/xopt_optimizer/models.py index 2f72d7f79..01effb034 100644 --- a/src/osprey/services/xopt_optimizer/models.py +++ b/src/osprey/services/xopt_optimizer/models.py @@ -30,7 +30,7 @@ import dataclasses from dataclasses import dataclass, field -from enum import Enum +from enum import StrEnum from typing import Annotated, Any, TypedDict from pydantic import BaseModel, Field @@ -68,7 +68,7 @@ def preserve_once_set(existing: Any | None, new: Any | None) -> Any | None: # ============================================================================= -class MachineState(str, Enum): +class MachineState(StrEnum): """Machine states for optimization readiness. NOTE: These are placeholders. Actual states will be determined @@ -87,7 +87,7 @@ class MachineState(str, Enum): # INTERLOCK_ACTIVE = "interlock_active" -class XOptStrategy(str, Enum): +class XOptStrategy(StrEnum): """Optimization strategy to execute.""" EXPLORATION = "exploration" # Explore parameter space From b54cd841432f63268474a0e34ae135f218fc6b9c Mon Sep 17 00:00:00 2001 From: Gianluca Martino Date: Fri, 13 Mar 2026 15:08:23 -0700 Subject: [PATCH 09/14] refactor(xopt): replace yaml_generation with config_generation and add tuning API client --- .../services/xopt_optimizer/__init__.py | 4 +- .../services/xopt_optimizer/analysis/node.py | 166 ++++-- .../services/xopt_optimizer/approval/node.py | 8 +- .../config_generation/__init__.py | 18 + .../xopt_optimizer/config_generation/agent.py | 178 +++++++ .../xopt_optimizer/config_generation/node.py | 389 ++++++++++++++ .../services/xopt_optimizer/decision/node.py | 4 +- .../services/xopt_optimizer/exceptions.py | 32 +- .../xopt_optimizer/execution/__init__.py | 12 +- .../xopt_optimizer/execution/api_client.py | 242 +++++++++ .../services/xopt_optimizer/execution/node.py | 114 ++-- src/osprey/services/xopt_optimizer/models.py | 16 +- src/osprey/services/xopt_optimizer/service.py | 58 ++- .../yaml_generation/__init__.py | 25 - .../xopt_optimizer/yaml_generation/agent.py | 488 ------------------ .../xopt_optimizer/yaml_generation/node.py | 308 ----------- src/osprey/utils/config.py | 4 +- 17 files changed, 1113 insertions(+), 953 deletions(-) create mode 100644 src/osprey/services/xopt_optimizer/config_generation/__init__.py create mode 100644 src/osprey/services/xopt_optimizer/config_generation/agent.py create mode 100644 src/osprey/services/xopt_optimizer/config_generation/node.py create mode 100644 src/osprey/services/xopt_optimizer/execution/api_client.py delete mode 100644 src/osprey/services/xopt_optimizer/yaml_generation/__init__.py delete mode 100644 src/osprey/services/xopt_optimizer/yaml_generation/agent.py delete mode 100644 src/osprey/services/xopt_optimizer/yaml_generation/node.py diff --git a/src/osprey/services/xopt_optimizer/__init__.py b/src/osprey/services/xopt_optimizer/__init__.py index 42bb0f262..1b0ae96ef 100644 --- a/src/osprey/services/xopt_optimizer/__init__.py +++ b/src/osprey/services/xopt_optimizer/__init__.py @@ -23,13 +23,13 @@ """ from .exceptions import ( + ConfigGenerationError, ConfigurationError, ErrorCategory, MachineStateAssessmentError, MaxIterationsExceededError, XOptExecutionError, XOptExecutorException, - YamlGenerationError, ) from .models import ( MachineState, @@ -55,7 +55,7 @@ "XOptExecutorException", "ErrorCategory", "MachineStateAssessmentError", - "YamlGenerationError", + "ConfigGenerationError", "XOptExecutionError", "MaxIterationsExceededError", "ConfigurationError", diff --git a/src/osprey/services/xopt_optimizer/analysis/node.py b/src/osprey/services/xopt_optimizer/analysis/node.py index 784720574..c189eea74 100644 --- a/src/osprey/services/xopt_optimizer/analysis/node.py +++ b/src/osprey/services/xopt_optimizer/analysis/node.py @@ -2,6 +2,11 @@ This node analyzes XOpt results and decides whether to continue with additional iterations or complete the optimization. + +When the run_artifact contains real data from the tuning_scripts API +(indicated by the presence of a ``job_id`` and ``data`` fields), the +node extracts the best point and generates meaningful recommendations. +Otherwise it falls back to the placeholder analysis path for testing. """ from typing import Any @@ -13,6 +18,79 @@ logger = get_logger("xopt_optimizer") +def _analyze_real_data(run_artifact: dict[str, Any]) -> dict[str, Any]: + """Analyze real optimization data from the tuning_scripts API. + + Args: + run_artifact: Full result dict containing ``data``, ``objective_name``, + ``variable_names``, etc. + + Returns: + Analysis dict with ``best_point``, ``total_evaluations``, and ``summary``. + """ + data = run_artifact.get("data") or [] + objective_name = run_artifact.get("objective_name") + variable_names = run_artifact.get("variable_names") or [] + + analysis: dict[str, Any] = { + "total_evaluations": len(data), + "objective_name": objective_name, + "variable_names": variable_names, + "job_id": run_artifact.get("job_id"), + } + + if not data or not objective_name: + analysis["summary"] = "No data or objective to analyze" + analysis["best_point"] = None + return analysis + + # Find best point (maximize by default β€” Xopt convention) + best_record = max( + (r for r in data if objective_name in r and r[objective_name] is not None), + key=lambda r: r[objective_name], + default=None, + ) + + if best_record is not None: + best_value = best_record[objective_name] + best_vars = {v: best_record.get(v) for v in variable_names if v in best_record} + analysis["best_point"] = {"objective_value": best_value, "variables": best_vars} + analysis["summary"] = ( + f"Best {objective_name} = {best_value:.6g} " + f"achieved at {', '.join(f'{k}={v:.4g}' for k, v in best_vars.items() if v is not None)}" + ) + else: + analysis["best_point"] = None + analysis["summary"] = f"No valid evaluations found for objective '{objective_name}'" + + return analysis + + +def _generate_real_recommendations( + analysis: dict[str, Any], iteration: int, max_iterations: int, should_continue: bool +) -> list[str]: + """Generate recommendations based on real optimization results.""" + recs: list[str] = [] + best = analysis.get("best_point") + total_evals = analysis.get("total_evaluations", 0) + + if best: + recs.append(f"Best result: {analysis['summary']}") + recs.append(f"Total evaluations: {total_evals}") + + if should_continue: + recs.append(f"Continuing to iteration {iteration + 1}/{max_iterations}") + else: + recs.append(f"Optimization complete after {iteration} iteration(s)") + if best and best.get("variables"): + recs.append( + "Recommended setpoints: " + + ", ".join(f"{k}={v:.4g}" for k, v in best["variables"].items() if v is not None) + ) + + return recs + + def create_analysis_node(): """Create the analysis node for LangGraph integration. @@ -26,56 +104,64 @@ def create_analysis_node(): async def analysis_node(state: XOptExecutionState) -> dict[str, Any]: """Analyze XOpt results and decide whether to continue. - Simple continuation logic based on iteration count. - Future implementation may include: - - Convergence detection - - Improvement rate analysis - - Domain-specific completion criteria + Routes real API data through ``_analyze_real_data`` when available, + otherwise falls back to the placeholder path. """ node_logger = get_logger("xopt_optimizer", state=state) node_logger.status("Analyzing optimization results...") - run_artifact = state.get("run_artifact") + run_artifact = state.get("run_artifact") or {} iteration = state.get("iteration_count", 0) + 1 max_iterations = state.get("max_iterations", 3) - # Simple continuation logic (can be refined) - # Future: Add convergence detection, improvement rate analysis, etc. should_continue = iteration < max_iterations - # Generate analysis result - # NOTE: This is a placeholder implementation for testing the workflow - analysis_result = { - "status": "PLACEHOLDER_TEST_SUCCESS", - "message": "XOpt optimizer service workflow test completed successfully", - "iteration": iteration, - "max_iterations": max_iterations, - "run_artifact": run_artifact, - "should_continue": should_continue, - "note": ( - "This is a placeholder implementation. All subsystems (state identification, " - "decision, YAML generation, approval, execution, analysis) executed successfully " - "with placeholder logic. Real optimization will be implemented when domain " - "requirements are defined by facility operators." - ), - } - - # Generate recommendations (placeholder - clearly indicate test status) - recommendations = [] - if should_continue: - recommendations.append(f"[TEST] Continuing to iteration {iteration + 1}") - else: - recommendations.append( - f"[TEST SUCCESS] XOpt optimizer workflow completed {iteration} iterations successfully" - ) - recommendations.append( - "[PLACEHOLDER] All subsystems executed with placeholder logic - " - "ready for real implementation when domain requirements are defined" - ) - recommendations.append( - "[NEXT STEPS] Implement real machine state assessment, YAML generation, " - "and XOpt execution based on facility-specific requirements" + # ----- Real data path ----- + has_real_data = run_artifact.get("job_id") and run_artifact.get("data") + + if has_real_data: + node_logger.info(f"Analyzing real data from job {run_artifact['job_id']}") + analysis_result = _analyze_real_data(run_artifact) + analysis_result["iteration"] = iteration + analysis_result["max_iterations"] = max_iterations + analysis_result["should_continue"] = should_continue + analysis_result["status"] = "success" + + recommendations = _generate_real_recommendations( + analysis_result, iteration, max_iterations, should_continue ) + else: + # ----- Placeholder path (backward compatible) ----- + analysis_result = { + "status": "PLACEHOLDER_TEST_SUCCESS", + "message": "XOpt optimizer service workflow test completed successfully", + "iteration": iteration, + "max_iterations": max_iterations, + "run_artifact": run_artifact, + "should_continue": should_continue, + "note": ( + "This is a placeholder implementation. All subsystems (state identification, " + "decision, config generation, approval, execution, analysis) executed successfully " + "with placeholder logic. Real optimization will be implemented when domain " + "requirements are defined by facility operators." + ), + } + + recommendations = [] + if should_continue: + recommendations.append(f"[TEST] Continuing to iteration {iteration + 1}") + else: + recommendations.append( + f"[TEST SUCCESS] XOpt optimizer workflow completed {iteration} iterations successfully" + ) + recommendations.append( + "[PLACEHOLDER] All subsystems executed with placeholder logic - " + "ready for real implementation when domain requirements are defined" + ) + recommendations.append( + "[NEXT STEPS] Implement real machine state assessment, YAML generation, " + "and XOpt execution based on facility-specific requirements" + ) node_logger.info(f"Iteration {iteration}/{max_iterations} complete") diff --git a/src/osprey/services/xopt_optimizer/approval/node.py b/src/osprey/services/xopt_optimizer/approval/node.py index 4d9a49eed..7cd6728e5 100644 --- a/src/osprey/services/xopt_optimizer/approval/node.py +++ b/src/osprey/services/xopt_optimizer/approval/node.py @@ -2,7 +2,7 @@ This node handles human approval for XOpt configurations using the standard Osprey LangGraph interrupt pattern. The approval interrupt data is pre-created -by the yaml_generation node, following the pattern from Python executor's +by the config_generation node, following the pattern from Python executor's analyzer node. """ @@ -24,7 +24,7 @@ def create_approval_node(): clean interrupt handler. The node is designed with single responsibility: processing LangGraph interrupts for user approval. - The approval interrupt data is pre-created by the yaml_generation node, + The approval interrupt data is pre-created by the config_generation node, following the pattern from Python executor's analyzer node. Returns: @@ -38,12 +38,12 @@ async def approval_node(state: XOptExecutionState) -> dict[str, Any]: node_logger = get_logger("xopt_optimizer", state=state) node_logger.status("Requesting human approval...") - # Get the pre-created interrupt data from yaml_generation node + # Get the pre-created interrupt data from config_generation node interrupt_data = state.get("approval_interrupt_data") if not interrupt_data: raise RuntimeError( "No approval interrupt data found in state. " - "The yaml_generation node should create this data." + "The config_generation node should create this data." ) node_logger.info("Requesting human approval for XOpt configuration") diff --git a/src/osprey/services/xopt_optimizer/config_generation/__init__.py b/src/osprey/services/xopt_optimizer/config_generation/__init__.py new file mode 100644 index 000000000..5483674cf --- /dev/null +++ b/src/osprey/services/xopt_optimizer/config_generation/__init__.py @@ -0,0 +1,18 @@ +"""Config Generation Subsystem for XOpt Optimizer. + +This subsystem generates OptimizationConfig dicts for the tuning_scripts API using: +1. Structured mode: LLM with structured output fills in config fields +2. Mock mode: Placeholder config for quick testing (use for fast iteration) + +The mode is controlled via configuration: + osprey.xopt_optimizer.config_generation.mode: "structured" | "mock" +""" + +from .agent import ConfigGenerationAgent, create_config_generation_agent +from .node import create_config_generation_node + +__all__ = [ + "create_config_generation_node", + "ConfigGenerationAgent", + "create_config_generation_agent", +] diff --git a/src/osprey/services/xopt_optimizer/config_generation/agent.py b/src/osprey/services/xopt_optimizer/config_generation/agent.py new file mode 100644 index 000000000..5eae53fa9 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/config_generation/agent.py @@ -0,0 +1,178 @@ +"""Structured Output Agent for Optimization Config Generation. + +This module provides an agent that generates OptimizationConfig dicts for the +tuning_scripts API. Instead of generating raw Xopt YAML, the LLM fills in +structured config fields (environment, algorithm, iterations, etc.) which are +then submitted as JSON to the ``/optimization/start`` endpoint. + +This simplifies the LLM's job from "generate valid Xopt YAML" to "fill in +config fields" and reduces validation complexity. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, Field + +from osprey.models.langchain import get_langchain_model +from osprey.utils.logger import get_logger + +logger = get_logger("xopt_optimizer") + + +# ============================================================================= +# STRUCTURED OUTPUT MODEL +# ============================================================================= + + +class OptimizationConfigOutput(BaseModel): + """Structured output model mirroring tuning_scripts' OptimizationConfig schema. + + The LLM fills in these fields based on the optimization objective and strategy. + Fields left as None are excluded from the final config dict. + """ + + environment_name: str | None = Field( + default=None, + description="Badger/Xopt environment name (e.g. 'als_injector_sim')", + ) + objective_name: str | None = Field( + default=None, + description="Name of the objective to optimize", + ) + algorithm: str = Field( + default="upper_confidence_bound", + description="Algorithm type: upper_confidence_bound, expected_improvement, mobo, random", + ) + n_iterations: int = Field( + default=20, + description="Number of optimization iterations to run", + ) + n_initial_samples: int | None = Field( + default=None, + description="Number of initial random samples before Bayesian optimization", + ) + variables: list[str] | None = Field( + default=None, + description="List of variable names to optimize", + ) + variable_overrides: dict[str, Any] | None = Field( + default=None, + description="Per-variable overrides (bounds, types, etc.)", + ) + + +# ============================================================================= +# SYSTEM PROMPT +# ============================================================================= + +CONFIG_GENERATION_PROMPT = """You are an optimization configuration generator for accelerator tuning. + +You must select appropriate settings for an optimization run based on the user's +objective and strategy. + +## Algorithm Selection Guide + +Based on the optimization strategy: + +**Exploration** (map the parameter space): +- Use "random" or "upper_confidence_bound" (with high exploration weight) +- Higher n_initial_samples for broader coverage + +**Optimization** (converge on optimal values): +- Use "expected_improvement" for single-objective +- Use "mobo" for multi-objective +- Use "upper_confidence_bound" for balanced explore/exploit + +## Available Algorithms +- "upper_confidence_bound" β€” Bayesian optimization with UCB acquisition (default, good general choice) +- "expected_improvement" β€” Bayesian optimization with EI acquisition (good for exploitation) +- "mobo" β€” Multi-objective Bayesian optimization +- "random" β€” Random sampling (good for initial exploration) + +## Important Notes +- Use generic/placeholder names unless the user provides specific names +- Do NOT invent specific accelerator channel names +- Set n_iterations based on the scope of the task (default 20) +- Set n_initial_samples only if you want to override the default +""" + + +# ============================================================================= +# AGENT CLASS +# ============================================================================= + + +class ConfigGenerationAgent: + """Agent for generating optimization config dicts via structured LLM output. + + Uses ``model.with_structured_output(OptimizationConfigOutput)`` for a single + LLM call that returns a validated Pydantic model, then dumps it to a dict. + """ + + def __init__(self, model_config: dict[str, Any] | None = None): + self.model_config = model_config + + def _get_model(self): + if self.model_config: + return get_langchain_model(model_config=self.model_config) + raise ValueError( + "No model_config provided to ConfigGenerationAgent. " + "Ensure xopt_optimizer.config_generation.model_config_name is set in config.yml " + "or that 'orchestrator' model is configured as fallback." + ) + + async def generate_config( + self, + objective: str, + strategy: str, + context: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Generate an optimization config dict using structured LLM output. + + Args: + objective: The optimization objective + strategy: The selected strategy ("exploration" or "optimization") + context: Optional additional context + + Returns: + Config dict with None values excluded + """ + model = self._get_model() + structured_model = model.with_structured_output(OptimizationConfigOutput) + + user_message = ( + f"Generate an optimization configuration for:\n\n" + f"**Objective:** {objective}\n" + f"**Strategy:** {strategy}\n" + ) + if context: + user_message += f"\n**Additional Context:** {context}\n" + + logger.info("Generating optimization config via structured output...") + + result = await structured_model.ainvoke( + [ + {"role": "system", "content": CONFIG_GENERATION_PROMPT}, + {"role": "user", "content": user_message}, + ] + ) + + config = result.model_dump(exclude_none=True) + logger.info(f"Config generation complete: {config}") + return config + + +def create_config_generation_agent( + model_config: dict[str, Any] | None = None, +) -> ConfigGenerationAgent: + """Factory function to create a config generation agent. + + Args: + model_config: Optional model configuration + + Returns: + Configured ConfigGenerationAgent instance + """ + return ConfigGenerationAgent(model_config=model_config) diff --git a/src/osprey/services/xopt_optimizer/config_generation/node.py b/src/osprey/services/xopt_optimizer/config_generation/node.py new file mode 100644 index 000000000..074d333d5 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/config_generation/node.py @@ -0,0 +1,389 @@ +"""Config Generation Node for XOpt Optimizer Service. + +This node generates OptimizationConfig dicts and prepares approval interrupt data. +It follows the Python executor's analyzer pattern where the node that generates +content also creates the approval interrupt data. + +Supports two modes (configured via osprey.xopt_optimizer.config_generation.mode): +- "structured": LLM with structured output fills in config fields +- "mock": Placeholder config for quick testing (use for fast iteration) + +DO NOT add accelerator-specific parameters without operator input. +""" + +from typing import Any + +from langgraph.types import interrupt + +from osprey.utils.config import get_model_config, get_xopt_optimizer_config +from osprey.utils.logger import get_logger + +from ..exceptions import ConfigGenerationError +from ..execution.api_client import TuningScriptsAPIError, TuningScriptsClient +from ..models import XOptError, XOptExecutionState, XOptStrategy + +logger = get_logger("xopt_optimizer") + +# Allowed algorithm values for validation +_ALLOWED_ALGORITHMS = frozenset({ + "upper_confidence_bound", + "expected_improvement", + "mobo", + "random", +}) + + +def _get_config_generation_config() -> dict[str, Any]: + """Get config generation configuration from osprey config. + + Reads from config structure: + xopt_optimizer: + config_generation: + mode: "structured" + model_config_name: "xopt_config_generation" + default_algorithm: "upper_confidence_bound" + default_environment: null + + Returns: + Configuration dict with mode, model_config, and defaults + """ + xopt_config = get_xopt_optimizer_config() + gen_config = xopt_config.get("config_generation", {}) + + # Resolve model config from name reference + model_config = None + model_config_name = gen_config.get("model_config_name", "xopt_config_generation") + try: + model_config = get_model_config(model_config_name) + if not model_config or not model_config.get("provider"): + logger.debug( + f"Model '{model_config_name}' not configured, falling back to orchestrator" + ) + model_config = get_model_config("orchestrator") + except Exception as e: + logger.warning( + f"Could not load model config '{model_config_name}': {e}, falling back to orchestrator" + ) + model_config = get_model_config("orchestrator") + + return { + "mode": gen_config.get("mode", "mock"), + "model_config": model_config, + "default_algorithm": gen_config.get("default_algorithm", "upper_confidence_bound"), + "default_environment": gen_config.get("default_environment"), + } + + +def _generate_placeholder_config(objective: str, strategy: XOptStrategy) -> dict[str, Any]: + """Generate a placeholder optimization config dict. + + Used when config_generation.mode is "mock". + + DO NOT add accelerator-specific parameters without operator input. + """ + algorithm = "random" if strategy == XOptStrategy.EXPLORATION else "upper_confidence_bound" + return { + "algorithm": algorithm, + "n_iterations": 20, + "note": ( + f"Placeholder config for: {objective} (strategy: {strategy.value}). " + "Set config_generation.mode: 'structured' to use the LLM agent." + ), + } + + +async def _generate_config_with_agent( + objective: str, + strategy: XOptStrategy, + model_config: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Generate config using the structured-output agent. + + Args: + objective: The optimization objective + strategy: The selected strategy + model_config: Model configuration for the agent + + Returns: + Config dict from the agent + """ + from .agent import create_config_generation_agent + + agent = create_config_generation_agent(model_config=model_config) + + try: + return await agent.generate_config( + objective=objective, + strategy=strategy.value, + ) + except Exception as e: + logger.warning(f"Structured agent failed, falling back to mock: {e}") + return _generate_placeholder_config(objective, strategy) + + +async def _resolve_environment( + config: dict[str, Any], node_logger: Any +) -> None: + """Resolve environment_name if missing by asking the user. + + Queries the tuning_scripts API for available environments, then + uses a LangGraph interrupt to present the options and wait for + the user's choice. + + If only one valid environment exists it is auto-selected silently. + + Modifies ``config`` in place. + """ + if config.get("environment_name"): + return + + # Fetch available environments from the API + try: + client = TuningScriptsClient() + environments = await client.list_environments() + except (TuningScriptsAPIError, Exception) as e: + raise ConfigGenerationError( + "No environment_name configured and the tuning_scripts API is unreachable. " + "Set xopt_optimizer.config_generation.default_environment in config.yml " + "or ensure the tuning_scripts API is running.", + generated_config=config, + validation_errors=[f"Missing environment_name, API unreachable: {e}"], + ) from e + + valid_envs = [env for env in environments if env.get("valid", False)] + + if not valid_envs: + available = [env.get("name", "?") for env in environments] + raise ConfigGenerationError( + "No valid optimization environments found on the tuning_scripts API. " + f"Available (invalid): {available}. " + "Configure a valid environment or set default_environment in config.yml.", + generated_config=config, + validation_errors=["No valid environments available"], + ) + + # Single environment β€” auto-select + if len(valid_envs) == 1: + config["environment_name"] = valid_envs[0]["name"] + node_logger.info(f"Auto-selected environment: {valid_envs[0]['name']}") + return + + # Multiple environments β€” ask the user + env_lines = [] + for i, env in enumerate(valid_envs, 1): + source = f" [{env['source']}]" if env.get("source") else "" + desc = env.get("description", "") + env_lines.append(f" {i}. **{env['name']}** β€” {desc}{source}") + + prompt = ( + "Multiple optimization environments are available. " + "Please select one by number or name:\n\n" + + "\n".join(env_lines) + ) + + node_logger.info("Asking user to select optimization environment...") + user_choice = interrupt({"question": prompt, "environments": valid_envs}) + + # Parse the user's response + choice = str(user_choice).strip() + selected = _match_environment(choice, valid_envs) + + if not selected: + raise ConfigGenerationError( + f"Could not match '{choice}' to an available environment. " + f"Valid options: {[e['name'] for e in valid_envs]}", + generated_config=config, + validation_errors=[f"Invalid environment selection: {choice}"], + ) + + config["environment_name"] = selected["name"] + node_logger.info(f"User selected environment: {selected['name']}") + + +def _match_environment( + choice: str, environments: list[dict[str, Any]] +) -> dict[str, Any] | None: + """Match a user's choice (number or name) to an environment. + + Returns the matched environment dict, or None if no match found. + """ + # Try as a 1-based index + try: + idx = int(choice) - 1 + if 0 <= idx < len(environments): + return environments[idx] + except ValueError: + pass + + # Try exact name match + for env in environments: + if env["name"] == choice: + return env + + # Try case-insensitive prefix match + lower = choice.lower() + for env in environments: + if env["name"].lower().startswith(lower): + return env + + return None + + +def _validate_config(config: dict[str, Any]) -> None: + """Validate the generated optimization config. + + Checks that required keys are present and algorithm is in the allowed set. + """ + if not config: + raise ConfigGenerationError( + "Generated config is empty", + generated_config=config, + validation_errors=["Empty configuration"], + ) + + algorithm = config.get("algorithm") + if algorithm and algorithm not in _ALLOWED_ALGORITHMS: + raise ConfigGenerationError( + f"Invalid algorithm: {algorithm}", + generated_config=config, + validation_errors=[f"Algorithm '{algorithm}' not in {sorted(_ALLOWED_ALGORITHMS)}"], + ) + + +def create_config_generation_node(): + """Create the config generation node for LangGraph integration. + + This factory function creates a node that generates optimization config + dicts and prepares approval interrupt data. + + The generation mode is controlled via configuration: + - osprey.xopt_optimizer.config_generation.mode: "mock" | "structured" + + Returns: + Async function that takes XOptExecutionState and returns state updates + """ + + async def config_generation_node(state: XOptExecutionState) -> dict[str, Any]: + """Generate optimization config. + + Supports two modes: + - "mock": Fast placeholder generation for testing (default) + - "structured": LLM with structured output fills in config fields + + Also prepares approval interrupt data following the Python + executor's analyzer pattern. + """ + node_logger = get_logger("xopt_optimizer", state=state) + + # Get configuration + gen_config = _get_config_generation_config() + mode = gen_config.get("mode", "mock") + is_mock = mode == "mock" + mode_indicator = " (mock)" if is_mock else "" + + node_logger.status(f"Generating optimization config{mode_indicator}...") + + # Track generation attempts + attempt = state.get("config_generation_attempt", 0) + 1 + request = state.get("request") + strategy = state.get("selected_strategy", XOptStrategy.EXPLORATION) + objective = request.optimization_objective if request else "Unknown objective" + + try: + # Generate config based on mode + if mode == "structured": + optimization_config = await _generate_config_with_agent( + objective=objective, + strategy=strategy, + model_config=gen_config.get("model_config"), + ) + else: + optimization_config = _generate_placeholder_config(objective, strategy) + + # Apply defaults from config if not already set by the generator + default_env = gen_config.get("default_environment") + if default_env and not optimization_config.get("environment_name"): + optimization_config["environment_name"] = default_env + + default_algo = gen_config.get("default_algorithm") + if default_algo and not optimization_config.get("algorithm"): + optimization_config["algorithm"] = default_algo + + # Resolve environment_name from the API if still missing + await _resolve_environment(optimization_config, node_logger) + + # Validate config + _validate_config(optimization_config) + + node_logger.key_info(f"Optimization config generated{mode_indicator}") + + # Prepare approval interrupt data (following Python executor pattern) + requires_approval = request.require_approval if request else True + + if requires_approval: + from osprey.approval.approval_system import create_xopt_approval_interrupt + + machine_state_details = state.get("machine_state_details") + + approval_interrupt_data = create_xopt_approval_interrupt( + optimization_config=optimization_config, + strategy=strategy.value, + objective=objective, + machine_state_details=machine_state_details, + step_objective=f"Execute XOpt optimization: {objective}", + ) + + return { + "optimization_config": optimization_config, + "config_generation_attempt": attempt, + "config_generation_failed": False, + "requires_approval": True, + "approval_interrupt_data": approval_interrupt_data, + "current_stage": "approval", + } + else: + return { + "optimization_config": optimization_config, + "config_generation_attempt": attempt, + "config_generation_failed": False, + "requires_approval": False, + "current_stage": "execution", + } + + except ConfigGenerationError: + raise + + except Exception as e: + # Re-raise GraphInterrupt β€” it's not an error, it's LangGraph + # pausing the graph to wait for user input (e.g. environment selection). + if e.__class__.__name__ == "GraphInterrupt": + raise + + node_logger.warning(f"Config generation failed: {e}") + + error = XOptError( + error_type="config_generation", + error_message=str(e), + stage="config_generation", + attempt_number=attempt, + ) + error_chain = list(state.get("error_chain", [])) + [error] + + # Check retry limit + max_retries = request.retries if request else 3 + retry_limit_exceeded = len(error_chain) >= max_retries + + return { + "config_generation_attempt": attempt, + "config_generation_failed": True, + "error_chain": error_chain, + "is_failed": retry_limit_exceeded, + "failure_reason": ( + f"Config generation failed after {max_retries} attempts" + if retry_limit_exceeded + else None + ), + "current_stage": "config_gen" if not retry_limit_exceeded else "failed", + } + + return config_generation_node diff --git a/src/osprey/services/xopt_optimizer/decision/node.py b/src/osprey/services/xopt_optimizer/decision/node.py index 220ad4dd4..1aacc03af 100644 --- a/src/osprey/services/xopt_optimizer/decision/node.py +++ b/src/osprey/services/xopt_optimizer/decision/node.py @@ -297,7 +297,7 @@ async def decision_node(state: XOptExecutionState) -> dict[str, Any]: return { "selected_strategy": decision.strategy, "decision_reasoning": decision.reasoning, - "current_stage": "yaml_gen", + "current_stage": "config_gen", } except Exception as e: @@ -322,7 +322,7 @@ async def decision_node(state: XOptExecutionState) -> dict[str, Any]: return { "selected_strategy": decision.strategy, "decision_reasoning": f"Fallback: {decision.reasoning}", - "current_stage": "yaml_gen", + "current_stage": "config_gen", } return decision_node diff --git a/src/osprey/services/xopt_optimizer/exceptions.py b/src/osprey/services/xopt_optimizer/exceptions.py index e9ba68a83..edcbe0cdf 100644 --- a/src/osprey/services/xopt_optimizer/exceptions.py +++ b/src/osprey/services/xopt_optimizer/exceptions.py @@ -7,7 +7,7 @@ Error Categories: - MACHINE_STATE: Machine not ready - may retry after delay - - YAML_GENERATION: Code generation issues - retry with feedback + - CONFIG_GENERATION: Config generation issues - retry with feedback - EXECUTION: XOpt runtime errors - CONFIGURATION: Invalid configuration - WORKFLOW: Service-level workflow issues @@ -21,7 +21,7 @@ class ErrorCategory(StrEnum): """Categorization of errors for retry logic.""" MACHINE_STATE = "machine_state" # Machine not ready - may retry after delay - YAML_GENERATION = "yaml_generation" # Code generation issues - retry with feedback + CONFIG_GENERATION = "config_generation" # Config generation issues - retry with feedback EXECUTION = "execution" # XOpt runtime errors CONFIGURATION = "configuration" # Invalid configuration WORKFLOW = "workflow" # Service-level workflow issues @@ -51,11 +51,11 @@ def __init__( def is_retriable(self) -> bool: """Check if this error type typically warrants a retry.""" - return self.category in (ErrorCategory.MACHINE_STATE, ErrorCategory.YAML_GENERATION) + return self.category in (ErrorCategory.MACHINE_STATE, ErrorCategory.CONFIG_GENERATION) - def should_retry_yaml_generation(self) -> bool: - """Check if YAML should be regenerated.""" - return self.category == ErrorCategory.YAML_GENERATION + def should_retry_config_generation(self) -> bool: + """Check if config should be regenerated.""" + return self.category == ErrorCategory.CONFIG_GENERATION class MachineStateAssessmentError(XOptExecutorException): @@ -78,26 +78,26 @@ def __init__( self.assessment_details = assessment_details or {} -class YamlGenerationError(XOptExecutorException): - """Failed to generate valid XOpt YAML configuration. +class ConfigGenerationError(XOptExecutorException): + """Failed to generate valid optimization configuration. - Raised when the YAML generation agent produces invalid configuration. + Raised when the config generation agent produces invalid configuration. Usually retryable with error feedback. :param message: Error description - :param generated_yaml: The invalid YAML that was generated + :param generated_config: The invalid config that was generated :param validation_errors: List of validation errors found """ def __init__( self, message: str, - generated_yaml: str | None = None, + generated_config: dict | None = None, validation_errors: list[str] | None = None, **kwargs, ): - super().__init__(message, category=ErrorCategory.YAML_GENERATION, **kwargs) - self.generated_yaml = generated_yaml + super().__init__(message, category=ErrorCategory.CONFIG_GENERATION, **kwargs) + self.generated_config = generated_config self.validation_errors = validation_errors or [] @@ -107,19 +107,19 @@ class XOptExecutionError(XOptExecutorException): Raised when XOpt itself fails during execution. :param message: Error description - :param yaml_used: The YAML configuration that was used + :param config_used: The optimization config that was used :param xopt_error: The original XOpt error message """ def __init__( self, message: str, - yaml_used: str | None = None, + config_used: dict | None = None, xopt_error: str | None = None, **kwargs, ): super().__init__(message, category=ErrorCategory.EXECUTION, **kwargs) - self.yaml_used = yaml_used + self.config_used = config_used self.xopt_error = xopt_error diff --git a/src/osprey/services/xopt_optimizer/execution/__init__.py b/src/osprey/services/xopt_optimizer/execution/__init__.py index 1d154ac87..d4a07288d 100644 --- a/src/osprey/services/xopt_optimizer/execution/__init__.py +++ b/src/osprey/services/xopt_optimizer/execution/__init__.py @@ -1,13 +1,11 @@ """Execution Subsystem for XOpt Optimizer. -This subsystem executes XOpt optimization runs using the generated -YAML configuration. - -PLACEHOLDER: Current implementation is a no-op placeholder. -Actual XOpt execution will be implemented when XOpt prototype -integration is ready. +This subsystem executes XOpt optimization runs by submitting optimization +configs to the tuning_scripts API and polling for results. +Falls back to placeholder execution when the API is unavailable. """ +from .api_client import TuningScriptsAPIError, TuningScriptsClient from .node import create_executor_node -__all__ = ["create_executor_node"] +__all__ = ["TuningScriptsAPIError", "TuningScriptsClient", "create_executor_node"] diff --git a/src/osprey/services/xopt_optimizer/execution/api_client.py b/src/osprey/services/xopt_optimizer/execution/api_client.py new file mode 100644 index 000000000..030069107 --- /dev/null +++ b/src/osprey/services/xopt_optimizer/execution/api_client.py @@ -0,0 +1,242 @@ +"""HTTP client for the tuning_scripts optimization API. + +Provides an async interface (aiohttp) for submitting Xopt YAML configurations, +polling for results, and controlling optimization jobs via the tuning_scripts +REST API. +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +try: + import aiohttp +except ImportError as e: + raise ImportError( + "aiohttp is required for TuningScriptsClient. " + "Install it with: pip install aiohttp" + ) from e + +from osprey.utils.config import get_full_configuration +from osprey.utils.logger import get_logger + +logger = get_logger("xopt_optimizer") + +# Terminal statuses that indicate the job is done +_TERMINAL_STATUSES = frozenset({"completed", "error", "cancelled"}) + + +class TuningScriptsAPIError(Exception): + """Error communicating with the tuning_scripts API.""" + + def __init__(self, message: str, status_code: int | None = None, detail: str | None = None): + self.status_code = status_code + self.detail = detail + full_msg = message + if status_code: + full_msg = f"[HTTP {status_code}] {message}" + if detail: + full_msg = f"{full_msg} β€” {detail}" + super().__init__(full_msg) + + +class TuningScriptsClient: + """Async HTTP client for the tuning_scripts optimization API. + + Configuration is read from the ``xopt_optimizer.api`` section of + Osprey's config.yml:: + + xopt_optimizer: + api: + base_url: "http://tuning-api:8001" + poll_interval_seconds: 5.0 + timeout_seconds: 3600 + + The client can also be instantiated directly with explicit parameters. + """ + + def __init__( + self, + base_url: str | None = None, + poll_interval_seconds: float | None = None, + timeout_seconds: float | None = None, + ): + api_config = self._load_api_config() + + self.base_url = (base_url or api_config.get("base_url", "http://localhost:8001")).rstrip("/") + self.poll_interval = poll_interval_seconds or api_config.get("poll_interval_seconds", 5.0) + self.timeout = timeout_seconds or api_config.get("timeout_seconds", 3600) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def health_check(self) -> dict[str, Any]: + """Check API health. + + Returns: + Health status dict (e.g. ``{"status": "ok"}``). + + Raises: + TuningScriptsAPIError: If the API is unreachable or unhealthy. + """ + return await self._get("/health") + + async def list_environments(self) -> list[dict[str, Any]]: + """List available optimization environments. + + Returns: + List of environment dicts with ``name``, ``display_name``, + ``description``, ``valid``, and ``source`` fields. + + Raises: + TuningScriptsAPIError: If the API is unreachable. + """ + return await self._get("/environments") + + async def submit_config(self, config: dict[str, Any]) -> str: + """Submit an OptimizationConfig dict to start an optimization. + + Args: + config: Optimization config dict (environment_name, algorithm, etc.). + + Returns: + The ``job_id`` of the submitted job. + + Raises: + TuningScriptsAPIError: On submission failure. + """ + data = await self._post("/optimization/start", json=config) + return data["job_id"] + + async def submit_yaml(self, yaml_config: str, n_iterations: int | None = None) -> str: + """Submit an Xopt YAML configuration to start an optimization. + + Args: + yaml_config: Raw Xopt YAML string. + n_iterations: Optional iteration count override. + + Returns: + The ``job_id`` of the submitted job. + + Raises: + TuningScriptsAPIError: On submission failure. + """ + payload: dict[str, Any] = {"yaml_config": yaml_config} + if n_iterations is not None: + payload["n_iterations"] = n_iterations + + data = await self._post("/optimization/start-yaml", json=payload) + return data["job_id"] + + async def get_status(self, job_id: str) -> dict[str, Any]: + """Get lightweight job status. + + Returns: + Dict with ``job_id``, ``status``, ``message``, ``results_path``. + """ + return await self._get(f"/optimization/{job_id}", params={"detail": "summary"}) + + async def get_full_state(self, job_id: str) -> dict[str, Any]: + """Get full optimization state including data. + + Returns: + Dict with full state: data records, logs, variable names, etc. + """ + return await self._get(f"/optimization/{job_id}", params={"detail": "full"}) + + async def poll_until_complete(self, job_id: str) -> dict[str, Any]: + """Poll job status until a terminal state is reached, then fetch full state. + + Args: + job_id: The job to poll. + + Returns: + Full optimization state from ``get_full_state``. + + Raises: + TuningScriptsAPIError: On timeout or if the job ends in error. + """ + elapsed = 0.0 + while elapsed < self.timeout: + status_resp = await self.get_status(job_id) + job_status = status_resp.get("status", "unknown") + + if job_status in _TERMINAL_STATUSES: + full_state = await self.get_full_state(job_id) + if job_status == "error": + error_msg = full_state.get("message") or "Optimization failed" + raise TuningScriptsAPIError( + f"Optimization job {job_id} failed: {error_msg}", + detail=error_msg, + ) + return full_state + + logger.info(f"Job {job_id} status: {job_status} (elapsed: {elapsed:.0f}s)") + await asyncio.sleep(self.poll_interval) + elapsed += self.poll_interval + + raise TuningScriptsAPIError( + f"Timeout waiting for job {job_id} after {self.timeout}s" + ) + + async def cancel(self, job_id: str) -> dict[str, Any]: + """Cancel a running optimization job.""" + return await self._post(f"/optimization/{job_id}/cancel") + + async def pause(self, job_id: str) -> dict[str, Any]: + """Pause a running optimization job.""" + return await self._post(f"/optimization/{job_id}/pause") + + async def resume(self, job_id: str) -> dict[str, Any]: + """Resume a paused optimization job.""" + return await self._post(f"/optimization/{job_id}/resume") + + # ------------------------------------------------------------------ + # HTTP helpers + # ------------------------------------------------------------------ + + async def _get(self, path: str, params: dict | None = None) -> Any: + url = f"{self.base_url}{path}" + try: + async with aiohttp.ClientSession() as session: + async with session.get(url, params=params) as resp: + return await self._handle_response(resp) + except aiohttp.ClientError as e: + raise TuningScriptsAPIError(f"Connection error: {e}") from e + + async def _post(self, path: str, json: dict | None = None) -> Any: + url = f"{self.base_url}{path}" + try: + async with aiohttp.ClientSession() as session: + async with session.post(url, json=json) as resp: + return await self._handle_response(resp) + except aiohttp.ClientError as e: + raise TuningScriptsAPIError(f"Connection error: {e}") from e + + async def _handle_response(self, resp: aiohttp.ClientResponse) -> Any: + if resp.status >= 400: + try: + body = await resp.json() + detail = body.get("detail", str(body)) + except Exception: + detail = await resp.text() + raise TuningScriptsAPIError( + f"API request failed: {resp.method} {resp.url}", + status_code=resp.status, + detail=detail, + ) + return await resp.json() + + # ------------------------------------------------------------------ + # Config + # ------------------------------------------------------------------ + + @staticmethod + def _load_api_config() -> dict[str, Any]: + try: + config = get_full_configuration() + return config.get("xopt_optimizer", {}).get("api", {}) + except Exception: + return {} diff --git a/src/osprey/services/xopt_optimizer/execution/node.py b/src/osprey/services/xopt_optimizer/execution/node.py index 0b9442c32..bf0c0b758 100644 --- a/src/osprey/services/xopt_optimizer/execution/node.py +++ b/src/osprey/services/xopt_optimizer/execution/node.py @@ -1,16 +1,14 @@ """Execution Node for XOpt Optimizer Service. -This node executes XOpt optimization runs using the generated YAML configuration. +This node executes XOpt optimization runs by submitting the generated optimization +config to the tuning_scripts API and polling for results. -PLACEHOLDER: This implementation is a no-op that returns placeholder results. - -TODO: Replace with actual XOpt prototype integration when ready. -This will require: -- Integration with existing XOpt Python prototype -- Proper error handling for XOpt execution failures -- Result artifact capture - -DO NOT add accelerator-specific execution logic without operator input. +The tuning_scripts API (FastAPI + Redis + Xopt) handles the actual optimization +execution. This node acts as an HTTP client that: +1. Health-checks the API +2. Submits the optimization config via POST /optimization/start +3. Polls until completion via GET /optimization/{job_id} +4. Returns the full result state as the run_artifact ## Badger/XOpt Environment Integration @@ -51,34 +49,23 @@ async def set_values(self, channel_inputs): from osprey.utils.logger import get_logger from ..models import XOptExecutionState +from .api_client import TuningScriptsAPIError, TuningScriptsClient logger = get_logger("xopt_optimizer") -async def _run_xopt_placeholder(yaml_config: str) -> dict[str, Any]: - """Placeholder for XOpt execution. - - PLACEHOLDER: Returns mock results. +async def _run_xopt_placeholder(config: dict[str, Any]) -> dict[str, Any]: + """Placeholder for XOpt execution (used when API is unavailable). - TODO: Replace with actual XOpt prototype integration. - This will involve: - - Parsing the YAML configuration - - Creating a Badger Environment with OspreyInterface (see module docstring) - - Setting up XOpt with proper generator and evaluator - - Running the optimization loop - - Capturing results and artifacts - - The Environment defines variables/observables; the OspreyInterface - bridges to Osprey's ConnectorFactory for control system access. + Returns mock results for testing without a running tuning_scripts API. """ return { "status": "completed", "evaluations": 0, "best_value": None, "best_parameters": {}, - "yaml_used": yaml_config, - "note": "This is a placeholder result. Actual XOpt execution will be " - "implemented when XOpt prototype integration is ready.", + "config_used": config, + "note": "This is a placeholder result. The tuning_scripts API was not available.", } @@ -86,27 +73,82 @@ def create_executor_node(): """Create the execution node for LangGraph integration. This factory function creates a node that executes XOpt optimization - runs. Currently implements a placeholder. + runs by submitting YAML to the tuning_scripts API. Falls back to + placeholder execution if the API is unreachable. Returns: Async function that takes XOptExecutionState and returns state updates """ async def executor_node(state: XOptExecutionState) -> dict[str, Any]: - """Execute XOpt optimization. - - PLACEHOLDER: Returns mock results. - """ + """Execute XOpt optimization via the tuning_scripts API.""" node_logger = get_logger("xopt_optimizer", state=state) node_logger.status("Executing XOpt optimization...") - yaml_config = state.get("generated_yaml") + optimization_config = state.get("optimization_config") + + if not optimization_config: + node_logger.error("No optimization config available for execution") + return { + "execution_error": "No optimization config generated", + "execution_failed": True, + "is_failed": True, + "failure_reason": "Missing optimization config for execution", + "current_stage": "failed", + } + + client = TuningScriptsClient() try: - # PLACEHOLDER: Call placeholder XOpt execution - run_artifact = await _run_xopt_placeholder(yaml_config) + # 1. Health check + node_logger.info("Checking tuning_scripts API health...") + await client.health_check() + node_logger.info("API health check passed") + + # 2. Submit config + node_logger.info("Submitting optimization config to tuning_scripts API...") + job_id = await client.submit_config(optimization_config) + node_logger.info(f"Optimization job submitted: {job_id}") + + # 3. Poll until complete + node_logger.info(f"Polling job {job_id} for results...") + full_state = await client.poll_until_complete(job_id) + node_logger.info(f"Job {job_id} completed with status: {full_state.get('status')}") + + # 4. Build run_artifact + run_artifact = { + "job_id": job_id, + "status": full_state.get("status", "unknown"), + "data": full_state.get("data"), + "environment_name": full_state.get("environment_name"), + "objective_name": full_state.get("objective_name"), + "variable_names": full_state.get("variable_names"), + "results_path": full_state.get("results_path"), + "config_used": optimization_config, + "logs": full_state.get("logs", ""), + } + + return { + "run_artifact": run_artifact, + "execution_failed": False, + "current_stage": "analysis", + } - node_logger.info("XOpt execution completed") + except TuningScriptsAPIError as e: + if e.status_code is not None: + # API returned an error response β€” real failure + node_logger.error(f"Tuning scripts API error: {e}") + return { + "execution_error": str(e), + "execution_failed": True, + "is_failed": True, + "failure_reason": f"Tuning scripts API error: {e}", + "current_stage": "failed", + } + + # Connection error (API not running) β€” fall back to placeholder + node_logger.warning(f"API unreachable ({e}), falling back to placeholder execution") + run_artifact = await _run_xopt_placeholder(optimization_config) return { "run_artifact": run_artifact, "execution_failed": False, diff --git a/src/osprey/services/xopt_optimizer/models.py b/src/osprey/services/xopt_optimizer/models.py index 01effb034..92a922991 100644 --- a/src/osprey/services/xopt_optimizer/models.py +++ b/src/osprey/services/xopt_optimizer/models.py @@ -107,7 +107,7 @@ class XOptError: Captures error context to help subsequent nodes understand what failed and potentially adjust their approach. - :param error_type: Category of error (state_assessment, yaml_generation, execution, analysis) + :param error_type: Category of error (state_assessment, config_generation, execution, analysis) :param error_message: Human-readable error message :param stage: Pipeline stage where error occurred :param attempt_number: Which attempt this error occurred in @@ -188,7 +188,7 @@ class XOptServiceResult: On failure, the service raises appropriate exceptions. :param run_artifact: Optimization run output data - :param generated_yaml: XOpt YAML configuration used + :param optimization_config: Optimization config dict submitted to tuning_scripts :param strategy: Strategy used (exploration/optimization) :param total_iterations: Number of iterations completed :param analysis_summary: Summary of optimization analysis @@ -196,7 +196,7 @@ class XOptServiceResult: """ run_artifact: dict[str, Any] - generated_yaml: str + optimization_config: dict[str, Any] strategy: XOptStrategy total_iterations: int analysis_summary: dict[str, Any] @@ -232,7 +232,7 @@ class XOptExecutionState(TypedDict): # Error tracking (matches Python executor pattern) error_chain: list[XOptError] - yaml_generation_attempt: int # For YAML regeneration retries + config_generation_attempt: int # For config regeneration retries # Machine state assessment machine_state: MachineState | None @@ -242,9 +242,9 @@ class XOptExecutionState(TypedDict): selected_strategy: XOptStrategy | None decision_reasoning: str | None - # YAML configuration - generated_yaml: str | None - yaml_generation_failed: bool | None + # Optimization configuration + optimization_config: dict[str, Any] | None + config_generation_failed: bool | None # Approval state (standard Osprey pattern) requires_approval: bool | None @@ -270,4 +270,4 @@ class XOptExecutionState(TypedDict): is_successful: bool is_failed: bool failure_reason: str | None - current_stage: str # "state_id", "decision", "yaml_gen", "approval", "execution", "analysis", "complete", "failed" + current_stage: str # "state_id", "decision", "config_gen", "approval", "execution", "analysis", "complete", "failed" diff --git a/src/osprey/services/xopt_optimizer/service.py b/src/osprey/services/xopt_optimizer/service.py index 01f3cdf7a..903ba45d4 100644 --- a/src/osprey/services/xopt_optimizer/service.py +++ b/src/osprey/services/xopt_optimizer/service.py @@ -7,7 +7,7 @@ The service implements a multi-stage workflow: 1. State Identification - Assess machine readiness 2. Decision - Select optimization strategy -3. YAML Generation - Create XOpt configuration +3. Config Generation - Create optimization config dict 4. Approval - Human approval of configuration 5. Execution - Run XOpt optimization 6. Analysis - Analyze results and decide continuation @@ -30,6 +30,7 @@ from .analysis import create_analysis_node from .approval import create_approval_node +from .config_generation import create_config_generation_node from .decision import create_decision_node from .exceptions import XOptExecutionError from .execution import create_executor_node @@ -40,7 +41,6 @@ XOptStrategy, ) from .state_identification import create_state_identification_node -from .yaml_generation import create_yaml_generation_node logger = get_logger("xopt_optimizer") @@ -92,6 +92,8 @@ async def ainvoke(self, input_data, config): XOptExecutionError: If optimization fails TypeError: If input_data is not a supported type """ + config = self._inject_application_config(config) + if isinstance(input_data, Command): # This is a resume command (approval response) if hasattr(input_data, "resume") and input_data.resume: @@ -149,16 +151,16 @@ def _create_initial_state(self, request: XOptExecutionRequest) -> XOptExecutionS capability_context_data=request.capability_context_data, # Error tracking error_chain=[], - yaml_generation_attempt=0, + config_generation_attempt=0, # Machine state machine_state=None, machine_state_details=None, # Decision selected_strategy=None, decision_reasoning=None, - # YAML - generated_yaml=None, - yaml_generation_failed=None, + # Config + optimization_config=None, + config_generation_failed=None, # Approval requires_approval=None, approval_interrupt_data=None, @@ -196,7 +198,7 @@ def _build_and_compile_graph(self): # Add nodes workflow.add_node("state_identification", create_state_identification_node()) workflow.add_node("decision", create_decision_node()) - workflow.add_node("yaml_generation", create_yaml_generation_node()) + workflow.add_node("config_generation", create_config_generation_node()) workflow.add_node("approval", create_approval_node()) workflow.add_node("execution", create_executor_node()) workflow.add_node("analysis", create_analysis_node()) @@ -208,16 +210,16 @@ def _build_and_compile_graph(self): workflow.add_conditional_edges( "decision", self._decision_router, - {"continue": "yaml_generation", "abort": "__end__"}, + {"continue": "config_generation", "abort": "__end__"}, ) workflow.add_conditional_edges( - "yaml_generation", - self._yaml_generation_router, + "config_generation", + self._config_generation_router, { "approve": "approval", "execute": "execution", - "retry": "yaml_generation", + "retry": "config_generation", "__end__": "__end__", }, ) @@ -258,8 +260,8 @@ def _decision_router(self, state: XOptExecutionState) -> str: return "abort" return "continue" - def _yaml_generation_router(self, state: XOptExecutionState) -> str: - """Route after YAML generation. + def _config_generation_router(self, state: XOptExecutionState) -> str: + """Route after config generation. Args: state: Current execution state @@ -269,7 +271,7 @@ def _yaml_generation_router(self, state: XOptExecutionState) -> str: """ if state.get("is_failed"): return "__end__" - if state.get("yaml_generation_failed"): + if state.get("config_generation_failed"): return "retry" if state.get("requires_approval"): return "approve" @@ -328,7 +330,7 @@ def _create_service_result(self, result: dict) -> XOptServiceResult: recommendations = result.get("recommendations") or [] return XOptServiceResult( run_artifact=result.get("run_artifact", {}), - generated_yaml=result.get("generated_yaml", ""), + optimization_config=result.get("optimization_config", {}), strategy=result.get("selected_strategy", XOptStrategy.EXPLORATION), total_iterations=result.get("iteration_count", 0), analysis_summary=result.get("analysis_result", {}), @@ -360,6 +362,32 @@ def _create_checkpointer(self): logger.info("XOpt optimizer service using in-memory checkpointer") return create_memory_checkpointer() + def _inject_application_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Inject application config into the LangGraph configurable. + + LangGraph nodes access configuration via get_config(), which returns the + runtime configurable dict. The caller typically only passes runtime keys + (thread_id, checkpoint_ns). This method merges in the application config + sections that the XOpt nodes need (xopt_optimizer, model_configs, etc.), + following the same pattern used by the main Osprey pipeline. + + Existing runtime keys are preserved and take precedence. + """ + configurable = config.setdefault("configurable", {}) + + # Sections required by XOpt nodes (get_xopt_optimizer_config, + # get_model_config, get_provider_config, get_full_configuration) + for key in ( + "xopt_optimizer", + "model_configs", + "provider_configs", + "project_root", + ): + if key not in configurable and key in self.config: + configurable[key] = self.config[key] + + return config + def _load_config(self) -> dict[str, Any]: """Load service configuration. diff --git a/src/osprey/services/xopt_optimizer/yaml_generation/__init__.py b/src/osprey/services/xopt_optimizer/yaml_generation/__init__.py deleted file mode 100644 index 1f76a4ab6..000000000 --- a/src/osprey/services/xopt_optimizer/yaml_generation/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -"""YAML Generation Subsystem for XOpt Optimizer. - -This subsystem generates XOpt YAML configurations using either: -1. ReAct mode (default): Agent-based generation that dynamically adapts: - - If example files exist: Agent reads them and learns patterns - - If no examples: Agent generates from built-in XOpt knowledge -2. Mock mode: Placeholder YAML for quick testing (use for fast iteration) - -The mode is controlled via configuration: - osprey.xopt_optimizer.yaml_generation.mode: "react" | "mock" - -When using ReAct mode with examples, place YAML files in: - osprey.xopt_optimizer.yaml_generation.examples_path: "path/to/yamls" - -Example files are optional - the agent adapts its behavior based on availability. -""" - -from .agent import YamlGenerationAgent, create_yaml_generation_agent -from .node import create_yaml_generation_node - -__all__ = [ - "create_yaml_generation_node", - "YamlGenerationAgent", - "create_yaml_generation_agent", -] diff --git a/src/osprey/services/xopt_optimizer/yaml_generation/agent.py b/src/osprey/services/xopt_optimizer/yaml_generation/agent.py deleted file mode 100644 index c33557277..000000000 --- a/src/osprey/services/xopt_optimizer/yaml_generation/agent.py +++ /dev/null @@ -1,488 +0,0 @@ -"""ReAct Agent for XOpt YAML Configuration Generation. - -This module provides a ReAct agent that generates XOpt YAML configurations. -The agent dynamically adapts based on whether example files are available: - -- **With examples**: Agent gets file reading tools and is instructed to - learn from historical configurations before generating new ones. -- **Without examples**: Agent generates YAML from its built-in knowledge - of XOpt configuration patterns. - -This design avoids requiring pre-created example files while still -benefiting from them when available. -""" - -from __future__ import annotations - -from pathlib import Path -from typing import Any - -from langchain_core.tools import tool -from langgraph.prebuilt import create_react_agent - -from osprey.models.langchain import get_langchain_model -from osprey.utils.logger import get_logger - -logger = get_logger("xopt_optimizer") - - -# ============================================================================= -# DYNAMIC PROMPTS -# ============================================================================= - -# Prompt when example files ARE available -PROMPT_WITH_EXAMPLES = """You are an expert XOpt configuration generator for accelerator optimization. - -You have access to example XOpt YAML configurations that you should read and learn from. - -## Your Workflow - -1. **READ EXAMPLES FIRST**: Use the `list_yaml_files` tool to see what examples are available, - then use `read_yaml_file` to read them. Study the structure carefully: - - Variable definitions (types, bounds) - - Objective specifications - - Generator selection patterns - - Comments explaining configuration choices - -2. **Understand the Objective**: Parse the user's optimization request to understand: - - What they want to optimize - - What strategy is appropriate (exploration vs optimization) - - Any constraints mentioned - -3. **Generate Configuration**: Create a valid XOpt YAML based on: - - The patterns you learned from examples - - The specific optimization objective - - Best practices for XOpt - -## Output Format - -Your final output MUST be a complete, valid YAML configuration wrapped in ```yaml``` code blocks. -Include comments explaining key configuration choices. - -## Important Notes - -- Always read the examples first - they show the expected structure -- Use placeholder names (param_1, param_2, objective_1) unless the user provides specific names -- Do NOT invent specific accelerator channel names or parameters -- Always include: generator, evaluator, vocs sections -""" - -# Prompt when NO example files are available -PROMPT_WITHOUT_EXAMPLES = """You are an expert XOpt configuration generator for accelerator optimization. - -No example configurations are available, so you must generate YAML from your knowledge of XOpt. - -## XOpt Configuration Structure - -A valid XOpt YAML configuration includes: - -```yaml -# Generator - how to sample new points -generator: - name: random # Options: random, latin_hypercube, sobol, bayesian - -# Evaluator - how to assess points -evaluator: - function: objective_function_name - -# VOCS - Variables, Objectives, Constraints, Statics -vocs: - variables: - param_name: - type: continuous # or discrete, ordinal - lower: 0.0 - upper: 10.0 - objectives: - objective_name: - type: minimize # or maximize - constraints: {} - statics: {} - -# Runtime settings -n_initial: 5 -max_evaluations: 20 -``` - -## Your Workflow - -1. **Understand the Objective**: Parse the user's optimization request -2. **Select Generator**: Based on strategy (exploration β†’ random/latin_hypercube, optimization β†’ bayesian) -3. **Define Variables**: Create placeholder variables with reasonable defaults -4. **Define Objectives**: Based on what user wants to optimize -5. **Generate YAML**: Complete, valid configuration - -## Output Format - -Your final output MUST be a complete, valid YAML configuration wrapped in ```yaml``` code blocks. -Include comments explaining key configuration choices. - -## Important Notes - -- Use placeholder names (param_1, param_2, objective_1) unless the user provides specific names -- Do NOT invent specific accelerator channel names or parameters -- Include reasonable default bounds (e.g., 0.0 to 10.0 for continuous variables) -- Always include: generator, evaluator, vocs sections -""" - - -# ============================================================================= -# FILE TOOLS (only created when examples exist) -# ============================================================================= - - -def _create_file_tools(examples_path: Path) -> list[Any]: - """Create file reading tools for the agent. - - These tools are only created when example files exist. - - Args: - examples_path: Path to directory containing example YAML files - - Returns: - List of LangChain tools for file operations - """ - - @tool - def list_yaml_files() -> str: - """List available YAML configuration files in the examples directory. - - Use this tool first to see what examples are available before reading them. - - Returns: - List of available YAML files with brief descriptions from their first comment. - """ - yaml_files = list(examples_path.glob("**/*.yaml")) + list(examples_path.glob("**/*.yml")) - - if not yaml_files: - return "No YAML files found in examples directory." - - results = ["Available YAML configurations:"] - for yaml_file in yaml_files: - rel_path = yaml_file.relative_to(examples_path) - # Try to extract description from first comment line - try: - first_lines = yaml_file.read_text(encoding="utf-8").split("\n")[:5] - description = "" - for line in first_lines: - if line.startswith("#") and not line.startswith("# ="): - description = line.lstrip("# ").strip() - break - if description: - results.append(f" - {rel_path}: {description}") - else: - results.append(f" - {rel_path}") - except Exception: - results.append(f" - {rel_path}") - - return "\n".join(results) - - @tool - def read_yaml_file(filename: str) -> str: - """Read the contents of a YAML configuration file. - - Use this after listing files to read specific examples and learn their structure. - - Args: - filename: Name of the YAML file to read (e.g., 'exploration_basic.yaml') - - Returns: - Contents of the YAML file, or error message if not found. - """ - # Security: only allow reading from examples directory - file_path = examples_path / filename - - # Check for path traversal attacks - try: - file_path = file_path.resolve() - examples_resolved = examples_path.resolve() - if not str(file_path).startswith(str(examples_resolved)): - return "Error: Cannot read files outside examples directory." - except Exception: - return "Error: Invalid file path." - - if not file_path.exists(): - # Try searching subdirectories - matches = list(examples_path.glob(f"**/{filename}")) - if matches: - file_path = matches[0] - else: - return f"Error: File '{filename}' not found. Use list_yaml_files to see available files." - - try: - content = file_path.read_text(encoding="utf-8") - return f"=== {filename} ===\n{content}" - except Exception as e: - return f"Error reading file: {e}" - - return [list_yaml_files, read_yaml_file] - - -# ============================================================================= -# AGENT CLASS -# ============================================================================= - - -class YamlGenerationAgent: - """ReAct agent for generating XOpt YAML configurations. - - This agent dynamically adapts based on whether example files are available: - - With examples: Gets file tools and prompt to read examples first - - Without examples: Generates from knowledge with appropriate prompt - - Attributes: - examples_path: Path to directory containing example YAML files (optional) - model_config: Configuration for the LLM model to use - """ - - def __init__( - self, - examples_path: str | Path | None = None, - model_config: dict[str, Any] | None = None, - ): - """Initialize the YAML generation agent. - - Args: - examples_path: Path to directory containing example YAML files. - If None or directory doesn't exist/is empty, agent generates from knowledge. - model_config: Optional model configuration. If not provided, - uses the 'fast' model from osprey config. - """ - self.examples_path = Path(examples_path) if examples_path else None - self.model_config = model_config - self._agent = None - self._has_examples = False - - def _check_examples_exist(self) -> bool: - """Check if example YAML files exist. - - Returns: - True if examples directory exists and contains YAML files - """ - if not self.examples_path: - return False - - if not self.examples_path.exists(): - return False - - yaml_files = list(self.examples_path.glob("**/*.yaml")) + list( - self.examples_path.glob("**/*.yml") - ) - return len(yaml_files) > 0 - - def _get_tools(self) -> list[Any]: - """Get tools for the agent based on file availability. - - Returns: - List of tools (file tools if examples exist, empty otherwise) - """ - if self._has_examples and self.examples_path: - return _create_file_tools(self.examples_path) - else: - return [] - - def _get_prompt(self) -> str: - """Get the appropriate system prompt based on file availability. - - Returns: - System prompt string - """ - if self._has_examples: - return PROMPT_WITH_EXAMPLES - else: - return PROMPT_WITHOUT_EXAMPLES - - def _get_model(self): - """Get the LangChain model for the agent. - - Uses model_config provided during initialization. - The node.py handles fallback to orchestrator model if xopt-specific - model is not configured. - - Returns: - LangChain BaseChatModel instance - - Raises: - ValueError: If no model_config is available - """ - if self.model_config: - return get_langchain_model(model_config=self.model_config) - - # This shouldn't happen if node.py fallback is working - raise ValueError( - "No model_config provided to YamlGenerationAgent. " - "Ensure xopt_optimizer.yaml_generation.model_config_name is set in config.yml " - "or that 'orchestrator' model is configured as fallback." - ) - - def _get_agent(self): - """Get or create the ReAct agent with dynamic configuration. - - Returns: - Compiled ReAct agent graph - """ - if self._agent is None: - # Check for examples at agent creation time - self._has_examples = self._check_examples_exist() - - model = self._get_model() - tools = self._get_tools() - - # Create agent with or without tools - self._agent = create_react_agent( - model=model, - tools=tools, - ) - - return self._agent - - async def generate_yaml( - self, - objective: str, - strategy: str, - additional_context: dict[str, Any] | None = None, - ) -> str: - """Generate XOpt YAML configuration using the ReAct agent. - - Args: - objective: The optimization objective (e.g., "maximize injection efficiency") - strategy: The selected strategy ("exploration" or "optimization") - additional_context: Optional additional context to include in the prompt - - Returns: - Generated YAML configuration as a string - - Raises: - ValueError: If YAML generation fails or produces invalid output - """ - agent = self._get_agent() - - # Build the user message - user_message = f"""Generate an XOpt YAML configuration for the following: - -**Optimization Objective:** {objective} -**Strategy:** {strategy} - -{ - "First, use the tools to read available example configurations. " - if self._has_examples - else "" - }Generate a complete, valid YAML configuration based on { - "what you learn from the examples" - if self._has_examples - else "your knowledge of XOpt configuration patterns" - }. - -Remember: -- Use generic parameter names unless specific names are provided -- Include comments explaining your configuration choices -- Output the final YAML in ```yaml``` code blocks -""" - - if additional_context: - user_message += f"\n**Additional Context:** {additional_context}" - - # Run the agent - logger.info("Starting YAML generation agent...") - - try: - result = await agent.ainvoke( - { - "messages": [ - {"role": "system", "content": self._get_prompt()}, - {"role": "user", "content": user_message}, - ] - } - ) - - # Extract the final response - messages = result.get("messages", []) - if not messages: - raise ValueError("Agent did not produce any output") - - # Get the last message content - last_message = messages[-1] - content = ( - last_message.content if hasattr(last_message, "content") else str(last_message) - ) - - # Extract YAML from response - yaml_content = self._extract_yaml(content) - - if not yaml_content: - logger.warning(f"Could not extract YAML from response. Response: {content[:500]}") - raise ValueError("Agent did not produce valid YAML output") - - logger.info(f"YAML generation complete: {len(yaml_content)} characters") - return yaml_content - - except Exception as e: - logger.error(f"YAML generation agent failed: {e}") - raise ValueError(f"YAML generation failed: {e}") from e - - def _extract_yaml(self, content: str) -> str | None: - """Extract YAML content from agent response. - - Args: - content: The agent's response text - - Returns: - Extracted YAML content or None if not found - """ - import re - - # Try to find YAML code blocks - yaml_pattern = r"```yaml\n(.*?)```" - matches = re.findall(yaml_pattern, content, re.DOTALL) - - if matches: - return matches[-1].strip() - - # Try generic code blocks - code_pattern = r"```\n(.*?)```" - matches = re.findall(code_pattern, content, re.DOTALL) - - for match in matches: - # Check if it looks like YAML - if "generator:" in match or "vocs:" in match or "evaluator:" in match: - return match.strip() - - # If no code blocks, check if the whole response is YAML-like - if "generator:" in content and "vocs:" in content: - # Try to extract just the YAML part - lines = content.split("\n") - yaml_lines = [] - in_yaml = False - - for line in lines: - if line.strip().startswith(("#", "generator:", "evaluator:", "vocs:")): - in_yaml = True - if in_yaml: - yaml_lines.append(line) - - if yaml_lines: - return "\n".join(yaml_lines).strip() - - return None - - -def create_yaml_generation_agent( - examples_path: str | Path | None = None, - model_config: dict[str, Any] | None = None, -) -> YamlGenerationAgent: - """Factory function to create a YAML generation agent. - - The agent dynamically adapts based on whether example files exist: - - If examples_path has YAML files: Agent gets tools to read them - - If no examples: Agent generates from its built-in knowledge - - Args: - examples_path: Path to directory containing example YAML files (optional) - model_config: Optional model configuration - - Returns: - Configured YamlGenerationAgent instance - """ - return YamlGenerationAgent( - examples_path=examples_path, - model_config=model_config, - ) diff --git a/src/osprey/services/xopt_optimizer/yaml_generation/node.py b/src/osprey/services/xopt_optimizer/yaml_generation/node.py deleted file mode 100644 index 9291e41cb..000000000 --- a/src/osprey/services/xopt_optimizer/yaml_generation/node.py +++ /dev/null @@ -1,308 +0,0 @@ -"""YAML Generation Node for XOpt Optimizer Service. - -This node generates XOpt YAML configurations and prepares approval interrupt data. -It follows the Python executor's analyzer pattern where the node that generates -content also creates the approval interrupt data. - -Supports two modes (configured via osprey.xopt_optimizer.yaml_generation.mode): -- "react": ReAct agent generates YAML (default) - adapts based on file availability: - - If example files exist: Agent reads them and learns patterns - - If no examples: Agent generates from built-in XOpt knowledge -- "mock": Placeholder YAML for quick testing (use for fast iteration) - -Example YAML files are optional. If provided, place them in: - osprey.xopt_optimizer.yaml_generation.examples_path: "path/to/yamls" - -DO NOT add accelerator-specific YAML parameters without operator input. -""" - -from pathlib import Path -from typing import Any - -from osprey.utils.config import get_model_config, get_xopt_optimizer_config -from osprey.utils.logger import get_logger - -from ..exceptions import YamlGenerationError -from ..models import XOptError, XOptExecutionState, XOptStrategy - -logger = get_logger("xopt_optimizer") - -# Default path for example YAML files (relative to working directory) -DEFAULT_EXAMPLES_PATH = "_agent_data/xopt_examples/yaml_templates" - - -def _get_yaml_generation_config() -> dict[str, Any]: - """Get YAML generation configuration from osprey config. - - Reads from config structure: - xopt_optimizer: - yaml_generation: - mode: "react" - examples_path: "..." - model_config_name: "xopt_yaml_generation" # References models section - - Returns: - Configuration dict with mode, examples_path, and model_config - """ - xopt_config = get_xopt_optimizer_config() - yaml_config = xopt_config.get("yaml_generation", {}) - - # Resolve model config from name reference - # Falls back to "orchestrator" model if xopt-specific model not configured - model_config = None - model_config_name = yaml_config.get("model_config_name", "xopt_yaml_generation") - try: - model_config = get_model_config(model_config_name) - # Check if the model config is valid (has provider) - if not model_config or not model_config.get("provider"): - logger.debug( - f"Model '{model_config_name}' not configured, falling back to orchestrator" - ) - model_config = get_model_config("orchestrator") - except Exception as e: - logger.warning( - f"Could not load model config '{model_config_name}': {e}, falling back to orchestrator" - ) - model_config = get_model_config("orchestrator") - - return { - "mode": yaml_config.get("mode", "react"), # Default to react (agent-based) - "examples_path": yaml_config.get("examples_path"), # None if not specified - "model_config": model_config, - } - - -def _generate_placeholder_yaml(objective: str, strategy: XOptStrategy) -> str: - """Generate placeholder XOpt YAML configuration. - - PLACEHOLDER: This generates a minimal valid YAML structure. - Used when yaml_generation.mode is "mock". - - DO NOT add accelerator-specific parameters without operator input. - """ - return f"""# XOpt Optimization Configuration -# PLACEHOLDER - Generated for: {objective} -# Strategy: {strategy.value} - -# NOTE: This is a MOCK configuration for testing the workflow. -# Set yaml_generation.mode: "react" to use the ReAct agent. - -generator: - name: random # Placeholder generator - # Real implementation would use appropriate generator based on strategy - -evaluator: - function: placeholder_objective - # Real implementation would define actual objective function - -vocs: - variables: - param_1: - type: continuous - lower: 0.0 - upper: 10.0 - param_2: - type: continuous - lower: -1.0 - upper: 1.0 - objectives: - objective_1: - type: minimize - constraints: {{}} - statics: {{}} - -n_initial: 5 -max_evaluations: 20 - -# NOTE: This is a placeholder configuration. -# Actual XOpt parameters will be determined based on: -# - Historical YAML examples from the facility -# - Operator-defined parameter bounds -# - Machine-specific safety constraints -""" - - -async def _generate_yaml_with_react_agent( - objective: str, - strategy: XOptStrategy, - examples_path: str | None, - model_config: dict[str, Any] | None = None, -) -> str: - """Generate YAML using the ReAct agent. - - The agent dynamically adapts: - - If examples_path has YAML files: Agent gets file tools and reads examples - - If no examples: Agent generates from built-in XOpt knowledge - - Args: - objective: The optimization objective - strategy: The selected strategy - examples_path: Path to example YAML files (optional) - model_config: Optional model configuration for the agent - - Returns: - Generated YAML configuration string - """ - from .agent import create_yaml_generation_agent - - # Check if examples path exists - if not, agent will work without file tools - if examples_path: - path = Path(examples_path) - if not path.exists(): - examples_path = None - - # Create and run the agent (it adapts based on whether examples exist) - agent = create_yaml_generation_agent( - examples_path=examples_path, - model_config=model_config, - ) - - try: - yaml_config = await agent.generate_yaml( - objective=objective, - strategy=strategy.value, - ) - return yaml_config - except Exception as e: - logger.warning(f"ReAct agent failed, falling back to mock: {e}") - return _generate_placeholder_yaml(objective, strategy) - - -def _validate_yaml(yaml_config: str) -> None: - """Validate generated YAML configuration. - - PLACEHOLDER: Basic validation only. - Real implementation would use XOpt schema validation. - """ - if not yaml_config or not yaml_config.strip(): - raise YamlGenerationError( - "Generated YAML is empty", - generated_yaml=yaml_config, - validation_errors=["Empty YAML configuration"], - ) - # Future: Add XOpt schema validation - - -def create_yaml_generation_node(): - """Create the YAML generation node for LangGraph integration. - - This factory function creates a node that generates XOpt YAML - configurations and prepares approval interrupt data. - - The generation mode is controlled via configuration: - - osprey.xopt_optimizer.yaml_generation.mode: "mock" | "react" - - Returns: - Async function that takes XOptExecutionState and returns state updates - """ - - async def yaml_generation_node(state: XOptExecutionState) -> dict[str, Any]: - """Generate XOpt YAML configuration. - - Supports two modes: - - "mock": Fast placeholder generation for testing (default) - - "react": ReAct agent reads examples and generates YAML - - Also prepares approval interrupt data following the Python - executor's analyzer pattern. - """ - node_logger = get_logger("xopt_optimizer", state=state) - - # Get configuration - yaml_gen_config = _get_yaml_generation_config() - mode = yaml_gen_config.get("mode", "mock") - is_mock = mode == "mock" - mode_indicator = " (mock)" if is_mock else "" - - node_logger.status(f"Generating XOpt configuration{mode_indicator}...") - - # Track generation attempts - attempt = state.get("yaml_generation_attempt", 0) + 1 - request = state.get("request") - strategy = state.get("selected_strategy", XOptStrategy.EXPLORATION) - objective = request.optimization_objective if request else "Unknown objective" - - try: - # Generate YAML configuration based on mode - if mode == "react": - yaml_config = await _generate_yaml_with_react_agent( - objective=objective, - strategy=strategy, - examples_path=yaml_gen_config.get("examples_path", DEFAULT_EXAMPLES_PATH), - model_config=yaml_gen_config.get("model_config"), - ) - else: - yaml_config = _generate_placeholder_yaml(objective, strategy) - - # Validate YAML - _validate_yaml(yaml_config) - - node_logger.key_info(f"YAML configuration generated{mode_indicator}") - - # Prepare approval interrupt data (following Python executor pattern) - requires_approval = request.require_approval if request else True - - if requires_approval: - # Import here to avoid circular imports - from osprey.approval.approval_system import create_xopt_approval_interrupt - - machine_state_details = state.get("machine_state_details") - - approval_interrupt_data = create_xopt_approval_interrupt( - yaml_config=yaml_config, - strategy=strategy.value, - objective=objective, - machine_state_details=machine_state_details, - step_objective=f"Execute XOpt optimization: {objective}", - ) - - return { - "generated_yaml": yaml_config, - "yaml_generation_attempt": attempt, - "yaml_generation_failed": False, - "requires_approval": True, - "approval_interrupt_data": approval_interrupt_data, - "current_stage": "approval", - } - else: - return { - "generated_yaml": yaml_config, - "yaml_generation_attempt": attempt, - "yaml_generation_failed": False, - "requires_approval": False, - "current_stage": "execution", - } - - except YamlGenerationError: - # Re-raise YAML generation errors - raise - - except Exception as e: - node_logger.warning(f"YAML generation failed: {e}") - - error = XOptError( - error_type="yaml_generation", - error_message=str(e), - stage="yaml_generation", - attempt_number=attempt, - ) - error_chain = list(state.get("error_chain", [])) + [error] - - # Check retry limit - max_retries = request.retries if request else 3 - retry_limit_exceeded = len(error_chain) >= max_retries - - return { - "yaml_generation_attempt": attempt, - "yaml_generation_failed": True, - "error_chain": error_chain, - "is_failed": retry_limit_exceeded, - "failure_reason": ( - f"YAML generation failed after {max_retries} attempts" - if retry_limit_exceeded - else None - ), - "current_stage": "yaml_gen" if not retry_limit_exceeded else "failed", - } - - return yaml_generation_node diff --git a/src/osprey/utils/config.py b/src/osprey/utils/config.py index 5ad65d00b..9c195b6ab 100644 --- a/src/osprey/utils/config.py +++ b/src/osprey/utils/config.py @@ -833,8 +833,8 @@ def get_xopt_optimizer_config(config_path: str | None = None) -> dict[str, Any]: mock_files: true decision: mode: "llm" # or "mock" - yaml_generation: - mode: "react" # or "mock" + config_generation: + mode: "structured" # or "mock" Args: config_path: Optional explicit path to configuration file From d94d69c616ec41dc41d88bc52d1f942dcd4de19b Mon Sep 17 00:00:00 2001 From: Gianluca Martino Date: Fri, 13 Mar 2026 15:09:15 -0700 Subject: [PATCH 10/14] feat(xopt): add tuning service templates and update optimization prompts --- src/osprey/prompts/defaults/optimization.py | 34 +++++------- src/osprey/prompts/loader.py | 10 ++++ .../apps/control_assistant/config.yml.j2 | 17 +++--- src/osprey/templates/project/config.yml.j2 | 55 +++++++++++++++++++ .../services/tuning-api/docker-compose.yml.j2 | 33 +++++++++++ .../tuning-redis/docker-compose.yml.j2 | 26 +++++++++ .../services/tuning-web/docker-compose.yml.j2 | 27 +++++++++ 7 files changed, 173 insertions(+), 29 deletions(-) create mode 100644 src/osprey/templates/services/tuning-api/docker-compose.yml.j2 create mode 100644 src/osprey/templates/services/tuning-redis/docker-compose.yml.j2 create mode 100644 src/osprey/templates/services/tuning-web/docker-compose.yml.j2 diff --git a/src/osprey/prompts/defaults/optimization.py b/src/osprey/prompts/defaults/optimization.py index 35face1c5..411ca014b 100644 --- a/src/osprey/prompts/defaults/optimization.py +++ b/src/osprey/prompts/defaults/optimization.py @@ -29,13 +29,13 @@ class DefaultOptimizationPromptBuilder(FrameworkPromptBuilder): Override Points: - get_instructions(): Domain-specific optimization guidance - get_machine_state_definitions(): Facility-specific machine states - - get_yaml_generation_guidance(): Historical patterns and templates + - get_config_generation_guidance(): Historical patterns and templates - get_strategy_selection_guidance(): Strategy selection criteria """ PROMPT_TYPE = "optimization" - def get_role_definition(self) -> str: + def get_role(self) -> str: """Get the role definition for optimization. :return: Role definition string @@ -43,14 +43,6 @@ def get_role_definition(self) -> str: """ return "You are an expert optimization assistant helping to configure and execute autonomous machine optimization using XOpt." - def get_task_definition(self) -> str: - """Get the task definition for optimization. - - :return: Task definition or None if task is provided externally - :rtype: Optional[str] - """ - return None # Task is provided via request - def get_instructions(self) -> str: """Get domain-specific optimization instructions. @@ -70,7 +62,7 @@ def get_instructions(self) -> str: When implementing actual optimization: 1. Assess machine readiness before proceeding 2. Use appropriate optimization strategy (exploration vs optimization) - 3. Generate valid XOpt YAML configuration + 3. Generate optimization configuration 4. Request human approval before execution 5. Analyze results and provide recommendations @@ -92,24 +84,24 @@ def get_machine_state_definitions(self) -> dict[str, str]: "unknown": "Machine state assessment inconclusive", } - def get_yaml_generation_guidance(self) -> str: - """Get guidance for XOpt YAML configuration generation. + def get_config_generation_guidance(self) -> str: + """Get guidance for optimization config generation. - :return: Domain-specific YAML generation guidance + :return: Domain-specific config generation guidance :rtype: str """ # Placeholder - facilities override with facility-specific templates return textwrap.dedent( """ - YAML Generation Guidance (Placeholder): + Config Generation Guidance (Placeholder): - When generating XOpt YAML configurations: - - Use valid XOpt schema structure - - Define appropriate variables, objectives, and constraints - - Select suitable generator and evaluator + When generating optimization configurations: + - Select appropriate algorithm (upper_confidence_bound, expected_improvement, mobo, random) + - Set iteration count based on task scope + - Specify environment name if known - NOTE: Actual YAML templates and parameter definitions will be - provided based on facility-specific requirements and historical examples. + NOTE: Actual config parameters and constraints will be defined + based on facility-specific requirements. """ ).strip() diff --git a/src/osprey/prompts/loader.py b/src/osprey/prompts/loader.py index 484c5a616..efe9753a1 100644 --- a/src/osprey/prompts/loader.py +++ b/src/osprey/prompts/loader.py @@ -398,6 +398,16 @@ def get_channel_finder_middle_layer_prompt_builder(self) -> "FrameworkPromptBuil return DefaultMiddleLayerPromptBuilder() + def get_optimization_prompt_builder(self) -> "FrameworkPromptBuilder": + """Provide prompt builder for optimization capability guides. + + Returns a default implementation. Override in application prompt providers + to supply facility-specific optimization guidance. + """ + from osprey.prompts.defaults.optimization import DefaultOptimizationPromptBuilder + + return DefaultOptimizationPromptBuilder() + # ================================================================= # ARIEL prompt builders (used by native ARIEL search service) # ================================================================= diff --git a/src/osprey/templates/apps/control_assistant/config.yml.j2 b/src/osprey/templates/apps/control_assistant/config.yml.j2 index 303d96ae2..8917289e7 100644 --- a/src/osprey/templates/apps/control_assistant/config.yml.j2 +++ b/src/osprey/templates/apps/control_assistant/config.yml.j2 @@ -57,10 +57,10 @@ models: provider: {{ default_provider }} model_id: {{ default_model }} max_tokens: 4096 # For channel finder semantic search - xopt_yaml_generation: + xopt_config_generation: provider: {{ default_provider }} model_id: {{ default_model }} - max_tokens: 4096 # For XOpt YAML configuration generation + max_tokens: 4096 # For XOpt optimization config generation xopt_decision: provider: {{ default_provider }} model_id: {{ default_model }} @@ -552,12 +552,13 @@ xopt_optimizer: mode: "llm" # "llm" (LLM-based with structured output) or "mock" (defaults to exploration) model_config_name: "xopt_decision" # Reference to models section - # YAML Generation Agent - # Generates XOpt configuration files using ReAct pattern - yaml_generation: - mode: "react" # "react" (agent-based, default) or "mock" (fast placeholder) - # examples_path: "_agent_data/xopt_examples/yaml_templates" # Optional - agent adapts if missing - model_config_name: "xopt_yaml_generation" # Reference to models section + # Config Generation + # Generates optimization config dicts for the tuning_scripts API + config_generation: + mode: "structured" # "structured" (LLM structured output) or "mock" (fast placeholder) + model_config_name: "xopt_config_generation" # Reference to models section + default_algorithm: "upper_confidence_bound" + default_environment: null # Analysis Agent (placeholder for future implementation) # Analyzes optimization results and decides continuation diff --git a/src/osprey/templates/project/config.yml.j2 b/src/osprey/templates/project/config.yml.j2 index c48bfbd3f..d6c4e070b 100644 --- a/src/osprey/templates/project/config.yml.j2 +++ b/src/osprey/templates/project/config.yml.j2 @@ -64,6 +64,15 @@ models: time_parsing: provider: {{ default_provider | default("cborg") }} model_id: {{ default_model | default("anthropic/claude-haiku") }} + # xopt_config_generation: + # provider: {{ default_provider | default("cborg") }} + # model_id: {{ default_model | default("anthropic/claude-haiku") }} + # xopt_decision: + # provider: {{ default_provider | default("cborg") }} + # model_id: {{ default_model | default("anthropic/claude-haiku") }} + # xopt_state_identification: + # provider: {{ default_provider | default("cborg") }} + # model_id: {{ default_model | default("anthropic/claude-haiku") }} # ============================================================ # DEPLOYMENT CONFIGURATION @@ -304,6 +313,52 @@ python_executor: max_execution_retries: 3 execution_timeout_seconds: 600 +# ============================================================ +# XOPT OPTIMIZER CONFIGURATION +# ============================================================ +# Configuration for the XOpt optimization service. +# Each stage can run in "mock" mode (for testing) or a real mode. +# Model references point to entries in the models section above. + +xopt_optimizer: + # --- Stage 1: State Identification --- + state_identification: + mode: "mock" # "mock" | "react" + model_config_name: "xopt_state_identification" + # reference_path: "_agent_data/xopt_docs" # Path to machine-readiness reference docs + + # --- Stage 2: Decision --- + decision: + mode: "mock" # "mock" | "llm" + model_config_name: "xopt_decision" + + # --- Stage 3: Config Generation --- + config_generation: + mode: "mock" # "mock" | "structured" + model_config_name: "xopt_config_generation" + default_algorithm: "upper_confidence_bound" + default_environment: null # e.g. "als_injector_sim" + + # --- API Connection (tuning_scripts backend) --- + api: + base_url: "http://tuning-api:8001" # Docker service name + poll_interval_seconds: 5.0 + timeout_seconds: 3600 + +# ============================================================ +# TUNING SERVICES +# ============================================================ +# Container service definitions for the tuning_scripts stack. +# Uncomment to enable Docker deployment of tuning services. + +# services: +# tuning_api: +# source_path: "~/git/tuning_scripts" +# port_host: 8001 +# epics_ca_name_servers: "" +# tuning_web: +# source_path: "~/git/tuning_scripts" +# port_host: 8050 # ============================================================ # APPLICATION METADATA diff --git a/src/osprey/templates/services/tuning-api/docker-compose.yml.j2 b/src/osprey/templates/services/tuning-api/docker-compose.yml.j2 new file mode 100644 index 000000000..447803ea9 --- /dev/null +++ b/src/osprey/templates/services/tuning-api/docker-compose.yml.j2 @@ -0,0 +1,33 @@ +services: + tuning-api: + build: + context: {{ services.tuning_api.source_path }} + dockerfile: tuning_scripts_app/Dockerfile + container_name: tuning-api + labels: + osprey.project.name: "{{ osprey_labels.project_name }}" + osprey.project.root: "{{ osprey_labels.project_root }}" + osprey.deployed.at: "{{ osprey_labels.deployed_at }}" + restart: unless-stopped + ports: + - "{{ deployment.bind_address | default('127.0.0.1') }}:{{ services.tuning_api.port_host | default(8001) }}:8000" + environment: + REDIS_URL: redis://tuning-redis:6379/0 + EPICS_CA_NAME_SERVERS: "{{ services.tuning_api.epics_ca_name_servers | default('') }}" + EPICS_CA_AUTO_ADDR_LIST: "NO" + TZ: {{ system.timezone }} + volumes: + - tuning_results:/app/results + - tuning_cache:/app/.cache + networks: + - osprey-network + depends_on: + tuning-redis: + condition: service_healthy + +volumes: + tuning_results: + tuning_cache: + +networks: + osprey-network: diff --git a/src/osprey/templates/services/tuning-redis/docker-compose.yml.j2 b/src/osprey/templates/services/tuning-redis/docker-compose.yml.j2 new file mode 100644 index 000000000..3b802bbf1 --- /dev/null +++ b/src/osprey/templates/services/tuning-redis/docker-compose.yml.j2 @@ -0,0 +1,26 @@ +services: + tuning-redis: + image: redis:7-alpine + container_name: tuning-redis + labels: + osprey.project.name: "{{ osprey_labels.project_name }}" + osprey.project.root: "{{ osprey_labels.project_root }}" + osprey.deployed.at: "{{ osprey_labels.deployed_at }}" + restart: unless-stopped + environment: + TZ: {{ system.timezone }} + volumes: + - tuning_redis_data:/data + networks: + - osprey-network + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 5 + +volumes: + tuning_redis_data: + +networks: + osprey-network: diff --git a/src/osprey/templates/services/tuning-web/docker-compose.yml.j2 b/src/osprey/templates/services/tuning-web/docker-compose.yml.j2 new file mode 100644 index 000000000..1e8378adf --- /dev/null +++ b/src/osprey/templates/services/tuning-web/docker-compose.yml.j2 @@ -0,0 +1,27 @@ +services: + tuning-web: + build: + context: {{ services.tuning_web.source_path }} + dockerfile: tuning_scripts_web/Dockerfile + container_name: tuning-web + labels: + osprey.project.name: "{{ osprey_labels.project_name }}" + osprey.project.root: "{{ osprey_labels.project_root }}" + osprey.deployed.at: "{{ osprey_labels.deployed_at }}" + restart: unless-stopped + ports: + - "{{ deployment.bind_address | default('127.0.0.1') }}:{{ services.tuning_web.port_host | default(8050) }}:8050" + environment: + API_URL: http://tuning-api:8001 + REDIS_URL: redis://tuning-redis:6379/0 + TZ: {{ system.timezone }} + networks: + - osprey-network + depends_on: + tuning-redis: + condition: service_healthy + tuning-api: + condition: service_started + +networks: + osprey-network: From 23e9f9df0d3a931efb31f1df0135ac44bf37cda2 Mon Sep 17 00:00:00 2001 From: Gianluca Martino Date: Fri, 13 Mar 2026 15:09:46 -0700 Subject: [PATCH 11/14] test(xopt): add API client tests and update tests for config_generation refactor --- tests/conftest.py | 20 +- .../xopt_optimizer/test_e2e_weather_agent.py | 181 ++++++++++ .../test_execution_api_client.py | 321 ++++++++++++++++++ .../xopt_optimizer/test_xopt_approval.py | 38 ++- .../xopt_optimizer/test_xopt_exceptions.py | 48 +-- .../xopt_optimizer/test_xopt_service.py | 44 +-- .../xopt_optimizer/test_xopt_workflow.py | 105 +++++- 7 files changed, 678 insertions(+), 79 deletions(-) create mode 100644 tests/services/xopt_optimizer/test_e2e_weather_agent.py create mode 100644 tests/services/xopt_optimizer/test_execution_api_client.py diff --git a/tests/conftest.py b/tests/conftest.py index 39508deb4..286b285a8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -442,7 +442,7 @@ def get_registry_config(self): "models": { "orchestrator": {"provider": "openai", "model_id": "gpt-4"}, "python_code_generator": {"provider": "openai", "model_id": "gpt-4"}, - "xopt_yaml_generation": {"provider": "openai", "model_id": "gpt-4"}, + "xopt_config_generation": {"provider": "openai", "model_id": "gpt-4"}, "xopt_decision": {"provider": "openai", "model_id": "gpt-4"}, }, "xopt_optimizer": { @@ -453,9 +453,13 @@ def get_registry_config(self): "mode": "mock", # Use mock for fast tests "model_config_name": "xopt_decision", }, - "yaml_generation": { + "config_generation": { "mode": "mock", # Use mock for fast tests - "model_config_name": "xopt_yaml_generation", + "model_config_name": "xopt_config_generation", + "default_environment": "test_environment", + }, + "api": { + "base_url": "http://localhost:19876", # Non-existent port to force placeholder fallback }, }, } @@ -552,7 +556,7 @@ def get_registry_config(self): "models": { "orchestrator": {"provider": "openai", "model_id": "gpt-4"}, "python_code_generator": {"provider": "openai", "model_id": "gpt-4"}, - "xopt_yaml_generation": {"provider": "openai", "model_id": "gpt-4"}, + "xopt_config_generation": {"provider": "openai", "model_id": "gpt-4"}, "xopt_decision": {"provider": "openai", "model_id": "gpt-4"}, }, "xopt_optimizer": { @@ -563,9 +567,13 @@ def get_registry_config(self): "mode": "mock", # Use mock for fast tests "model_config_name": "xopt_decision", }, - "yaml_generation": { + "config_generation": { "mode": "mock", # Use mock for fast tests - "model_config_name": "xopt_yaml_generation", + "model_config_name": "xopt_config_generation", + "default_environment": "test_environment", + }, + "api": { + "base_url": "http://localhost:19876", # Non-existent port to force placeholder fallback }, }, } diff --git a/tests/services/xopt_optimizer/test_e2e_weather_agent.py b/tests/services/xopt_optimizer/test_e2e_weather_agent.py new file mode 100644 index 000000000..af0a206fd --- /dev/null +++ b/tests/services/xopt_optimizer/test_e2e_weather_agent.py @@ -0,0 +1,181 @@ +"""End-to-end integration test: optimization capability + xopt_optimizer service. + +Validates that the optimization capability can be loaded from a registry +and that the xopt_optimizer service runs the full mock-mode workflow. +This simulates what would happen when a user says "Optimize injection efficiency" +through the weather-agent project. +""" + +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from osprey.services.xopt_optimizer.execution.api_client import TuningScriptsAPIError + + +class TestEndToEndOptimization: + """Test optimization from capability registration through service execution.""" + + @pytest.mark.asyncio + async def test_capability_loads_and_service_runs(self, test_config): + """Verify OptimizationCapability can be imported and xopt_optimizer service runs.""" + os.environ["CONFIG_FILE"] = str(test_config) + + # Clear config cache + from osprey.utils import config as config_module + + config_module._default_config = None + config_module._default_configurable = None + config_module._config_cache.clear() + + # 1. Verify capability can be imported and instantiated + from osprey.capabilities.optimization import ( + OptimizationCapability, + OptimizationResultContext, + ) + + assert OptimizationCapability.name == "optimization" + assert "OPTIMIZATION_RESULT" in OptimizationCapability.provides + + # 2. Verify context class works + ctx = OptimizationResultContext( + run_artifact={"status": "test"}, + strategy="exploration", + total_iterations=1, + optimization_config={"algorithm": "random"}, + ) + assert ctx.context_type == "OPTIMIZATION_RESULT" + summary = ctx.get_summary() + assert summary["strategy"] == "exploration" + assert summary["iterations"] == 1 + + # 3. Run the service directly (mock mode, no LLM calls) + from osprey.services.xopt_optimizer import ( + XOptExecutionRequest, + XOptOptimizerService, + XOptServiceResult, + ) + + mock_client = MagicMock() + mock_client.health_check = AsyncMock( + side_effect=TuningScriptsAPIError("Connection refused") + ) + + service = XOptOptimizerService() + request = XOptExecutionRequest( + user_query="Optimize injection efficiency", + optimization_objective="Maximize injection efficiency", + max_iterations=1, + require_approval=False, + ) + + config = { + "configurable": { + "thread_id": "test_e2e", + "checkpoint_ns": "xopt_test", + } + } + + with patch( + "osprey.services.xopt_optimizer.execution.node.TuningScriptsClient", + return_value=mock_client, + ): + result = await service.ainvoke(request, config) + + assert isinstance(result, XOptServiceResult) + assert isinstance(result.optimization_config, dict) + assert "algorithm" in result.optimization_config + assert result.total_iterations == 1 + + # 4. Verify context can be created from service result + from osprey.capabilities.optimization import _create_optimization_context + + ctx = _create_optimization_context(result) + assert ctx.context_type == "OPTIMIZATION_RESULT" + assert ctx.strategy == result.strategy.value + assert ctx.optimization_config == result.optimization_config + + @pytest.mark.asyncio + async def test_service_with_real_api_data_path(self, test_config): + """Test full path with mocked API returning real data.""" + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.capabilities.optimization import _create_optimization_context + from osprey.services.xopt_optimizer import ( + XOptExecutionRequest, + XOptOptimizerService, + XOptServiceResult, + ) + + full_state = { + "job_id": "e2e-test-001", + "status": "completed", + "data": [ + {"x1": 0.5, "x2": -1.0, "objective": 0.85}, + {"x1": 1.0, "x2": -0.5, "objective": 0.92}, + {"x1": 0.8, "x2": -0.7, "objective": 0.95}, + ], + "environment_name": "test_env", + "objective_name": "objective", + "variable_names": ["x1", "x2"], + "results_path": "2024/e2e-test", + "logs": "Completed", + } + + mock_client = MagicMock() + mock_client.health_check = AsyncMock(return_value={"status": "ok"}) + mock_client.submit_config = AsyncMock(return_value="e2e-test-001") + mock_client.poll_until_complete = AsyncMock(return_value=full_state) + + service = XOptOptimizerService() + request = XOptExecutionRequest( + user_query="Optimize the objective function", + optimization_objective="Maximize objective", + max_iterations=1, + require_approval=False, + ) + + config = { + "configurable": { + "thread_id": "test_e2e_api", + "checkpoint_ns": "xopt_test", + } + } + + with patch( + "osprey.services.xopt_optimizer.execution.node.TuningScriptsClient", + return_value=mock_client, + ): + result = await service.ainvoke(request, config) + + assert isinstance(result, XOptServiceResult) + assert result.run_artifact["job_id"] == "e2e-test-001" + assert result.run_artifact["data"] is not None + assert len(result.run_artifact["data"]) == 3 + + # Analysis should identify best point + assert any("0.95" in r for r in result.recommendations) + + # Context creation should work + ctx = _create_optimization_context(result) + assert ctx.optimization_config == result.optimization_config + assert ctx.total_iterations == 1 + + @pytest.mark.asyncio + async def test_prompt_builder_loads(self): + """Verify optimization prompt builder is accessible.""" + from osprey.prompts.defaults.optimization import DefaultOptimizationPromptBuilder + + builder = DefaultOptimizationPromptBuilder() + + # Classifier guide should have examples + classifier = builder.get_classifier_guide() + assert classifier is not None + assert len(classifier.examples) > 0 + + # Orchestrator guide should exist + # Note: requires registry to be initialized, so we test the builder directly + assert hasattr(builder, "get_orchestrator_guide") + assert hasattr(builder, "get_config_generation_guidance") + assert hasattr(builder, "get_strategy_selection_guidance") diff --git a/tests/services/xopt_optimizer/test_execution_api_client.py b/tests/services/xopt_optimizer/test_execution_api_client.py new file mode 100644 index 000000000..2c9ef1091 --- /dev/null +++ b/tests/services/xopt_optimizer/test_execution_api_client.py @@ -0,0 +1,321 @@ +"""Tests for TuningScriptsClient and the execution node. + +Uses mocked aiohttp responses β€” no running API required. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from osprey.services.xopt_optimizer.execution.api_client import ( + TuningScriptsAPIError, + TuningScriptsClient, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_mock_response(status: int = 200, json_data: dict | None = None, text: str = ""): + """Create a mock aiohttp response.""" + resp = MagicMock() + resp.status = status + resp.method = "GET" + resp.url = "http://test/mock" + resp.json = AsyncMock(return_value=json_data or {}) + resp.text = AsyncMock(return_value=text) + return resp + + +def _make_client(**kwargs) -> TuningScriptsClient: + """Create a client with config loading disabled.""" + with patch.object(TuningScriptsClient, "_load_api_config", return_value={}): + return TuningScriptsClient( + base_url="http://test-api:8001", + poll_interval_seconds=0.01, # Fast polling for tests + timeout_seconds=1.0, + **kwargs, + ) + + +# --------------------------------------------------------------------------- +# Client method tests +# --------------------------------------------------------------------------- + + +def _mock_session_for(mock_resp, method="get"): + """Build a mock aiohttp.ClientSession whose `method` returns mock_resp.""" + # The response is returned from an async context manager (session.get/post) + resp_ctx = AsyncMock() + resp_ctx.__aenter__ = AsyncMock(return_value=mock_resp) + resp_ctx.__aexit__ = AsyncMock(return_value=False) + + mock_session = AsyncMock() + setattr(mock_session, method, MagicMock(return_value=resp_ctx)) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + return mock_session + + +class TestTuningScriptsClient: + """Unit tests for TuningScriptsClient HTTP methods.""" + + @pytest.mark.asyncio + async def test_health_check_success(self): + client = _make_client() + mock_resp = _make_mock_response(200, {"status": "ok"}) + mock_session = _mock_session_for(mock_resp, "get") + + with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + result = await client.health_check() + + assert result == {"status": "ok"} + + @pytest.mark.asyncio + async def test_health_check_failure(self): + client = _make_client() + mock_resp = _make_mock_response(503, {"detail": "unhealthy"}) + mock_session = _mock_session_for(mock_resp, "get") + + with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + with pytest.raises(TuningScriptsAPIError) as exc_info: + await client.health_check() + + assert exc_info.value.status_code == 503 + + @pytest.mark.asyncio + async def test_submit_config_returns_job_id(self): + client = _make_client() + mock_resp = _make_mock_response(200, {"job_id": "abc-123", "status": "submitted"}) + mock_session = _mock_session_for(mock_resp, "post") + + with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + job_id = await client.submit_config({"algorithm": "random", "n_iterations": 10}) + + assert job_id == "abc-123" + + @pytest.mark.asyncio + async def test_submit_yaml_returns_job_id(self): + client = _make_client() + mock_resp = _make_mock_response(200, {"job_id": "abc-123", "status": "submitted"}) + mock_session = _mock_session_for(mock_resp, "post") + + with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + job_id = await client.submit_yaml("generator:\n name: random\n") + + assert job_id == "abc-123" + + @pytest.mark.asyncio + async def test_submit_yaml_with_iterations(self): + client = _make_client() + mock_resp = _make_mock_response(200, {"job_id": "abc-123", "status": "submitted"}) + mock_session = _mock_session_for(mock_resp, "post") + + with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + job_id = await client.submit_yaml("yaml: data\n", n_iterations=10) + + assert job_id == "abc-123" + + @pytest.mark.asyncio + async def test_get_status(self): + client = _make_client() + mock_resp = _make_mock_response(200, {"job_id": "abc", "status": "running"}) + mock_session = _mock_session_for(mock_resp, "get") + + with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + result = await client.get_status("abc") + + assert result["status"] == "running" + + @pytest.mark.asyncio + async def test_get_full_state(self): + client = _make_client() + full_data = { + "job_id": "abc", + "status": "completed", + "data": [{"x": 1.0, "f": 0.5}], + "variable_names": ["x"], + "objective_name": "f", + } + mock_resp = _make_mock_response(200, full_data) + mock_session = _mock_session_for(mock_resp, "get") + + with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + result = await client.get_full_state("abc") + + assert result["data"] == [{"x": 1.0, "f": 0.5}] + + @pytest.mark.asyncio + async def test_cancel(self): + client = _make_client() + mock_resp = _make_mock_response(200, {"status": "cancelled", "job_id": "abc"}) + mock_session = _mock_session_for(mock_resp, "post") + + with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + result = await client.cancel("abc") + + assert result["status"] == "cancelled" + + +# --------------------------------------------------------------------------- +# Polling tests +# --------------------------------------------------------------------------- + + +class TestPolling: + """Tests for poll_until_complete.""" + + @pytest.mark.asyncio + async def test_poll_completes_on_success(self): + client = _make_client() + + # First call returns running, second returns completed + status_responses = [ + {"job_id": "j1", "status": "running"}, + {"job_id": "j1", "status": "completed"}, + ] + full_state = { + "job_id": "j1", + "status": "completed", + "data": [{"x": 1.0}], + } + + call_count = 0 + + async def mock_get_status(job_id): + nonlocal call_count + resp = status_responses[min(call_count, len(status_responses) - 1)] + call_count += 1 + return resp + + client.get_status = mock_get_status + client.get_full_state = AsyncMock(return_value=full_state) + + result = await client.poll_until_complete("j1") + assert result["status"] == "completed" + assert result["data"] == [{"x": 1.0}] + + @pytest.mark.asyncio + async def test_poll_raises_on_error_status(self): + client = _make_client() + + client.get_status = AsyncMock(return_value={"job_id": "j1", "status": "error"}) + client.get_full_state = AsyncMock( + return_value={"job_id": "j1", "status": "error", "message": "Divergence detected"} + ) + + with pytest.raises(TuningScriptsAPIError, match="Divergence detected"): + await client.poll_until_complete("j1") + + @pytest.mark.asyncio + async def test_poll_timeout(self): + client = _make_client() + client.timeout = 0.05 # Very short timeout + + client.get_status = AsyncMock(return_value={"job_id": "j1", "status": "running"}) + + with pytest.raises(TuningScriptsAPIError, match="Timeout"): + await client.poll_until_complete("j1") + + +# --------------------------------------------------------------------------- +# Executor node tests +# --------------------------------------------------------------------------- + + +class TestExecutorNode: + """Tests for the execution node with mocked client.""" + + @pytest.mark.asyncio + async def test_executor_with_api_success(self): + from osprey.services.xopt_optimizer.execution.node import create_executor_node + + full_state = { + "job_id": "test-job", + "status": "completed", + "data": [{"x": 1.0, "f": 0.5}], + "environment_name": "test_env", + "objective_name": "f", + "variable_names": ["x"], + "results_path": "2024/test", + "logs": "some logs", + } + + mock_client = MagicMock(spec=TuningScriptsClient) + mock_client.health_check = AsyncMock(return_value={"status": "ok"}) + mock_client.submit_config = AsyncMock(return_value="test-job") + mock_client.poll_until_complete = AsyncMock(return_value=full_state) + + node = create_executor_node() + + state = { + "optimization_config": {"algorithm": "random", "n_iterations": 10}, + "request": MagicMock(), + } + + with patch( + "osprey.services.xopt_optimizer.execution.node.TuningScriptsClient", + return_value=mock_client, + ): + result = await node(state) + + assert result["execution_failed"] is False + assert result["run_artifact"]["job_id"] == "test-job" + assert result["run_artifact"]["data"] == [{"x": 1.0, "f": 0.5}] + + @pytest.mark.asyncio + async def test_executor_api_error(self): + from osprey.services.xopt_optimizer.execution.node import create_executor_node + + mock_client = MagicMock(spec=TuningScriptsClient) + mock_client.health_check = AsyncMock(return_value={"status": "ok"}) + mock_client.submit_config = AsyncMock( + side_effect=TuningScriptsAPIError("conflict", status_code=409, detail="already running") + ) + + node = create_executor_node() + state = {"optimization_config": {"algorithm": "random"}, "request": MagicMock()} + + with patch( + "osprey.services.xopt_optimizer.execution.node.TuningScriptsClient", + return_value=mock_client, + ): + result = await node(state) + + assert result["execution_failed"] is True + assert result["is_failed"] is True + + @pytest.mark.asyncio + async def test_executor_falls_back_on_connection_error(self): + from osprey.services.xopt_optimizer.execution.node import create_executor_node + + mock_client = MagicMock(spec=TuningScriptsClient) + mock_client.health_check = AsyncMock( + side_effect=TuningScriptsAPIError("Connection refused") + ) + + node = create_executor_node() + state = {"optimization_config": {"algorithm": "random"}, "request": MagicMock()} + + with patch( + "osprey.services.xopt_optimizer.execution.node.TuningScriptsClient", + return_value=mock_client, + ): + result = await node(state) + + # Should fall back to placeholder, not fail + assert result["execution_failed"] is False + assert "placeholder" in result["run_artifact"].get("note", "").lower() + + @pytest.mark.asyncio + async def test_executor_missing_config(self): + from osprey.services.xopt_optimizer.execution.node import create_executor_node + + node = create_executor_node() + state = {"optimization_config": None, "request": MagicMock()} + + result = await node(state) + + assert result["execution_failed"] is True + assert "Missing optimization config" in result["failure_reason"] diff --git a/tests/services/xopt_optimizer/test_xopt_approval.py b/tests/services/xopt_optimizer/test_xopt_approval.py index e5e2c37b5..a8204ca66 100644 --- a/tests/services/xopt_optimizer/test_xopt_approval.py +++ b/tests/services/xopt_optimizer/test_xopt_approval.py @@ -12,7 +12,7 @@ class TestCreateXOptApprovalInterrupt: def test_basic_interrupt_creation(self): """Should create interrupt data with required fields.""" result = create_xopt_approval_interrupt( - yaml_config="xopt:\n generator: random", + optimization_config={"algorithm": "upper_confidence_bound", "n_iterations": 20}, strategy="exploration", objective="Maximize efficiency", ) @@ -24,12 +24,15 @@ def test_basic_interrupt_creation(self): assert "HUMAN APPROVAL REQUIRED" in result["user_message"] assert "Maximize efficiency" in result["user_message"] assert "EXPLORATION" in result["user_message"] - assert "xopt:" in result["user_message"] + assert "algorithm" in result["user_message"] # Check resume payload payload = result["resume_payload"] assert payload["approval_type"] == "xopt_optimizer" - assert payload["yaml_config"] == "xopt:\n generator: random" + assert payload["optimization_config"] == { + "algorithm": "upper_confidence_bound", + "n_iterations": 20, + } assert payload["strategy"] == "exploration" assert payload["objective"] == "Maximize efficiency" @@ -41,7 +44,7 @@ def test_interrupt_with_machine_state_details(self): } result = create_xopt_approval_interrupt( - yaml_config="test: yaml", + optimization_config={"algorithm": "random"}, strategy="optimization", objective="Test objective", machine_state_details=machine_details, @@ -57,7 +60,7 @@ def test_interrupt_with_machine_state_details(self): def test_interrupt_with_custom_step_objective(self): """Should use custom step objective.""" result = create_xopt_approval_interrupt( - yaml_config="test: yaml", + optimization_config={"algorithm": "random"}, strategy="exploration", objective="Test", step_objective="Custom optimization task", @@ -69,7 +72,7 @@ def test_interrupt_with_custom_step_objective(self): def test_interrupt_contains_approval_instructions(self): """Should contain clear approval instructions.""" result = create_xopt_approval_interrupt( - yaml_config="test: yaml", + optimization_config={"algorithm": "random"}, strategy="exploration", objective="Test", ) @@ -79,21 +82,20 @@ def test_interrupt_contains_approval_instructions(self): assert "no" in message.lower() assert "approve" in message.lower() - def test_interrupt_yaml_displayed_correctly(self): - """Should display YAML in code block.""" - yaml_config = """xopt: - generator: - name: bayesian - vocs: - variables: - x1: [0, 10] -""" + def test_interrupt_config_displayed_as_yaml(self): + """Should display config as YAML in code block.""" + config = { + "algorithm": "expected_improvement", + "n_iterations": 30, + "environment_name": "test_env", + } result = create_xopt_approval_interrupt( - yaml_config=yaml_config, + optimization_config=config, strategy="optimization", objective="Test", ) - # YAML should be in code block + # Config should be rendered in a yaml code block assert "```yaml" in result["user_message"] - assert yaml_config in result["user_message"] + assert "expected_improvement" in result["user_message"] + assert "n_iterations" in result["user_message"] diff --git a/tests/services/xopt_optimizer/test_xopt_exceptions.py b/tests/services/xopt_optimizer/test_xopt_exceptions.py index 4f526bd05..5ba81b119 100644 --- a/tests/services/xopt_optimizer/test_xopt_exceptions.py +++ b/tests/services/xopt_optimizer/test_xopt_exceptions.py @@ -4,13 +4,13 @@ """ from osprey.services.xopt_optimizer.exceptions import ( + ConfigGenerationError, ConfigurationError, ErrorCategory, MachineStateAssessmentError, MaxIterationsExceededError, XOptExecutionError, XOptExecutorException, - YamlGenerationError, ) @@ -20,7 +20,7 @@ class TestErrorCategory: def test_error_categories_exist(self): """All expected error categories should exist.""" assert ErrorCategory.MACHINE_STATE.value == "machine_state" - assert ErrorCategory.YAML_GENERATION.value == "yaml_generation" + assert ErrorCategory.CONFIG_GENERATION.value == "config_generation" assert ErrorCategory.EXECUTION.value == "execution" assert ErrorCategory.CONFIGURATION.value == "configuration" assert ErrorCategory.WORKFLOW.value == "workflow" @@ -44,20 +44,20 @@ def test_base_exception_with_category(self): def test_is_retriable(self): """is_retriable should return True for retriable categories.""" machine_exc = XOptExecutorException("Test", category=ErrorCategory.MACHINE_STATE) - yaml_exc = XOptExecutorException("Test", category=ErrorCategory.YAML_GENERATION) + config_exc = XOptExecutorException("Test", category=ErrorCategory.CONFIG_GENERATION) workflow_exc = XOptExecutorException("Test", category=ErrorCategory.WORKFLOW) assert machine_exc.is_retriable() is True - assert yaml_exc.is_retriable() is True + assert config_exc.is_retriable() is True assert workflow_exc.is_retriable() is False - def test_should_retry_yaml_generation(self): - """should_retry_yaml_generation should return True for YAML errors.""" - yaml_exc = XOptExecutorException("Test", category=ErrorCategory.YAML_GENERATION) + def test_should_retry_config_generation(self): + """should_retry_config_generation should return True for config gen errors.""" + config_exc = XOptExecutorException("Test", category=ErrorCategory.CONFIG_GENERATION) other_exc = XOptExecutorException("Test", category=ErrorCategory.EXECUTION) - assert yaml_exc.should_retry_yaml_generation() is True - assert other_exc.should_retry_yaml_generation() is False + assert config_exc.should_retry_config_generation() is True + assert other_exc.should_retry_config_generation() is False class TestMachineStateAssessmentError: @@ -79,25 +79,25 @@ def test_is_retriable(self): assert exc.is_retriable() is True -class TestYamlGenerationError: - """Test YamlGenerationError.""" +class TestConfigGenerationError: + """Test ConfigGenerationError.""" def test_creation(self): - """Should be creatable with message and yaml details.""" - exc = YamlGenerationError( - "Invalid YAML", - generated_yaml="bad: yaml", + """Should be creatable with message and config details.""" + exc = ConfigGenerationError( + "Invalid config", + generated_config={"bad": "config"}, validation_errors=["Missing field X"], ) - assert exc.message == "Invalid YAML" - assert exc.category == ErrorCategory.YAML_GENERATION - assert exc.generated_yaml == "bad: yaml" + assert exc.message == "Invalid config" + assert exc.category == ErrorCategory.CONFIG_GENERATION + assert exc.generated_config == {"bad": "config"} assert exc.validation_errors == ["Missing field X"] - def test_should_retry_yaml_generation(self): - """YAML generation errors should trigger retry.""" - exc = YamlGenerationError("Test") - assert exc.should_retry_yaml_generation() is True + def test_should_retry_config_generation(self): + """Config generation errors should trigger retry.""" + exc = ConfigGenerationError("Test") + assert exc.should_retry_config_generation() is True class TestXOptExecutionError: @@ -107,12 +107,12 @@ def test_creation(self): """Should be creatable with message and execution details.""" exc = XOptExecutionError( "XOpt failed", - yaml_used="test: yaml", + config_used={"algorithm": "random"}, xopt_error="Runtime error", ) assert exc.message == "XOpt failed" assert exc.category == ErrorCategory.EXECUTION - assert exc.yaml_used == "test: yaml" + assert exc.config_used == {"algorithm": "random"} assert exc.xopt_error == "Runtime error" def test_not_retriable(self): diff --git a/tests/services/xopt_optimizer/test_xopt_service.py b/tests/services/xopt_optimizer/test_xopt_service.py index ad10f9d1f..523e31b2d 100644 --- a/tests/services/xopt_optimizer/test_xopt_service.py +++ b/tests/services/xopt_optimizer/test_xopt_service.py @@ -72,25 +72,25 @@ def test_xopt_error_dataclass(self): error = XOptError( error_type="test_error", error_message="Test error message", - stage="yaml_generation", + stage="config_generation", attempt_number=1, details={"key": "value"}, ) assert error.error_type == "test_error" assert error.error_message == "Test error message" - assert error.stage == "yaml_generation" + assert error.stage == "config_generation" assert error.attempt_number == 1 # Test prompt text formatting prompt_text = error.to_prompt_text() - assert "YAML_GENERATION FAILED" in prompt_text + assert "CONFIG_GENERATION FAILED" in prompt_text assert "Test error message" in prompt_text def test_xopt_service_result_creation(self): """XOptServiceResult should be creatable with all fields.""" result = XOptServiceResult( run_artifact={"status": "completed"}, - generated_yaml="test: yaml", + optimization_config={"algorithm": "random"}, strategy=XOptStrategy.EXPLORATION, total_iterations=3, analysis_summary={"summary": "test"}, @@ -193,13 +193,13 @@ async def test_state_identification_returns_ready(self, test_config): ), capability_context_data=None, error_chain=[], - yaml_generation_attempt=0, + config_generation_attempt=0, machine_state=None, machine_state_details=None, selected_strategy=None, decision_reasoning=None, - generated_yaml=None, - yaml_generation_failed=None, + optimization_config=None, + config_generation_failed=None, requires_approval=None, approval_interrupt_data=None, approval_result=None, @@ -229,8 +229,8 @@ class TestDecisionNode: """Test decision node.""" @pytest.mark.asyncio - async def test_decision_routes_to_yaml_gen_when_ready(self, test_config): - """Decision node should route to yaml_gen when machine is READY.""" + async def test_decision_routes_to_config_gen_when_ready(self, test_config): + """Decision node should route to config_gen when machine is READY.""" os.environ["CONFIG_FILE"] = str(test_config) from osprey.services.xopt_optimizer.decision import create_decision_node @@ -243,13 +243,13 @@ async def test_decision_routes_to_yaml_gen_when_ready(self, test_config): ), capability_context_data=None, error_chain=[], - yaml_generation_attempt=0, + config_generation_attempt=0, machine_state=MachineState.READY, machine_state_details={"assessment": "test"}, selected_strategy=None, decision_reasoning=None, - generated_yaml=None, - yaml_generation_failed=None, + optimization_config=None, + config_generation_failed=None, requires_approval=None, approval_interrupt_data=None, approval_result=None, @@ -271,7 +271,7 @@ async def test_decision_routes_to_yaml_gen_when_ready(self, test_config): result = await node(state) assert result["selected_strategy"] == XOptStrategy.EXPLORATION - assert result["current_stage"] == "yaml_gen" + assert result["current_stage"] == "config_gen" assert "decision_reasoning" in result @pytest.mark.asyncio @@ -289,13 +289,13 @@ async def test_decision_aborts_when_not_ready(self, test_config): ), capability_context_data=None, error_chain=[], - yaml_generation_attempt=0, + config_generation_attempt=0, machine_state=MachineState.NOT_READY, machine_state_details={"reason": "Machine offline"}, selected_strategy=None, decision_reasoning=None, - generated_yaml=None, - yaml_generation_failed=None, + optimization_config=None, + config_generation_failed=None, requires_approval=None, approval_interrupt_data=None, approval_result=None, @@ -339,13 +339,13 @@ async def test_analysis_continues_when_under_max_iterations(self, test_config): ), capability_context_data=None, error_chain=[], - yaml_generation_attempt=0, + config_generation_attempt=0, machine_state=MachineState.READY, machine_state_details={}, selected_strategy=XOptStrategy.EXPLORATION, decision_reasoning="test", - generated_yaml="test: yaml", - yaml_generation_failed=False, + optimization_config={"algorithm": "random"}, + config_generation_failed=False, requires_approval=False, approval_interrupt_data=None, approval_result=None, @@ -385,13 +385,13 @@ async def test_analysis_completes_at_max_iterations(self, test_config): ), capability_context_data=None, error_chain=[], - yaml_generation_attempt=0, + config_generation_attempt=0, machine_state=MachineState.READY, machine_state_details={}, selected_strategy=XOptStrategy.EXPLORATION, decision_reasoning="test", - generated_yaml="test: yaml", - yaml_generation_failed=False, + optimization_config={"algorithm": "random"}, + config_generation_failed=False, requires_approval=False, approval_interrupt_data=None, approval_result=None, diff --git a/tests/services/xopt_optimizer/test_xopt_workflow.py b/tests/services/xopt_optimizer/test_xopt_workflow.py index d648d1656..37bba6633 100644 --- a/tests/services/xopt_optimizer/test_xopt_workflow.py +++ b/tests/services/xopt_optimizer/test_xopt_workflow.py @@ -1,13 +1,17 @@ """Integration test for XOpt optimizer service workflow. This test runs the full service workflow without approval to verify -the placeholder implementation works end-to-end. +the placeholder implementation works end-to-end, and also tests +the workflow with a mocked TuningScriptsClient for the real API path. """ import os +from unittest.mock import AsyncMock, MagicMock, patch import pytest +from osprey.services.xopt_optimizer.execution.api_client import TuningScriptsAPIError + class TestXOptWorkflow: """Test complete XOpt workflow execution.""" @@ -17,10 +21,19 @@ async def test_full_workflow_without_approval(self, test_config): """Test complete workflow execution without approval. This runs the service through all nodes: - state_id -> decision -> yaml_gen -> execution -> analysis + state_id -> decision -> config_gen -> execution -> analysis + + The execution node falls back to placeholder when the API is unreachable. """ os.environ["CONFIG_FILE"] = str(test_config) + # Clear config cache so test fixture config is picked up + from osprey.utils import config as config_module + + config_module._default_config = None + config_module._default_configurable = None + config_module._config_cache.clear() + from osprey.services.xopt_optimizer import ( XOptExecutionRequest, XOptOptimizerService, @@ -28,6 +41,12 @@ async def test_full_workflow_without_approval(self, test_config): XOptStrategy, ) + # Mock the client to simulate API-unreachable (connection error β†’ placeholder fallback) + mock_client = MagicMock() + mock_client.health_check = AsyncMock( + side_effect=TuningScriptsAPIError("Connection refused") + ) + service = XOptOptimizerService() # Create request with approval disabled @@ -46,18 +65,19 @@ async def test_full_workflow_without_approval(self, test_config): } } - # Run the service - result = await service.ainvoke(request, config) + # Run the service with mocked client to avoid depending on external API + with patch( + "osprey.services.xopt_optimizer.execution.node.TuningScriptsClient", + return_value=mock_client, + ): + result = await service.ainvoke(request, config) # Verify result structure assert isinstance(result, XOptServiceResult) assert result.strategy == XOptStrategy.EXPLORATION assert result.total_iterations == 2 # We set max_iterations=2 - assert result.generated_yaml is not None - # Check for valid XOpt YAML structure (generator and vocs are required) - yaml_lower = result.generated_yaml.lower() - assert "generator" in yaml_lower, "Generated YAML should contain generator config" - assert "vocs" in yaml_lower, "Generated YAML should contain vocs config" + assert isinstance(result.optimization_config, dict) + assert "algorithm" in result.optimization_config assert len(result.recommendations) > 0 @pytest.mark.asyncio @@ -91,3 +111,70 @@ async def test_single_iteration_workflow(self, test_config): assert isinstance(result, XOptServiceResult) assert result.total_iterations == 1 + + @pytest.mark.asyncio + async def test_workflow_with_mocked_api_client(self, test_config): + """Test full workflow with a mocked TuningScriptsClient. + + This simulates the real API path where the execution node talks + to the tuning_scripts API and receives actual optimization data. + """ + os.environ["CONFIG_FILE"] = str(test_config) + + from osprey.services.xopt_optimizer import ( + XOptExecutionRequest, + XOptOptimizerService, + XOptServiceResult, + ) + + # Mock API responses with real-looking data + full_state = { + "job_id": "mock-job-001", + "status": "completed", + "data": [ + {"quad1_k1": 0.5, "quad2_k1": -1.0, "injection_efficiency": 0.85}, + {"quad1_k1": 1.0, "quad2_k1": -0.5, "injection_efficiency": 0.92}, + {"quad1_k1": 0.8, "quad2_k1": -0.7, "injection_efficiency": 0.95}, + ], + "environment_name": "als_injector", + "objective_name": "injection_efficiency", + "variable_names": ["quad1_k1", "quad2_k1"], + "results_path": "2024/mock-job-001", + "logs": "Optimization completed in 3 evaluations", + } + + mock_client = MagicMock() + mock_client.health_check = AsyncMock(return_value={"status": "ok"}) + mock_client.submit_config = AsyncMock(return_value="mock-job-001") + mock_client.poll_until_complete = AsyncMock(return_value=full_state) + + service = XOptOptimizerService() + + request = XOptExecutionRequest( + user_query="Optimize injection efficiency", + optimization_objective="Maximize injection efficiency", + max_iterations=1, + require_approval=False, + ) + + config = { + "configurable": { + "thread_id": "test_mocked_api", + "checkpoint_ns": "xopt_test", + } + } + + with patch( + "osprey.services.xopt_optimizer.execution.node.TuningScriptsClient", + return_value=mock_client, + ): + result = await service.ainvoke(request, config) + + assert isinstance(result, XOptServiceResult) + assert result.total_iterations == 1 + assert result.run_artifact["job_id"] == "mock-job-001" + assert result.run_artifact["data"] is not None + assert len(result.run_artifact["data"]) == 3 + # Analysis should have found best point + assert len(result.recommendations) > 0 + assert any("injection_efficiency" in r or "0.95" in r for r in result.recommendations) From 3d8c73fa3ff65f2c8b37b55abd6869ec01da8bc6 Mon Sep 17 00:00:00 2001 From: Gianluca Martino Date: Fri, 13 Mar 2026 15:10:24 -0700 Subject: [PATCH 12/14] feat(xopt): add question interrupt support for environment selection --- src/osprey/approval/__init__.py | 2 + src/osprey/approval/approval_system.py | 48 ++++++++-- src/osprey/capabilities/optimization.py | 91 +++++++++++++++++-- src/osprey/infrastructure/gateway.py | 19 ++++ .../xopt_optimizer/config_generation/node.py | 49 +++------- src/osprey/services/xopt_optimizer/models.py | 3 + 6 files changed, 158 insertions(+), 54 deletions(-) diff --git a/src/osprey/approval/__init__.py b/src/osprey/approval/__init__.py index 8903dbe70..30b8c96da 100644 --- a/src/osprey/approval/__init__.py +++ b/src/osprey/approval/__init__.py @@ -56,6 +56,7 @@ create_code_approval_interrupt, create_memory_approval_interrupt, create_plan_approval_interrupt, + create_question_interrupt, create_step_approval_interrupt, create_xopt_approval_interrupt, get_approval_resume_data, @@ -78,6 +79,7 @@ "create_code_approval_interrupt", "create_memory_approval_interrupt", "create_channel_write_approval_interrupt", + "create_question_interrupt", "create_xopt_approval_interrupt", "get_approved_payload_from_state", "get_approval_resume_data", diff --git a/src/osprey/approval/approval_system.py b/src/osprey/approval/approval_system.py index da0c688f7..1bf425145 100644 --- a/src/osprey/approval/approval_system.py +++ b/src/osprey/approval/approval_system.py @@ -525,8 +525,33 @@ def create_code_approval_interrupt( } +def create_question_interrupt( + question: str, + options: list[str] | None = None, +) -> dict[str, Any]: + """Create a question interrupt payload for non-approval user interactions. + + Unlike approval interrupts (which expect yes/no), question interrupts pass + the user's free-text answer back via ``Command(resume=answer)``. The + gateway detects the ``"type": "question"`` key and skips approval + classification. + + :param question: The question to present to the user + :type question: str + :param options: Optional list of valid option strings for display + :type options: list[str] | None + :return: Interrupt payload dict with type, user_message, and options + :rtype: dict[str, Any] + """ + return { + "type": "question", + "user_message": question, + "options": options or [], + } + + def create_xopt_approval_interrupt( - yaml_config: str, + optimization_config: dict[str, Any], strategy: str, objective: str, machine_state_details: dict[str, Any] | None = None, @@ -536,11 +561,10 @@ def create_xopt_approval_interrupt( Generates LangGraph-compatible interrupt data for XOpt configurations that require human approval before execution. The interrupt provides comprehensive - context including the generated YAML, optimization strategy, and machine - state details. + context including the optimization config, strategy, and machine state details. - :param yaml_config: Generated XOpt YAML configuration - :type yaml_config: str + :param optimization_config: Generated optimization config dict + :type optimization_config: Dict[str, Any] :param strategy: Selected optimization strategy (exploration/optimization) :type strategy: str :param objective: Optimization objective description @@ -556,7 +580,7 @@ def create_xopt_approval_interrupt( Basic XOpt approval:: >>> interrupt_data = create_xopt_approval_interrupt( - ... yaml_config="xopt:\\n generator: random", + ... optimization_config={"algorithm": "upper_confidence_bound"}, ... strategy="exploration", ... objective="Maximize injection efficiency", ... step_objective="Execute XOpt optimization" @@ -568,6 +592,11 @@ def create_xopt_approval_interrupt( This function is used for security-critical approval decisions for optimization operations that may affect machine parameters. """ + import yaml as _yaml + + # Format config as YAML for human readability + config_display = _yaml.dump(optimization_config, default_flow_style=False, sort_keys=False) + # Format machine state if available machine_state_section = "" if machine_state_details: @@ -585,10 +614,9 @@ def create_xopt_approval_interrupt( **Optimization Objective:** {objective} **Strategy:** {strategy.upper()} {machine_state_section} -**Generated XOpt Configuration:** +**Optimization Configuration:** ```yaml -{yaml_config} -``` +{config_display}``` **Review the configuration above carefully.** @@ -602,7 +630,7 @@ def create_xopt_approval_interrupt( "resume_payload": { "approval_type": create_approval_type("xopt_optimizer"), "step_objective": step_objective, - "yaml_config": yaml_config, + "optimization_config": optimization_config, "strategy": strategy, "objective": objective, "machine_state_details": machine_state_details, diff --git a/src/osprey/capabilities/optimization.py b/src/osprey/capabilities/optimization.py index 4059155a9..fe1f229e0 100644 --- a/src/osprey/capabilities/optimization.py +++ b/src/osprey/capabilities/optimization.py @@ -33,12 +33,13 @@ from typing import Any, ClassVar -from langgraph.types import Command +from langgraph.types import Command, interrupt from pydantic import Field from osprey.approval import ( clear_approval_state, create_approval_type, + create_question_interrupt, get_approval_resume_data, handle_service_with_interrupts, ) @@ -50,6 +51,11 @@ from osprey.prompts.loader import get_framework_prompts from osprey.registry import get_registry from osprey.services.xopt_optimizer import XOptExecutionRequest, XOptServiceResult +from osprey.services.xopt_optimizer.config_generation.node import ( + _get_config_generation_config, + _match_environment, +) +from osprey.services.xopt_optimizer.execution.api_client import TuningScriptsClient from osprey.state import StateManager from osprey.utils.config import get_full_configuration from osprey.utils.logger import get_logger @@ -81,8 +87,8 @@ class OptimizationResultContext(CapabilityContext): :type analysis_summary: Dict[str, Any] :param recommendations: List of recommendations from analysis :type recommendations: List[str] - :param generated_yaml: XOpt YAML configuration used - :type generated_yaml: str + :param optimization_config: Optimization config dict submitted to tuning_scripts + :type optimization_config: Dict[str, Any] .. note:: The run_artifact contains the primary optimization outputs that @@ -98,7 +104,7 @@ class OptimizationResultContext(CapabilityContext): total_iterations: int = 0 analysis_summary: dict[str, Any] = Field(default_factory=dict) recommendations: list[str] = Field(default_factory=list) - generated_yaml: str = "" + optimization_config: dict[str, Any] = Field(default_factory=dict) CONTEXT_TYPE: ClassVar[str] = "OPTIMIZATION_RESULT" CONTEXT_CATEGORY: ClassVar[str] = "OPTIMIZATION_DATA" @@ -122,7 +128,7 @@ def get_access_details(self, key: str) -> dict[str, Any]: "total_iterations": f"Completed {self.total_iterations} iterations", "analysis_summary": "Summary of optimization analysis", "recommendations": "List of recommendations from analysis", - "generated_yaml": "XOpt YAML configuration used", + "optimization_config": "Optimization config dict used", "access_pattern": f"context.OPTIMIZATION_RESULT.{key}", } @@ -158,7 +164,7 @@ def _create_optimization_context(service_result: XOptServiceResult) -> Optimizat total_iterations=service_result.total_iterations, analysis_summary=service_result.analysis_summary, recommendations=list(service_result.recommendations), - generated_yaml=service_result.generated_yaml, + optimization_config=service_result.optimization_config, ) @@ -370,6 +376,9 @@ async def execute(self) -> dict[str, Any]: task_objective = self.get_task_objective(default="") capability_contexts = self._state.get("capability_context_data", {}) + # Resolve environment before calling the service (avoids subgraph interrupt) + resolved_env = await self._resolve_environment_if_needed(cap_logger) + # Create execution request execution_request = XOptExecutionRequest( user_query=user_query, @@ -377,6 +386,7 @@ async def execute(self) -> dict[str, Any]: capability_context_data=capability_contexts, require_approval=True, max_iterations=3, + environment_name=resolved_env, ) cap_logger.status("Invoking XOpt optimizer service...") @@ -407,6 +417,71 @@ async def execute(self) -> dict[str, Any]: results_context, ) + # ======================================== + # ENVIRONMENT RESOLUTION + # ======================================== + + async def _resolve_environment_if_needed(self, cap_logger) -> str | None: + """Resolve environment name before service invocation. + + Checks config for a default, then queries the tuning_scripts API. + If multiple valid environments exist, interrupts with a question + so the user can pick one in the chat. + + Returns: + Resolved environment name, or None if resolution is not possible. + """ + gen_config = _get_config_generation_config() + default_env = gen_config.get("default_environment") + if default_env: + return default_env + + # Query the tuning_scripts API + try: + client = TuningScriptsClient() + environments = await client.list_environments() + except Exception as e: + cap_logger.warning(f"Could not query environments from API: {e}") + return None + + valid_envs = [env for env in environments if env.get("valid", False)] + + if not valid_envs: + cap_logger.warning("No valid environments returned by API") + return None + + # Single valid environment β€” auto-select + if len(valid_envs) == 1: + cap_logger.info(f"Auto-selected environment: {valid_envs[0]['name']}") + return valid_envs[0]["name"] + + # Multiple environments β€” ask the user via question interrupt + env_lines = [] + for i, env in enumerate(valid_envs, 1): + source = f" [{env['source']}]" if env.get("source") else "" + desc = env.get("description", "") + env_lines.append(f" {i}. **{env['name']}** β€” {desc}{source}") + + question = ( + "Multiple optimization environments are available. " + "Please select one by number or name:\n\n" + "\n".join(env_lines) + ) + option_names = [env["name"] for env in valid_envs] + + cap_logger.info("Asking user to select optimization environment...") + user_choice = interrupt(create_question_interrupt(question, option_names)) + + # Parse the user's response + choice = str(user_choice).strip() + selected = _match_environment(choice, valid_envs) + + if not selected: + cap_logger.warning(f"Could not match '{choice}' to an available environment") + return None + + cap_logger.info(f"User selected environment: {selected['name']}") + return selected["name"] + # ======================================== # ERROR CLASSIFICATION # ======================================== @@ -423,9 +498,9 @@ def classify_error(exc: Exception, context: dict) -> ErrorClassification: :rtype: ErrorClassification """ from osprey.services.xopt_optimizer.exceptions import ( + ConfigGenerationError, MachineStateAssessmentError, XOptExecutionError, - YamlGenerationError, ) if isinstance(exc, MachineStateAssessmentError): @@ -438,7 +513,7 @@ def classify_error(exc: Exception, context: dict) -> ErrorClassification: }, ) - elif isinstance(exc, YamlGenerationError): + elif isinstance(exc, ConfigGenerationError): return ErrorClassification( severity=ErrorSeverity.REPLANNING, user_message=f"Failed to generate optimization configuration: {exc}", diff --git a/src/osprey/infrastructure/gateway.py b/src/osprey/infrastructure/gateway.py index 95edb03f8..63d5b9ef3 100644 --- a/src/osprey/infrastructure/gateway.py +++ b/src/osprey/infrastructure/gateway.py @@ -191,8 +191,27 @@ async def _handle_interrupt_flow( Gateway detects approval/rejection and uses Command(update=...) to inject interrupt payload into agent state while resuming execution. + + Question interrupts (type=="question") bypass approval classification and + return the raw user text via Command(resume=user_input). """ + # Extract interrupt payload to check for question type + success, interrupt_payload = self._extract_resume_payload(compiled_graph, config) + + if success and interrupt_payload.get("type") == "question": + emitter.emit( + StatusEvent( + component="gateway", + message="Detected question interrupt - passing user input as resume value", + level="info", + ) + ) + return GatewayResult( + resume_command=Command(resume=user_input.strip()), + is_interrupt_resume=True, + ) + # Detect approval or rejection approval_data = self._detect_approval_response(user_input) diff --git a/src/osprey/services/xopt_optimizer/config_generation/node.py b/src/osprey/services/xopt_optimizer/config_generation/node.py index 074d333d5..a7078260f 100644 --- a/src/osprey/services/xopt_optimizer/config_generation/node.py +++ b/src/osprey/services/xopt_optimizer/config_generation/node.py @@ -13,8 +13,6 @@ from typing import Any -from langgraph.types import interrupt - from osprey.utils.config import get_model_config, get_xopt_optimizer_config from osprey.utils.logger import get_logger @@ -168,37 +166,17 @@ async def _resolve_environment( node_logger.info(f"Auto-selected environment: {valid_envs[0]['name']}") return - # Multiple environments β€” ask the user - env_lines = [] - for i, env in enumerate(valid_envs, 1): - source = f" [{env['source']}]" if env.get("source") else "" - desc = env.get("description", "") - env_lines.append(f" {i}. **{env['name']}** β€” {desc}{source}") - - prompt = ( - "Multiple optimization environments are available. " - "Please select one by number or name:\n\n" - + "\n".join(env_lines) + # Multiple environments β€” auto-select the first valid one as fallback. + # (Environment selection is normally handled at the capability level via + # a question interrupt; this path runs only when the capability didn't + # resolve an environment ahead of time.) + first = valid_envs[0] + config["environment_name"] = first["name"] + node_logger.info( + f"Auto-selected first valid environment: {first['name']} " + f"(from {len(valid_envs)} available)" ) - node_logger.info("Asking user to select optimization environment...") - user_choice = interrupt({"question": prompt, "environments": valid_envs}) - - # Parse the user's response - choice = str(user_choice).strip() - selected = _match_environment(choice, valid_envs) - - if not selected: - raise ConfigGenerationError( - f"Could not match '{choice}' to an available environment. " - f"Valid options: {[e['name'] for e in valid_envs]}", - generated_config=config, - validation_errors=[f"Invalid environment selection: {choice}"], - ) - - config["environment_name"] = selected["name"] - node_logger.info(f"User selected environment: {selected['name']}") - def _match_environment( choice: str, environments: list[dict[str, Any]] @@ -300,6 +278,10 @@ async def config_generation_node(state: XOptExecutionState) -> dict[str, Any]: else: optimization_config = _generate_placeholder_config(objective, strategy) + # Honor pre-resolved environment from capability (highest priority) + if request and request.environment_name and not optimization_config.get("environment_name"): + optimization_config["environment_name"] = request.environment_name + # Apply defaults from config if not already set by the generator default_env = gen_config.get("default_environment") if default_env and not optimization_config.get("environment_name"): @@ -354,11 +336,6 @@ async def config_generation_node(state: XOptExecutionState) -> dict[str, Any]: raise except Exception as e: - # Re-raise GraphInterrupt β€” it's not an error, it's LangGraph - # pausing the graph to wait for user input (e.g. environment selection). - if e.__class__.__name__ == "GraphInterrupt": - raise - node_logger.warning(f"Config generation failed: {e}") error = XOptError( diff --git a/src/osprey/services/xopt_optimizer/models.py b/src/osprey/services/xopt_optimizer/models.py index 92a922991..b8308fe9d 100644 --- a/src/osprey/services/xopt_optimizer/models.py +++ b/src/osprey/services/xopt_optimizer/models.py @@ -167,6 +167,9 @@ class XOptExecutionRequest(BaseModel): None, description="Capability context data from main graph state" ) + # Pre-resolved environment (set by capability before service invocation) + environment_name: str | None = Field(None, description="Pre-resolved environment name") + # Standard Osprey fields require_approval: bool = Field(default=True) session_context: dict[str, Any] | None = Field( From 8343234a231cb258029f85dc69cc7ff648b1e44a Mon Sep 17 00:00:00 2001 From: Gianluca Martino Date: Fri, 13 Mar 2026 15:49:18 -0700 Subject: [PATCH 13/14] feat(xopt): add API-driven objective resolution and fix iteration interrupt --- src/osprey/capabilities/optimization.py | 135 ++++++++++++++---- .../xopt_optimizer/config_generation/node.py | 22 ++- .../xopt_optimizer/execution/api_client.py | 12 ++ src/osprey/services/xopt_optimizer/models.py | 3 +- 4 files changed, 140 insertions(+), 32 deletions(-) diff --git a/src/osprey/capabilities/optimization.py b/src/osprey/capabilities/optimization.py index fe1f229e0..556cc9795 100644 --- a/src/osprey/capabilities/optimization.py +++ b/src/osprey/capabilities/optimization.py @@ -330,40 +330,18 @@ async def execute(self) -> dict[str, Any]: except Exception as e: # Import here to avoid circular imports from langgraph.errors import GraphInterrupt - from langgraph.types import interrupt # Check if this is a GraphInterrupt (service looped and needs approval for next iteration) if isinstance(e, GraphInterrupt): cap_logger.info( "XOptOptimizer: Service completed iteration and requests approval for next" ) - - try: - # Extract interrupt data from GraphInterrupt - interrupt_data = e.args[0][0].value - cap_logger.debug( - f"XOptOptimizer: Extracted interrupt data with keys: {list(interrupt_data.keys())}" - ) - - # Re-raise interrupt in main graph context for next iteration - cap_logger.info( - "⏸️ XOptOptimizer: Creating approval interrupt for next iteration" - ) - interrupt(interrupt_data) - - # This line should never be reached - cap_logger.error( - "UNEXPECTED: interrupt() returned instead of pausing execution" - ) - raise RuntimeError("Interrupt mechanism failed in XOptOptimizer") - - except (IndexError, KeyError, AttributeError) as extract_error: - cap_logger.error( - f"XOptOptimizer: Failed to extract interrupt data: {extract_error}" - ) - raise RuntimeError( - f"XOptOptimizer: Failed to handle service interrupt: {extract_error}" - ) from extract_error + # Re-raise the GraphInterrupt directly so the main graph + # pauses and presents the approval to the user. + # NOTE: Do NOT use interrupt() here β€” it has context-dependent + # behavior and may return instead of raising inside an + # exception handler, breaking the interrupt mechanism. + raise else: # Re-raise non-interrupt exceptions raise @@ -379,6 +357,9 @@ async def execute(self) -> dict[str, Any]: # Resolve environment before calling the service (avoids subgraph interrupt) resolved_env = await self._resolve_environment_if_needed(cap_logger) + # Resolve objective from the environment's available objectives + resolved_objective = await self._resolve_objective_if_needed(resolved_env, cap_logger) + # Create execution request execution_request = XOptExecutionRequest( user_query=user_query, @@ -387,6 +368,7 @@ async def execute(self) -> dict[str, Any]: require_approval=True, max_iterations=3, environment_name=resolved_env, + objective_name=resolved_objective, ) cap_logger.status("Invoking XOpt optimizer service...") @@ -417,6 +399,103 @@ async def execute(self) -> dict[str, Any]: results_context, ) + # ======================================== + # OBJECTIVE RESOLUTION + # ======================================== + + async def _resolve_objective_if_needed( + self, environment_name: str | None, cap_logger + ) -> str | None: + """Resolve the objective_name from the environment's available objectives. + + Queries ``GET /environments/{name}`` for available objectives, then: + - Auto-selects if there is a default or only one objective + - Asks the user to pick if there are multiple + + Returns: + Selected objective name, or None if resolution is not possible. + """ + if not environment_name: + return None + + try: + client = TuningScriptsClient() + env_details = await client.get_environment_details(environment_name) + except Exception as e: + cap_logger.warning(f"Could not query environment details from API: {e}") + return None + + available = env_details.get("available_objectives", []) + default = env_details.get("default_objective") + metadata = env_details.get("observables_metadata", {}) + + if not available: + cap_logger.warning(f"No available objectives for environment '{environment_name}'") + return None + + # Single objective β€” auto-select + if len(available) == 1: + cap_logger.info(f"Auto-selected objective: {available[0]}") + return available[0] + + # Default exists and is in the available list β€” auto-select + if default and default in available: + cap_logger.info(f"Auto-selected default objective: {default}") + return default + + # Multiple objectives, no default β€” ask the user + obj_lines = [] + for i, obj_name in enumerate(available, 1): + meta = metadata.get(obj_name, {}) + desc = meta.get("description", "") + units = meta.get("units", "") + direction = meta.get("direction", "") + detail_parts = [p for p in [desc, units, direction] if p] + detail = f" β€” {', '.join(detail_parts)}" if detail_parts else "" + obj_lines.append(f" {i}. **{obj_name}**{detail}") + + question = ( + f"Multiple optimization objectives are available for environment " + f"'{environment_name}'. Please select one:\n\n" + "\n".join(obj_lines) + ) + + cap_logger.info("Asking user to select optimization objective...") + user_choice = interrupt(create_question_interrupt(question, available)) + + choice = str(user_choice).strip() + + # Match by number or name + selected = self._match_objective(choice, available) + if not selected: + cap_logger.warning(f"Could not match '{choice}' to an available objective") + return None + + cap_logger.info(f"User selected objective: {selected}") + return selected + + @staticmethod + def _match_objective(choice: str, available: list[str]) -> str | None: + """Match a user's choice (number or name) to an available objective.""" + # Try as a 1-based index + try: + idx = int(choice) - 1 + if 0 <= idx < len(available): + return available[idx] + except ValueError: + pass + + # Try exact match + if choice in available: + return choice + + # Try case-insensitive prefix match + lower = choice.lower() + for obj in available: + if obj.lower().startswith(lower): + return obj + + return None + # ======================================== # ENVIRONMENT RESOLUTION # ======================================== diff --git a/src/osprey/services/xopt_optimizer/config_generation/node.py b/src/osprey/services/xopt_optimizer/config_generation/node.py index a7078260f..074703a42 100644 --- a/src/osprey/services/xopt_optimizer/config_generation/node.py +++ b/src/osprey/services/xopt_optimizer/config_generation/node.py @@ -72,7 +72,11 @@ def _get_config_generation_config() -> dict[str, Any]: } -def _generate_placeholder_config(objective: str, strategy: XOptStrategy) -> dict[str, Any]: +def _generate_placeholder_config( + objective: str, + strategy: XOptStrategy, + objective_name: str | None = None, +) -> dict[str, Any]: """Generate a placeholder optimization config dict. Used when config_generation.mode is "mock". @@ -80,7 +84,7 @@ def _generate_placeholder_config(objective: str, strategy: XOptStrategy) -> dict DO NOT add accelerator-specific parameters without operator input. """ algorithm = "random" if strategy == XOptStrategy.EXPLORATION else "upper_confidence_bound" - return { + config: dict[str, Any] = { "algorithm": algorithm, "n_iterations": 20, "note": ( @@ -88,6 +92,9 @@ def _generate_placeholder_config(objective: str, strategy: XOptStrategy) -> dict "Set config_generation.mode: 'structured' to use the LLM agent." ), } + if objective_name: + config["objective_name"] = objective_name + return config async def _generate_config_with_agent( @@ -268,6 +275,9 @@ async def config_generation_node(state: XOptExecutionState) -> dict[str, Any]: objective = request.optimization_objective if request else "Unknown objective" try: + # Extract pre-resolved objective_name from request + objective_name = request.objective_name if request else None + # Generate config based on mode if mode == "structured": optimization_config = await _generate_config_with_agent( @@ -276,12 +286,18 @@ async def config_generation_node(state: XOptExecutionState) -> dict[str, Any]: model_config=gen_config.get("model_config"), ) else: - optimization_config = _generate_placeholder_config(objective, strategy) + optimization_config = _generate_placeholder_config( + objective, strategy, objective_name=objective_name + ) # Honor pre-resolved environment from capability (highest priority) if request and request.environment_name and not optimization_config.get("environment_name"): optimization_config["environment_name"] = request.environment_name + # Honor pre-resolved objective_name from capability (highest priority) + if objective_name and not optimization_config.get("objective_name"): + optimization_config["objective_name"] = objective_name + # Apply defaults from config if not already set by the generator default_env = gen_config.get("default_environment") if default_env and not optimization_config.get("environment_name"): diff --git a/src/osprey/services/xopt_optimizer/execution/api_client.py b/src/osprey/services/xopt_optimizer/execution/api_client.py index 030069107..232cf942a 100644 --- a/src/osprey/services/xopt_optimizer/execution/api_client.py +++ b/src/osprey/services/xopt_optimizer/execution/api_client.py @@ -95,6 +95,18 @@ async def list_environments(self) -> list[dict[str, Any]]: """ return await self._get("/environments") + async def get_environment_details(self, name: str) -> dict[str, Any]: + """Get detailed info for a specific environment (variables, objectives). + + Returns: + Environment details dict with ``available_objectives``, + ``default_objective``, ``observables_metadata``, etc. + + Raises: + TuningScriptsAPIError: If the API is unreachable or env not found. + """ + return await self._get(f"/environments/{name}") + async def submit_config(self, config: dict[str, Any]) -> str: """Submit an OptimizationConfig dict to start an optimization. diff --git a/src/osprey/services/xopt_optimizer/models.py b/src/osprey/services/xopt_optimizer/models.py index b8308fe9d..1a06f04e1 100644 --- a/src/osprey/services/xopt_optimizer/models.py +++ b/src/osprey/services/xopt_optimizer/models.py @@ -167,8 +167,9 @@ class XOptExecutionRequest(BaseModel): None, description="Capability context data from main graph state" ) - # Pre-resolved environment (set by capability before service invocation) + # Pre-resolved environment and objective (set by capability before service invocation) environment_name: str | None = Field(None, description="Pre-resolved environment name") + objective_name: str | None = Field(None, description="Pre-resolved objective name from environment") # Standard Osprey fields require_approval: bool = Field(default=True) From 50966ef88221ba6d845af2698cd46eda709bc868 Mon Sep 17 00:00:00 2001 From: Gianluca Martino Date: Fri, 13 Mar 2026 15:56:27 -0700 Subject: [PATCH 14/14] style: fix import sorting and formatting --- .../xopt_optimizer/config_generation/node.py | 31 +++++++------- .../xopt_optimizer/execution/api_client.py | 11 +++-- src/osprey/services/xopt_optimizer/models.py | 4 +- .../test_execution_api_client.py | 42 +++++++++++++++---- 4 files changed, 57 insertions(+), 31 deletions(-) diff --git a/src/osprey/services/xopt_optimizer/config_generation/node.py b/src/osprey/services/xopt_optimizer/config_generation/node.py index 074703a42..3347c825f 100644 --- a/src/osprey/services/xopt_optimizer/config_generation/node.py +++ b/src/osprey/services/xopt_optimizer/config_generation/node.py @@ -23,12 +23,14 @@ logger = get_logger("xopt_optimizer") # Allowed algorithm values for validation -_ALLOWED_ALGORITHMS = frozenset({ - "upper_confidence_bound", - "expected_improvement", - "mobo", - "random", -}) +_ALLOWED_ALGORITHMS = frozenset( + { + "upper_confidence_bound", + "expected_improvement", + "mobo", + "random", + } +) def _get_config_generation_config() -> dict[str, Any]: @@ -126,9 +128,7 @@ async def _generate_config_with_agent( return _generate_placeholder_config(objective, strategy) -async def _resolve_environment( - config: dict[str, Any], node_logger: Any -) -> None: +async def _resolve_environment(config: dict[str, Any], node_logger: Any) -> None: """Resolve environment_name if missing by asking the user. Queries the tuning_scripts API for available environments, then @@ -180,14 +180,11 @@ async def _resolve_environment( first = valid_envs[0] config["environment_name"] = first["name"] node_logger.info( - f"Auto-selected first valid environment: {first['name']} " - f"(from {len(valid_envs)} available)" + f"Auto-selected first valid environment: {first['name']} (from {len(valid_envs)} available)" ) -def _match_environment( - choice: str, environments: list[dict[str, Any]] -) -> dict[str, Any] | None: +def _match_environment(choice: str, environments: list[dict[str, Any]]) -> dict[str, Any] | None: """Match a user's choice (number or name) to an environment. Returns the matched environment dict, or None if no match found. @@ -291,7 +288,11 @@ async def config_generation_node(state: XOptExecutionState) -> dict[str, Any]: ) # Honor pre-resolved environment from capability (highest priority) - if request and request.environment_name and not optimization_config.get("environment_name"): + if ( + request + and request.environment_name + and not optimization_config.get("environment_name") + ): optimization_config["environment_name"] = request.environment_name # Honor pre-resolved objective_name from capability (highest priority) diff --git a/src/osprey/services/xopt_optimizer/execution/api_client.py b/src/osprey/services/xopt_optimizer/execution/api_client.py index 232cf942a..eb2e91f1f 100644 --- a/src/osprey/services/xopt_optimizer/execution/api_client.py +++ b/src/osprey/services/xopt_optimizer/execution/api_client.py @@ -14,8 +14,7 @@ import aiohttp except ImportError as e: raise ImportError( - "aiohttp is required for TuningScriptsClient. " - "Install it with: pip install aiohttp" + "aiohttp is required for TuningScriptsClient. Install it with: pip install aiohttp" ) from e from osprey.utils.config import get_full_configuration @@ -64,7 +63,9 @@ def __init__( ): api_config = self._load_api_config() - self.base_url = (base_url or api_config.get("base_url", "http://localhost:8001")).rstrip("/") + self.base_url = (base_url or api_config.get("base_url", "http://localhost:8001")).rstrip( + "/" + ) self.poll_interval = poll_interval_seconds or api_config.get("poll_interval_seconds", 5.0) self.timeout = timeout_seconds or api_config.get("timeout_seconds", 3600) @@ -189,9 +190,7 @@ async def poll_until_complete(self, job_id: str) -> dict[str, Any]: await asyncio.sleep(self.poll_interval) elapsed += self.poll_interval - raise TuningScriptsAPIError( - f"Timeout waiting for job {job_id} after {self.timeout}s" - ) + raise TuningScriptsAPIError(f"Timeout waiting for job {job_id} after {self.timeout}s") async def cancel(self, job_id: str) -> dict[str, Any]: """Cancel a running optimization job.""" diff --git a/src/osprey/services/xopt_optimizer/models.py b/src/osprey/services/xopt_optimizer/models.py index 1a06f04e1..b401c0038 100644 --- a/src/osprey/services/xopt_optimizer/models.py +++ b/src/osprey/services/xopt_optimizer/models.py @@ -169,7 +169,9 @@ class XOptExecutionRequest(BaseModel): # Pre-resolved environment and objective (set by capability before service invocation) environment_name: str | None = Field(None, description="Pre-resolved environment name") - objective_name: str | None = Field(None, description="Pre-resolved objective name from environment") + objective_name: str | None = Field( + None, description="Pre-resolved objective name from environment" + ) # Standard Osprey fields require_approval: bool = Field(default=True) diff --git a/tests/services/xopt_optimizer/test_execution_api_client.py b/tests/services/xopt_optimizer/test_execution_api_client.py index 2c9ef1091..0e8cbff15 100644 --- a/tests/services/xopt_optimizer/test_execution_api_client.py +++ b/tests/services/xopt_optimizer/test_execution_api_client.py @@ -12,11 +12,11 @@ TuningScriptsClient, ) - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _make_mock_response(status: int = 200, json_data: dict | None = None, text: str = ""): """Create a mock aiohttp response.""" resp = MagicMock() @@ -67,7 +67,10 @@ async def test_health_check_success(self): mock_resp = _make_mock_response(200, {"status": "ok"}) mock_session = _mock_session_for(mock_resp, "get") - with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + with patch( + "osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", + return_value=mock_session, + ): result = await client.health_check() assert result == {"status": "ok"} @@ -78,7 +81,10 @@ async def test_health_check_failure(self): mock_resp = _make_mock_response(503, {"detail": "unhealthy"}) mock_session = _mock_session_for(mock_resp, "get") - with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + with patch( + "osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", + return_value=mock_session, + ): with pytest.raises(TuningScriptsAPIError) as exc_info: await client.health_check() @@ -90,7 +96,10 @@ async def test_submit_config_returns_job_id(self): mock_resp = _make_mock_response(200, {"job_id": "abc-123", "status": "submitted"}) mock_session = _mock_session_for(mock_resp, "post") - with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + with patch( + "osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", + return_value=mock_session, + ): job_id = await client.submit_config({"algorithm": "random", "n_iterations": 10}) assert job_id == "abc-123" @@ -101,7 +110,10 @@ async def test_submit_yaml_returns_job_id(self): mock_resp = _make_mock_response(200, {"job_id": "abc-123", "status": "submitted"}) mock_session = _mock_session_for(mock_resp, "post") - with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + with patch( + "osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", + return_value=mock_session, + ): job_id = await client.submit_yaml("generator:\n name: random\n") assert job_id == "abc-123" @@ -112,7 +124,10 @@ async def test_submit_yaml_with_iterations(self): mock_resp = _make_mock_response(200, {"job_id": "abc-123", "status": "submitted"}) mock_session = _mock_session_for(mock_resp, "post") - with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + with patch( + "osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", + return_value=mock_session, + ): job_id = await client.submit_yaml("yaml: data\n", n_iterations=10) assert job_id == "abc-123" @@ -123,7 +138,10 @@ async def test_get_status(self): mock_resp = _make_mock_response(200, {"job_id": "abc", "status": "running"}) mock_session = _mock_session_for(mock_resp, "get") - with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + with patch( + "osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", + return_value=mock_session, + ): result = await client.get_status("abc") assert result["status"] == "running" @@ -141,7 +159,10 @@ async def test_get_full_state(self): mock_resp = _make_mock_response(200, full_data) mock_session = _mock_session_for(mock_resp, "get") - with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + with patch( + "osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", + return_value=mock_session, + ): result = await client.get_full_state("abc") assert result["data"] == [{"x": 1.0, "f": 0.5}] @@ -152,7 +173,10 @@ async def test_cancel(self): mock_resp = _make_mock_response(200, {"status": "cancelled", "job_id": "abc"}) mock_session = _mock_session_for(mock_resp, "post") - with patch("osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", return_value=mock_session): + with patch( + "osprey.services.xopt_optimizer.execution.api_client.aiohttp.ClientSession", + return_value=mock_session, + ): result = await client.cancel("abc") assert result["status"] == "cancelled"