-
Notifications
You must be signed in to change notification settings - Fork 4.7k
feat: Add token counting utility + Add support for it in Compression #5593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 19 commits
bc3cc01
517f2d7
22ab15d
7011c06
8a43877
4c73db2
6f43ed5
ff1e84a
f6e7200
259b5a7
8728502
3e51b13
2fec0e9
4269391
131f190
5d1ed33
5f09d4b
6458a30
be4e3c1
0f17f6d
d568ff0
4dc5a2b
5e7dbeb
bb73ed7
7f4498e
c4a74aa
b38b84b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
"""
This example shows how to set a context token based limit for tool call compression.

Run: `python cookbook/agents/context_compression/token_based_tool_call_compression.py`
"""

from agno.agent import Agent
from agno.compression.manager import CompressionManager
from agno.db.sqlite import SqliteDb
from agno.models.openai import OpenAIChat
from agno.tools.duckduckgo import DuckDuckGoTools

# Custom instructions handed to the compression model; they replace the
# default compression prompt via `compress_tool_call_instructions` below.
compression_prompt = """
You are a compression expert. Your goal is to compress web search results for a competitive intelligence analyst.
YOUR GOAL: Extract only actionable competitive insights while being extremely concise.
MUST PRESERVE:
- Competitor names and specific actions (product launches, partnerships, acquisitions, pricing changes)
- Exact numbers (revenue, market share, growth rates, pricing, headcount)
- Precise dates (announcement dates, launch dates, deal dates)
- Direct quotes from executives or official statements
- Funding rounds and valuations
MUST REMOVE:
- Company history and background information
- General industry trends (unless competitor-specific)
- Analyst opinions and speculation (keep only facts)
- Detailed product descriptions (keep only key differentiators and pricing)
- Marketing fluff and promotional language
OUTPUT FORMAT:
Return a bullet-point list where each line follows this format:
"[Company Name] - [Date]: [Action/Event] ([Key Numbers/Details])"
Keep it under 200 words total. Be ruthlessly concise. Facts only.
Example:
- Acme Corp - Mar 15, 2024: Launched AcmeGPT at $99/user/month, targeting enterprise market
- TechCo - Feb 10, 2024: Acquired DataStart for $150M, gaining 500 enterprise customers
"""

# A separate model performs the compression. The token limit (rather than a
# tool-call count) is the threshold used to decide when to compress.
compression_manager = CompressionManager(
    model=OpenAIChat(id="gpt-5-mini"),
    compress_tool_results_token_limit=5000,
    compress_tool_call_instructions=compression_prompt,
)

agent = Agent(
    model=OpenAIChat(id="gpt-4o-mini"),
    tools=[DuckDuckGoTools()],
    description="Specialized in tracking competitor activities",
    instructions="Use the search tools and always use the latest information and data.",
    db=SqliteDb(db_file="tmp/dbs/token_based_tool_call_compression.db"),
    compression_manager=compression_manager,
    add_history_to_context=True,  # Add history to context
    num_history_runs=3,
    session_id="token_based_tool_call_compression",
)

# The research task; it asks for several searches so that enough tool output
# accumulates to exercise the compression threshold.
research_prompt = """
Use the search tools and always use the latest information and data.
Research recent activities (last 3 months) for these AI companies:
1. OpenAI - product launches, partnerships, pricing
2. Anthropic - new features, enterprise deals, funding
3. Google DeepMind - research breakthroughs, product releases
4. Meta AI - open source releases, research papers
For each, find specific actions with dates and numbers."""

# Only hit the network when executed as a script, so importing this module
# (e.g. to reuse `agent`) has no side effects beyond object construction.
if __name__ == "__main__":
    agent.print_response(research_prompt, stream=True)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,27 +48,41 @@ | |
| class CompressionManager: | ||
| model: Optional[Model] = None | ||
| compress_tool_results: bool = True | ||
| compress_tool_results_limit: int = 3 | ||
| compress_tool_results_limit: Optional[int] = None | ||
manuhortet marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| compress_tool_results_token_limit: Optional[int] = None | ||
Mustafa-Esoofally marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| compress_tool_call_instructions: Optional[str] = None | ||
|
|
||
| stats: Dict[str, Any] = field(default_factory=dict) | ||
|
|
||
| def _is_tool_result_message(self, msg: Message) -> bool: | ||
| return msg.role == "tool" | ||
|
|
||
def should_compress(
    self,
    messages: List[Message],
    tools: Optional[List] = None,
    main_model: Optional[Model] = None,
) -> bool:
    """Decide whether the tool results in ``messages`` should be compressed now.

    Compression is requested as soon as either configured threshold is reached:

    * count-based: the number of tool-result messages that have no
      ``compressed_content`` yet is >= ``compress_tool_results_limit``;
    * token-based: ``main_model.count_tokens(messages, tools)`` is
      >= ``compress_tool_results_token_limit``.

    Args:
        messages: The conversation messages about to be sent to the model.
        tools: Tool definitions forwarded to the token counter (optional).
        main_model: The model whose ``count_tokens`` is used for the
            token-based check. When omitted, only the count-based check runs.

    Returns:
        True if compression should run, False otherwise (including when
        ``compress_tool_results`` is disabled or no threshold is configured).
    """
    if not self.compress_tool_results:
        return False

    # Count-based check first: it only scans the messages in memory, whereas
    # counting tokens may be costly (some models count via a remote API call),
    # so we avoid that work when this threshold already triggers compression.
    if self.compress_tool_results_limit is not None:
        uncompressed_tools_count = sum(
            1 for m in messages if self._is_tool_result_message(m) and m.compressed_content is None
        )
        if uncompressed_tools_count >= self.compress_tool_results_limit:
            log_info(f"Tool count limit hit: {uncompressed_tools_count} >= {self.compress_tool_results_limit}")
            return True

    # Token-based check: needs both a configured limit and a model to count with.
    if self.compress_tool_results_token_limit is not None and main_model is not None:
        tokens = main_model.count_tokens(messages, tools)
        if tokens >= self.compress_tool_results_token_limit:
            log_info(f"Token limit hit: {tokens} >= {self.compress_tool_results_token_limit}")
            return True

    return False
|
|
||
| def _compress_tool_result(self, tool_result: Message) -> Optional[str]: | ||
| if not tool_result: | ||
|
|
@@ -112,8 +126,11 @@ def compress(self, messages: List[Message]) -> None: | |
| compressed = self._compress_tool_result(tool_msg) | ||
| if compressed: | ||
| tool_msg.compressed_content = compressed | ||
| # Track stats | ||
| self.stats["messages_compressed"] = self.stats.get("messages_compressed", 0) + 1 | ||
| # Count actual tool results (Gemini combines multiple in one message) | ||
| tool_results_count = len(tool_msg.tool_calls) if tool_msg.tool_calls else 1 | ||
| self.stats["tool_results_compressed"] = ( | ||
| self.stats.get("tool_results_compressed", 0) + tool_results_count | ||
| ) | ||
| self.stats["original_size"] = self.stats.get("original_size", 0) + original_len | ||
| self.stats["compressed_size"] = self.stats.get("compressed_size", 0) + len(compressed) | ||
| else: | ||
|
|
@@ -168,8 +185,11 @@ async def acompress(self, messages: List[Message]) -> None: | |
| for msg, compressed, original_len in zip(uncompressed_tools, results, original_sizes): | ||
| if compressed: | ||
| msg.compressed_content = compressed | ||
| # Track stats | ||
| self.stats["messages_compressed"] = self.stats.get("messages_compressed", 0) + 1 | ||
| # Count actual tool results (Gemini combines multiple in one message) | ||
| tool_results_count = len(msg.tool_calls) if msg.tool_calls else 1 | ||
| self.stats["tool_results_compressed"] = ( | ||
| self.stats.get("tool_results_compressed", 0) + tool_results_count | ||
| ) | ||
| self.stats["original_size"] = self.stats.get("original_size", 0) + original_len | ||
| self.stats["compressed_size"] = self.stats.get("compressed_size", 0) + len(compressed) | ||
| else: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -357,6 +357,32 @@ def _format_messages( | |
| # TODO: Add caching: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html | ||
| return formatted_messages, system_message | ||
|
|
||
def count_tokens(
    self,
    messages: List[Message],
    tools: Optional[List[Dict[str, Any]]] = None,
) -> int:
    """Count input tokens for ``messages`` (plus ``tools``) using the Bedrock client.

    The messages are formatted with ``compress_tool_results=True`` and sent to
    the client's ``count_tokens`` call as a ``converse`` input. Tool tokens are
    added on top via ``_count_tool_tokens`` when tools are supplied.

    Args:
        messages: Conversation messages to count.
        tools: Optional tool definitions to include in the count.

    Returns:
        The token count. If anything in the Bedrock path raises, a warning is
        logged and the parent class's ``count_tokens`` result is returned.
    """
    try:
        bedrock_messages, system_prompt = self._format_messages(messages, compress_tool_results=True)
        payload: Dict[str, Any] = {"messages": bedrock_messages}
        if system_prompt:
            payload["system"] = system_prompt

        api_result = self.get_client().count_tokens(modelId=self.id, input={"converse": payload})
        total = api_result.get("inputTokens", 0)

        # No tools: the API figure is the final answer.
        if not tools:
            return total

        # Tool definitions are counted separately and added to the API figure.
        from agno.utils.tokens import _count_tool_tokens

        has_system_message = any(message.role == "system" for message in messages)
        return total + _count_tool_tokens(tools, self.id, has_system_message)
    except Exception as e:
        # Best-effort: fall back to the parent implementation instead of failing.
        log_warning(f"Failed to count tokens via Bedrock API: {e}")
        return super().count_tokens(messages, tools)
|
||
|
|
||
| def invoke( | ||
| self, | ||
| messages: List[Message], | ||
|
|
@@ -719,4 +745,9 @@ def _get_metrics(self, response_usage: Dict[str, Any]) -> Metrics: | |
| metrics.output_tokens = response_usage.get("outputTokens", 0) or 0 | ||
| metrics.total_tokens = metrics.input_tokens + metrics.output_tokens | ||
|
|
||
| log_debug( | ||
manuhortet marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| f"Bedrock response metrics: input_tokens={metrics.input_tokens}, " | ||
| f"output_tokens={metrics.output_tokens}, total_tokens={metrics.total_tokens}" | ||
| ) | ||
|
|
||
| return metrics | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| List, | ||
| Literal, | ||
| Optional, | ||
| Sequence, | ||
| Tuple, | ||
| Type, | ||
| Union, | ||
|
|
@@ -427,6 +428,15 @@ def _format_tools(self, tools: Optional[List[Union[Function, dict]]]) -> List[Di | |
| _tool_dicts.append(tool) | ||
| return _tool_dicts | ||
|
|
||
def count_tokens(
    self,
    messages: List[Message],
    tools: Optional[Sequence[Union[Function, Dict[str, Any]]]] = None,
) -> int:
    """Count tokens for ``messages`` and optional ``tools`` for this model.

    Delegates to ``agno.utils.tokens.count_tokens``, passing this model's
    ``id`` as ``model_id``.

    Args:
        messages: Conversation messages to count.
        tools: Optional tool definitions; passed on as a list, or as ``None``
            when missing or empty.

    Returns:
        The token count returned by the utility.
    """
    # Imported inside the method, as in the original implementation.
    from agno.utils.tokens import count_tokens as _count_tokens

    tool_list = list(tools) if tools else None
    return _count_tokens(messages, tools=tool_list, model_id=self.id)
|
|
||
| def response( | ||
| self, | ||
| messages: List[Message], | ||
|
|
@@ -476,6 +486,10 @@ def response( | |
| _compress_tool_results = compression_manager is not None and compression_manager.compress_tool_results | ||
|
|
||
| while True: | ||
| # Compress tool results | ||
| if compression_manager and compression_manager.should_compress(messages, tools, main_model=self): | ||
Mustafa-Esoofally marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| compression_manager.compress(messages) | ||
|
|
||
| # Get response from model | ||
| assistant_message = Message(role=self.assistant_message_role) | ||
| self._process_model_response( | ||
|
|
@@ -574,11 +588,6 @@ def response( | |
| # Add a function call for each successful execution | ||
| function_call_count += len(function_call_results) | ||
|
|
||
| all_messages = messages + function_call_results | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are we changing this? I think you are probably right, but there was a reason we did it here.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Before we were limited by |
||
| # Compress tool results | ||
| if compression_manager and compression_manager.should_compress(all_messages): | ||
| compression_manager.compress(all_messages) | ||
|
|
||
| # Format and add results to messages | ||
| self.format_function_call_results( | ||
| messages=messages, | ||
|
|
@@ -678,6 +687,10 @@ async def aresponse( | |
| function_call_count = 0 | ||
|
|
||
| while True: | ||
| # Compress existing tool results BEFORE making API call to avoid context overflow | ||
| if compression_manager and compression_manager.should_compress(messages, tools, main_model=self): | ||
| await compression_manager.acompress(messages) | ||
|
|
||
| # Get response from model | ||
| assistant_message = Message(role=self.assistant_message_role) | ||
| await self._aprocess_model_response( | ||
|
|
@@ -775,11 +788,6 @@ async def aresponse( | |
| # Add a function call for each successful execution | ||
| function_call_count += len(function_call_results) | ||
|
|
||
| all_messages = messages + function_call_results | ||
| # Compress tool results | ||
| if compression_manager and compression_manager.should_compress(all_messages): | ||
| await compression_manager.acompress(all_messages) | ||
|
|
||
| # Format and add results to messages | ||
| self.format_function_call_results( | ||
| messages=messages, | ||
|
|
@@ -1105,6 +1113,10 @@ def response_stream( | |
| function_call_count = 0 | ||
|
|
||
| while True: | ||
| # Compress existing tool results BEFORE invoke | ||
| if compression_manager and compression_manager.should_compress(messages, tools, main_model=self): | ||
| compression_manager.compress(messages) | ||
|
|
||
| assistant_message = Message(role=self.assistant_message_role) | ||
| # Create assistant message and stream data | ||
| stream_data = MessageData() | ||
|
|
@@ -1166,11 +1178,6 @@ def response_stream( | |
| # Add a function call for each successful execution | ||
| function_call_count += len(function_call_results) | ||
|
|
||
| all_messages = messages + function_call_results | ||
| # Compress tool results | ||
| if compression_manager and compression_manager.should_compress(all_messages): | ||
| compression_manager.compress(all_messages) | ||
|
|
||
| # Format and add results to messages | ||
| if stream_data and stream_data.extra is not None: | ||
| self.format_function_call_results( | ||
|
|
@@ -1323,6 +1330,10 @@ async def aresponse_stream( | |
| function_call_count = 0 | ||
|
|
||
| while True: | ||
| # Compress existing tool results BEFORE making API call to avoid context overflow | ||
| if compression_manager and compression_manager.should_compress(messages, tools, main_model=self): | ||
| await compression_manager.acompress(messages) | ||
|
|
||
| # Create assistant message and stream data | ||
| assistant_message = Message(role=self.assistant_message_role) | ||
| stream_data = MessageData() | ||
|
|
@@ -1384,11 +1395,6 @@ async def aresponse_stream( | |
| # Add a function call for each successful execution | ||
| function_call_count += len(function_call_results) | ||
|
|
||
| all_messages = messages + function_call_results | ||
| # Compress tool results | ||
| if compression_manager and compression_manager.should_compress(all_messages): | ||
| await compression_manager.acompress(all_messages) | ||
|
|
||
| # Format and add results to messages | ||
| if stream_data and stream_data.extra is not None: | ||
| self.format_function_call_results( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.