@@ -601,6 +601,7 @@ def __init__(self, fallback_to_cascadeflow: bool = True):
601601 self .fallback_to_cascadeflow = fallback_to_cascadeflow
602602 self .budget_manager = None
603603 self .cost_provider = LiteLLMCostProvider ()
604+ self ._user_budgets : dict [str , dict ] = {}
604605
605606 if BUDGET_MANAGER_AVAILABLE :
606607 self .budget_manager = BudgetManager (project_name = "cascadeflow" )
@@ -633,11 +634,18 @@ def set_user_budget(self, user: str, max_budget: float) -> None:
633634 Example:
634635 >>> tracker.set_user_budget("user_123", max_budget=10.0)
635636 """
637+ self ._user_budgets [user ] = {
638+ "max_budget" : max_budget ,
639+ "current_cost" : 0.0 ,
640+ }
641+
636642 if self .budget_manager :
637- self .budget_manager .create_budget (user = user , max_budget = max_budget )
638- logger .info (f"Set budget for { user } : ${ max_budget :.2f} " )
639- else :
640- logger .warning (f"Cannot set budget for { user } - BudgetManager unavailable" )
643+ try :
644+ self .budget_manager .create_budget (user = user , total_budget = max_budget )
645+ except Exception as e :
646+ logger .debug (f"BudgetManager.create_budget failed for { user } : { e } " )
647+
648+ logger .info (f"Set budget for { user } : ${ max_budget :.2f} " )
641649
642650 def update_cost (
643651 self ,
@@ -677,51 +685,36 @@ def update_cost(
677685 ... response=api_response
678686 ... )
679687 """
680- if self .budget_manager :
681- try :
682- # If we have actual API response, use it
683- if response :
684- cost = self .budget_manager .update_cost (completion_obj = response , user = user )
685- else :
686- # Calculate cost from tokens
687- cost = self .cost_provider .calculate_cost (
688- model = model ,
689- input_tokens = prompt_tokens ,
690- output_tokens = completion_tokens ,
691- )
692-
693- # Update budget manager
694- self .budget_manager .update_cost (user = user , cost = cost )
695-
696- logger .debug (f"Updated cost for { user } : ${ cost :.6f} " )
697- return cost
698-
699- except Exception as e :
700- logger .error (f"Error updating cost for { user } : { e } " )
701- # Fall through to fallback
702-
703- # Fallback to cascadeflow CostTracker
704- if self .fallback_to_cascadeflow and hasattr (self , "cost_tracker" ):
688+ # Calculate cost from tokens or response
689+ if response :
690+ cost = self .cost_provider .calculate_cost (
691+ model = model ,
692+ input_tokens = prompt_tokens ,
693+ output_tokens = completion_tokens ,
694+ )
695+ else :
705696 cost = self .cost_provider .calculate_cost (
706697 model = model ,
707698 input_tokens = prompt_tokens ,
708699 output_tokens = completion_tokens ,
709700 )
701+
702+ # Track in internal budget dict
703+ if user in self ._user_budgets :
704+ self ._user_budgets [user ]["current_cost" ] += cost
705+
706+ # Also track in cascadeflow CostTracker if available
707+ if self .fallback_to_cascadeflow and hasattr (self , "cost_tracker" ) and self .cost_tracker :
710708 self .cost_tracker .add_cost (
711709 model = model ,
712710 provider = "" ,
713711 tokens = prompt_tokens + completion_tokens ,
714712 cost = cost ,
715713 user_id = user ,
716714 )
717- return cost
718715
719- # Just calculate cost without tracking
720- return self .cost_provider .calculate_cost (
721- model = model ,
722- input_tokens = prompt_tokens ,
723- output_tokens = completion_tokens ,
724- )
716+ logger .debug (f"Updated cost for { user } : ${ cost :.6f} " )
717+ return cost
725718
726719 def get_user_budget (self , user : str ) -> dict :
727720 """
@@ -742,23 +735,19 @@ def get_user_budget(self, user: str) -> dict:
742735 >>> print(f"Spent: ${info['current_cost']:.2f}")
743736 >>> print(f"Remaining: ${info['remaining']:.2f}")
744737 """
745- if self .budget_manager :
746- try :
747- budget = self .budget_manager .get_budget (user )
748-
749- max_budget = budget .get ("max_budget" , 0 )
750- current_cost = budget .get ("current_cost" , 0 )
751- remaining = max_budget - current_cost
752- exceeded = current_cost > max_budget
753-
754- return {
755- "max_budget" : max_budget ,
756- "current_cost" : current_cost ,
757- "remaining" : remaining ,
758- "exceeded" : exceeded ,
759- }
760- except Exception as e :
761- logger .error (f"Error getting budget for { user } : { e } " )
738+ budget = self ._user_budgets .get (user )
739+ if budget :
740+ max_budget = budget ["max_budget" ]
741+ current_cost = budget ["current_cost" ]
742+ remaining = max_budget - current_cost
743+ exceeded = current_cost > max_budget
744+
745+ return {
746+ "max_budget" : max_budget ,
747+ "current_cost" : current_cost ,
748+ "remaining" : remaining ,
749+ "exceeded" : exceeded ,
750+ }
762751
763752 return {
764753 "max_budget" : 0 ,
@@ -805,12 +794,9 @@ def reset_user_budget(self, user: str) -> None:
805794 Example:
806795 >>> tracker.reset_user_budget("user_123")
807796 """
808- if self .budget_manager :
809- try :
810- self .budget_manager .reset_cost (user = user )
811- logger .info (f"Reset budget for { user } " )
812- except Exception as e :
813- logger .error (f"Error resetting budget for { user } : { e } " )
797+ if user in self ._user_budgets :
798+ self ._user_budgets [user ]["current_cost" ] = 0.0
799+ logger .info (f"Reset budget for { user } " )
814800
815801
816802# ============================================================================
0 commit comments