@@ -12,7 +12,7 @@ use convert_case::{Case, Casing};
1212use forge_api:: {
1313 API , AgentId , AnyProvider , ApiKeyRequest , AuthContextRequest , AuthContextResponse , ChatRequest ,
1414 ChatResponse , CodeRequest , Conversation , ConversationId , DeviceCodeRequest , Event ,
15- InterruptionReason , Model , ModelId , Provider , ProviderId , TextMessage , UserPrompt ,
15+ InterruptionReason , ModelId , Provider , ProviderId , TextMessage , UserPrompt ,
1616} ;
1717use forge_app:: utils:: { format_display_path, truncate_key} ;
1818use forge_app:: { CommitResult , ToolResolver } ;
@@ -127,14 +127,6 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
127127 self . spinner . ewrite_ln ( title)
128128 }
129129
130- /// Retrieve available models
131- async fn get_models ( & mut self ) -> Result < Vec < Model > > {
132- self . spinner . start ( Some ( "Loading" ) ) ?;
133- let models = self . api . get_models ( ) . await ?;
134- self . spinner . stop ( None ) ?;
135- Ok ( models)
136- }
137-
138130 /// Helper to get provider for an optional agent, defaulting to the current
139131 /// active agent's provider
140132 async fn get_provider ( & self , agent_id : Option < AgentId > ) -> Result < Provider < Url > > {
@@ -649,6 +641,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
649641 return Ok ( ( ) ) ;
650642 }
651643 TopLevelCommand :: Commit ( commit_group) => {
644+ self . init_state ( false ) . await ?;
652645 let preview = commit_group. preview ;
653646 let result = self . handle_commit_command ( commit_group) . await ?;
654647 if preview {
@@ -1899,7 +1892,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
18991892 self . on_custom_event ( event. into ( ) ) . await ?;
19001893 }
19011894 SlashCommand :: Model => {
1902- self . on_model_selection ( None ) . await ?;
1895+ self . on_model_selection ( None , None ) . await ?;
19031896 }
19041897 SlashCommand :: Provider => {
19051898 self . on_provider_selection ( ) . await ?;
@@ -2074,15 +2067,11 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
20742067 provider_filter : Option < ProviderId > ,
20752068 ) -> Result < Option < ModelId > > {
20762069 // Check if provider is set otherwise first ask to select a provider
2077- if self . api . get_default_provider ( ) . await . is_err ( ) {
2078- self . on_provider_selection ( ) . await ?;
2079-
2080- // Check if a model was already selected during provider activation
2081- // Return None to signal the model selection is complete and message was already
2082- // printed
2083- if self . api . get_default_model ( ) . await . is_some ( ) {
2070+ if provider_filter. is_none ( ) && self . api . get_default_provider ( ) . await . is_err ( ) {
2071+ if !self . on_provider_selection ( ) . await ? {
20842072 return Ok ( None ) ;
20852073 }
2074+ return Ok ( None ) ;
20862075 }
20872076
20882077 // Fetch models from ALL configured providers (matches shell plugin's
@@ -2713,6 +2702,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
27132702 async fn on_model_selection (
27142703 & mut self ,
27152704 provider_filter : Option < ProviderId > ,
2705+ provider_to_activate : Option < ProviderId > ,
27162706 ) -> Result < Option < ModelId > > {
27172707 // Select a model
27182708 let model_option = self . select_model ( provider_filter) . await ?;
@@ -2723,8 +2713,14 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
27232713 None => return Ok ( None ) ,
27242714 } ;
27252715
2726- // Update the operating model via API
2727- self . api . set_default_model ( model. clone ( ) ) . await ?;
2716+ // If we have a provider to activate, write both atomically
2717+ if let Some ( provider_id) = provider_to_activate {
2718+ self . api
2719+ . set_default_provider_and_model ( provider_id, model. clone ( ) )
2720+ . await ?;
2721+ } else {
2722+ self . api . set_default_model ( model. clone ( ) ) . await ?;
2723+ }
27282724
27292725 // Update the UI state with the new model
27302726 self . update_model ( Some ( model. clone ( ) ) ) ;
@@ -2734,15 +2730,18 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
27342730 Ok ( Some ( model) )
27352731 }
27362732
2737- async fn on_provider_selection ( & mut self ) -> Result < ( ) > {
2733+ async fn on_provider_selection ( & mut self ) -> Result < bool > {
27382734 // Select a provider
27392735 // If no provider was selected (user canceled), return early
27402736 let any_provider = match self . select_provider ( ) . await ? {
27412737 Some ( provider) => provider,
2742- None => return Ok ( ( ) ) ,
2738+ None => return Ok ( false ) ,
27432739 } ;
27442740
2745- self . activate_provider ( any_provider) . await
2741+ self . activate_provider ( any_provider) . await ?;
2742+ // Check if provider was actually saved — if user cancelled model selection
2743+ // inside activate_provider, nothing was written
2744+ Ok ( self . api . get_default_provider ( ) . await . is_ok ( ) )
27462745 }
27472746
27482747 /// Activates a provider by configuring it if needed, setting it as default,
@@ -2789,21 +2788,19 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
27892788 provider : Provider < Url > ,
27902789 model : Option < ModelId > ,
27912790 ) -> Result < ( ) > {
2792- // Set the provider via API
2793- self . api . set_default_provider ( provider. id . clone ( ) ) . await ?;
2794-
2795- self . writeln_title (
2796- TitleFormat :: action ( format ! ( "{}" , provider. id) )
2797- . sub_title ( "is now the default provider" ) ,
2798- ) ?;
2799-
28002791 // If a model was pre-selected (e.g. from :model), validate and set it
28012792 // directly without prompting
28022793 if let Some ( model) = model {
28032794 let model_id = self
28042795 . validate_model ( model. as_str ( ) , Some ( & provider. id ) )
28052796 . await ?;
2806- self . api . set_default_model ( model_id. clone ( ) ) . await ?;
2797+ self . api
2798+ . set_default_provider_and_model ( provider. id . clone ( ) , model_id. clone ( ) )
2799+ . await ?;
2800+ self . writeln_title (
2801+ TitleFormat :: action ( format ! ( "{}" , provider. id) )
2802+ . sub_title ( "is now the default provider" ) ,
2803+ ) ?;
28072804 self . writeln_title (
28082805 TitleFormat :: action ( model_id. as_str ( ) ) . sub_title ( "is now the default model" ) ,
28092806 ) ?;
@@ -2812,18 +2809,37 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
28122809
28132810 // Check if the current model is available for the new provider
28142811 let current_model = self . api . get_default_model ( ) . await ;
2815- if let Some ( current_model) = current_model {
2816- let models = self . get_models ( ) . await ?;
2817- let model_available = models. iter ( ) . any ( |m| m. id == current_model) ;
2812+ let needs_model_selection = match current_model {
2813+ None => true ,
2814+ Some ( current_model) => {
2815+ let provider_models = self . api . get_all_provider_models ( ) . await ?;
2816+ let model_available = provider_models
2817+ . iter ( )
2818+ . find ( |pm| pm. provider_id == provider. id )
2819+ . map ( |pm| pm. models . iter ( ) . any ( |m| m. id == current_model) )
2820+ . unwrap_or ( false ) ;
2821+ !model_available
2822+ }
2823+ } ;
28182824
2819- if !model_available {
2820- // Prompt user to select a new model, scoped to the activated provider
2821- self . writeln_title ( TitleFormat :: info ( "Please select a new model" ) ) ?;
2822- self . on_model_selection ( Some ( provider. id . clone ( ) ) ) . await ?;
2825+ if needs_model_selection {
2826+ self . writeln_title ( TitleFormat :: info ( "Please select a new model" ) ) ?;
2827+ let selected = self
2828+ . on_model_selection ( Some ( provider. id . clone ( ) ) , Some ( provider. id . clone ( ) ) )
2829+ . await ?;
2830+ if selected. is_none ( ) {
2831+ // User cancelled — preserve existing config untouched
2832+ return Ok ( ( ) ) ;
28232833 }
28242834 } else {
2825- // No model set, select one now scoped to the activated provider
2826- self . on_model_selection ( Some ( provider. id . clone ( ) ) ) . await ?;
2835+ // Set the provider via API
2836+ // Only reaches here if model is confirmed — safe to write provider now
2837+ self . api . set_default_provider ( provider. id . clone ( ) ) . await ?;
2838+
2839+ self . writeln_title (
2840+ TitleFormat :: action ( format ! ( "{}" , provider. id) )
2841+ . sub_title ( "is now the default provider" ) ,
2842+ ) ?;
28272843 }
28282844
28292845 Ok ( ( ) )
@@ -2931,17 +2947,19 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
29312947 // Ensure we have a model selected before proceeding with initialization
29322948 let active_agent = self . api . get_active_agent ( ) . await ;
29332949
2934- let mut operating_model = self . get_agent_model ( active_agent. clone ( ) ) . await ;
2935- if operating_model. is_none ( ) {
2936- // Use the model returned from selection instead of re-fetching
2937- operating_model = self . on_model_selection ( None ) . await ?;
2938- }
2939-
29402950 // Validate provider is configured before loading agents
29412951 // If provider is set in config but not configured (no credentials), prompt user
29422952 // to login
29432953 if self . api . get_default_provider ( ) . await . is_err ( ) {
2944- self . on_provider_selection ( ) . await ?;
2954+ if !self . on_provider_selection ( ) . await ? {
2955+ return Ok ( ( ) ) ;
2956+ }
2957+ }
2958+
2959+ let mut operating_model = self . get_agent_model ( active_agent. clone ( ) ) . await ;
2960+ if operating_model. is_none ( ) {
2961+ // Use the model returned from selection instead of re-fetching
2962+ operating_model = self . on_model_selection ( None , None ) . await ?;
29452963 }
29462964
29472965 if first {
0 commit comments