@@ -175,7 +175,9 @@ def __init__(
175175 self .image_detail : str = "low"
176176 self .max_input_tokens_per_request : int = max_input_tokens_per_request
177177 self .max_output_tokens_per_request : int = max_output_tokens_per_request
178- self .max_tokens_per_request : int = self .max_input_tokens_per_request + self .max_output_tokens_per_request
178+ self .max_tokens_per_request : int = (
179+ self .max_input_tokens_per_request + self .max_output_tokens_per_request
180+ )
179181 self .rag_query_n : int = 4
180182 self .rag_similarity_threshold : float = 0.45
181183 self .frame_processor : Optional [FrameProcessor ] = None
@@ -200,10 +202,14 @@ def __init__(
200202 RxOps .map (
201203 lambda combined : {
202204 "query" : combined [0 ],
203- "objects" : combined [1 ] if len (combined ) > 1 else "No object data available" ,
205+ "objects" : combined [1 ]
206+ if len (combined ) > 1
207+ else "No object data available" ,
204208 }
205209 ),
206- RxOps .map (lambda data : f"{ data ['query' ]} \n \n Current objects detected:\n { data ['objects' ]} " ),
210+ RxOps .map (
211+ lambda data : f"{ data ['query' ]} \n \n Current objects detected:\n { data ['objects' ]} "
212+ ),
207213 RxOps .do_action (
208214 lambda x : print (f"\033 [34mEnriched query: { x .split (chr (10 ))[0 ]} \033 [0m" )
209215 or [print (f"\033 [34m{ line } \033 [0m" ) for line in x .split (chr (10 ))[1 :]]
@@ -222,7 +228,9 @@ def __init__(
222228 # Define a query extractor for the merged stream
223229 query_extractor = lambda emission : (emission [0 ], emission [1 ][0 ])
224230 self .disposables .add (
225- self .subscribe_to_image_processing (self .merged_stream , query_extractor = query_extractor )
231+ self .subscribe_to_image_processing (
232+ self .merged_stream , query_extractor = query_extractor
233+ )
226234 )
227235 else :
228236 # If no merged stream, fall back to individual streams
@@ -250,7 +258,9 @@ def _get_rag_context(self) -> Tuple[str, str]:
250258 and condensed results (for use in the prompt).
251259 """
252260 results = self .agent_memory .query (
253- query_texts = self .query , n_results = self .rag_query_n , similarity_threshold = self .rag_similarity_threshold
261+ query_texts = self .query ,
262+ n_results = self .rag_query_n ,
263+ similarity_threshold = self .rag_similarity_threshold ,
254264 )
255265 formatted_results = "\n " .join (
256266 f"Document ID: { doc .id } \n Metadata: { doc .metadata } \n Content: { doc .page_content } \n Score: { score } \n "
@@ -334,7 +344,12 @@ def _tooling_callback(message, messages, response_message, skill_library: SkillL
334344 result = skill_library .call (name , ** args )
335345 logger .info (f"Function Call Results: { result } " )
336346 new_messages .append (
337- {"role" : "tool" , "tool_call_id" : tool_call .id , "content" : str (result ), "name" : name }
347+ {
348+ "role" : "tool" ,
349+ "tool_call_id" : tool_call .id ,
350+ "content" : str (result ),
351+ "name" : name ,
352+ }
338353 )
339354 if has_called_tools :
340355 logger .info ("Sending Another Query." )
@@ -347,7 +362,9 @@ def _tooling_callback(message, messages, response_message, skill_library: SkillL
347362 return None
348363
349364 if response_message .tool_calls is not None :
350- return _tooling_callback (response_message , messages , response_message , self .skill_library )
365+ return _tooling_callback (
366+ response_message , messages , response_message , self .skill_library
367+ )
351368 return None
352369
353370 def _observable_query (
@@ -373,7 +390,9 @@ def _observable_query(
373390 try :
374391 self ._update_query (incoming_query )
375392 _ , condensed_results = self ._get_rag_context ()
376- messages = self ._build_prompt (base64_image , dimensions , override_token_limit , condensed_results )
393+ messages = self ._build_prompt (
394+ base64_image , dimensions , override_token_limit , condensed_results
395+ )
377396 # logger.debug(f"Sending Query: {messages}")
378397 logger .info ("Sending Query." )
379398 response_message = self ._send_query (messages )
@@ -391,13 +410,19 @@ def _observable_query(
391410 final_msg = (
392411 response_message .parsed
393412 if hasattr (response_message , "parsed" ) and response_message .parsed
394- else (response_message .content if hasattr (response_message , "content" ) else response_message )
413+ else (
414+ response_message .content
415+ if hasattr (response_message , "content" )
416+ else response_message
417+ )
395418 )
396419 observer .on_next (final_msg )
397420 self .response_subject .on_next (final_msg )
398421 else :
399422 response_message_2 = self ._handle_tooling (response_message , messages )
400- final_msg = response_message_2 if response_message_2 is not None else response_message
423+ final_msg = (
424+ response_message_2 if response_message_2 is not None else response_message
425+ )
401426 if isinstance (final_msg , BaseModel ): # TODO: Test
402427 final_msg = str (final_msg .content )
403428 observer .on_next (final_msg )
@@ -440,7 +465,9 @@ def _log_response_to_file(self, response, output_dir: str = None):
440465 file .write (f"{ self .dev_name } : { response } \n " )
441466 logger .info (f"LLM Response [{ self .dev_name } ]: { response } " )
442467
443- def subscribe_to_image_processing (self , frame_observable : Observable , query_extractor = None ) -> Disposable :
468+ def subscribe_to_image_processing (
469+ self , frame_observable : Observable , query_extractor = None
470+ ) -> Disposable :
444471 """Subscribes to a stream of video frames for processing.
445472
446473 This method sets up a subscription to process incoming video frames.
@@ -480,7 +507,9 @@ def _process_frame(emission) -> Observable:
480507 RxOps .subscribe_on (self .pool_scheduler ),
481508 MyOps .print_emission (id = "D" , ** print_emission_args ),
482509 MyVidOps .with_jpeg_export (
483- self .frame_processor , suffix = f"{ self .dev_name } _frame_" , save_limit = _MAX_SAVED_FRAMES
510+ self .frame_processor ,
511+ suffix = f"{ self .dev_name } _frame_" ,
512+ save_limit = _MAX_SAVED_FRAMES ,
484513 ),
485514 MyOps .print_emission (id = "E" , ** print_emission_args ),
486515 MyVidOps .encode_image (),
@@ -562,7 +591,9 @@ def _process_query(query) -> Observable:
562591 return just (query ).pipe (
563592 MyOps .print_emission (id = "Pr A" , ** print_emission_args ),
564593 RxOps .flat_map (
565- lambda query : create (lambda observer , _ : self ._observable_query (observer , incoming_query = query ))
594+ lambda query : create (
595+ lambda observer , _ : self ._observable_query (observer , incoming_query = query )
596+ )
566597 ),
567598 MyOps .print_emission (id = "Pr B" , ** print_emission_args ),
568599 )
@@ -612,7 +643,9 @@ def get_response_observable(self) -> Observable:
612643 Observable: An observable that emits string responses from the agent.
613644 """
614645 return self .response_subject .pipe (
615- RxOps .observe_on (self .pool_scheduler ), RxOps .subscribe_on (self .pool_scheduler ), RxOps .share ()
646+ RxOps .observe_on (self .pool_scheduler ),
647+ RxOps .subscribe_on (self .pool_scheduler ),
648+ RxOps .share (),
616649 )
617650
618651 def run_observable_query (self , query_text : str , ** kwargs ) -> Observable :
@@ -631,7 +664,11 @@ def run_observable_query(self, query_text: str, **kwargs) -> Observable:
631664 Returns:
632665 Observable: An observable that emits the response as a string.
633666 """
634- return create (lambda observer , _ : self ._observable_query (observer , incoming_query = query_text , ** kwargs ))
667+ return create (
668+ lambda observer , _ : self ._observable_query (
669+ observer , incoming_query = query_text , ** kwargs
670+ )
671+ )
635672
636673 def dispose_all (self ):
637674 """Disposes of all active subscriptions managed by this agent."""
@@ -749,7 +786,9 @@ def __init__(
749786 self .response_model = response_model if response_model is not None else NOT_GIVEN
750787 self .model_name = model_name
751788 self .tokenizer = tokenizer or OpenAITokenizer (model_name = self .model_name )
752- self .prompt_builder = prompt_builder or PromptBuilder (self .model_name , tokenizer = self .tokenizer )
789+ self .prompt_builder = prompt_builder or PromptBuilder (
790+ self .model_name , tokenizer = self .tokenizer
791+ )
753792 self .rag_query_n = rag_query_n
754793 self .rag_similarity_threshold = rag_similarity_threshold
755794 self .image_detail = image_detail
@@ -767,8 +806,14 @@ def __init__(
767806 def _add_context_to_memory (self ):
768807 """Adds initial context to the agent's memory."""
769808 context_data = [
770- ("id0" , "Optical Flow is a technique used to track the movement of objects in a video sequence." ),
771- ("id1" , "Edge Detection is a technique used to identify the boundaries of objects in an image." ),
809+ (
810+ "id0" ,
811+ "Optical Flow is a technique used to track the movement of objects in a video sequence." ,
812+ ),
813+ (
814+ "id1" ,
815+ "Edge Detection is a technique used to identify the boundaries of objects in an image." ,
816+ ),
772817 ("id2" , "Video is a sequence of frames captured at regular intervals." ),
773818 (
774819 "id3" ,
@@ -805,15 +850,23 @@ def _send_query(self, messages: list) -> Any:
805850 model = self .model_name ,
806851 messages = messages ,
807852 response_format = self .response_model ,
808- tools = (self .skill_library .get_tools () if self .skill_library is not None else NOT_GIVEN ),
853+ tools = (
854+ self .skill_library .get_tools ()
855+ if self .skill_library is not None
856+ else NOT_GIVEN
857+ ),
809858 max_tokens = self .max_output_tokens_per_request ,
810859 )
811860 else :
812861 response = self .client .chat .completions .create (
813862 model = self .model_name ,
814863 messages = messages ,
815864 max_tokens = self .max_output_tokens_per_request ,
816- tools = (self .skill_library .get_tools () if self .skill_library is not None else NOT_GIVEN ),
865+ tools = (
866+ self .skill_library .get_tools ()
867+ if self .skill_library is not None
868+ else NOT_GIVEN
869+ ),
817870 )
818871 response_message = response .choices [0 ].message
819872 if response_message is None :
@@ -843,7 +896,9 @@ def stream_query(self, query_text: str) -> Observable:
843896 Returns:
844897 Observable: An observable that emits the response as a string.
845898 """
846- return create (lambda observer , _ : self ._observable_query (observer , incoming_query = query_text ))
899+ return create (
900+ lambda observer , _ : self ._observable_query (observer , incoming_query = query_text )
901+ )
847902
848903
849904# endregion OpenAIAgent Subclass (OpenAI-Specific Implementation)
0 commit comments