Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ runtime_common = [
"modelscope",
"msgspec",
"ninja",
"openai-harmony==0.0.3",
"orjson",
"outlines==0.1.11",
"packaging",
Expand Down Expand Up @@ -96,7 +97,7 @@ srt_cpu = ["sglang[runtime_common]", "einops"]
# https://vllm-ascend.readthedocs.io/en/latest/installation.html
srt_npu = ["sglang[runtime_common]"]

openai = ["openai>=1.0", "tiktoken"]
openai = ["openai>=1.99.1", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
torch_memory_saver = ["torch_memory_saver>=0.0.8"]
Expand Down
244 changes: 244 additions & 0 deletions python/sglang/srt/entrypoints/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# Copied from vLLM
import json
import logging
from abc import ABC, abstractmethod
from typing import Union

logger = logging.getLogger(__name__)

try:
from mcp import ClientSession
except ImportError:
logger.warning("Ignoring mcp import error")

from openai_harmony import Author, Message, Role, StreamState, TextContent

from sglang.srt.entrypoints.harmony_utils import (
get_encoding,
get_streamable_parser_for_assistant,
render_for_completion,
)
from sglang.srt.entrypoints.tool import Tool


class ConversationContext(ABC):

@abstractmethod
def append_output(self, output) -> None:
pass

@abstractmethod
async def call_tool(self) -> list[Message]:
pass

@abstractmethod
def need_builtin_tool_call(self) -> bool:
pass

@abstractmethod
def render_for_completion(self) -> list[int]:
pass


class SimpleContext(ConversationContext):

def __init__(self):
self.last_output = None

def append_output(self, output) -> None:
self.last_output = output

def need_builtin_tool_call(self) -> bool:
return False

async def call_tool(self) -> list[Message]:
raise NotImplementedError("Should not be called.")

def render_for_completion(self) -> list[int]:
raise NotImplementedError("Should not be called.")


class HarmonyContext(ConversationContext):

def __init__(
self,
messages: list,
tool_sessions: dict[str, Union["ClientSession", Tool]],
):
# TODO: Remove the hack of Union[ClientSession, Tool] by using MCP
# when demo.
self._messages = messages
self.tool_sessions = tool_sessions

self.parser = get_streamable_parser_for_assistant()
self.num_init_messages = len(messages)
# TODO
self.num_prompt_tokens = 0
self.num_cached_tokens = 0
self.num_output_tokens = 0
self.num_reasoning_tokens = 0

def append_output(self, output) -> None:
if isinstance(output, dict) and "output_ids" in output:
output_token_ids = output["output_ids"]

# TODO: REMOVE here:
# Very hacky, find the first occurrence of token 200006 and cut from there
try:
start_index = output_token_ids.index(200006)
output_token_ids = output_token_ids[start_index:]
except ValueError:
pass

for token_id in output_token_ids:
self.parser.process(token_id)
output_msgs = self.parser.messages

meta_info = output["meta_info"]

if isinstance(meta_info, dict):
if "prompt_token_ids" in meta_info:
self.num_prompt_tokens = meta_info["prompt_tokens"]
if "cached_tokens" in meta_info:
self.num_cached_tokens = meta_info["cached_tokens"]
if "completion_tokens" in meta_info:
self.num_output_tokens += meta_info["completion_tokens"]

else:
output_msgs = output

self._messages.extend(output_msgs)

@property
def messages(self) -> list:
return self._messages

def need_builtin_tool_call(self) -> bool:
last_msg = self.messages[-1]
recipient = last_msg.recipient
return recipient is not None and (
recipient.startswith("browser.") or recipient.startswith("python")
)

async def call_tool(self) -> list[Message]:
if not self.messages:
return []
last_msg = self.messages[-1]
recipient = last_msg.recipient
if recipient is not None:
if recipient.startswith("browser."):
return await self.call_search_tool(
self.tool_sessions["browser"], last_msg
)
elif recipient.startswith("python"):
return await self.call_python_tool(
self.tool_sessions["python"], last_msg
)
raise ValueError("No tool call found")

def render_for_completion(self) -> list[int]:
return render_for_completion(self.messages)

async def call_search_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: Message
) -> list[Message]:
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1]
args = json.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [Message(author=author, content=[content], recipient=Role.ASSISTANT)]

async def call_python_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: Message
) -> list[Message]:
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
param = {
"code": last_msg.content[0].text,
}
result = await tool_session.call_tool("python", param)
result_str = result.content[0].text

content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name="python")

return [
Message(
author=author,
content=[content],
channel=last_msg.channel,
recipient=Role.ASSISTANT,
)
]


class StreamingHarmonyContext(HarmonyContext):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.last_output = None

self.parser = get_streamable_parser_for_assistant()
self.encoding = get_encoding()
self.last_tok = None

@property
def messages(self) -> list:
return self.parser.messages

def append_output(self, output) -> None:
if isinstance(output, dict) and "output_ids" in output:
# RequestOutput from SGLang with outputs
output_token_ids = output["output_ids"]

# TODO: REMOVE here:
# Very hacky, find the first occurrence of token 200006 and cut from there
# Find the first occurrence of token 200006 and cut from there
try:
start_index = output_token_ids.index(200006)
output_token_ids = output_token_ids[start_index:]
except ValueError:
pass

for token_id in output_token_ids:
self.parser.process(token_id)

else:
# Handle the case of tool output in direct message format
assert len(output) == 1, "Tool output should be a single message"
msg = output[0]
# Sometimes the recipient is not set for tool messages,
# so we set it to "assistant"
if msg.author.role == Role.TOOL and msg.recipient is None:
msg.recipient = "assistant"
toks = self.encoding.render(msg)
for tok in toks:
self.parser.process(tok)
self.last_tok = toks[-1]

def is_expecting_start(self) -> bool:
return self.parser.state == StreamState.EXPECT_START

def is_assistant_action_turn(self) -> bool:
return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()

def render_for_completion(self) -> list[int]:
# now this list of tokens as next turn's starting tokens
# `<|start|>assistant``,
# we need to process them in parser.
rendered_tokens = super().render_for_completion()

last_n = -1
to_process = []
while rendered_tokens[last_n] != self.last_tok:
to_process.append(rendered_tokens[last_n])
last_n -= 1
for tok in reversed(to_process):
self.parser.process(tok)

return rendered_tokens
Loading
Loading