 from __future__ import annotations

 # Standard library imports
-import json
 import logging
 import os
 from typing import Any, Optional
 from reactivex.scheduler import ThreadPoolScheduler
 from reactivex.subject import Subject
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM

 # Local imports
 from dimos.agents.agent import LLMAgent
 # Initialize logger for the agent module
 logger = setup_logger("dimos.agents", level=logging.DEBUG)

+
 # HuggingFaceLLMAgent Class
 class HuggingFaceLocalAgent(LLMAgent):
-    def __init__(self,
-                 dev_name: str,
-                 agent_type: str = "HF-LLM",
-                 model_name: str = "Qwen/Qwen2.5-3B",
-                 device: str = "auto",
-                 query: str = "How many r's are in the word 'strawberry'?",
-                 input_query_stream: Optional[Observable] = None,
-                 input_video_stream: Optional[Observable] = None,
-                 output_dir: str = os.path.join(os.getcwd(), "assets",
-                                                "agent"),
-                 agent_memory: Optional[AbstractAgentSemanticMemory] = None,
-                 system_query: Optional[str] = None,
-                 max_output_tokens_per_request: int = None,
-                 max_input_tokens_per_request: int = None,
-                 prompt_builder: Optional[PromptBuilder] = None,
-                 tokenizer: Optional[AbstractTokenizer] = None,
-                 image_detail: str = "low",
-                 pool_scheduler: Optional[ThreadPoolScheduler] = None,
-                 process_all_inputs: Optional[bool] = None,):
-
+    def __init__(
+        self,
+        dev_name: str,
+        agent_type: str = "HF-LLM",
+        model_name: str = "Qwen/Qwen2.5-3B",
+        device: str = "auto",
+        query: str = "How many r's are in the word 'strawberry'?",
+        input_query_stream: Optional[Observable] = None,
+        input_video_stream: Optional[Observable] = None,
+        output_dir: str = os.path.join(os.getcwd(), "assets", "agent"),
+        agent_memory: Optional[AbstractAgentSemanticMemory] = None,
+        system_query: Optional[str] = None,
+        max_output_tokens_per_request: int = None,
+        max_input_tokens_per_request: int = None,
+        prompt_builder: Optional[PromptBuilder] = None,
+        tokenizer: Optional[AbstractTokenizer] = None,
+        image_detail: str = "low",
+        pool_scheduler: Optional[ThreadPoolScheduler] = None,
+        process_all_inputs: Optional[bool] = None,
+    ):
         # Determine appropriate default for process_all_inputs if not provided
         if process_all_inputs is None:
             # Default to True for text queries, False for video streams
@@ -79,7 +79,7 @@ def __init__(self,
             agent_memory=agent_memory or LocalSemanticMemory(),
             pool_scheduler=pool_scheduler,
             process_all_inputs=process_all_inputs,
-            system_query=system_query
+            system_query=system_query,
         )

         self.query = query
@@ -98,15 +98,10 @@ def __init__(self,

         self.tokenizer = tokenizer or HuggingFaceTokenizer(self.model_name)

-        self.prompt_builder = prompt_builder or PromptBuilder(
-            self.model_name,
-            tokenizer=self.tokenizer
-        )
+        self.prompt_builder = prompt_builder or PromptBuilder(self.model_name, tokenizer=self.tokenizer)

         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
-            device_map=self.device
+            model_name, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, device_map=self.device
         )

         self.max_output_tokens_per_request = max_output_tokens_per_request
@@ -118,111 +113,101 @@ def __init__(self,

         # Ensure only one input stream is provided.
         if self.input_video_stream is not None and self.input_query_stream is not None:
-            raise ValueError(
-                "More than one input stream provided. Please provide only one input stream."
-            )
+            raise ValueError("More than one input stream provided. Please provide only one input stream.")

         if self.input_video_stream is not None:
             logger.info("Subscribing to input video stream...")
-            self.disposables.add(
-                self.subscribe_to_image_processing(self.input_video_stream))
+            self.disposables.add(self.subscribe_to_image_processing(self.input_video_stream))
         if self.input_query_stream is not None:
             logger.info("Subscribing to input query stream...")
-            self.disposables.add(
-                self.subscribe_to_query_processing(self.input_query_stream))
-
+            self.disposables.add(self.subscribe_to_query_processing(self.input_query_stream))

     def _send_query(self, messages: list) -> Any:
         _BLUE_PRINT_COLOR: str = "\033[34m"
         _RESET_COLOR: str = "\033[0m"
-
+
         try:
             # Log the incoming messages
             print(f"{_BLUE_PRINT_COLOR}Messages: {str(messages)}{_RESET_COLOR}")
-
+
             # Process with chat template
             try:
                 print("Applying chat template...")
                 prompt_text = self.tokenizer.tokenizer.apply_chat_template(
                     conversation=[{"role": "user", "content": str(messages)}],
                     tokenize=False,
-                    add_generation_prompt=True
+                    add_generation_prompt=True,
                 )
                 print("Chat template applied.")
-
+
                 # Tokenize the prompt
                 print("Preparing model inputs...")
                 model_inputs = self.tokenizer.tokenizer([prompt_text], return_tensors="pt").to(self.model.device)
                 print("Model inputs prepared.")
-
+
                 # Generate the response
                 print("Generating response...")
-                generated_ids = self.model.generate(
-                    **model_inputs,
-                    max_new_tokens=self.max_output_tokens_per_request
-                )
-
+                generated_ids = self.model.generate(**model_inputs, max_new_tokens=self.max_output_tokens_per_request)
+
                 # Extract the generated tokens (excluding the input prompt tokens)
                 print("Processing generated output...")
                 generated_ids = [
-                    output_ids[len(input_ids):]
-                    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+                    output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
                 ]
-
+
                 # Convert tokens back to text
                 response = self.tokenizer.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
                 print("Response successfully generated.")
-
+
                 return response
-
+
             except AttributeError as e:
                 # Handle case where tokenizer doesn't have the expected methods
                 logger.warning(f"Chat template not available: {e}. Using simple format.")
                 # Continue with execution and use simple format
-
+
             except Exception as e:
                 # Log any other errors but continue execution
                 logger.warning(f"Error in chat template processing: {e}. Falling back to simple format.")
-
+
             # Fallback approach for models without chat template support
             # This code runs if the try block above raises an exception
             print("Using simple prompt format...")
-
+
             # Convert messages to a simple text format
             if isinstance(messages, list) and messages and isinstance(messages[0], dict) and "content" in messages[0]:
                 prompt_text = messages[0]["content"]
             else:
                 prompt_text = str(messages)
-
+
             # Tokenize the prompt
             model_inputs = self.tokenizer.tokenize_text(prompt_text)
             model_inputs = torch.tensor([model_inputs], device=self.model.device)
-
+
             # Generate the response
             generated_ids = self.model.generate(
-                input_ids=model_inputs,
-                max_new_tokens=self.max_output_tokens_per_request
+                input_ids=model_inputs, max_new_tokens=self.max_output_tokens_per_request
             )
-
+
             # Extract the generated tokens
-            generated_ids = generated_ids[0][len(model_inputs[0]):]
-
+            generated_ids = generated_ids[0][len(model_inputs[0]) :]
+
             # Convert tokens back to text
             response = self.tokenizer.detokenize_text(generated_ids.tolist())
             print("Response generated using simple format.")
-
+
             return response
-
+
         except Exception as e:
             # Catch all other errors
             logger.error(f"Error during query processing: {e}", exc_info=True)
-            return f"Error processing request. Please try again."
+            return "Error processing request. Please try again."

     def stream_query(self, query_text: str) -> Subject:
         """
         Creates an observable that processes a text query and emits the response.
         """
-        return create(lambda observer, _: self._observable_query(
-            observer, incoming_query=query_text))
+        return create(lambda observer, _: self._observable_query(observer, incoming_query=query_text))
+

 # endregion HuggingFaceLLMAgent Subclass (HuggingFace-Specific Implementation)
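
For context, here is a minimal usage sketch of the class as it stands after this change. It is not part of the diff: the constructor arguments come from the signature above, while the module path and the subscribe() callbacks are assumptions about how the reactivex observable returned by stream_query() would be consumed.

# Hedged usage sketch, not part of the commit.
from dimos.agents.huggingface_local_agent import HuggingFaceLocalAgent  # hypothetical import path

agent = HuggingFaceLocalAgent(
    dev_name="local-qwen",              # required positional argument
    model_name="Qwen/Qwen2.5-3B",       # default model from the signature
    device="auto",                      # device_map picks CPU/GPU
    max_output_tokens_per_request=256,  # cap generation length
)

# stream_query() wraps _observable_query() in a reactivex observable,
# so the generated text arrives through on_next.
agent.stream_query("How many r's are in the word 'strawberry'?").subscribe(
    on_next=lambda response: print(response),
    on_error=lambda err: print(f"query failed: {err}"),
)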