Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions crates/forge_api/src/forge_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;

use anyhow::{Context, Result};
use anyhow::Result;
use forge_app::dto::ToolsOverview;
use forge_app::{
AgentProviderResolver, AgentRegistry, AppConfigService, AuthService, CommandInfra,
Expand Down Expand Up @@ -67,14 +67,7 @@ impl<A: Services, F: CommandInfra + EnvironmentInfra + SkillRepository + AppConf
}

async fn get_models(&self) -> Result<Vec<Model>> {
Ok(self
.services
.models(
self.get_default_provider()
.await
.context("Failed to fetch models")?,
)
.await?)
self.app().get_models().await
}
async fn get_agents(&self) -> Result<Vec<Agent>> {
self.services.get_agents().await
Expand Down
30 changes: 0 additions & 30 deletions crates/forge_app/src/agent_provider_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,36 +39,6 @@ where
self.0.get_default_provider().await?
};

// Check if credential needs refresh (5 minute buffer before expiry)
if let Some(credential) = &provider.credential {
let buffer = chrono::Duration::minutes(5);

if credential.needs_refresh(buffer) {
for auth_method in &provider.auth_methods {
match auth_method {
forge_domain::AuthMethod::OAuthDevice(_)
| forge_domain::AuthMethod::OAuthCode(_) => {
match self
.0
.refresh_provider_credential(&provider, auth_method.clone())
.await
{
Ok(refreshed_credential) => {
let mut updated_provider = provider.clone();
updated_provider.credential = Some(refreshed_credential);
return Ok(updated_provider);
}
Err(_) => {
return Ok(provider);
}
}
}
forge_domain::AuthMethod::ApiKey => {}
}
}
}
}

Ok(provider)
}

Expand Down
24 changes: 23 additions & 1 deletion crates/forge_app/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ use crate::changed_files::ChangedFiles;
use crate::dto::ToolsOverview;
use crate::init_conversation_metrics::InitConversationMetrics;
use crate::orch::Orchestrator;
use crate::services::{AgentRegistry, CustomInstructionsService, TemplateService};
use crate::services::{
AgentRegistry, CustomInstructionsService, ProviderAuthService, TemplateService,
};
use crate::set_conversation_id::SetConversationId;
use crate::system_prompt::SystemPrompt;
use crate::tool_registry::ToolRegistry;
Expand Down Expand Up @@ -104,6 +106,12 @@ impl<S: Services> ForgeApp<S> {
let agent_provider = agent_provider_resolver
.get_provider(Some(agent.id.clone()))
.await?;
let agent_provider = self
.services
.provider_auth_service()
.refresh_provider_credential(agent_provider)
.await?;

let models = services.models(agent_provider).await?;

// Get system and mcp tool definitions and resolve them for the agent
Expand Down Expand Up @@ -265,6 +273,20 @@ impl<S: Services> ForgeApp<S> {
pub async fn list_tools(&self) -> Result<ToolsOverview> {
self.tool_registry.tools_overview().await
}

/// Gets available models for the default provider with automatic credential
/// refresh.
pub async fn get_models(&self) -> Result<Vec<Model>> {
let agent_provider_resolver = AgentProviderResolver::new(self.services.clone());
let provider = agent_provider_resolver.get_provider(None).await?;
let provider = self
.services
.provider_auth_service()
.refresh_provider_credential(provider)
.await?;

self.services.models(provider).await
}
pub async fn login(&self, init_auth: &InitAuth) -> Result<()> {
self.authenticator.login(init_auth).await
}
Expand Down
2 changes: 1 addition & 1 deletion crates/forge_app/src/git_app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ where
agent_provider_resolver.get_provider(agent_id.clone()),
agent_provider_resolver.get_model(agent_id)
)?;

let provider = self.services.refresh_provider_credential(provider).await?;
// Build git diff content with optional truncation notice
// Build user message using Element
let mut user_message = Element::new("user_message")
Expand Down
27 changes: 15 additions & 12 deletions crates/forge_app/src/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ use std::time::Duration;
use bytes::Bytes;
use derive_setters::Setters;
use forge_domain::{
AgentId, AnyProvider, Attachment, AuthContextRequest, AuthContextResponse, AuthCredential,
AuthMethod, ChatCompletionMessage, CommandOutput, Context, Conversation, ConversationId,
Environment, File, Image, InitAuth, LoginInfo, McpConfig, McpServers, Model, ModelId,
PatchOperation, Provider, ProviderId, ResultStream, Scope, Template, ToolCallFull, ToolOutput,
Workflow,
AgentId, AnyProvider, Attachment, AuthContextRequest, AuthContextResponse, AuthMethod,
ChatCompletionMessage, CommandOutput, Context, Conversation, ConversationId, Environment, File,
Image, InitAuth, LoginInfo, McpConfig, McpServers, Model, ModelId, PatchOperation, Provider,
ProviderId, ResultStream, Scope, Template, ToolCallFull, ToolOutput, Workflow,
};
use merge::Merge;
use reqwest::Response;
Expand Down Expand Up @@ -474,11 +473,16 @@ pub trait ProviderAuthService: Send + Sync {
context: AuthContextResponse,
timeout: Duration,
) -> anyhow::Result<()>;

/// Refreshes provider credentials if they're about to expire.
/// Checks if credential needs refresh (5 minute buffer before expiry),
/// iterates through provider's auth methods, and attempts to refresh.
/// Returns the provider with updated credentials, or original if refresh
/// fails or isn't needed.
async fn refresh_provider_credential(
&self,
provider: &Provider<Url>,
method: AuthMethod,
) -> anyhow::Result<AuthCredential>;
provider: Provider<Url>,
) -> anyhow::Result<Provider<Url>>;
}

/// Core app trait providing access to services and repositories.
Expand Down Expand Up @@ -986,11 +990,10 @@ impl<I: Services> ProviderAuthService for I {
}
async fn refresh_provider_credential(
&self,
provider: &Provider<Url>,
method: AuthMethod,
) -> anyhow::Result<AuthCredential> {
provider: Provider<Url>,
) -> anyhow::Result<Provider<Url>> {
self.provider_auth_service()
.refresh_provider_credential(provider, method)
.refresh_provider_credential(provider)
.await
}
}
98 changes: 67 additions & 31 deletions crates/forge_services/src/provider_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use std::time::Duration;

use forge_app::{AuthStrategy, ProviderAuthService, StrategyFactory};
use forge_domain::{
AuthContextRequest, AuthContextResponse, AuthCredential, AuthMethod, Provider, ProviderId,
ProviderRepository,
AuthContextRequest, AuthContextResponse, AuthMethod, Provider, ProviderId, ProviderRepository,
};

/// Forge Provider Authentication Service
Expand Down Expand Up @@ -103,37 +102,74 @@ where
self.infra.upsert_credential(credential).await
}

/// Refresh provider credential
/// Refreshes provider credentials if they're about to expire.
/// Checks if credential needs refresh (5 minute buffer before expiry),
/// iterates through provider's auth methods, and attempts to refresh.
/// Returns the provider with updated credentials, or original if refresh
/// fails or isn't needed.
async fn refresh_provider_credential(
&self,
provider: &Provider<url::Url>,
auth_method: AuthMethod,
) -> anyhow::Result<AuthCredential> {
// Get existing credential
let credential = self
.infra
.get_credential(&provider.id)
.await?
.ok_or_else(|| forge_domain::Error::ProviderNotAvailable {
provider: provider.id.clone(),
})?;

// Get required params (only used for API key, but needed for factory)
let required_params = if matches!(auth_method, AuthMethod::ApiKey) {
provider.url_params.clone()
} else {
vec![]
};

// Create strategy and refresh credential
let strategy =
self.infra
.create_auth_strategy(provider.id.clone(), auth_method, required_params)?;
let refreshed = strategy.refresh(&credential).await?;

// Store refreshed credential
self.infra.upsert_credential(refreshed.clone()).await?;
mut provider: Provider<url::Url>,
) -> anyhow::Result<Provider<url::Url>> {
// Check if credential needs refresh (5 minute buffer before expiry)
if let Some(credential) = &provider.credential {
let buffer = chrono::Duration::minutes(5);

if credential.needs_refresh(buffer) {
// Iterate through auth methods and try to refresh
for auth_method in &provider.auth_methods {
match auth_method {
AuthMethod::OAuthDevice(_) | AuthMethod::OAuthCode(_) => {
// Get existing credential
let existing_credential =
self.infra.get_credential(&provider.id).await?.ok_or_else(
|| forge_domain::Error::ProviderNotAvailable {
provider: provider.id.clone(),
},
)?;

// Get required params (only used for API key, but needed for factory)
let required_params = if matches!(auth_method, AuthMethod::ApiKey) {
provider.url_params.clone()
} else {
vec![]
};

// Create strategy and refresh credential
if let Ok(strategy) = self.infra.create_auth_strategy(
provider.id.clone(),
auth_method.clone(),
required_params,
) {
match strategy.refresh(&existing_credential).await {
Ok(refreshed) => {
// Store refreshed credential
if self
.infra
.upsert_credential(refreshed.clone())
.await
.is_err()
{
continue;
}

// Update provider with refreshed credential
provider.credential = Some(refreshed);
break; // Success, stop trying other methods
}
Err(_) => {
// If refresh fails, continue with
// existing credentials
}
}
}
}
_ => {}
}
}
}
}

Ok(refreshed)
Ok(provider)
}
}
Loading