Commit d0191da

Merge pull request #301: Fix global 100 line length

Ruff line length 100

2 parents: b6b7f72 + 373f6ce

189 files changed: +3617 additions, -978 deletions

Note: GitHub hides most of a large commit's diff by default; only a subset of the 189 changed files is reproduced below.

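For context: this commit enforces Ruff's 100-character line limit across the repository (typically line-length = 100 under [tool.ruff] in pyproject.toml; the configuration file itself is among the 189 changed files but is not shown in this excerpt, so its exact location here is an assumption). Below is a minimal runnable sketch of the rewrite pattern the formatter applies throughout the diff, with illustrative values and names borrowed from the first hunk of dimos/agents/agent.py:

    # Before the commit, the sum below sat on one line that overran 100
    # characters. The formatter wraps the right-hand side in parentheses
    # and splits it across lines, which is the dominant change in this diff.
    max_input_tokens_per_request = 4096   # illustrative value
    max_output_tokens_per_request = 1024  # illustrative value

    max_tokens_per_request = (
        max_input_tokens_per_request + max_output_tokens_per_request
    )

    assert max_tokens_per_request == 5120

The same idea recurs for long calls (each argument moves to its own line with a trailing comma) and for long conditional expressions (breaks at if/else).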

dimos/agents/agent.py

Lines changed: 76 additions & 21 deletions
@@ -175,7 +175,9 @@ def __init__(
         self.image_detail: str = "low"
         self.max_input_tokens_per_request: int = max_input_tokens_per_request
         self.max_output_tokens_per_request: int = max_output_tokens_per_request
-        self.max_tokens_per_request: int = self.max_input_tokens_per_request + self.max_output_tokens_per_request
+        self.max_tokens_per_request: int = (
+            self.max_input_tokens_per_request + self.max_output_tokens_per_request
+        )
         self.rag_query_n: int = 4
         self.rag_similarity_threshold: float = 0.45
         self.frame_processor: Optional[FrameProcessor] = None
@@ -200,10 +202,14 @@ def __init__(
             RxOps.map(
                 lambda combined: {
                     "query": combined[0],
-                    "objects": combined[1] if len(combined) > 1 else "No object data available",
+                    "objects": combined[1]
+                    if len(combined) > 1
+                    else "No object data available",
                 }
             ),
-            RxOps.map(lambda data: f"{data['query']}\n\nCurrent objects detected:\n{data['objects']}"),
+            RxOps.map(
+                lambda data: f"{data['query']}\n\nCurrent objects detected:\n{data['objects']}"
+            ),
             RxOps.do_action(
                 lambda x: print(f"\033[34mEnriched query: {x.split(chr(10))[0]}\033[0m")
                 or [print(f"\033[34m{line}\033[0m") for line in x.split(chr(10))[1:]]
@@ -222,7 +228,9 @@ def __init__(
             # Define a query extractor for the merged stream
             query_extractor = lambda emission: (emission[0], emission[1][0])
             self.disposables.add(
-                self.subscribe_to_image_processing(self.merged_stream, query_extractor=query_extractor)
+                self.subscribe_to_image_processing(
+                    self.merged_stream, query_extractor=query_extractor
+                )
             )
         else:
             # If no merged stream, fall back to individual streams
@@ -250,7 +258,9 @@ def _get_rag_context(self) -> Tuple[str, str]:
         and condensed results (for use in the prompt).
         """
         results = self.agent_memory.query(
-            query_texts=self.query, n_results=self.rag_query_n, similarity_threshold=self.rag_similarity_threshold
+            query_texts=self.query,
+            n_results=self.rag_query_n,
+            similarity_threshold=self.rag_similarity_threshold,
         )
         formatted_results = "\n".join(
             f"Document ID: {doc.id}\nMetadata: {doc.metadata}\nContent: {doc.page_content}\nScore: {score}\n"
@@ -334,7 +344,12 @@ def _tooling_callback(message, messages, response_message, skill_library: SkillL
                 result = skill_library.call(name, **args)
                 logger.info(f"Function Call Results: {result}")
                 new_messages.append(
-                    {"role": "tool", "tool_call_id": tool_call.id, "content": str(result), "name": name}
+                    {
+                        "role": "tool",
+                        "tool_call_id": tool_call.id,
+                        "content": str(result),
+                        "name": name,
+                    }
                 )
             if has_called_tools:
                 logger.info("Sending Another Query.")
@@ -347,7 +362,9 @@ def _tooling_callback(message, messages, response_message, skill_library: SkillL
             return None

         if response_message.tool_calls is not None:
-            return _tooling_callback(response_message, messages, response_message, self.skill_library)
+            return _tooling_callback(
+                response_message, messages, response_message, self.skill_library
+            )
         return None

     def _observable_query(
@@ -373,7 +390,9 @@ def _observable_query(
         try:
             self._update_query(incoming_query)
             _, condensed_results = self._get_rag_context()
-            messages = self._build_prompt(base64_image, dimensions, override_token_limit, condensed_results)
+            messages = self._build_prompt(
+                base64_image, dimensions, override_token_limit, condensed_results
+            )
             # logger.debug(f"Sending Query: {messages}")
             logger.info("Sending Query.")
             response_message = self._send_query(messages)
@@ -391,13 +410,19 @@
                 final_msg = (
                     response_message.parsed
                     if hasattr(response_message, "parsed") and response_message.parsed
-                    else (response_message.content if hasattr(response_message, "content") else response_message)
+                    else (
+                        response_message.content
+                        if hasattr(response_message, "content")
+                        else response_message
+                    )
                 )
                 observer.on_next(final_msg)
                 self.response_subject.on_next(final_msg)
             else:
                 response_message_2 = self._handle_tooling(response_message, messages)
-                final_msg = response_message_2 if response_message_2 is not None else response_message
+                final_msg = (
+                    response_message_2 if response_message_2 is not None else response_message
+                )
                 if isinstance(final_msg, BaseModel):  # TODO: Test
                     final_msg = str(final_msg.content)
                 observer.on_next(final_msg)
@@ -440,7 +465,9 @@ def _log_response_to_file(self, response, output_dir: str = None):
             file.write(f"{self.dev_name}: {response}\n")
         logger.info(f"LLM Response [{self.dev_name}]: {response}")

-    def subscribe_to_image_processing(self, frame_observable: Observable, query_extractor=None) -> Disposable:
+    def subscribe_to_image_processing(
+        self, frame_observable: Observable, query_extractor=None
+    ) -> Disposable:
         """Subscribes to a stream of video frames for processing.

         This method sets up a subscription to process incoming video frames.
@@ -480,7 +507,9 @@ def _process_frame(emission) -> Observable:
                 RxOps.subscribe_on(self.pool_scheduler),
                 MyOps.print_emission(id="D", **print_emission_args),
                 MyVidOps.with_jpeg_export(
-                    self.frame_processor, suffix=f"{self.dev_name}_frame_", save_limit=_MAX_SAVED_FRAMES
+                    self.frame_processor,
+                    suffix=f"{self.dev_name}_frame_",
+                    save_limit=_MAX_SAVED_FRAMES,
                 ),
                 MyOps.print_emission(id="E", **print_emission_args),
                 MyVidOps.encode_image(),
@@ -562,7 +591,9 @@ def _process_query(query) -> Observable:
             return just(query).pipe(
                 MyOps.print_emission(id="Pr A", **print_emission_args),
                 RxOps.flat_map(
-                    lambda query: create(lambda observer, _: self._observable_query(observer, incoming_query=query))
+                    lambda query: create(
+                        lambda observer, _: self._observable_query(observer, incoming_query=query)
+                    )
                 ),
                 MyOps.print_emission(id="Pr B", **print_emission_args),
             )
@@ -612,7 +643,9 @@ def get_response_observable(self) -> Observable:
             Observable: An observable that emits string responses from the agent.
         """
         return self.response_subject.pipe(
-            RxOps.observe_on(self.pool_scheduler), RxOps.subscribe_on(self.pool_scheduler), RxOps.share()
+            RxOps.observe_on(self.pool_scheduler),
+            RxOps.subscribe_on(self.pool_scheduler),
+            RxOps.share(),
         )

     def run_observable_query(self, query_text: str, **kwargs) -> Observable:
@@ -631,7 +664,11 @@ def run_observable_query(self, query_text: str, **kwargs) -> Observable:
         Returns:
             Observable: An observable that emits the response as a string.
         """
-        return create(lambda observer, _: self._observable_query(observer, incoming_query=query_text, **kwargs))
+        return create(
+            lambda observer, _: self._observable_query(
+                observer, incoming_query=query_text, **kwargs
+            )
+        )

     def dispose_all(self):
         """Disposes of all active subscriptions managed by this agent."""
@@ -749,7 +786,9 @@ def __init__(
         self.response_model = response_model if response_model is not None else NOT_GIVEN
         self.model_name = model_name
         self.tokenizer = tokenizer or OpenAITokenizer(model_name=self.model_name)
-        self.prompt_builder = prompt_builder or PromptBuilder(self.model_name, tokenizer=self.tokenizer)
+        self.prompt_builder = prompt_builder or PromptBuilder(
+            self.model_name, tokenizer=self.tokenizer
+        )
         self.rag_query_n = rag_query_n
         self.rag_similarity_threshold = rag_similarity_threshold
         self.image_detail = image_detail
@@ -767,8 +806,14 @@ def __init__(
     def _add_context_to_memory(self):
         """Adds initial context to the agent's memory."""
         context_data = [
-            ("id0", "Optical Flow is a technique used to track the movement of objects in a video sequence."),
-            ("id1", "Edge Detection is a technique used to identify the boundaries of objects in an image."),
+            (
+                "id0",
+                "Optical Flow is a technique used to track the movement of objects in a video sequence.",
+            ),
+            (
+                "id1",
+                "Edge Detection is a technique used to identify the boundaries of objects in an image.",
+            ),
             ("id2", "Video is a sequence of frames captured at regular intervals."),
             (
                 "id3",
@@ -805,15 +850,23 @@ def _send_query(self, messages: list) -> Any:
                 model=self.model_name,
                 messages=messages,
                 response_format=self.response_model,
-                tools=(self.skill_library.get_tools() if self.skill_library is not None else NOT_GIVEN),
+                tools=(
+                    self.skill_library.get_tools()
+                    if self.skill_library is not None
+                    else NOT_GIVEN
+                ),
                 max_tokens=self.max_output_tokens_per_request,
             )
         else:
             response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=messages,
                 max_tokens=self.max_output_tokens_per_request,
-                tools=(self.skill_library.get_tools() if self.skill_library is not None else NOT_GIVEN),
+                tools=(
+                    self.skill_library.get_tools()
+                    if self.skill_library is not None
+                    else NOT_GIVEN
+                ),
             )
         response_message = response.choices[0].message
         if response_message is None:
@@ -843,7 +896,9 @@ def stream_query(self, query_text: str) -> Observable:
         Returns:
             Observable: An observable that emits the response as a string.
         """
-        return create(lambda observer, _: self._observable_query(observer, incoming_query=query_text))
+        return create(
+            lambda observer, _: self._observable_query(observer, incoming_query=query_text)
+        )


 # endregion OpenAIAgent Subclass (OpenAI-Specific Implementation)

dimos/agents/agent_ctransformers_gguf.py

Lines changed: 9 additions & 3 deletions
@@ -141,7 +141,9 @@ def __init__(

         self.tokenizer = CTransformersTokenizerAdapter(self.model)

-        self.prompt_builder = prompt_builder or PromptBuilder(self.model_name, tokenizer=self.tokenizer)
+        self.prompt_builder = prompt_builder or PromptBuilder(
+            self.model_name, tokenizer=self.tokenizer
+        )

         self.max_output_tokens_per_request = max_output_tokens_per_request

@@ -152,7 +154,9 @@ def __init__(

         # Ensure only one input stream is provided.
         if self.input_video_stream is not None and self.input_query_stream is not None:
-            raise ValueError("More than one input stream provided. Please provide only one input stream.")
+            raise ValueError(
+                "More than one input stream provided. Please provide only one input stream."
+            )

         if self.input_video_stream is not None:
             logger.info("Subscribing to input video stream...")
@@ -198,7 +202,9 @@ def stream_query(self, query_text: str) -> Subject:
         """
         Creates an observable that processes a text query and emits the response.
         """
-        return create(lambda observer, _: self._observable_query(observer, incoming_query=query_text))
+        return create(
+            lambda observer, _: self._observable_query(observer, incoming_query=query_text)
+        )


 # endregion HuggingFaceLLMAgent Subclass (HuggingFace-Specific Implementation)

dimos/agents/agent_huggingface_local.py

Lines changed: 32 additions & 10 deletions
@@ -98,10 +98,14 @@ def __init__(

         self.tokenizer = tokenizer or HuggingFaceTokenizer(self.model_name)

-        self.prompt_builder = prompt_builder or PromptBuilder(self.model_name, tokenizer=self.tokenizer)
+        self.prompt_builder = prompt_builder or PromptBuilder(
+            self.model_name, tokenizer=self.tokenizer
+        )

         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, device_map=self.device
+            model_name,
+            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+            device_map=self.device,
         )

         self.max_output_tokens_per_request = max_output_tokens_per_request
@@ -113,7 +117,9 @@ def __init__(

         # Ensure only one input stream is provided.
         if self.input_video_stream is not None and self.input_query_stream is not None:
-            raise ValueError("More than one input stream provided. Please provide only one input stream.")
+            raise ValueError(
+                "More than one input stream provided. Please provide only one input stream."
+            )

         if self.input_video_stream is not None:
             logger.info("Subscribing to input video stream...")
@@ -142,21 +148,28 @@ def _send_query(self, messages: list) -> Any:

             # Tokenize the prompt
             print("Preparing model inputs...")
-            model_inputs = self.tokenizer.tokenizer([prompt_text], return_tensors="pt").to(self.model.device)
+            model_inputs = self.tokenizer.tokenizer([prompt_text], return_tensors="pt").to(
+                self.model.device
+            )
             print("Model inputs prepared.")

             # Generate the response
             print("Generating response...")
-            generated_ids = self.model.generate(**model_inputs, max_new_tokens=self.max_output_tokens_per_request)
+            generated_ids = self.model.generate(
+                **model_inputs, max_new_tokens=self.max_output_tokens_per_request
+            )

             # Extract the generated tokens (excluding the input prompt tokens)
             print("Processing generated output...")
             generated_ids = [
-                output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+                output_ids[len(input_ids) :]
+                for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
             ]

             # Convert tokens back to text
-            response = self.tokenizer.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+            response = self.tokenizer.tokenizer.batch_decode(
+                generated_ids, skip_special_tokens=True
+            )[0]
             print("Response successfully generated.")

             return response
@@ -168,14 +181,21 @@ def _send_query(self, messages: list) -> Any:

         except Exception as e:
             # Log any other errors but continue execution
-            logger.warning(f"Error in chat template processing: {e}. Falling back to simple format.")
+            logger.warning(
+                f"Error in chat template processing: {e}. Falling back to simple format."
+            )

         # Fallback approach for models without chat template support
         # This code runs if the try block above raises an exception
         print("Using simple prompt format...")

         # Convert messages to a simple text format
-        if isinstance(messages, list) and messages and isinstance(messages[0], dict) and "content" in messages[0]:
+        if (
+            isinstance(messages, list)
+            and messages
+            and isinstance(messages[0], dict)
+            and "content" in messages[0]
+        ):
             prompt_text = messages[0]["content"]
         else:
             prompt_text = str(messages)
@@ -207,7 +227,9 @@ def stream_query(self, query_text: str) -> Subject:
         """
         Creates an observable that processes a text query and emits the response.
         """
-        return create(lambda observer, _: self._observable_query(observer, incoming_query=query_text))
+        return create(
+            lambda observer, _: self._observable_query(observer, incoming_query=query_text)
+        )


 # endregion HuggingFaceLLMAgent Subclass (HuggingFace-Specific Implementation)

dimos/agents/agent_huggingface_remote.py

Lines changed: 6 additions & 2 deletions
@@ -110,7 +110,9 @@ def __init__(

         # Ensure only one input stream is provided.
         if self.input_video_stream is not None and self.input_query_stream is not None:
-            raise ValueError("More than one input stream provided. Please provide only one input stream.")
+            raise ValueError(
+                "More than one input stream provided. Please provide only one input stream."
+            )

         if self.input_video_stream is not None:
             logger.info("Subscribing to input video stream...")
@@ -136,4 +138,6 @@ def stream_query(self, query_text: str) -> Subject:
         """
         Creates an observable that processes a text query and emits the response.
         """
-        return create(lambda observer, _: self._observable_query(observer, incoming_query=query_text))
+        return create(
+            lambda observer, _: self._observable_query(observer, incoming_query=query_text)
+        )
