diff --git a/debot/agent/loop.py b/debot/agent/loop.py index b097211..0289b56 100644 --- a/debot/agent/loop.py +++ b/debot/agent/loop.py @@ -284,36 +284,104 @@ async def _process_message(self, msg: InboundMessage) -> OutboundMessage | None: messages=messages, tools=self.tools.get_definitions(), model=chosen_model ) - # Auto-escalate: if model failed, try next tier (up to 3 escalations) - if ( - _debot_rust - and current_tier - and response.finish_reason in ("error", "context_length_exceeded") - ): - for _esc in range(3): - fb_json = _debot_rust.get_fallback_model(current_tier) - if not fb_json: - break - fb = json.loads(fb_json) - logger.warning( - "Escalating: {} ({}) failed [{}] → {} ({})", - chosen_model, - current_tier, - response.finish_reason, - fb["model"], - fb["tier"], - ) + # Auto-reroute on failure + _fail_reasons = ("error", "context_length_exceeded", "insufficient_credits") + if _debot_rust and current_tier and response.finish_reason in _fail_reasons: + if response.finish_reason == "insufficient_credits": + # Billing error strategy: + # 1) Try same-tier alternatives (cheaper models that handle same complexity) + # 2) If all same-tier alternatives exhausted, escalate to next tier + tried = {chosen_model} + rerouted = False + # Phase 1: same-tier alternatives sorted by cost ascending try: - _debot_rust.record_escalation() + alts = json.loads(_debot_rust.get_tier_alternatives(current_tier)) except Exception: - pass - chosen_model = fb["model"] - current_tier = fb["tier"] - response = await self.provider.chat( - messages=messages, tools=self.tools.get_definitions(), model=chosen_model - ) - if response.finish_reason not in ("error", "context_length_exceeded"): - break + alts = [] + for alt in alts: + if alt["model"] in tried: + continue + tried.add(alt["model"]) + logger.warning( + "Billing fallback: {} failed [{}] → trying same-tier {} (${:.2f}/M)", + chosen_model, response.finish_reason, + alt["model"], alt["cost"], + ) + try: + _debot_rust.record_escalation() + except Exception: + pass + chosen_model = alt["model"] + response = await self.provider.chat( + messages=messages, tools=self.tools.get_definitions(), model=chosen_model + ) + if response.finish_reason not in _fail_reasons: + rerouted = True + break + # Phase 2: same-tier exhausted, escalate up tier by tier + if not rerouted and response.finish_reason in _fail_reasons: + esc_tier = current_tier + for _esc in range(3): + fb_json = _debot_rust.get_fallback_model(esc_tier) + if not fb_json: + break + fb = json.loads(fb_json) + if fb["model"] in tried: + esc_tier = fb["tier"] + continue + tried.add(fb["model"]) + logger.warning( + "Billing fallback: same-tier exhausted, escalating → {} ({})", + fb["model"], fb["tier"], + ) + try: + _debot_rust.record_escalation() + except Exception: + pass + chosen_model = fb["model"] + current_tier = fb["tier"] + esc_tier = fb["tier"] + response = await self.provider.chat( + messages=messages, tools=self.tools.get_definitions(), model=chosen_model + ) + if response.finish_reason not in _fail_reasons: + break + else: + # Context / other errors → escalate to more capable model + for _esc in range(3): + fb_json = _debot_rust.get_fallback_model(current_tier) + if not fb_json: + break + fb = json.loads(fb_json) + logger.warning( + "Escalating: {} ({}) failed [{}] → {} ({})", + chosen_model, current_tier, response.finish_reason, + fb["model"], fb["tier"], + ) + try: + _debot_rust.record_escalation() + except Exception: + pass + chosen_model = fb["model"] + current_tier = fb["tier"] + response = await self.provider.chat( + messages=messages, tools=self.tools.get_definitions(), model=chosen_model + ) + if response.finish_reason not in _fail_reasons: + break + + # If all fallbacks exhausted, give a friendly error instead of raw API dump + if response.finish_reason == "insufficient_credits": + response.content = ( + "All available models failed due to insufficient credits. " + "Please top up your API provider credits and try again.\n\n" + "OpenRouter: https://openrouter.ai/settings/credits" + ) + elif response.finish_reason == "context_length_exceeded": + response.content = ( + "The conversation is too long for all available models. " + "Try starting a new conversation or use /compact to compress history." + ) # Handle tool calls if response.has_tool_calls: diff --git a/debot/cli/commands.py b/debot/cli/commands.py index 741e30a..1b66681 100644 --- a/debot/cli/commands.py +++ b/debot/cli/commands.py @@ -315,7 +315,8 @@ def gateway( raise typer.Exit(1) provider = LiteLLMProvider( - api_key=api_key, api_base=api_base, default_model=config.agents.defaults.model + api_key=api_key, api_base=api_base, default_model=config.agents.defaults.model, + all_api_keys=config.get_all_api_keys(), ) # Create agent @@ -423,7 +424,8 @@ def agent( bus = MessageBus() provider = LiteLLMProvider( - api_key=api_key, api_base=api_base, default_model=config.agents.defaults.model + api_key=api_key, api_base=api_base, default_model=config.agents.defaults.model, + all_api_keys=config.get_all_api_keys(), ) agent_loop = AgentLoop( diff --git a/debot/config/schema.py b/debot/config/schema.py index dc2ebfb..4738ef9 100644 --- a/debot/config/schema.py +++ b/debot/config/schema.py @@ -134,6 +134,23 @@ def get_api_key(self) -> str | None: or None ) + def get_all_api_keys(self) -> dict[str, str]: + """Return all configured provider API keys (non-empty only).""" + keys: dict[str, str] = {} + if self.providers.openrouter.api_key: + keys["openrouter"] = self.providers.openrouter.api_key + if self.providers.anthropic.api_key: + keys["anthropic"] = self.providers.anthropic.api_key + if self.providers.openai.api_key: + keys["openai"] = self.providers.openai.api_key + if self.providers.gemini.api_key: + keys["gemini"] = self.providers.gemini.api_key + if self.providers.groq.api_key: + keys["groq"] = self.providers.groq.api_key + if self.providers.zhipu.api_key: + keys["zhipu"] = self.providers.zhipu.api_key + return keys + def get_api_base(self) -> str | None: """Get API base URL if using OpenRouter, Zhipu or vLLM.""" if self.providers.openrouter.api_key: diff --git a/debot/providers/litellm_provider.py b/debot/providers/litellm_provider.py index 8cf0fa9..5ea5bb1 100644 --- a/debot/providers/litellm_provider.py +++ b/debot/providers/litellm_provider.py @@ -22,6 +22,7 @@ def __init__( api_key: str | None = None, api_base: str | None = None, default_model: str = "anthropic/claude-opus-4-5", + all_api_keys: dict[str, str] | None = None, ): super().__init__(api_key, api_base) self.default_model = default_model @@ -34,27 +35,42 @@ def __init__( # Track if using custom endpoint (vLLM, etc.) self.is_vllm = bool(api_base) and not self.is_openrouter - # Configure LiteLLM based on provider + # Set ALL configured provider API keys so LiteLLM can route across + # providers during fallback (e.g. OpenRouter credits exhausted → Anthropic direct). + _key_env_map = { + "openrouter": "OPENROUTER_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "openai": "OPENAI_API_KEY", + "gemini": "GEMINI_API_KEY", + "groq": "GROQ_API_KEY", + "zhipu": "ZHIPUAI_API_KEY", + } + if all_api_keys: + for provider_name, key in all_api_keys.items(): + env_var = _key_env_map.get(provider_name) + if env_var and key: + os.environ.setdefault(env_var, key) + + # Configure primary provider key (overrides setdefault above) if api_key: if self.is_openrouter: - # OpenRouter mode - set key os.environ["OPENROUTER_API_KEY"] = api_key elif self.is_vllm: - # vLLM/custom endpoint - uses OpenAI-compatible API os.environ["OPENAI_API_KEY"] = api_key elif "anthropic" in default_model: - os.environ.setdefault("ANTHROPIC_API_KEY", api_key) + os.environ["ANTHROPIC_API_KEY"] = api_key elif "openai" in default_model or "gpt" in default_model: - os.environ.setdefault("OPENAI_API_KEY", api_key) + os.environ["OPENAI_API_KEY"] = api_key elif "gemini" in default_model.lower(): - os.environ.setdefault("GEMINI_API_KEY", api_key) + os.environ["GEMINI_API_KEY"] = api_key elif "zhipu" in default_model or "glm" in default_model or "zai" in default_model: - os.environ.setdefault("ZHIPUAI_API_KEY", api_key) + os.environ["ZHIPUAI_API_KEY"] = api_key elif "groq" in default_model: - os.environ.setdefault("GROQ_API_KEY", api_key) + os.environ["GROQ_API_KEY"] = api_key - if api_base: - litellm.api_base = api_base + # Do NOT set litellm.api_base globally — it would route ALL calls + # (including cross-provider fallback) through OpenRouter. + # api_base is passed per-call in chat() kwargs instead. # Disable LiteLLM logging noise litellm.suppress_debug_info = True @@ -82,9 +98,18 @@ async def chat( """ model = model or self.default_model - # For OpenRouter, prefix model name if not already prefixed + # For OpenRouter, prefix model name — but skip if the model's native + # provider has a direct API key configured (cross-provider fallback). if self.is_openrouter and not model.startswith("openrouter/"): - model = f"openrouter/{model}" + # Check if a direct provider key is available for this model + _direct_available = ( + (model.startswith("anthropic/") and os.environ.get("ANTHROPIC_API_KEY")) + or (model.startswith("openai/") and os.environ.get("OPENAI_API_KEY")) + or (model.startswith("groq/") and os.environ.get("GROQ_API_KEY")) + or (model.startswith("gemini/") and os.environ.get("GEMINI_API_KEY")) + ) + if not _direct_available: + model = f"openrouter/{model}" # For Zhipu/Z.ai, ensure prefix is present # Handle cases like "glm-4.7-flash" -> "zhipu/glm-4.7-flash" @@ -111,8 +136,10 @@ async def chat( "temperature": temperature, } - # Pass api_base directly for custom endpoints (vLLM, etc.) - if self.api_base: + # Pass api_base for custom endpoints — but NOT for direct-provider calls + # (cross-provider fallback should hit the native API, not OpenRouter). + _is_direct = not model.startswith("openrouter/") and not model.startswith("hosted_vllm/") + if self.api_base and not _is_direct: kwargs["api_base"] = self.api_base if tools: @@ -124,18 +151,21 @@ async def chat( return self._parse_response(response) except Exception as e: err_str = str(e).lower() - # Classify context-window / token-limit errors for auto-escalation + # Billing check FIRST — billing errors often contain context-related + # words like "max_tokens" (e.g. "requires more credits, or fewer + # max_tokens"), so must be checked before context keywords. + billing_keywords = ( + "credits", "afford", "402", "billing", + "payment", "quota", "budget", + ) context_keywords = ( - "context_length", - "context window", - "maximum context", - "token limit", - "too many tokens", - "max_tokens", - "input too long", - "reduce your prompt", + "context_length", "context window", "maximum context", + "token limit", "too many tokens", + "input too long", "reduce your prompt", ) - if any(kw in err_str for kw in context_keywords): + if any(kw in err_str for kw in billing_keywords): + finish_reason = "insufficient_credits" + elif any(kw in err_str for kw in context_keywords): finish_reason = "context_length_exceeded" else: finish_reason = "error" diff --git a/pyproject.toml b/pyproject.toml index 71474b8..6fde117 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "debot" -version = "0.1.2" +version = "0.1.3" description = "A lightweight and secure personal AI assistant framework" requires-python = ">=3.11" license = {file = "LICENSE"} diff --git a/rust/src/router/config.rs b/rust/src/router/config.rs index 7b017c0..b034686 100644 --- a/rust/src/router/config.rs +++ b/rust/src/router/config.rs @@ -32,3 +32,44 @@ pub fn next_tier(current: &str) -> Option<&'static str> { let idx = TIER_ORDER.iter().position(|t| *t == current)?; TIER_ORDER.get(idx + 1).copied() } + +/// Returns the next cheaper tier for downgrade, or None if already at bottom. +pub fn prev_tier(current: &str) -> Option<&'static str> { + let idx = TIER_ORDER.iter().position(|t| *t == current)?; + if idx == 0 { + None + } else { + Some(TIER_ORDER[idx - 1]) + } +} + +/// Alternative models per tier, sorted by cost ascending (cheapest first). +/// Includes models from multiple providers for cross-provider billing fallback. +pub fn tier_alternatives() -> HashMap<&'static str, Vec<&'static str>> { + let mut m = HashMap::new(); + m.insert("SIMPLE", vec![ + "groq/llama-3.3-70b-versatile", // free tier + "deepseek/deepseek-chat", // $0.42 + "openai/gpt-4o-mini", // $0.60 + "openai/gpt-3.5-turbo", // $1.50 + ]); + m.insert("MEDIUM", vec![ + "groq/llama-3.3-70b-versatile", // free tier + "deepseek/deepseek-chat", // $0.42 + "openai/gpt-4o-mini", // $0.60 + "minimax/minimax-m2", // $1.20 + ]); + m.insert("COMPLEX", vec![ + "groq/llama-3.3-70b-versatile", // free tier (best-effort) + "anthropic/claude-sonnet-4-5", // $15.00 + "openai/gpt-4o", // $10.00 + "anthropic/claude-opus-4-5", // $25.00 + ]); + m.insert("REASONING", vec![ + "groq/llama-3.3-70b-versatile", // free tier (best-effort) + "deepseek/deepseek-reasoner", // $2.19 + "openai/o3-mini", // $4.40 + "openai/o3", // $8.00 + ]); + m +} diff --git a/rust/src/router/router.rs b/rust/src/router/router.rs index efbc10e..5c6b0e6 100644 --- a/rust/src/router/router.rs +++ b/rust/src/router/router.rs @@ -54,10 +54,35 @@ fn get_fallback_model(current_tier: &str) -> PyResult { } } +/// Returns a JSON array of alternative models for a tier, sorted by cost ascending. +/// Each entry: {"model": "...", "cost": ...} +/// Used for billing fallback: try same-tier alternatives before escalating. +#[pyfunction] +fn get_tier_alternatives(tier: &str) -> PyResult { + let alts = config::tier_alternatives(); + let pricing = catalog::default_pricing(); + let models = alts.get(tier).cloned().unwrap_or_default(); + let mut entries: Vec = models + .iter() + .map(|&model| { + let cost = *pricing.get(model).unwrap_or(&1.0); + json!({"model": model, "cost": cost}) + }) + .collect(); + // Sort by cost ascending (cheapest first) + entries.sort_by(|a, b| { + let ca = a["cost"].as_f64().unwrap_or(f64::MAX); + let cb = b["cost"].as_f64().unwrap_or(f64::MAX); + ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal) + }); + Ok(serde_json::to_string(&entries).unwrap_or_else(|_| "[]".to_string())) +} + pub fn pybindings(m: &pyo3::Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(route_text, m)?)?; m.add_function(wrap_pyfunction!(get_context_length, m)?)?; m.add_function(wrap_pyfunction!(get_fallback_model, m)?)?; + m.add_function(wrap_pyfunction!(get_tier_alternatives, m)?)?; m.add_function(wrap_pyfunction!(metrics::get_router_metrics, m)?)?; m.add_function(wrap_pyfunction!(metrics::reset_router_metrics, m)?)?; m.add_function(wrap_pyfunction!(metrics::get_router_metrics_count, m)?)?;