Skip to content

Commit 554f51f

Browse files
authored
Revert "fix: prevent partial config write when model selection is cancelled d…"
This reverts commit 6a36b83.
1 parent fbeea84 commit 554f51f

1 file changed

Lines changed: 38 additions & 50 deletions

File tree

  • crates/forge_main/src

crates/forge_main/src/ui.rs

Lines changed: 38 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use convert_case::{Case, Casing};
1212
use forge_api::{
1313
API, AgentId, AnyProvider, ApiKeyRequest, AuthContextRequest, AuthContextResponse, ChatRequest,
1414
ChatResponse, CodeRequest, Conversation, ConversationId, DeviceCodeRequest, Event,
15-
InterruptionReason, ModelId, Provider, ProviderId, TextMessage, UserPrompt,
15+
InterruptionReason, Model, ModelId, Provider, ProviderId, TextMessage, UserPrompt,
1616
};
1717
use forge_app::utils::{format_display_path, truncate_key};
1818
use forge_app::{CommitResult, ToolResolver};
@@ -127,6 +127,14 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
127127
self.spinner.ewrite_ln(title)
128128
}
129129

130+
/// Retrieve available models
131+
async fn get_models(&mut self) -> Result<Vec<Model>> {
132+
self.spinner.start(Some("Loading"))?;
133+
let models = self.api.get_models().await?;
134+
self.spinner.stop(None)?;
135+
Ok(models)
136+
}
137+
130138
/// Helper to get provider for an optional agent, defaulting to the current
131139
/// active agent's provider
132140
async fn get_provider(&self, agent_id: Option<AgentId>) -> Result<Provider<Url>> {
@@ -2019,9 +2027,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
20192027
) -> Result<Option<ModelId>> {
20202028
// Check if provider is set otherwise first ask to select a provider
20212029
if self.api.get_default_provider().await.is_err() {
2022-
if !self.on_provider_selection().await? {
2023-
return Ok(None);
2024-
}
2030+
self.on_provider_selection().await?;
20252031

20262032
// Check if a model was already selected during provider activation
20272033
// Return None to signal the model selection is complete and message was already
@@ -2680,16 +2686,15 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
26802686
Ok(Some(model))
26812687
}
26822688

2683-
async fn on_provider_selection(&mut self) -> Result<bool> {
2689+
async fn on_provider_selection(&mut self) -> Result<()> {
26842690
// Select a provider
26852691
// If no provider was selected (user canceled), return early
26862692
let any_provider = match self.select_provider().await? {
26872693
Some(provider) => provider,
2688-
None => return Ok(false),
2694+
None => return Ok(()),
26892695
};
26902696

2691-
self.activate_provider(any_provider).await?;
2692-
Ok(true)
2697+
self.activate_provider(any_provider).await
26932698
}
26942699

26952700
/// Activates a provider by configuring it if needed, setting it as default,
@@ -2736,22 +2741,20 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
27362741
provider: Provider<Url>,
27372742
model: Option<ModelId>,
27382743
) -> Result<()> {
2744+
// Set the provider via API
2745+
self.api.set_default_provider(provider.id.clone()).await?;
2746+
2747+
self.writeln_title(
2748+
TitleFormat::action(format!("{}", provider.id))
2749+
.sub_title("is now the default provider"),
2750+
)?;
2751+
27392752
// If a model was pre-selected (e.g. from :model), validate and set it
27402753
// directly without prompting
27412754
if let Some(model) = model {
27422755
let model_id = self
27432756
.validate_model(model.as_str(), Some(&provider.id))
27442757
.await?;
2745-
2746-
//set provider
2747-
self.api.set_default_provider(provider.id.clone()).await?;
2748-
2749-
self.writeln_title(
2750-
TitleFormat::action(format!("{}", provider.id))
2751-
.sub_title("is now the default provider"),
2752-
)?;
2753-
2754-
//set model
27552758
self.api.set_default_model(model_id.clone()).await?;
27562759
self.writeln_title(
27572760
TitleFormat::action(model_id.as_str()).sub_title("is now the default model"),
@@ -2760,35 +2763,20 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
27602763
}
27612764

27622765
// Check if the current model is available for the new provider
2763-
27642766
let current_model = self.api.get_default_model().await;
2767+
if let Some(current_model) = current_model {
2768+
let models = self.get_models().await?;
2769+
let model_available = models.iter().any(|m| m.id == current_model);
27652770

2766-
let needs_model_selection = match current_model {
2767-
None => true,
2768-
Some(current_model) => {
2769-
let provider_models = self.api.get_all_provider_models().await?;
2770-
!provider_models
2771-
.iter()
2772-
.find(|pm| pm.provider_id == provider.id)
2773-
.map(|pm| pm.models.iter().any(|m| m.id == current_model))
2774-
.unwrap_or(false)
2775-
}
2776-
};
2777-
2778-
if needs_model_selection {
2779-
self.writeln_title(TitleFormat::info("Please select a new model"))?;
2780-
let selected = self.on_model_selection(Some(provider.id.clone())).await?;
2781-
if selected.is_none() {
2782-
// User cancelled — preserve existing config untouched
2783-
return Ok(());
2771+
if !model_available {
2772+
// Prompt user to select a new model, scoped to the activated provider
2773+
self.writeln_title(TitleFormat::info("Please select a new model"))?;
2774+
self.on_model_selection(Some(provider.id.clone())).await?;
27842775
}
2776+
} else {
2777+
// No model set, select one now scoped to the activated provider
2778+
self.on_model_selection(Some(provider.id.clone())).await?;
27852779
}
2786-
// Only reaches here if model is confirmed — safe to write provider now
2787-
self.api.set_default_provider(provider.id.clone()).await?;
2788-
self.writeln_title(
2789-
TitleFormat::action(format!("{}", provider.id))
2790-
.sub_title("is now the default provider"),
2791-
)?;
27922780

27932781
Ok(())
27942782
}
@@ -2895,19 +2883,19 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
28952883
// Ensure we have a model selected before proceeding with initialization
28962884
let active_agent = self.api.get_active_agent().await;
28972885

2898-
// Validate provider is configured before loading agents
2899-
// If provider is set in config but not configured (no credentials), prompt user
2900-
// to login
2901-
if self.api.get_default_provider().await.is_err() && !self.on_provider_selection().await? {
2902-
return Ok(());
2903-
}
2904-
29052886
let mut operating_model = self.get_agent_model(active_agent.clone()).await;
29062887
if operating_model.is_none() {
29072888
// Use the model returned from selection instead of re-fetching
29082889
operating_model = self.on_model_selection(None).await?;
29092890
}
29102891

2892+
// Validate provider is configured before loading agents
2893+
// If provider is set in config but not configured (no credentials), prompt user
2894+
// to login
2895+
if self.api.get_default_provider().await.is_err() {
2896+
self.on_provider_selection().await?;
2897+
}
2898+
29112899
if first {
29122900
// For chat, we are trying to get active agent or setting it to default.
29132901
// So for default values, `/info` doesn't show active provider, model, etc.

0 commit comments

Comments
 (0)