Skip to content

Commit ee65bd2

Browse files
Aiswarya PrakasanAiswarya Prakasan
authored andcommitted
atomic model and provider selection on login
1 parent b1054bd commit ee65bd2

6 files changed

Lines changed: 115 additions & 48 deletions

File tree

crates/forge_api/src/api.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,15 @@ pub trait API: Sync + Send {
129129
/// Sets the default provider for all the agents
130130
async fn set_default_provider(&self, provider_id: ProviderId) -> anyhow::Result<()>;
131131

132+
/// Updates the caller's default provider and model together, ensuring all
133+
/// commands resolve a consistent pair without requiring a follow-up model
134+
/// selection call.
135+
async fn set_default_provider_and_model(
136+
&self,
137+
provider_id: ProviderId,
138+
model: ModelId,
139+
) -> anyhow::Result<()>;
140+
132141
/// Retrieves information about the currently authenticated user
133142
async fn user_info(&self) -> anyhow::Result<Option<User>>;
134143

crates/forge_api/src/forge_api.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,19 @@ impl<A: Services, F: CommandInfra + EnvironmentInfra + SkillRepository + GrpcInf
278278
result
279279
}
280280

281+
async fn set_default_provider_and_model(
282+
&self,
283+
provider_id: ProviderId,
284+
model: ModelId,
285+
) -> anyhow::Result<()> {
286+
let result = self
287+
.services
288+
.set_default_provider_and_model(provider_id, model)
289+
.await;
290+
let _ = self.services.reload_agents().await;
291+
result
292+
}
293+
281294
async fn get_commit_config(&self) -> anyhow::Result<Option<CommitConfig>> {
282295
self.services.get_commit_config().await
283296
}

crates/forge_app/src/services.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,14 @@ pub trait AppConfigService: Send + Sync {
211211
/// Returns an error if no default provider is configured.
212212
async fn set_default_model(&self, model: ModelId) -> anyhow::Result<()>;
213213

214+
/// Sets the user's default provider and default model in a single atomic
215+
/// update so the persisted configuration never stores a mismatched pair.
216+
async fn set_default_provider_and_model(
217+
&self,
218+
provider_id: ProviderId,
219+
model: ModelId,
220+
) -> anyhow::Result<()>;
221+
214222
/// Gets the commit configuration (provider and model for commit message
215223
/// generation).
216224
async fn get_commit_config(&self) -> anyhow::Result<Option<forge_domain::CommitConfig>>;
@@ -971,6 +979,16 @@ impl<I: Services> AppConfigService for I {
971979
self.config_service().get_provider_model(provider_id).await
972980
}
973981

982+
async fn set_default_provider_and_model(
983+
&self,
984+
provider_id: forge_domain::ProviderId,
985+
model: ModelId,
986+
) -> anyhow::Result<()> {
987+
self.config_service()
988+
.set_default_provider_and_model(provider_id, model)
989+
.await
990+
}
991+
974992
async fn set_default_model(&self, model: ModelId) -> anyhow::Result<()> {
975993
self.config_service().set_default_model(model).await
976994
}

crates/forge_main/src/info.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ impl Section {
7575
/// # Output Format
7676
///
7777
/// ```text
78-
///
78+
///
7979
/// CONFIGURATION
8080
/// model gpt-4
8181
/// provider openai

crates/forge_main/src/ui.rs

Lines changed: 65 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use convert_case::{Case, Casing};
1212
use forge_api::{
1313
API, AgentId, AnyProvider, ApiKeyRequest, AuthContextRequest, AuthContextResponse, ChatRequest,
1414
ChatResponse, CodeRequest, Conversation, ConversationId, DeviceCodeRequest, Event,
15-
InterruptionReason, Model, ModelId, Provider, ProviderId, TextMessage, UserPrompt,
15+
InterruptionReason, ModelId, Provider, ProviderId, TextMessage, UserPrompt,
1616
};
1717
use forge_app::utils::{format_display_path, truncate_key};
1818
use forge_app::{CommitResult, ToolResolver};
@@ -127,14 +127,6 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
127127
self.spinner.ewrite_ln(title)
128128
}
129129

130-
/// Retrieve available models
131-
async fn get_models(&mut self) -> Result<Vec<Model>> {
132-
self.spinner.start(Some("Loading"))?;
133-
let models = self.api.get_models().await?;
134-
self.spinner.stop(None)?;
135-
Ok(models)
136-
}
137-
138130
/// Helper to get provider for an optional agent, defaulting to the current
139131
/// active agent's provider
140132
async fn get_provider(&self, agent_id: Option<AgentId>) -> Result<Provider<Url>> {
@@ -649,6 +641,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
649641
return Ok(());
650642
}
651643
TopLevelCommand::Commit(commit_group) => {
644+
self.init_state(false).await?;
652645
let preview = commit_group.preview;
653646
let result = self.handle_commit_command(commit_group).await?;
654647
if preview {
@@ -1899,7 +1892,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
18991892
self.on_custom_event(event.into()).await?;
19001893
}
19011894
SlashCommand::Model => {
1902-
self.on_model_selection(None).await?;
1895+
self.on_model_selection(None, None).await?;
19031896
}
19041897
SlashCommand::Provider => {
19051898
self.on_provider_selection().await?;
@@ -2074,15 +2067,11 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
20742067
provider_filter: Option<ProviderId>,
20752068
) -> Result<Option<ModelId>> {
20762069
// Check if provider is set otherwise first ask to select a provider
2077-
if self.api.get_default_provider().await.is_err() {
2078-
self.on_provider_selection().await?;
2079-
2080-
// Check if a model was already selected during provider activation
2081-
// Return None to signal the model selection is complete and message was already
2082-
// printed
2083-
if self.api.get_default_model().await.is_some() {
2070+
if provider_filter.is_none() && self.api.get_default_provider().await.is_err() {
2071+
if !self.on_provider_selection().await? {
20842072
return Ok(None);
20852073
}
2074+
return Ok(None);
20862075
}
20872076

20882077
// Fetch models from ALL configured providers (matches shell plugin's
@@ -2713,6 +2702,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
27132702
async fn on_model_selection(
27142703
&mut self,
27152704
provider_filter: Option<ProviderId>,
2705+
provider_to_activate: Option<ProviderId>,
27162706
) -> Result<Option<ModelId>> {
27172707
// Select a model
27182708
let model_option = self.select_model(provider_filter).await?;
@@ -2723,8 +2713,14 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
27232713
None => return Ok(None),
27242714
};
27252715

2726-
// Update the operating model via API
2727-
self.api.set_default_model(model.clone()).await?;
2716+
// If we have a provider to activate, write both atomically
2717+
if let Some(provider_id) = provider_to_activate {
2718+
self.api
2719+
.set_default_provider_and_model(provider_id, model.clone())
2720+
.await?;
2721+
} else {
2722+
self.api.set_default_model(model.clone()).await?;
2723+
}
27282724

27292725
// Update the UI state with the new model
27302726
self.update_model(Some(model.clone()));
@@ -2734,15 +2730,18 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
27342730
Ok(Some(model))
27352731
}
27362732

2737-
async fn on_provider_selection(&mut self) -> Result<()> {
2733+
async fn on_provider_selection(&mut self) -> Result<bool> {
27382734
// Select a provider
27392735
// If no provider was selected (user canceled), return early
27402736
let any_provider = match self.select_provider().await? {
27412737
Some(provider) => provider,
2742-
None => return Ok(()),
2738+
None => return Ok(false),
27432739
};
27442740

2745-
self.activate_provider(any_provider).await
2741+
self.activate_provider(any_provider).await?;
2742+
// Check if provider was actually saved — if user cancelled model selection
2743+
// inside activate_provider, nothing was written
2744+
Ok(self.api.get_default_provider().await.is_ok())
27462745
}
27472746

27482747
/// Activates a provider by configuring it if needed, setting it as default,
@@ -2789,21 +2788,19 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
27892788
provider: Provider<Url>,
27902789
model: Option<ModelId>,
27912790
) -> Result<()> {
2792-
// Set the provider via API
2793-
self.api.set_default_provider(provider.id.clone()).await?;
2794-
2795-
self.writeln_title(
2796-
TitleFormat::action(format!("{}", provider.id))
2797-
.sub_title("is now the default provider"),
2798-
)?;
2799-
28002791
// If a model was pre-selected (e.g. from :model), validate and set it
28012792
// directly without prompting
28022793
if let Some(model) = model {
28032794
let model_id = self
28042795
.validate_model(model.as_str(), Some(&provider.id))
28052796
.await?;
2806-
self.api.set_default_model(model_id.clone()).await?;
2797+
self.api
2798+
.set_default_provider_and_model(provider.id.clone(), model_id.clone())
2799+
.await?;
2800+
self.writeln_title(
2801+
TitleFormat::action(format!("{}", provider.id))
2802+
.sub_title("is now the default provider"),
2803+
)?;
28072804
self.writeln_title(
28082805
TitleFormat::action(model_id.as_str()).sub_title("is now the default model"),
28092806
)?;
@@ -2812,18 +2809,37 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
28122809

28132810
// Check if the current model is available for the new provider
28142811
let current_model = self.api.get_default_model().await;
2815-
if let Some(current_model) = current_model {
2816-
let models = self.get_models().await?;
2817-
let model_available = models.iter().any(|m| m.id == current_model);
2812+
let needs_model_selection = match current_model {
2813+
None => true,
2814+
Some(current_model) => {
2815+
let provider_models = self.api.get_all_provider_models().await?;
2816+
let model_available = provider_models
2817+
.iter()
2818+
.find(|pm| pm.provider_id == provider.id)
2819+
.map(|pm| pm.models.iter().any(|m| m.id == current_model))
2820+
.unwrap_or(false);
2821+
!model_available
2822+
}
2823+
};
28182824

2819-
if !model_available {
2820-
// Prompt user to select a new model, scoped to the activated provider
2821-
self.writeln_title(TitleFormat::info("Please select a new model"))?;
2822-
self.on_model_selection(Some(provider.id.clone())).await?;
2825+
if needs_model_selection {
2826+
self.writeln_title(TitleFormat::info("Please select a new model"))?;
2827+
let selected = self
2828+
.on_model_selection(Some(provider.id.clone()), Some(provider.id.clone()))
2829+
.await?;
2830+
if selected.is_none() {
2831+
// User cancelled — preserve existing config untouched
2832+
return Ok(());
28232833
}
28242834
} else {
2825-
// No model set, select one now scoped to the activated provider
2826-
self.on_model_selection(Some(provider.id.clone())).await?;
2835+
// Set the provider via API
2836+
// Only reaches here if model is confirmed — safe to write provider now
2837+
self.api.set_default_provider(provider.id.clone()).await?;
2838+
2839+
self.writeln_title(
2840+
TitleFormat::action(format!("{}", provider.id))
2841+
.sub_title("is now the default provider"),
2842+
)?;
28272843
}
28282844

28292845
Ok(())
@@ -2931,17 +2947,19 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
29312947
// Ensure we have a model selected before proceeding with initialization
29322948
let active_agent = self.api.get_active_agent().await;
29332949

2934-
let mut operating_model = self.get_agent_model(active_agent.clone()).await;
2935-
if operating_model.is_none() {
2936-
// Use the model returned from selection instead of re-fetching
2937-
operating_model = self.on_model_selection(None).await?;
2938-
}
2939-
29402950
// Validate provider is configured before loading agents
29412951
// If provider is set in config but not configured (no credentials), prompt user
29422952
// to login
29432953
if self.api.get_default_provider().await.is_err() {
2944-
self.on_provider_selection().await?;
2954+
if !self.on_provider_selection().await? {
2955+
return Ok(());
2956+
}
2957+
}
2958+
2959+
let mut operating_model = self.get_agent_model(active_agent.clone()).await;
2960+
if operating_model.is_none() {
2961+
// Use the model returned from selection instead of re-fetching
2962+
operating_model = self.on_model_selection(None, None).await?;
29452963
}
29462964

29472965
if first {

crates/forge_services/src/app_config.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,15 @@ impl<F: ProviderRepository + EnvironmentInfra + Send + Sync> AppConfigService
9393
.await
9494
}
9595

96+
async fn set_default_provider_and_model(
97+
&self,
98+
provider_id: ProviderId,
99+
model: ModelId,
100+
) -> anyhow::Result<()> {
101+
self.update(ConfigOperation::SetModel(provider_id, model))
102+
.await
103+
}
104+
96105
async fn get_commit_config(&self) -> anyhow::Result<Option<forge_domain::CommitConfig>> {
97106
let config = self.infra.get_config();
98107
Ok(config.commit.map(|mc| CommitConfig {

0 commit comments

Comments
 (0)