124 changes: 96 additions & 28 deletions debot/agent/loop.py
@@ -284,36 +284,104 @@ async def _process_message(self, msg: InboundMessage) -> OutboundMessage | None:
             messages=messages, tools=self.tools.get_definitions(), model=chosen_model
         )
 
-        # Auto-escalate: if model failed, try next tier (up to 3 escalations)
-        if (
-            _debot_rust
-            and current_tier
-            and response.finish_reason in ("error", "context_length_exceeded")
-        ):
-            for _esc in range(3):
-                fb_json = _debot_rust.get_fallback_model(current_tier)
-                if not fb_json:
-                    break
-                fb = json.loads(fb_json)
-                logger.warning(
-                    "Escalating: {} ({}) failed [{}] → {} ({})",
-                    chosen_model,
-                    current_tier,
-                    response.finish_reason,
-                    fb["model"],
-                    fb["tier"],
-                )
-                try:
-                    _debot_rust.record_escalation()
-                except Exception:
-                    pass
-                chosen_model = fb["model"]
-                current_tier = fb["tier"]
-                response = await self.provider.chat(
-                    messages=messages, tools=self.tools.get_definitions(), model=chosen_model
-                )
-                if response.finish_reason not in ("error", "context_length_exceeded"):
-                    break
+        # Auto-reroute on failure
+        _fail_reasons = ("error", "context_length_exceeded", "insufficient_credits")
+        if _debot_rust and current_tier and response.finish_reason in _fail_reasons:
+            if response.finish_reason == "insufficient_credits":
+                # Billing error strategy:
+                # 1) Try same-tier alternatives (cheaper models that handle same complexity)
+                # 2) If all same-tier alternatives exhausted, escalate to next tier
+                tried = {chosen_model}
+                rerouted = False
+                # Phase 1: same-tier alternatives sorted by cost ascending
+                try:
+                    alts = json.loads(_debot_rust.get_tier_alternatives(current_tier))
+                except Exception:
+                    alts = []
+                for alt in alts:
+                    if alt["model"] in tried:
+                        continue
+                    tried.add(alt["model"])
+                    logger.warning(
+                        "Billing fallback: {} failed [{}] → trying same-tier {} (${:.2f}/M)",
+                        chosen_model, response.finish_reason,
+                        alt["model"], alt["cost"],
+                    )
+                    try:
+                        _debot_rust.record_escalation()
+                    except Exception:
+                        pass
+                    chosen_model = alt["model"]
+                    response = await self.provider.chat(
+                        messages=messages, tools=self.tools.get_definitions(), model=chosen_model
+                    )
+                    if response.finish_reason not in _fail_reasons:
+                        rerouted = True
+                        break
+                # Phase 2: same-tier exhausted, escalate up tier by tier
+                if not rerouted and response.finish_reason in _fail_reasons:
+                    esc_tier = current_tier
+                    for _esc in range(3):
+                        fb_json = _debot_rust.get_fallback_model(esc_tier)
+                        if not fb_json:
+                            break
+                        fb = json.loads(fb_json)
+                        if fb["model"] in tried:
+                            esc_tier = fb["tier"]
+                            continue
+                        tried.add(fb["model"])
+                        logger.warning(
+                            "Billing fallback: same-tier exhausted, escalating → {} ({})",
+                            fb["model"], fb["tier"],
+                        )
+                        try:
+                            _debot_rust.record_escalation()
+                        except Exception:
+                            pass
+                        chosen_model = fb["model"]
+                        current_tier = fb["tier"]
+                        esc_tier = fb["tier"]
+                        response = await self.provider.chat(
+                            messages=messages, tools=self.tools.get_definitions(), model=chosen_model
+                        )
+                        if response.finish_reason not in _fail_reasons:
+                            break
+            else:
+                # Context / other errors → escalate to more capable model
+                for _esc in range(3):
+                    fb_json = _debot_rust.get_fallback_model(current_tier)
+                    if not fb_json:
+                        break
+                    fb = json.loads(fb_json)
+                    logger.warning(
+                        "Escalating: {} ({}) failed [{}] → {} ({})",
+                        chosen_model, current_tier, response.finish_reason,
+                        fb["model"], fb["tier"],
+                    )
+                    try:
+                        _debot_rust.record_escalation()
+                    except Exception:
+                        pass
+                    chosen_model = fb["model"]
+                    current_tier = fb["tier"]
+                    response = await self.provider.chat(
+                        messages=messages, tools=self.tools.get_definitions(), model=chosen_model
+                    )
+                    if response.finish_reason not in _fail_reasons:
+                        break
 
+        # If all fallbacks exhausted, give a friendly error instead of raw API dump
+        if response.finish_reason == "insufficient_credits":
+            response.content = (
+                "All available models failed due to insufficient credits. "
+                "Please top up your API provider credits and try again.\n\n"
+                "OpenRouter: https://openrouter.ai/settings/credits"
+            )
+        elif response.finish_reason == "context_length_exceeded":
+            response.content = (
+                "The conversation is too long for all available models. "
+                "Try starting a new conversation or use /compact to compress history."
+            )
+
         # Handle tool calls
         if response.has_tool_calls:
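Reviewer aside: the two-phase reroute above is easier to verify in isolation. Below is a minimal runnable sketch, not project code: the provider call is reduced to a `call(model) -> finish_reason` callback, and the two Rust bindings are replaced by stubs with invented data; only their names and JSON shapes mirror this diff.

import json

def get_tier_alternatives(tier):  # stub for _debot_rust.get_tier_alternatives
    return json.dumps([
        {"model": "groq/llama-3.3-70b-versatile", "cost": 0.0},
        {"model": "deepseek/deepseek-chat", "cost": 0.42},
    ])

def get_fallback_model(tier):  # stub: one escalation step, then exhausted
    if tier == "MEDIUM":
        return json.dumps({"model": "openai/gpt-4o", "tier": "COMPLEX"})
    return ""

def reroute(chosen_model, tier, call):
    """Phase 1: cheapest same-tier alternatives; Phase 2: escalate tiers."""
    tried = {chosen_model}
    for alt in json.loads(get_tier_alternatives(tier)):  # Phase 1
        if alt["model"] in tried:
            continue
        tried.add(alt["model"])
        if call(alt["model"]) == "stop":
            return alt["model"]
    for _ in range(3):  # Phase 2, capped at three hops like the diff
        fb_json = get_fallback_model(tier)
        if not fb_json:
            break
        fb = json.loads(fb_json)
        tier = fb["tier"]
        if fb["model"] in tried:
            continue
        tried.add(fb["model"])
        if call(fb["model"]) == "stop":
            return fb["model"]
    return None  # every candidate failed

# Everything fails with a billing error except openai/gpt-4o:
print(reroute("minimax/minimax-m2", "MEDIUM",
              lambda m: "stop" if m == "openai/gpt-4o" else "insufficient_credits"))
# -> openai/gpt-4o (Phase 1 exhausts the MEDIUM list, Phase 2 escalates once)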
6 changes: 4 additions & 2 deletions debot/cli/commands.py
@@ -315,7 +315,8 @@ def gateway(
         raise typer.Exit(1)
 
     provider = LiteLLMProvider(
-        api_key=api_key, api_base=api_base, default_model=config.agents.defaults.model
+        api_key=api_key, api_base=api_base, default_model=config.agents.defaults.model,
+        all_api_keys=config.get_all_api_keys(),
     )
 
     # Create agent
@@ -423,7 +424,8 @@ def agent(
 
     bus = MessageBus()
     provider = LiteLLMProvider(
-        api_key=api_key, api_base=api_base, default_model=config.agents.defaults.model
+        api_key=api_key, api_base=api_base, default_model=config.agents.defaults.model,
+        all_api_keys=config.get_all_api_keys(),
     )
 
     agent_loop = AgentLoop(
17 changes: 17 additions & 0 deletions debot/config/schema.py
@@ -134,6 +134,23 @@ def get_api_key(self) -> str | None:
             or None
         )
 
+    def get_all_api_keys(self) -> dict[str, str]:
+        """Return all configured provider API keys (non-empty only)."""
+        keys: dict[str, str] = {}
+        if self.providers.openrouter.api_key:
+            keys["openrouter"] = self.providers.openrouter.api_key
+        if self.providers.anthropic.api_key:
+            keys["anthropic"] = self.providers.anthropic.api_key
+        if self.providers.openai.api_key:
+            keys["openai"] = self.providers.openai.api_key
+        if self.providers.gemini.api_key:
+            keys["gemini"] = self.providers.gemini.api_key
+        if self.providers.groq.api_key:
+            keys["groq"] = self.providers.groq.api_key
+        if self.providers.zhipu.api_key:
+            keys["zhipu"] = self.providers.zhipu.api_key
+        return keys
+
     def get_api_base(self) -> str | None:
         """Get API base URL if using OpenRouter, Zhipu or vLLM."""
         if self.providers.openrouter.api_key:
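Reviewer aside: this accessor exists to feed the new `all_api_keys` parameter on `LiteLLMProvider`, as wired up in debot/cli/commands.py above. A hypothetical call site, assuming `config` is a loaded `Config` instance:

keys = config.get_all_api_keys()  # e.g. {"openrouter": "...", "anthropic": "..."}
provider = LiteLLMProvider(
    api_key=config.get_api_key(),  # primary key still takes precedence
    api_base=config.get_api_base(),
    default_model=config.agents.defaults.model,
    all_api_keys=keys,  # seeds per-provider env vars for cross-provider fallback
)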
78 changes: 54 additions & 24 deletions debot/providers/litellm_provider.py
@@ -22,6 +22,7 @@ def __init__(
         api_key: str | None = None,
         api_base: str | None = None,
         default_model: str = "anthropic/claude-opus-4-5",
+        all_api_keys: dict[str, str] | None = None,
     ):
         super().__init__(api_key, api_base)
         self.default_model = default_model
@@ -34,27 +35,42 @@ def __init__(
         # Track if using custom endpoint (vLLM, etc.)
         self.is_vllm = bool(api_base) and not self.is_openrouter
 
-        # Configure LiteLLM based on provider
+        # Set ALL configured provider API keys so LiteLLM can route across
+        # providers during fallback (e.g. OpenRouter credits exhausted → Anthropic direct).
+        _key_env_map = {
+            "openrouter": "OPENROUTER_API_KEY",
+            "anthropic": "ANTHROPIC_API_KEY",
+            "openai": "OPENAI_API_KEY",
+            "gemini": "GEMINI_API_KEY",
+            "groq": "GROQ_API_KEY",
+            "zhipu": "ZHIPUAI_API_KEY",
+        }
+        if all_api_keys:
+            for provider_name, key in all_api_keys.items():
+                env_var = _key_env_map.get(provider_name)
+                if env_var and key:
+                    os.environ.setdefault(env_var, key)
+
+        # Configure primary provider key (overrides setdefault above)
         if api_key:
             if self.is_openrouter:
                 # OpenRouter mode - set key
                 os.environ["OPENROUTER_API_KEY"] = api_key
             elif self.is_vllm:
                 # vLLM/custom endpoint - uses OpenAI-compatible API
                 os.environ["OPENAI_API_KEY"] = api_key
             elif "anthropic" in default_model:
-                os.environ.setdefault("ANTHROPIC_API_KEY", api_key)
+                os.environ["ANTHROPIC_API_KEY"] = api_key
             elif "openai" in default_model or "gpt" in default_model:
-                os.environ.setdefault("OPENAI_API_KEY", api_key)
+                os.environ["OPENAI_API_KEY"] = api_key
             elif "gemini" in default_model.lower():
-                os.environ.setdefault("GEMINI_API_KEY", api_key)
+                os.environ["GEMINI_API_KEY"] = api_key
             elif "zhipu" in default_model or "glm" in default_model or "zai" in default_model:
-                os.environ.setdefault("ZHIPUAI_API_KEY", api_key)
+                os.environ["ZHIPUAI_API_KEY"] = api_key
             elif "groq" in default_model:
-                os.environ.setdefault("GROQ_API_KEY", api_key)
+                os.environ["GROQ_API_KEY"] = api_key
 
-        if api_base:
-            litellm.api_base = api_base
+        # Do NOT set litellm.api_base globally — it would route ALL calls
+        # (including cross-provider fallback) through OpenRouter.
+        # api_base is passed per-call in chat() kwargs instead.
 
         # Disable LiteLLM logging noise
         litellm.suppress_debug_info = True
@@ -82,9 +98,18 @@ async def chat(
         """
         model = model or self.default_model
 
-        # For OpenRouter, prefix model name if not already prefixed
+        # For OpenRouter, prefix model name — but skip if the model's native
+        # provider has a direct API key configured (cross-provider fallback).
         if self.is_openrouter and not model.startswith("openrouter/"):
-            model = f"openrouter/{model}"
+            # Check if a direct provider key is available for this model
+            _direct_available = (
+                (model.startswith("anthropic/") and os.environ.get("ANTHROPIC_API_KEY"))
+                or (model.startswith("openai/") and os.environ.get("OPENAI_API_KEY"))
+                or (model.startswith("groq/") and os.environ.get("GROQ_API_KEY"))
+                or (model.startswith("gemini/") and os.environ.get("GEMINI_API_KEY"))
+            )
+            if not _direct_available:
+                model = f"openrouter/{model}"
 
         # For Zhipu/Z.ai, ensure prefix is present
         # Handle cases like "glm-4.7-flash" -> "zhipu/glm-4.7-flash"
@@ -111,8 +136,10 @@
             "temperature": temperature,
         }
 
-        # Pass api_base directly for custom endpoints (vLLM, etc.)
-        if self.api_base:
+        # Pass api_base for custom endpoints — but NOT for direct-provider calls
+        # (cross-provider fallback should hit the native API, not OpenRouter).
+        _is_direct = not model.startswith("openrouter/") and not model.startswith("hosted_vllm/")
+        if self.api_base and not _is_direct:
             kwargs["api_base"] = self.api_base
 
         if tools:
@@ -124,18 +151,21 @@
             return self._parse_response(response)
         except Exception as e:
             err_str = str(e).lower()
-            # Classify context-window / token-limit errors for auto-escalation
+            # Billing check FIRST — billing errors often contain context-related
+            # words like "max_tokens" (e.g. "requires more credits, or fewer
+            # max_tokens"), so must be checked before context keywords.
+            billing_keywords = (
+                "credits", "afford", "402", "billing",
+                "payment", "quota", "budget",
+            )
             context_keywords = (
-                "context_length",
-                "context window",
-                "maximum context",
-                "token limit",
-                "too many tokens",
-                "max_tokens",
-                "input too long",
-                "reduce your prompt",
+                "context_length", "context window", "maximum context",
+                "token limit", "too many tokens",
+                "input too long", "reduce your prompt",
            )
-            if any(kw in err_str for kw in context_keywords):
+            if any(kw in err_str for kw in billing_keywords):
+                finish_reason = "insufficient_credits"
+            elif any(kw in err_str for kw in context_keywords):
                 finish_reason = "context_length_exceeded"
             else:
                 finish_reason = "error"
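Reviewer aside: the check order is the whole point of the last hunk, since a billing error such as "requires more credits, or fewer max_tokens" contains a context-looking keyword. A minimal sketch of the reworked classifier; the keyword tuples are copied from the diff, while `classify` itself is an invented helper for illustration:

billing_keywords = ("credits", "afford", "402", "billing", "payment", "quota", "budget")
context_keywords = ("context_length", "context window", "maximum context",
                    "token limit", "too many tokens", "input too long", "reduce your prompt")

def classify(err: str) -> str:
    err = err.lower()
    if any(kw in err for kw in billing_keywords):  # billing checked first
        return "insufficient_credits"
    if any(kw in err for kw in context_keywords):
        return "context_length_exceeded"
    return "error"

print(classify("This request requires more credits, or fewer max_tokens"))
# -> insufficient_credits; the old order (context keywords first, with
#    "max_tokens" still in the tuple) would have misclassified this as
#    context_length_exceeded and escalated instead of rerouting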
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "debot"
-version = "0.1.2"
+version = "0.1.3"
 description = "A lightweight and secure personal AI assistant framework"
 requires-python = ">=3.11"
 license = {file = "LICENSE"}
41 changes: 41 additions & 0 deletions rust/src/router/config.rs
@@ -32,3 +32,44 @@ pub fn next_tier(current: &str) -> Option<&'static str> {
     let idx = TIER_ORDER.iter().position(|t| *t == current)?;
     TIER_ORDER.get(idx + 1).copied()
 }
+
+/// Returns the next cheaper tier for downgrade, or None if already at bottom.
+pub fn prev_tier(current: &str) -> Option<&'static str> {
+    let idx = TIER_ORDER.iter().position(|t| *t == current)?;
+    if idx == 0 {
+        None
+    } else {
+        Some(TIER_ORDER[idx - 1])
+    }
+}
+
+/// Alternative models per tier, sorted by cost ascending (cheapest first).
+/// Includes models from multiple providers for cross-provider billing fallback.
+pub fn tier_alternatives() -> HashMap<&'static str, Vec<&'static str>> {
+    let mut m = HashMap::new();
+    m.insert("SIMPLE", vec![
+        "groq/llama-3.3-70b-versatile", // free tier
+        "deepseek/deepseek-chat",       // $0.42
+        "openai/gpt-4o-mini",           // $0.60
+        "openai/gpt-3.5-turbo",         // $1.50
+    ]);
+    m.insert("MEDIUM", vec![
+        "groq/llama-3.3-70b-versatile", // free tier
+        "deepseek/deepseek-chat",       // $0.42
+        "openai/gpt-4o-mini",           // $0.60
+        "minimax/minimax-m2",           // $1.20
+    ]);
+    m.insert("COMPLEX", vec![
+        "groq/llama-3.3-70b-versatile", // free tier (best-effort)
+        "openai/gpt-4o",                // $10.00
+        "anthropic/claude-sonnet-4-5",  // $15.00
+        "anthropic/claude-opus-4-5",    // $25.00
+    ]);
+    m.insert("REASONING", vec![
+        "groq/llama-3.3-70b-versatile", // free tier (best-effort)
+        "deepseek/deepseek-reasoner",   // $2.19
+        "openai/o3-mini",               // $4.40
+        "openai/o3",                    // $8.00
+    ]);
+    m
+}
25 changes: 25 additions & 0 deletions rust/src/router/router.rs
@@ -54,10 +54,35 @@ fn get_fallback_model(current_tier: &str) -> PyResult<String> {
     }
 }
 
+/// Returns a JSON array of alternative models for a tier, sorted by cost ascending.
+/// Each entry: {"model": "...", "cost": ...}
+/// Used for billing fallback: try same-tier alternatives before escalating.
+#[pyfunction]
+fn get_tier_alternatives(tier: &str) -> PyResult<String> {
+    let alts = config::tier_alternatives();
+    let pricing = catalog::default_pricing();
+    let models = alts.get(tier).cloned().unwrap_or_default();
+    let mut entries: Vec<serde_json::Value> = models
+        .iter()
+        .map(|&model| {
+            let cost = *pricing.get(model).unwrap_or(&1.0);
+            json!({"model": model, "cost": cost})
+        })
+        .collect();
+    // Sort by cost ascending (cheapest first)
+    entries.sort_by(|a, b| {
+        let ca = a["cost"].as_f64().unwrap_or(f64::MAX);
+        let cb = b["cost"].as_f64().unwrap_or(f64::MAX);
+        ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
+    });
+    Ok(serde_json::to_string(&entries).unwrap_or_else(|_| "[]".to_string()))
+}
+
 pub fn pybindings(m: &pyo3::Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(route_text, m)?)?;
     m.add_function(wrap_pyfunction!(get_context_length, m)?)?;
     m.add_function(wrap_pyfunction!(get_fallback_model, m)?)?;
+    m.add_function(wrap_pyfunction!(get_tier_alternatives, m)?)?;
     m.add_function(wrap_pyfunction!(metrics::get_router_metrics, m)?)?;
     m.add_function(wrap_pyfunction!(metrics::reset_router_metrics, m)?)?;
     m.add_function(wrap_pyfunction!(metrics::get_router_metrics_count, m)?)?;
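For completeness, a sketch of how the Python side consumes the new binding. The bare `_debot_rust` import is an assumption based on its usage in debot/agent/loop.py, and real costs come from catalog::default_pricing, which is outside this diff:

import json
import _debot_rust  # compiled PyO3 extension; import path assumed from loop.py

for tier in ("SIMPLE", "MEDIUM", "COMPLEX", "REASONING"):
    alts = json.loads(_debot_rust.get_tier_alternatives(tier))
    # Entries arrive pre-sorted by cost ascending, cheapest first
    print(tier, [(a["model"], a["cost"]) for a in alts])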