diff --git a/platform/reworkd_platform/services/tokenizer/dependencies.py b/platform/reworkd_platform/services/tokenizer/dependencies.py
index c076cf2171..6e73487813 100644
--- a/platform/reworkd_platform/services/tokenizer/dependencies.py
+++ b/platform/reworkd_platform/services/tokenizer/dependencies.py
@@ -1,6 +1,6 @@
 from fastapi import Request
 
-from reworkd_platform.services.tokenizer.service import TokenService
+from reworkd_platform.services.tokenizer.token_service import TokenService
 
 
 def get_token_service(request: Request) -> TokenService:
diff --git a/platform/reworkd_platform/services/tokenizer/service.py b/platform/reworkd_platform/services/tokenizer/service.py
deleted file mode 100644
index d654dcf971..0000000000
--- a/platform/reworkd_platform/services/tokenizer/service.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from tiktoken import Encoding
-
-
-class TokenService:
-    def __init__(self, encoding: Encoding):
-        self.encoding = encoding
-
-    def tokenize(self, text: str) -> list[int]:
-        return self.encoding.encode(text)
-
-    def detokenize(self, tokens: list[int]) -> str:
-        return self.encoding.decode(tokens)
-
-    def count(self, text: str) -> int:
-        return len(self.tokenize(text))
diff --git a/platform/reworkd_platform/services/tokenizer/token_service.py b/platform/reworkd_platform/services/tokenizer/token_service.py
new file mode 100644
index 0000000000..d1eef326a8
--- /dev/null
+++ b/platform/reworkd_platform/services/tokenizer/token_service.py
@@ -0,0 +1,26 @@
+from tiktoken import Encoding
+
+from reworkd_platform.schemas import LLM_MODEL_MAX_TOKENS
+from reworkd_platform.web.api.agent.model_settings import WrappedChatOpenAI
+
+
+class TokenService:
+    def __init__(self, encoding: Encoding):
+        self.encoding = encoding
+
+    def tokenize(self, text: str) -> list[int]:
+        return self.encoding.encode(text)
+
+    def detokenize(self, tokens: list[int]) -> str:
+        return self.encoding.decode(tokens)
+
+    def count(self, text: str) -> int:
+        return len(self.tokenize(text))
+
+    def calculate_max_tokens(self, model: WrappedChatOpenAI, *prompts: str) -> None:
+        max_allowed_tokens = LLM_MODEL_MAX_TOKENS.get(model.model_name, 4000)
+        prompt_tokens = sum([self.count(p) for p in prompts])
+        requested_tokens = max_allowed_tokens - prompt_tokens
+
+        model.max_tokens = min(model.max_tokens, requested_tokens)
+        model.max_tokens = max(model.max_tokens, 1)
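
Reviewer note: the relocated `calculate_max_tokens` mutates the wrapped model in place, clamping `max_tokens` between 1 and the model's token budget minus the prompt cost. A standalone sketch of that clamping arithmetic follows; the 4000-token budget mirrors the fallback default in the diff, and the prompt sizes are invented for illustration, not real `LLM_MODEL_MAX_TOKENS` entries.

```python
# Standalone sketch of the clamping in TokenService.calculate_max_tokens.
def clamp_max_tokens(current_max: int, budget: int, prompt_tokens: int) -> int:
    requested = budget - prompt_tokens          # tokens left for the completion
    return max(min(current_max, requested), 1)  # cap at budget, floor at 1


assert clamp_max_tokens(3000, 4000, 5) == 3000    # small request is untouched
assert clamp_max_tokens(8000, 4000, 500) == 3500  # clipped to the remaining budget
assert clamp_max_tokens(8000, 4000, 9000) == 1    # floor of 1 when prompts overflow
```

The `max(..., 1)` floor is the behavioral change versus the old method on `OpenAIAgentService` (removed at the end of this patch), which could previously drive `max_tokens` negative.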
diff --git a/platform/reworkd_platform/tests/test_token_service.py b/platform/reworkd_platform/tests/test_token_service.py
index 56d61acd26..15f3a4cc1c 100644
--- a/platform/reworkd_platform/tests/test_token_service.py
+++ b/platform/reworkd_platform/tests/test_token_service.py
@@ -1,26 +1,105 @@
+from unittest.mock import Mock
+
 import tiktoken
 
-from reworkd_platform.services.tokenizer.service import TokenService
+from reworkd_platform.schemas import LLM_MODEL_MAX_TOKENS
+from reworkd_platform.services.tokenizer.token_service import TokenService
 
 encoding = tiktoken.get_encoding("cl100k_base")
 
 
-def test_happy_path():
+def test_happy_path() -> None:
     service = TokenService(encoding)
     text = "Hello world!"
-
     validate_tokenize_and_detokenize(service, text, 3)
 
 
-def test_nothing():
+def test_nothing() -> None:
     service = TokenService(encoding)
     text = ""
-
     validate_tokenize_and_detokenize(service, text, 0)
 
 
-def validate_tokenize_and_detokenize(service, text, expected_token_count):
+def validate_tokenize_and_detokenize(
+    service: TokenService, text: str, expected_token_count: int
+) -> None:
     tokens = service.tokenize(text)
     assert text == service.detokenize(tokens)
     assert len(tokens) == service.count(text)
     assert len(tokens) == expected_token_count
+
+
+def test_calculate_max_tokens_with_small_max_tokens() -> None:
+    initial_max_tokens = 3000
+    service = TokenService(encoding)
+    model = Mock(spec=["model_name", "max_tokens"])
+    model.model_name = "gpt-3.5-turbo"
+    model.max_tokens = initial_max_tokens
+
+    service.calculate_max_tokens(model, "Hello")
+
+    assert model.max_tokens == initial_max_tokens
+
+
+def test_calculate_max_tokens_with_high_completion_tokens() -> None:
+    service = TokenService(encoding)
+    prompt_tokens = service.count(LONG_TEXT)
+    model = Mock(spec=["model_name", "max_tokens"])
+    model.model_name = "gpt-3.5-turbo"
+    model.max_tokens = 8000
+
+    service.calculate_max_tokens(model, LONG_TEXT)
+
+    assert model.max_tokens == (
+        LLM_MODEL_MAX_TOKENS.get("gpt-3.5-turbo") - prompt_tokens
+    )
+
+
+def test_calculate_max_tokens_with_negative_result() -> None:
+    service = TokenService(encoding)
+    model = Mock(spec=["model_name", "max_tokens"])
+    model.model_name = "gpt-3.5-turbo"
+    model.max_tokens = 8000
+
+    service.calculate_max_tokens(model, *([LONG_TEXT] * 100))
+
+    # We use the minimum length of 1
+    assert model.max_tokens == 1
+
+
+LONG_TEXT = """
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+This is some long text. This is some long text. This is some long text.
+"""
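
A side note on the test style: `Mock(spec=[...])` limits the stub to the two attributes `calculate_max_tokens` actually reads and writes, so a typo in the service (say, `max_token`) would surface as an `AttributeError` rather than a silently-passing assertion. A minimal illustration, with `temperature` standing in as an arbitrary out-of-spec attribute:

```python
# Minimal illustration of why the tests pass spec=["model_name", "max_tokens"].
from unittest.mock import Mock

model = Mock(spec=["model_name", "max_tokens"])
model.model_name = "gpt-3.5-turbo"
model.max_tokens = 3000

print(model.max_tokens)  # 3000: in-spec attributes behave like plain fields
try:
    model.temperature    # not declared in the spec and never assigned
except AttributeError as err:
    print(err)           # "Mock object has no attribute 'temperature'"
```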
diff --git a/platform/reworkd_platform/web/api/agent/agent_service/agent_service_provider.py b/platform/reworkd_platform/web/api/agent/agent_service/agent_service_provider.py
index 593688d06d..8096fbbe51 100644
--- a/platform/reworkd_platform/web/api/agent/agent_service/agent_service_provider.py
+++ b/platform/reworkd_platform/web/api/agent/agent_service/agent_service_provider.py
@@ -1,10 +1,10 @@
-from typing import Any, Coroutine, Callable
+from typing import Any, Callable, Coroutine
 
 from fastapi import Depends
 
 from reworkd_platform.schemas import AgentRun, UserBase
 from reworkd_platform.services.tokenizer.dependencies import get_token_service
-from reworkd_platform.services.tokenizer.service import TokenService
+from reworkd_platform.services.tokenizer.token_service import TokenService
 from reworkd_platform.settings import settings
 from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
 from reworkd_platform.web.api.agent.agent_service.mock_agent_service import (
@@ -13,9 +13,7 @@
 from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import (
     OpenAIAgentService,
 )
-from reworkd_platform.web.api.agent.dependancies import (
-    get_agent_memory,
-)
+from reworkd_platform.web.api.agent.dependancies import get_agent_memory
 from reworkd_platform.web.api.agent.model_settings import create_model
 from reworkd_platform.web.api.dependencies import get_current_user
 from reworkd_platform.web.api.memory.memory import AgentMemory
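
For context on `get_token_service` (its body is unchanged and not shown in this diff): the usual FastAPI pattern is to build the `Encoding` once at startup and hand it out per request. The `app.state` wiring below is an assumption for illustration, not the project's actual code:

```python
# Assumed wiring for get_token_service; the real body is not part of this diff.
import tiktoken
from fastapi import FastAPI, Request

from reworkd_platform.services.tokenizer.token_service import TokenService

app = FastAPI()


@app.on_event("startup")
def load_encoding() -> None:
    # Loading the BPE ranks is comparatively slow; keep one Encoding per process.
    app.state.token_encoding = tiktoken.get_encoding("cl100k_base")


def get_token_service(request: Request) -> TokenService:
    return TokenService(request.app.state.token_encoding)
```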
diff --git a/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py b/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
index 20f6a6ad7d..f903dab8cf 100644
--- a/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
+++ b/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
@@ -7,8 +7,8 @@
 from loguru import logger
 from pydantic import ValidationError
 
-from reworkd_platform.schemas import LLM_MODEL_MAX_TOKENS, ModelSettings
-from reworkd_platform.services.tokenizer.service import TokenService
+from reworkd_platform.schemas import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
 from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
 from reworkd_platform.web.api.agent.analysis import Analysis, AnalysisArguments
 from reworkd_platform.web.api.agent.helpers import (
@@ -54,7 +54,8 @@ async def start_goal_agent(self, *, goal: str) -> List[str]:
             [SystemMessagePromptTemplate(prompt=start_goal_prompt)]
         )
 
-        self.calculate_max_tokens(
+        self.token_service.calculate_max_tokens(
+            self.model,
             prompt.format_prompt(
                 goal=goal,
                 language=self.settings.language,
@@ -90,7 +91,8 @@ async def analyze_task_agent(
             language=self.settings.language,
         )
 
-        self.calculate_max_tokens(
+        self.token_service.calculate_max_tokens(
+            self.model,
             prompt.to_string(),
             str(functions),
         )
@@ -153,7 +155,9 @@ async def create_tasks_agent(
             "result": result,
         }
 
-        self.calculate_max_tokens(prompt.format_prompt(**args).to_string())
+        self.token_service.calculate_max_tokens(
+            self.model, prompt.format_prompt(**args).to_string()
+        )
 
         completion = await call_model_with_handling(
             self.model, prompt, args, settings=self.settings, callbacks=self.callbacks
@@ -177,10 +181,3 @@ async def create_tasks_agent(
             memory.add_tasks(unique_tasks)
 
         return unique_tasks
-
-    def calculate_max_tokens(self, *prompts: str) -> None:
-        max_allowed_tokens = LLM_MODEL_MAX_TOKENS.get(self.model.model_name, 4000)
-        prompt_tokens = sum([self.token_service.count(p) for p in prompts])
-        requested_tokens = max_allowed_tokens - prompt_tokens
-
-        self.model.max_tokens = min(self.model.max_tokens, requested_tokens)
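
Taken together, the agent call sites now delegate token budgeting to the shared service, which mutates the model they pass in. A rough end-to-end sketch, where `StubModel` is a duck-typed stand-in for `WrappedChatOpenAI` (an assumption; the real class lives in `web.api.agent.model_settings`, and the method only touches its `model_name` and `max_tokens` attributes):

```python
# End-to-end sketch; StubModel is a duck-typed stand-in for WrappedChatOpenAI.
from dataclasses import dataclass

import tiktoken

from reworkd_platform.services.tokenizer.token_service import TokenService


@dataclass
class StubModel:
    model_name: str = "gpt-3.5-turbo"
    max_tokens: int = 8000


service = TokenService(tiktoken.get_encoding("cl100k_base"))
model = StubModel()

# After the call, max_tokens is the model's budget minus the prompt cost,
# clamped to a floor of 1 -- exactly what the new tests assert.
service.calculate_max_tokens(model, "You are an autonomous agent.", "Goal: ...")
print(model.max_tokens)
```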