Merged
platform/reworkd_platform/services/tokenizer/dependencies.py
@@ -1,6 +1,6 @@
from fastapi import Request

-from reworkd_platform.services.tokenizer.service import TokenService
+from reworkd_platform.services.tokenizer.token_service import TokenService


def get_token_service(request: Request) -> TokenService:
15 changes: 0 additions & 15 deletions platform/reworkd_platform/services/tokenizer/service.py

This file was deleted.

26 changes: 26 additions & 0 deletions platform/reworkd_platform/services/tokenizer/token_service.py
@@ -0,0 +1,26 @@
from tiktoken import Encoding

from reworkd_platform.schemas import LLM_MODEL_MAX_TOKENS
from reworkd_platform.web.api.agent.model_settings import WrappedChatOpenAI


class TokenService:
Contributor: Nice!

    def __init__(self, encoding: Encoding):
        self.encoding = encoding

    def tokenize(self, text: str) -> list[int]:
        return self.encoding.encode(text)

    def detokenize(self, tokens: list[int]) -> str:
        return self.encoding.decode(tokens)

    def count(self, text: str) -> int:
        return len(self.tokenize(text))

    def calculate_max_tokens(self, model: WrappedChatOpenAI, *prompts: str) -> None:
        max_allowed_tokens = LLM_MODEL_MAX_TOKENS.get(model.model_name, 4000)
        prompt_tokens = sum([self.count(p) for p in prompts])
        requested_tokens = max_allowed_tokens - prompt_tokens

        model.max_tokens = min(model.max_tokens, requested_tokens)
        model.max_tokens = max(model.max_tokens, 1)
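
To make the clamping arithmetic concrete, here is a minimal usage sketch, mirroring the Mock pattern used in the tests below. The model name is invented so the 4000-token fallback in LLM_MODEL_MAX_TOKENS.get(..., 4000) applies, and per test_happy_path further down, "Hello world!" encodes to 3 tokens with cl100k_base.

from unittest.mock import Mock

import tiktoken

from reworkd_platform.services.tokenizer.token_service import TokenService

service = TokenService(tiktoken.get_encoding("cl100k_base"))

# Stand-in for WrappedChatOpenAI; only the attributes the method touches.
model = Mock(spec=["model_name", "max_tokens"])
model.model_name = "made-up-model"  # hypothetical name -> 4000-token fallback
model.max_tokens = 8000

service.calculate_max_tokens(model, "Hello world!")  # prompt costs 3 tokens

# requested = 4000 - 3 = 3997; min(8000, 3997) = 3997; max(3997, 1) = 3997
assert model.max_tokens == 3997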
91 changes: 85 additions & 6 deletions platform/reworkd_platform/tests/test_token_service.py
@@ -1,26 +1,105 @@
+from unittest.mock import Mock

import tiktoken

-from reworkd_platform.services.tokenizer.service import TokenService
+from reworkd_platform.schemas import LLM_MODEL_MAX_TOKENS
+from reworkd_platform.services.tokenizer.token_service import TokenService

encoding = tiktoken.get_encoding("cl100k_base")


-def test_happy_path():
+def test_happy_path() -> None:
    service = TokenService(encoding)
    text = "Hello world!"

    validate_tokenize_and_detokenize(service, text, 3)


-def test_nothing():
+def test_nothing() -> None:
    service = TokenService(encoding)
    text = ""

    validate_tokenize_and_detokenize(service, text, 0)


-def validate_tokenize_and_detokenize(service, text, expected_token_count):
+def validate_tokenize_and_detokenize(
+    service: TokenService, text: str, expected_token_count: int
+) -> None:
    tokens = service.tokenize(text)
    assert text == service.detokenize(tokens)
    assert len(tokens) == service.count(text)
    assert len(tokens) == expected_token_count


def test_calculate_max_tokens_with_small_max_tokens() -> None:
    initial_max_tokens = 3000
    service = TokenService(encoding)
    model = Mock(spec=["model_name", "max_tokens"])
    model.model_name = "gpt-3.5-turbo"
    model.max_tokens = initial_max_tokens

    service.calculate_max_tokens(model, "Hello")

    assert model.max_tokens == initial_max_tokens


def test_calculate_max_tokens_with_high_completion_tokens() -> None:
    service = TokenService(encoding)
    prompt_tokens = service.count(LONG_TEXT)
    model = Mock(spec=["model_name", "max_tokens"])
    model.model_name = "gpt-3.5-turbo"
    model.max_tokens = 8000

    service.calculate_max_tokens(model, LONG_TEXT)

    assert model.max_tokens == (
        LLM_MODEL_MAX_TOKENS.get("gpt-3.5-turbo") - prompt_tokens
    )


def test_calculate_max_tokens_with_negative_result() -> None:
    service = TokenService(encoding)
    model = Mock(spec=["model_name", "max_tokens"])
    model.model_name = "gpt-3.5-turbo"
    model.max_tokens = 8000

    service.calculate_max_tokens(model, *([LONG_TEXT] * 100))

    # We use the minimum length of 1
    assert model.max_tokens == 1


LONG_TEXT = """
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
"""
@@ -1,10 +1,10 @@
-from typing import Any, Coroutine, Callable
+from typing import Any, Callable, Coroutine

from fastapi import Depends

from reworkd_platform.schemas import AgentRun, UserBase
from reworkd_platform.services.tokenizer.dependencies import get_token_service
-from reworkd_platform.services.tokenizer.service import TokenService
+from reworkd_platform.services.tokenizer.token_service import TokenService
from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
from reworkd_platform.web.api.agent.agent_service.mock_agent_service import (
@@ -13,9 +13,7 @@
from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import (
OpenAIAgentService,
)
-from reworkd_platform.web.api.agent.dependancies import (
-    get_agent_memory,
-)
+from reworkd_platform.web.api.agent.dependancies import get_agent_memory
from reworkd_platform.web.api.agent.model_settings import create_model
from reworkd_platform.web.api.dependencies import get_current_user
from reworkd_platform.web.api.memory.memory import AgentMemory
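The hunks above only show the import reshuffle; for context, here is a hypothetical sketch of how such a TokenService dependency is typically consumed in a FastAPI route. The router name, path, and endpoint below are illustrative, not part of this PR.

from fastapi import APIRouter, Depends

from reworkd_platform.services.tokenizer.dependencies import get_token_service
from reworkd_platform.services.tokenizer.token_service import TokenService

router = APIRouter()  # hypothetical router, for illustration only


@router.get("/token-count")
def token_count(
    text: str,
    token_service: TokenService = Depends(get_token_service),
) -> int:
    # FastAPI calls get_token_service(request) and injects the result.
    return token_service.count(text)
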
platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
@@ -7,8 +7,8 @@
from loguru import logger
from pydantic import ValidationError

-from reworkd_platform.schemas import LLM_MODEL_MAX_TOKENS, ModelSettings
-from reworkd_platform.services.tokenizer.service import TokenService
+from reworkd_platform.schemas import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
from reworkd_platform.web.api.agent.analysis import Analysis, AnalysisArguments
from reworkd_platform.web.api.agent.helpers import (
@@ -54,7 +54,8 @@ async def start_goal_agent(self, *, goal: str) -> List[str]:
            [SystemMessagePromptTemplate(prompt=start_goal_prompt)]
        )

-        self.calculate_max_tokens(
+        self.token_service.calculate_max_tokens(
+            self.model,
            prompt.format_prompt(
                goal=goal,
                language=self.settings.language,
@@ -90,7 +91,8 @@ async def analyze_task_agent(
            language=self.settings.language,
        )

-        self.calculate_max_tokens(
+        self.token_service.calculate_max_tokens(
+            self.model,
            prompt.to_string(),
            str(functions),
        )
@@ -153,7 +155,9 @@ async def create_tasks_agent(
"result": result,
}

-        self.calculate_max_tokens(prompt.format_prompt(**args).to_string())
+        self.token_service.calculate_max_tokens(
+            self.model, prompt.format_prompt(**args).to_string()
+        )

        completion = await call_model_with_handling(
            self.model, prompt, args, settings=self.settings, callbacks=self.callbacks
@@ -177,10 +181,3 @@ async def create_tasks_agent(
        memory.add_tasks(unique_tasks)

        return unique_tasks

-    def calculate_max_tokens(self, *prompts: str) -> None:
-        max_allowed_tokens = LLM_MODEL_MAX_TOKENS.get(self.model.model_name, 4000)
-        prompt_tokens = sum([self.token_service.count(p) for p in prompts])
-        requested_tokens = max_allowed_tokens - prompt_tokens
-
-        self.model.max_tokens = min(self.model.max_tokens, requested_tokens)
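
Net effect of the move: callers now pass the model explicitly, and the relocated method clamps the result with max(model.max_tokens, 1), a floor the deleted version above did not have; test_calculate_max_tokens_with_negative_result exercises exactly that case.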