Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/data_parallel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.managers.io_struct import (
BlockReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
)
Expand Down Expand Up @@ -243,6 +244,9 @@ def event_loop(self):
),
):
self.dispatching(recv_req)
elif isinstance(recv_req, BlockReqInput):
for worker in self.workers:
worker.send_pyobj(recv_req)
else:
# Send other control messages to first worker of tp group
for worker in self.workers[:: self.control_message_step]:
Expand Down
10 changes: 10 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,3 +911,13 @@ class RpcReqInput:
class RpcReqOutput:
success: bool
message: str


class BlockReqType(Enum):
    """Kind of scheduler-input control message carried by BlockReqInput."""

    BLOCK = 1  # Start queuing incoming requests instead of scheduling them.
    UNBLOCK = 2  # Release queued requests (after all ranks reach the barrier).


@dataclass
class BlockReqInput:
    """Control message telling the scheduler to block or unblock its input.

    Broadcast to every worker (not just the first of each TP group) so that
    all schedulers change state together.
    """

    # NOTE(review): field name shadows the ``type`` builtin; kept because
    # callers may pass it as a keyword argument.
    type: BlockReqType
12 changes: 12 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
PrefillAdder,
SchedulePolicy,
)
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
from sglang.srt.managers.scheduler_output_processor_mixin import (
SchedulerOutputProcessorMixin,
)
Expand All @@ -122,6 +123,7 @@
broadcast_pyobj,
configure_logger,
crash_on_warnings,
enable_colocated_batch_gen,
get_bool_env_var,
get_zmq_socket,
kill_itself_when_parent_died,
Expand Down Expand Up @@ -386,6 +388,13 @@ def __init__(
enable=server_args.enable_memory_saver
)

self.input_blocker = (
SchedulerInputBlocker(server_args, noop=self.attn_tp_rank != 0)
if enable_colocated_batch_gen()
or server_args.enable_scheduler_input_blocker
else None
)

# Init profiler
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
Expand Down Expand Up @@ -739,6 +748,9 @@ def recv_requests(self) -> List[Req]:
else:
recv_reqs = None

if self.input_blocker is not None:
recv_reqs = self.input_blocker.handle(recv_reqs)

if self.server_args.enable_dp_attention:
if self.attn_tp_rank == 0:
work_reqs = [
Expand Down
97 changes: 97 additions & 0 deletions python/sglang/srt/managers/scheduler_input_blocker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from enum import Enum, auto
from typing import Any, List, Optional

import torch

from sglang import ServerArgs
from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType


class SchedulerInputBlocker:
    """Buffers incoming scheduler requests while a BLOCK control message is active.

    After a ``BlockReqInput(BLOCK)`` is seen, ordinary requests are queued in
    ``_pending_reqs`` instead of being returned to the scheduler.  On
    ``BlockReqInput(UNBLOCK)`` the instance enters a global-barrier state and
    releases the queued requests only once *all* ranks have reached the barrier
    (agreed via a distributed MIN all-reduce), so every rank unblocks in the
    same scheduler iteration.

    A ``noop`` instance (a rank that receives no requests, e.g.
    ``attn_tp_rank != 0``) performs no state transitions of its own but still
    participates in the collective barrier call so the all-reduce stays
    matched across ranks.
    """

    def __init__(self, server_args: ServerArgs, noop: bool):
        # Current position in the block/unblock state machine.
        self._state = _State.UNBLOCKED
        # Requests received while blocked; flushed when the barrier completes.
        self._pending_reqs = []
        # True on ranks that never receive requests but must still join the
        # collective barrier.
        self._noop = noop
        # The overlap scheduler would reorder request handling relative to
        # the block/unblock control messages, so it must be disabled.
        assert (
            server_args.disable_overlap_schedule
        ), "SchedulerInputBlocker requires overlap scheduler to be disabled"

    def handle(self, recv_reqs: Optional[List[Any]]):
        """Filter one iteration's received requests through the blocker.

        Args:
            recv_reqs: Requests received this iteration, or ``None`` on a
                noop rank (the two must agree, hence the assert).

        Returns:
            The requests that may be processed now — possibly including
            previously queued ones — or ``None`` on a noop rank.

        NOTE: must be called every scheduler iteration on every rank, because
        ``_compute_global_unblock_barrier`` performs a collective all-reduce
        that has to stay matched across all ranks.
        """
        assert (recv_reqs is None) == self._noop

        if not self._noop:
            output_reqs = []
            for recv_req in recv_reqs:
                output_reqs += self._handle_recv_req(recv_req)

        # Collective call: executed unconditionally so all ranks participate
        # regardless of local state.
        global_arrived_unblock_barrier = self._compute_global_unblock_barrier()
        if (
            self._state == _State.GLOBAL_UNBLOCK_BARRIER
            and global_arrived_unblock_barrier
        ):
            # Only reachable on non-noop ranks (noop ranks never leave
            # UNBLOCKED), so ``output_reqs`` is always bound here.
            output_reqs += self._handle_arrive_unblock_barrier()

        if not self._noop:
            return output_reqs

    def _handle_recv_req(self, recv_req):
        # Control messages mutate the state machine and are consumed here;
        # ordinary requests pass through or get queued depending on state.
        if isinstance(recv_req, BlockReqInput):
            if recv_req.type == BlockReqType.BLOCK:
                self._execute_block_req()
                return []
            elif recv_req.type == BlockReqType.UNBLOCK:
                self._execute_unblock_req()
                return []
            else:
                raise NotImplementedError(f"{recv_req=}")
        else:
            if self._state == _State.UNBLOCKED:
                return [recv_req]
            else:
                self._pending_reqs.append(recv_req)
                return []

    def _execute_block_req(self):
        # BLOCK is only legal while unblocked (enforced by _change_state).
        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)

    def _execute_unblock_req(self):
        # UNBLOCK does not release requests immediately; it arms the global
        # barrier so all ranks release in the same iteration.
        self._change_state(
            original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
        )

    def _compute_global_unblock_barrier(self):
        """Return truthy iff every rank has arrived at the unblock barrier.

        Noop ranks always report arrival, so the MIN all-reduce yields a
        truthy value only once all non-noop ranks are in
        GLOBAL_UNBLOCK_BARRIER.

        NOTE(review): this runs a CUDA all-reduce on every scheduler
        iteration, even while fully unblocked — presumably cheap relative to
        a scheduling step, but worth confirming under profiling.
        """
        local_arrived = self._noop or (self._state == _State.GLOBAL_UNBLOCK_BARRIER)
        global_arrived = torch.tensor(local_arrived).cuda()
        torch.distributed.all_reduce(global_arrived, torch.distributed.ReduceOp.MIN)
        return global_arrived.cpu().item()

    def _handle_arrive_unblock_barrier(self):
        # All ranks agreed: flush everything queued while blocked.
        self._change_state(
            original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
        )
        output_reqs = [*self._pending_reqs]
        self._pending_reqs.clear()
        return output_reqs

    def _change_state(self, original: "_State", target: "_State"):
        # Guards against out-of-order control messages (e.g. a double BLOCK
        # or an UNBLOCK without a preceding BLOCK).
        assert self._state == original, f"{self._state=} {original=} {target=}"
        self._state = target


class _State(Enum):
    """Lifecycle states of SchedulerInputBlocker."""

    # Requests flow straight through to the scheduler.
    UNBLOCKED = 1
    # A BLOCK message was received; incoming requests are being queued.
    BLOCKED = 2
    # An UNBLOCK message was received; waiting for every rank to agree
    # before releasing queued requests.
    GLOBAL_UNBLOCK_BARRIER = 3
11 changes: 11 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
BatchMultimodalOut,
BatchStrOut,
BatchTokenIDOut,
BlockReqInput,
BlockReqType,
CloseSessionReqInput,
ConfigureLoggingReq,
EmbeddingReqInput,
Expand Down Expand Up @@ -100,6 +102,8 @@
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
dataclass_to_string_truncated,
enable_colocated_batch_gen,
get_bool_env_var,
get_zmq_socket,
kill_process_tree,
)
Expand Down Expand Up @@ -481,6 +485,9 @@ def _send_one_request(
self.rid_to_state[obj.rid] = state
self.send_to_scheduler.send_pyobj(tokenized_obj)

def _send_block_request(self, type: BlockReqType):
    """Send a block/unblock control message to the scheduler process."""
    # NOTE(review): parameter name shadows the ``type`` builtin; kept for
    # backward compatibility with keyword callers.
    message = BlockReqInput(type)
    self.send_to_scheduler.send_pyobj(message)

async def _wait_one_response(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
Expand Down Expand Up @@ -550,12 +557,16 @@ async def _handle_batch_request(
rids = []
if getattr(obj, "parallel_sample_num", 1) == 1:
# Send all requests
if enable_colocated_batch_gen():
self._send_block_request(BlockReqType.BLOCK)
for i in range(batch_size):
tmp_obj = obj[i]
tokenized_obj = await self._tokenize_one_request(tmp_obj)
self._send_one_request(tmp_obj, tokenized_obj, created_time)
generators.append(self._wait_one_response(tmp_obj, request))
rids.append(tmp_obj.rid)
if enable_colocated_batch_gen():
self._send_block_request(BlockReqType.UNBLOCK)
else:
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
if batch_size > 128:
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class ServerArgs:
warmups: Optional[str] = None
n_share_experts_fusion: int = 0
disable_shared_experts_fusion: bool = False
enable_scheduler_input_blocker: bool = False

# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
Expand Down Expand Up @@ -1117,6 +1118,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
)
parser.add_argument(
"--enable-scheduler-input-blocker",
action="store_true",
help="Enable input blocker for Scheduler.",
)

# Server warmups
parser.add_argument(
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ def clone(self) -> "DynamicGradMode":
return self.__class__()


def enable_colocated_batch_gen():
    """Return True when colocated batch generation is enabled.

    Controlled by the ``SGLANG_ENABLE_COLOCATED_BATCH_GEN`` environment
    variable; defaults to disabled.
    """
    env_name = "SGLANG_ENABLE_COLOCATED_BATCH_GEN"
    return get_bool_env_var(env_name, "false")


def enable_show_time_cost():
global show_time_cost
show_time_cost = True
Expand Down
Loading