Commit aedd928

Merge pull request #300 - Global Ruff Reformat to 100 line length

Global reformat 100 line length

2 parents (c328efd + e686ab2), commit aedd928

File tree: 245 files changed (+31786 / -10281 lines)
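
Note: the commit contains only the mechanical result of the reformat; the configuration change that drives it is not part of the diff shown here. As a rough sketch (the exact file and table names in this repo are an assumption), a reformat like this is typically configured in pyproject.toml and applied with a single formatter run:

    [tool.ruff]
    # Assumed configuration: raise the limit from Ruff's default of 88 to 100.
    line-length = 100

followed by `ruff format .` at the repository root. Most of the churn in the hunks below comes from two formatter behaviors: short multi-line calls are collapsed onto one line when they fit within 100 characters, and argument lists ending in a trailing comma (the "magic trailing comma") are exploded one parameter per line at a fixed 4-space indent.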


dimos/agents/agent.py

Lines changed: 215 additions & 234 deletions
Large diffs are not rendered by default.

dimos/agents/agent_config.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 from typing import List
 from dimos.agents.agent import Agent
 
+
 class AgentConfig:
     def __init__(self, agents: List[Agent] = None):
         """

dimos/agents/agent_ctransformers_gguf.py

Lines changed: 31 additions & 46 deletions
@@ -15,7 +15,6 @@
 from __future__ import annotations
 
 # Standard library imports
-import json
 import logging
 import os
 from typing import Any, Optional
@@ -26,14 +25,11 @@
 from reactivex.scheduler import ThreadPoolScheduler
 from reactivex.subject import Subject
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
 
 # Local imports
 from dimos.agents.agent import LLMAgent
 from dimos.agents.memory.base import AbstractAgentSemanticMemory
 from dimos.agents.prompt_builder.impl import PromptBuilder
-from dimos.agents.tokenizer.base import AbstractTokenizer
-from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer
 from dimos.utils.logging_config import setup_logger
 
 # Initialize environment variables
@@ -44,6 +40,7 @@
 
 from ctransformers import AutoModelForCausalLM as CTransformersModel
 
+
 class CTransformersTokenizerAdapter:
     def __init__(self, model):
         self.model = model
@@ -84,26 +81,27 @@ def apply_chat_template(self, conversation, tokenize=False, add_generation_promp
 
 # CTransformers Agent Class
 class CTransformersGGUFAgent(LLMAgent):
-    def __init__(self,
-                 dev_name: str,
-                 agent_type: str = "HF-LLM",
-                 model_name: str = "TheBloke/Llama-2-7B-GGUF",
-                 model_file: str = "llama-2-7b.Q4_K_M.gguf",
-                 model_type: str = "llama",
-                 gpu_layers: int = 50,
-                 device: str = "auto",
-                 query: str = "How many r's are in the word 'strawberry'?",
-                 input_query_stream: Optional[Observable] = None,
-                 input_video_stream: Optional[Observable] = None,
-                 output_dir: str = os.path.join(os.getcwd(), "assets", "agent"),
-                 agent_memory: Optional[AbstractAgentSemanticMemory] = None,
-                 system_query: Optional[str] = "You are a helpful assistant.",
-                 max_output_tokens_per_request: int = 10,
-                 max_input_tokens_per_request: int = 250,
-                 prompt_builder: Optional[PromptBuilder] = None,
-                 pool_scheduler: Optional[ThreadPoolScheduler] = None,
-                 process_all_inputs: Optional[bool] = None,):
-
+    def __init__(
+        self,
+        dev_name: str,
+        agent_type: str = "HF-LLM",
+        model_name: str = "TheBloke/Llama-2-7B-GGUF",
+        model_file: str = "llama-2-7b.Q4_K_M.gguf",
+        model_type: str = "llama",
+        gpu_layers: int = 50,
+        device: str = "auto",
+        query: str = "How many r's are in the word 'strawberry'?",
+        input_query_stream: Optional[Observable] = None,
+        input_video_stream: Optional[Observable] = None,
+        output_dir: str = os.path.join(os.getcwd(), "assets", "agent"),
+        agent_memory: Optional[AbstractAgentSemanticMemory] = None,
+        system_query: Optional[str] = "You are a helpful assistant.",
+        max_output_tokens_per_request: int = 10,
+        max_input_tokens_per_request: int = 250,
+        prompt_builder: Optional[PromptBuilder] = None,
+        pool_scheduler: Optional[ThreadPoolScheduler] = None,
+        process_all_inputs: Optional[bool] = None,
+    ):
         # Determine appropriate default for process_all_inputs if not provided
         if process_all_inputs is None:
             # Default to True for text queries, False for video streams
@@ -120,7 +118,7 @@ def __init__(self,
             process_all_inputs=process_all_inputs,
             system_query=system_query,
             max_output_tokens_per_request=max_output_tokens_per_request,
-            max_input_tokens_per_request=max_input_tokens_per_request
+            max_input_tokens_per_request=max_input_tokens_per_request,
         )
 
         self.query = query
@@ -138,18 +136,12 @@ def __init__(self,
         print(f"Device: {self.device}")
 
         self.model = CTransformersModel.from_pretrained(
-            model_name,
-            model_file=model_file,
-            model_type=model_type,
-            gpu_layers=gpu_layers
+            model_name, model_file=model_file, model_type=model_type, gpu_layers=gpu_layers
         )
 
         self.tokenizer = CTransformersTokenizerAdapter(self.model)
 
-        self.prompt_builder = prompt_builder or PromptBuilder(
-            self.model_name,
-            tokenizer=self.tokenizer
-        )
+        self.prompt_builder = prompt_builder or PromptBuilder(self.model_name, tokenizer=self.tokenizer)
 
         self.max_output_tokens_per_request = max_output_tokens_per_request
 
@@ -160,19 +152,14 @@ def __init__(self,
 
         # Ensure only one input stream is provided.
         if self.input_video_stream is not None and self.input_query_stream is not None:
-            raise ValueError(
-                "More than one input stream provided. Please provide only one input stream."
-            )
+            raise ValueError("More than one input stream provided. Please provide only one input stream.")
 
         if self.input_video_stream is not None:
             logger.info("Subscribing to input video stream...")
-            self.disposables.add(
-                self.subscribe_to_image_processing(self.input_video_stream))
+            self.disposables.add(self.subscribe_to_image_processing(self.input_video_stream))
         if self.input_query_stream is not None:
             logger.info("Subscribing to input query stream...")
-            self.disposables.add(
-                self.subscribe_to_query_processing(self.input_query_stream))
-
+            self.disposables.add(self.subscribe_to_query_processing(self.input_query_stream))
 
     def _send_query(self, messages: list) -> Any:
         try:
@@ -194,9 +181,7 @@ def _send_query(self, messages: list) -> Any:
 
             print("Applying chat template...")
             prompt_text = self.tokenizer.apply_chat_template(
-                conversation=flat_messages,
-                tokenize=False,
-                add_generation_prompt=True
+                conversation=flat_messages, tokenize=False, add_generation_prompt=True
             )
             print("Chat template applied.")
             print(f"Prompt text:\n{prompt_text}")
@@ -213,7 +198,7 @@ def stream_query(self, query_text: str) -> Subject:
         """
         Creates an observable that processes a text query and emits the response.
         """
-        return create(lambda observer, _: self._observable_query(
-            observer, incoming_query=query_text))
+        return create(lambda observer, _: self._observable_query(observer, incoming_query=query_text))
+
 
 # endregion HuggingFaceLLMAgent Subclass (HuggingFace-Specific Implementation)
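
For orientation, the reformatted signature above implies construction along the following lines. This is a sketch only: the argument values are the defaults visible in the diff, and it assumes the dimos and ctransformers packages are installed and the GGUF weights can be fetched.

    from dimos.agents.agent_ctransformers_gguf import CTransformersGGUFAgent

    # Defaults taken from the signature shown in the diff above.
    agent = CTransformersGGUFAgent(
        dev_name="gguf-demo",
        model_name="TheBloke/Llama-2-7B-GGUF",
        model_file="llama-2-7b.Q4_K_M.gguf",
        model_type="llama",
        gpu_layers=50,
    )

    # stream_query (shown in the diff) wraps _observable_query in a reactivex
    # observable; subscribing with print emits the generated response.
    agent.stream_query("How many r's are in the word 'strawberry'?").subscribe(print)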

dimos/agents/agent_huggingface_local.py

Lines changed: 53 additions & 68 deletions
@@ -15,7 +15,6 @@
 from __future__ import annotations
 
 # Standard library imports
-import json
 import logging
 import os
 from typing import Any, Optional
@@ -26,7 +25,7 @@
 from reactivex.scheduler import ThreadPoolScheduler
 from reactivex.subject import Subject
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM
 
 # Local imports
 from dimos.agents.agent import LLMAgent
@@ -43,28 +42,29 @@
 # Initialize logger for the agent module
 logger = setup_logger("dimos.agents", level=logging.DEBUG)
 
+
 # HuggingFaceLLMAgent Class
 class HuggingFaceLocalAgent(LLMAgent):
-    def __init__(self,
-                 dev_name: str,
-                 agent_type: str = "HF-LLM",
-                 model_name: str = "Qwen/Qwen2.5-3B",
-                 device: str = "auto",
-                 query: str = "How many r's are in the word 'strawberry'?",
-                 input_query_stream: Optional[Observable] = None,
-                 input_video_stream: Optional[Observable] = None,
-                 output_dir: str = os.path.join(os.getcwd(), "assets",
-                                                "agent"),
-                 agent_memory: Optional[AbstractAgentSemanticMemory] = None,
-                 system_query: Optional[str] = None,
-                 max_output_tokens_per_request: int = None,
-                 max_input_tokens_per_request: int = None,
-                 prompt_builder: Optional[PromptBuilder] = None,
-                 tokenizer: Optional[AbstractTokenizer] = None,
-                 image_detail: str = "low",
-                 pool_scheduler: Optional[ThreadPoolScheduler] = None,
-                 process_all_inputs: Optional[bool] = None,):
-
+    def __init__(
+        self,
+        dev_name: str,
+        agent_type: str = "HF-LLM",
+        model_name: str = "Qwen/Qwen2.5-3B",
+        device: str = "auto",
+        query: str = "How many r's are in the word 'strawberry'?",
+        input_query_stream: Optional[Observable] = None,
+        input_video_stream: Optional[Observable] = None,
+        output_dir: str = os.path.join(os.getcwd(), "assets", "agent"),
+        agent_memory: Optional[AbstractAgentSemanticMemory] = None,
+        system_query: Optional[str] = None,
+        max_output_tokens_per_request: int = None,
+        max_input_tokens_per_request: int = None,
+        prompt_builder: Optional[PromptBuilder] = None,
+        tokenizer: Optional[AbstractTokenizer] = None,
+        image_detail: str = "low",
+        pool_scheduler: Optional[ThreadPoolScheduler] = None,
+        process_all_inputs: Optional[bool] = None,
+    ):
         # Determine appropriate default for process_all_inputs if not provided
         if process_all_inputs is None:
             # Default to True for text queries, False for video streams
@@ -79,7 +79,7 @@ def __init__(self,
             agent_memory=agent_memory or LocalSemanticMemory(),
             pool_scheduler=pool_scheduler,
             process_all_inputs=process_all_inputs,
-            system_query=system_query
+            system_query=system_query,
         )
 
         self.query = query
@@ -98,15 +98,10 @@ def __init__(self,
 
         self.tokenizer = tokenizer or HuggingFaceTokenizer(self.model_name)
 
-        self.prompt_builder = prompt_builder or PromptBuilder(
-            self.model_name,
-            tokenizer=self.tokenizer
-        )
+        self.prompt_builder = prompt_builder or PromptBuilder(self.model_name, tokenizer=self.tokenizer)
 
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
-            device_map=self.device
+            model_name, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, device_map=self.device
         )
 
         self.max_output_tokens_per_request = max_output_tokens_per_request
@@ -118,111 +113,101 @@ def __init__(self,
 
         # Ensure only one input stream is provided.
         if self.input_video_stream is not None and self.input_query_stream is not None:
-            raise ValueError(
-                "More than one input stream provided. Please provide only one input stream."
-            )
+            raise ValueError("More than one input stream provided. Please provide only one input stream.")
 
         if self.input_video_stream is not None:
             logger.info("Subscribing to input video stream...")
-            self.disposables.add(
-                self.subscribe_to_image_processing(self.input_video_stream))
+            self.disposables.add(self.subscribe_to_image_processing(self.input_video_stream))
         if self.input_query_stream is not None:
             logger.info("Subscribing to input query stream...")
-            self.disposables.add(
-                self.subscribe_to_query_processing(self.input_query_stream))
-
+            self.disposables.add(self.subscribe_to_query_processing(self.input_query_stream))
 
     def _send_query(self, messages: list) -> Any:
         _BLUE_PRINT_COLOR: str = "\033[34m"
         _RESET_COLOR: str = "\033[0m"
-
+
         try:
             # Log the incoming messages
             print(f"{_BLUE_PRINT_COLOR}Messages: {str(messages)}{_RESET_COLOR}")
-
+
             # Process with chat template
             try:
                 print("Applying chat template...")
                 prompt_text = self.tokenizer.tokenizer.apply_chat_template(
                     conversation=[{"role": "user", "content": str(messages)}],
                     tokenize=False,
-                    add_generation_prompt=True
+                    add_generation_prompt=True,
                 )
                 print("Chat template applied.")
-
+
                 # Tokenize the prompt
                 print("Preparing model inputs...")
                 model_inputs = self.tokenizer.tokenizer([prompt_text], return_tensors="pt").to(self.model.device)
                 print("Model inputs prepared.")
-
+
                 # Generate the response
                 print("Generating response...")
-                generated_ids = self.model.generate(
-                    **model_inputs,
-                    max_new_tokens=self.max_output_tokens_per_request
-                )
-
+                generated_ids = self.model.generate(**model_inputs, max_new_tokens=self.max_output_tokens_per_request)
+
                 # Extract the generated tokens (excluding the input prompt tokens)
                 print("Processing generated output...")
                 generated_ids = [
-                    output_ids[len(input_ids):]
-                    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+                    output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
                 ]
-
+
                 # Convert tokens back to text
                 response = self.tokenizer.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
                 print("Response successfully generated.")
-
+
                 return response
-
+
             except AttributeError as e:
                 # Handle case where tokenizer doesn't have the expected methods
                 logger.warning(f"Chat template not available: {e}. Using simple format.")
                 # Continue with execution and use simple format
-
+
             except Exception as e:
                 # Log any other errors but continue execution
                 logger.warning(f"Error in chat template processing: {e}. Falling back to simple format.")
-
+
             # Fallback approach for models without chat template support
            # This code runs if the try block above raises an exception
             print("Using simple prompt format...")
-
+
             # Convert messages to a simple text format
             if isinstance(messages, list) and messages and isinstance(messages[0], dict) and "content" in messages[0]:
                 prompt_text = messages[0]["content"]
             else:
                 prompt_text = str(messages)
-
+
             # Tokenize the prompt
             model_inputs = self.tokenizer.tokenize_text(prompt_text)
             model_inputs = torch.tensor([model_inputs], device=self.model.device)
-
+
             # Generate the response
             generated_ids = self.model.generate(
-                input_ids=model_inputs,
-                max_new_tokens=self.max_output_tokens_per_request
+                input_ids=model_inputs, max_new_tokens=self.max_output_tokens_per_request
             )
-
+
             # Extract the generated tokens
-            generated_ids = generated_ids[0][len(model_inputs[0]):]
-
+            generated_ids = generated_ids[0][len(model_inputs[0]) :]
+
             # Convert tokens back to text
             response = self.tokenizer.detokenize_text(generated_ids.tolist())
             print("Response generated using simple format.")
-
+
             return response
-
+
         except Exception as e:
             # Catch all other errors
             logger.error(f"Error during query processing: {e}", exc_info=True)
-            return f"Error processing request. Please try again."
+            return "Error processing request. Please try again."
 
     def stream_query(self, query_text: str) -> Subject:
         """
         Creates an observable that processes a text query and emits the response.
         """
-        return create(lambda observer, _: self._observable_query(
-            observer, incoming_query=query_text))
+        return create(lambda observer, _: self._observable_query(observer, incoming_query=query_text))
+
 
 # endregion HuggingFaceLLMAgent Subclass (HuggingFace-Specific Implementation)
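
The same construction pattern applies to the local HuggingFace agent. A hedged sketch based on the defaults in the signature above; it assumes the Qwen/Qwen2.5-3B weights can be downloaded and that only a text query is used, so process_all_inputs defaults to True:

    from dimos.agents.agent_huggingface_local import HuggingFaceLocalAgent

    # Defaults taken from the signature in the diff; max_output_tokens_per_request
    # has no usable default (None), so a value is passed explicitly here.
    agent = HuggingFaceLocalAgent(
        dev_name="hf-local-demo",
        model_name="Qwen/Qwen2.5-3B",
        max_output_tokens_per_request=64,
    )

    agent.stream_query("How many r's are in the word 'strawberry'?").subscribe(print)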
