19 changes: 13 additions & 6 deletions python/sglang/srt/managers/data_parallel_controller.py
@@ -21,7 +21,7 @@
 import time
 from collections import deque
 from enum import Enum, auto
-from typing import List, Optional
+from typing import Callable, List, Optional

 import psutil
 import setproctitle
@@ -119,14 +119,19 @@ def dispatch(self):
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""

-    def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        run_scheduler_process_func: Callable,
+    ) -> None:
         # Parse args
         self.server_args = server_args
         self.port_args = port_args
         self.load_balance_method = LoadBalanceMethod.from_str(
             server_args.load_balance_method
         )
-        self.run_scheduler_process = run_scheduler_process
+        self.run_scheduler_process_func = run_scheduler_process_func

         # For DP balance
         self.global_balance_id = 0
@@ -429,7 +434,7 @@ def launch_tensor_parallel_group(
             moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
             with self.env_lock, maybe_reindex_device_id(gpu_id) as gpu_id:
                 proc = mp.Process(
-                    target=self.run_scheduler_process,
+                    target=self.run_scheduler_process_func,
                     args=(
                         server_args,
                         rank_port_args,
@@ -511,7 +516,7 @@ def run_data_parallel_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    data_parallel_controller_class=DataParallelController,
+    run_scheduler_process_func: Callable = run_scheduler_process,
 ):
     setproctitle.setproctitle("sglang::data_parallel_controller")
     faulthandler.enable()
@@ -529,7 +534,9 @@ def run_data_parallel_controller_process(
         trace_set_thread_info(thread_label)

     try:
-        controller = data_parallel_controller_class(server_args, port_args)
+        controller = DataParallelController(
+            server_args, port_args, run_scheduler_process_func
+        )
         pipe_writer.send(
             {
                 "status": "ready",
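Note on the seam this file changes: `run_data_parallel_controller_process` previously accepted a replaceable `data_parallel_controller_class`; it now accepts a replaceable `run_scheduler_process_func` instead, which `DataParallelController` threads through to every scheduler `mp.Process` it spawns. Below is a minimal sketch of how an integrator might use the new hook; it is not part of this PR. The wrapper and launch helper are hypothetical, and the import path assumes `run_scheduler_process` lives in `sglang.srt.managers.scheduler`, as the new default argument suggests.

```python
# Hypothetical usage sketch for the run_scheduler_process_func injection point.
import multiprocessing as mp

from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
from sglang.srt.managers.scheduler import run_scheduler_process


def my_run_scheduler_process(*args, **kwargs):
    # Hypothetical hook: runs inside each scheduler subprocess before handing
    # control to the stock entry point (e.g., to register custom models).
    # Kept at module top level so it pickles cleanly for mp.Process.
    print("scheduler subprocess starting")
    run_scheduler_process(*args, **kwargs)


def launch_controller(server_args, port_args):
    # Mirrors how the server would spawn the controller, with the custom hook.
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(
        target=run_data_parallel_controller_process,
        args=(server_args, port_args, writer),
        kwargs={"run_scheduler_process_func": my_run_scheduler_process},
    )
    proc.start()
    return proc, reader
```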
57 changes: 38 additions & 19 deletions python/sglang/srt/managers/detokenizer_manager.py
@@ -84,6 +84,7 @@ def __init__(
             context, zmq.PUSH, port_args.tokenizer_ipc_name, False
         )

+        # Init tokenizer
         if server_args.skip_tokenizer_init:
             self.tokenizer = None
         else:
@@ -95,8 +96,11 @@ def __init__(
             )

         self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
-        self.is_dummy = server_args.load_format == "dummy"
+        self.is_dummy = False
+        self.is_tool_call_parser_gpt_oss = server_args.tool_call_parser == "gpt-oss"
+        self.disable_tokenizer_batch_decode = server_args.disable_tokenizer_batch_decode

+        # Init dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
                 (BatchEmbeddingOutput, self.handle_batch_embedding_out),
@@ -106,9 +110,6 @@ def __init__(
             ]
         )

-        self.is_tool_call_parser_gpt_oss = server_args.tool_call_parser == "gpt-oss"
-        self.disable_tokenizer_batch_decode = server_args.disable_tokenizer_batch_decode
-
     def event_loop(self):
         """The event loop that handles requests"""
         while True:
@@ -148,7 +149,7 @@ def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOutput):
         # If it is embedding model, no detokenization is needed.
         return recv_obj

-    def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
+    def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput):
         bs = len(recv_obj.rids)

         # Initialize decode status
@@ -176,8 +177,31 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
                 )
             surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])

-        # TODO(lmzheng): better handle skip_special_tokens/spaces_between_special_tokens per request
-        if self.disable_tokenizer_batch_decode:
+        # Decode token ids to strings
+        # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
+        if not self.disable_tokenizer_batch_decode:
+            if not self.is_dummy:
+                # Run normal batch decode
+                surr_texts = self.tokenizer.batch_decode(
+                    surr_ids,
+                    skip_special_tokens=recv_obj.skip_special_tokens[0],
+                    spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
+                        0
+                    ],
+                )
+                read_texts = self.tokenizer.batch_decode(
+                    read_ids,
+                    skip_special_tokens=recv_obj.skip_special_tokens[0],
+                    spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
+                        0
+                    ],
+                )
+            else:
+                # If it is dummy weights, just return dummy strings to prevent potential detokenization edge cases
+                surr_texts = ["dog" for _ in surr_ids]
+                read_texts = ["cat" for _ in read_ids]
+        else:
+            # Do not use batch decode to prevent some detokenization edge cases (e.g., gpt-oss).
             surr_texts = [
                 self.tokenizer.decode(
                     surr, skip_special_tokens=skip, spaces_between_special_tokens=space
@@ -198,17 +222,6 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
                     recv_obj.spaces_between_special_tokens,
                 )
             ]
-        else:
-            surr_texts = self.tokenizer.batch_decode(
-                surr_ids,
-                skip_special_tokens=recv_obj.skip_special_tokens[0],
-                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
-            )
-            read_texts = self.tokenizer.batch_decode(
-                read_ids,
-                skip_special_tokens=recv_obj.skip_special_tokens[0],
-                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
-            )

         # Incremental decoding
         output_strs = []
@@ -247,6 +260,11 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
             s.sent_offset = len(output_str)
             output_strs.append(incremental_output)

+        return output_strs
+
+    def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
+        output_strs = self._decode_batch_token_id_output(recv_obj)
+
         return BatchStrOutput(
             rids=recv_obj.rids,
             http_worker_ipcs=recv_obj.http_worker_ipcs,
@@ -306,14 +324,15 @@ def __setitem__(self, key, value):
 def run_detokenizer_process(
     server_args: ServerArgs,
     port_args: PortArgs,
+    detokenizer_manager_class=DetokenizerManager,
 ):
     kill_itself_when_parent_died()
     setproctitle.setproctitle("sglang::detokenizer")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()

     try:
-        manager = DetokenizerManager(server_args, port_args)
+        manager = detokenizer_manager_class(server_args, port_args)
         if server_args.tokenizer_worker_num > 1:
             manager.multi_http_worker_event_loop()
         else:
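Background on the decode branches above: `tokenizer.batch_decode` can only apply one `skip_special_tokens`/`spaces_between_special_tokens` setting to the whole batch (hence the `[0]` indexing noted in the TODO), while the `disable_tokenizer_batch_decode` path decodes each request separately and honors its own flags. A standalone sketch with a Hugging Face tokenizer; the model choice and token ids are illustrative only, not taken from this PR.

```python
# Contrast batch decode (one flag set per batch) with per-request decode.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 ids for "<|endoftext|>Hello" and "Hello world" (50256 is the special
# <|endoftext|> token).
ids = [[50256, 15496], [15496, 995]]

# Batch path: a single flag applies to every request, mirroring
# recv_obj.skip_special_tokens[0] in the diff above.
print(tok.batch_decode(ids, skip_special_tokens=True))
# expected: ["Hello", "Hello world"]

# Per-request path: each request keeps its own flag, mirroring the
# disable_tokenizer_batch_decode branch.
flags = [False, True]
print([tok.decode(i, skip_special_tokens=s) for i, s in zip(ids, flags)])
# expected: ["<|endoftext|>Hello", "Hello world"]
```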
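The detokenizer side keeps class-level injection: `run_detokenizer_process` now takes a `detokenizer_manager_class`, and the decode logic is factored into `_decode_batch_token_id_output` so that `handle_batch_token_id_out` (and any override of it) can reuse the stock incremental-decoding path. A minimal sketch, not part of this PR: the subclass, its filtering, and the launch helper are hypothetical, and it assumes the `BatchStrOutput` returned by `handle_batch_token_id_out` exposes an `output_strs` list.

```python
# Hypothetical usage sketch for the detokenizer_manager_class injection point.
import multiprocessing as mp

from sglang.srt.managers.detokenizer_manager import (
    DetokenizerManager,
    run_detokenizer_process,
)


class RedactingDetokenizerManager(DetokenizerManager):
    """Example subclass: post-process each incremental chunk before it is sent."""

    def handle_batch_token_id_out(self, recv_obj):
        # Reuse the stock decode path (now routed through
        # _decode_batch_token_id_output), then scrub a hypothetical marker.
        out = super().handle_batch_token_id_out(recv_obj)
        out.output_strs = [s.replace("<secret>", "") for s in out.output_strs]
        return out


def launch_detokenizer(server_args, port_args):
    # Mirrors how the server would spawn the detokenizer, with the custom class.
    proc = mp.Process(
        target=run_detokenizer_process,
        args=(server_args, port_args),
        kwargs={"detokenizer_manager_class": RedactingDetokenizerManager},
    )
    proc.start()
    return proc
```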