8 changes: 7 additions & 1 deletion python/sglang/srt/disaggregation/common/conn.py
@@ -109,10 +109,12 @@ def __init__(
         mgr: BaseKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         self.bootstrap_room = bootstrap_room
         self.bootstrap_addr = bootstrap_addr
         self.kv_mgr = mgr
+        self.data_parallel_rank = data_parallel_rank

         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
             self.prefill_tp_size, self.prefill_dp_size = (
@@ -180,7 +182,11 @@ def __init__(
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1

-        self.target_dp_group = bootstrap_room % self.prefill_dp_size
+        if self.data_parallel_rank is not None:
+            logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
+            self.target_dp_group = self.data_parallel_rank
+        else:
+            self.target_dp_group = bootstrap_room % self.prefill_dp_size

         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
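For readability, here is a standalone sketch of the target-selection rule this hunk introduces (the same rule reappears in the Mooncake receiver below): an explicitly requested data_parallel_rank takes precedence, otherwise the receiver keeps hashing bootstrap_room across the prefill DP groups. The helper name and signature are illustrative, not part of the patch.

```python
from typing import Optional


def select_target_dp_group(
    prefill_dp_size: int,
    bootstrap_room: int,
    data_parallel_rank: Optional[int] = None,
) -> int:
    """Illustrative helper: pick the prefill DP group a KV receiver targets."""
    if data_parallel_rank is not None:
        # The request pinned a specific prefill DP group.
        return data_parallel_rank
    # Fall back to the original behavior: spread rooms across DP groups.
    return bootstrap_room % prefill_dp_size


assert select_target_dp_group(prefill_dp_size=4, bootstrap_room=10) == 2
assert select_target_dp_group(prefill_dp_size=4, bootstrap_room=10, data_parallel_rank=1) == 1
```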
1 change: 1 addition & 0 deletions python/sglang/srt/disaggregation/decode.py
@@ -156,6 +156,7 @@ def add(self, req: Req) -> None:
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
+            data_parallel_rank=req.data_parallel_rank,
         )
         self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))

1 change: 1 addition & 0 deletions python/sglang/srt/disaggregation/fake/conn.py
@@ -56,6 +56,7 @@ def __init__(
         mgr: BaseKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         self.has_init = False

8 changes: 7 additions & 1 deletion python/sglang/srt/disaggregation/mooncake/conn.py
@@ -765,13 +765,15 @@ def __init__(
         mgr: MooncakeKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         self.bootstrap_room = bootstrap_room
         self.bootstrap_addr = bootstrap_addr
         self.kv_mgr = mgr
         self.session_id = self.kv_mgr.get_session_id()
         self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
         self.conclude_state = None
+        self.data_parallel_rank = data_parallel_rank

         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
             self.prefill_tp_size, self.prefill_dp_size = (
@@ -845,7 +847,11 @@ def __init__(
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1

-        self.target_dp_group = self.bootstrap_room % self.prefill_dp_size
+        if self.data_parallel_rank is not None:
+            logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
+            self.target_dp_group = self.data_parallel_rank
+        else:
+            self.target_dp_group = bootstrap_room % self.prefill_dp_size

         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
3 changes: 2 additions & 1 deletion python/sglang/srt/disaggregation/nixl/conn.py
@@ -407,9 +407,10 @@ def __init__(
         mgr: NixlKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         self.started_transfer = False
-        super().__init__(mgr, bootstrap_addr, bootstrap_room)
+        super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)

     def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
6 changes: 6 additions & 0 deletions python/sglang/srt/entrypoints/EngineBase.py
@@ -23,6 +23,12 @@ def generate(
        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
        lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[str], str]] = None,
+       return_hidden_states: Optional[bool] = None,
+       stream: Optional[bool] = None,
+       bootstrap_host: Optional[Union[List[str], str]] = None,
+       bootstrap_port: Optional[Union[List[int], int]] = None,
+       bootstrap_room: Optional[Union[List[int], int]] = None,
+       data_parallel_rank: Optional[int] = None,
    ) -> Union[Dict, Iterator[Dict]]:
        """Generate outputs based on given inputs."""
        pass
26 changes: 26 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
@@ -167,11 +167,22 @@ def generate(
         bootstrap_host: Optional[Union[List[str], str]] = None,
         bootstrap_port: Optional[Union[List[int], int]] = None,
         bootstrap_room: Optional[Union[List[int], int]] = None,
+        data_parallel_rank: Optional[int] = None,
     ) -> Union[Dict, Iterator[Dict]]:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
         Please refer to `GenerateReqInput` for the documentation.
         """
+        if self.server_args.enable_dp_attention:
+            if data_parallel_rank is None:
+                logger.info("data_parallel_rank not provided, using default dispatch")
+            elif data_parallel_rank < 0:
+                raise ValueError("data_parallel_rank must be non-negative")
+            elif data_parallel_rank >= self.server_args.dp_size:
+                raise ValueError(
+                    f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}"
+                )
+
         obj = GenerateReqInput(
             text=prompt,
             input_ids=input_ids,
@@ -188,6 +199,7 @@ def generate(
             bootstrap_host=bootstrap_host,
             bootstrap_port=bootstrap_port,
             bootstrap_room=bootstrap_room,
+            data_parallel_rank=data_parallel_rank,
         )
         loop = asyncio.get_event_loop()
         generator = self.tokenizer_manager.generate_request(obj, None)
@@ -237,11 +249,24 @@ async def async_generate(
         bootstrap_host: Optional[Union[List[str], str]] = None,
         bootstrap_port: Optional[Union[List[int], int]] = None,
         bootstrap_room: Optional[Union[List[int], int]] = None,
+        data_parallel_rank: Optional[int] = None,
     ) -> Union[Dict, AsyncIterator[Dict]]:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
         Please refer to `GenerateReqInput` for the documentation.
         """
+
+        if self.server_args.enable_dp_attention:
+            if data_parallel_rank is None:
+                logger.info("data_parallel_rank not provided, using default dispatch")
+            elif data_parallel_rank < 0:
+                raise ValueError("data_parallel_rank must be non-negative")
+            elif data_parallel_rank >= self.server_args.dp_size:
+                raise ValueError(
+                    f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]"
+                )
+
+        logger.info(f"data_parallel_rank: {data_parallel_rank}")
         obj = GenerateReqInput(
             text=prompt,
             input_ids=input_ids,
@@ -257,6 +282,7 @@ async def async_generate(
             bootstrap_host=bootstrap_host,
             bootstrap_port=bootstrap_port,
             bootstrap_room=bootstrap_room,
+            data_parallel_rank=data_parallel_rank,
         )
         generator = self.tokenizer_manager.generate_request(obj, None)

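Taken together, the engine now validates an optional data_parallel_rank against dp_size (when DP attention is enabled) and forwards it into GenerateReqInput. A hedged usage sketch follows; the model path, parallelism settings, and sampling parameters are placeholders, not part of the patch:

```python
import sglang as sgl

# Illustrative only: launch an engine with two DP replicas and DP attention.
engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    dp_size=2,
    enable_dp_attention=True,
)

# Pin this request to DP rank 1. Values outside [0, dp_size) raise ValueError
# per the checks added above; omitting the argument keeps the default dispatch.
out = engine.generate(
    prompt="The capital of France is",
    sampling_params={"max_new_tokens": 8},
    data_parallel_rank=1,
)
print(out["text"])
```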
18 changes: 13 additions & 5 deletions python/sglang/srt/managers/data_parallel_controller.py
@@ -248,12 +248,20 @@ def launch_tensor_parallel_group(

     def round_robin_scheduler(self, req: Req):
         if self.server_args.disaggregation_mode == "null":
-            self.workers[self.round_robin_counter].send_pyobj(req)
-            self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                self.workers
-            )
+            if req.data_parallel_rank is not None:
+                logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
+                self.workers[req.data_parallel_rank].send_pyobj(req)
+            else:
+                self.workers[self.round_robin_counter].send_pyobj(req)
+                self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                    self.workers
+                )
         else:
-            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
+            if req.data_parallel_rank is not None:
+                logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
+                self.workers[req.data_parallel_rank].send_pyobj(req)
+            else:
+                self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)

     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
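The controller change boils down to a new dispatch precedence: an explicit per-request DP rank overrides both round-robin and bootstrap-room routing. A self-contained sketch of that precedence, with simplified names that are not part of the actual scheduler:

```python
from typing import Optional, Tuple


def pick_worker(
    num_workers: int,
    round_robin_counter: int,
    data_parallel_rank: Optional[int],
    bootstrap_room: Optional[int],
    disaggregation_mode: str = "null",
) -> Tuple[int, int]:
    """Return (worker_index, next_round_robin_counter) under the new precedence."""
    if data_parallel_rank is not None:
        # Direct routing: honor the rank requested by the client.
        return data_parallel_rank, round_robin_counter
    if disaggregation_mode == "null":
        # Previous default: plain round robin across DP workers.
        return round_robin_counter, (round_robin_counter + 1) % num_workers
    # Disaggregated mode: route by the bootstrap room, as before.
    return bootstrap_room % num_workers, round_robin_counter


# Example: with 4 workers, an explicit rank wins over round robin.
assert pick_worker(4, 0, data_parallel_rank=3, bootstrap_room=None)[0] == 3
assert pick_worker(4, 0, data_parallel_rank=None, bootstrap_room=None) == (0, 1)
```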
9 changes: 9 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -106,6 +106,9 @@ class GenerateReqInput:
     bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
     bootstrap_room: Optional[Union[List[int], int]] = None

+    # For data parallel rank routing
+    data_parallel_rank: Optional[int] = None
+
     def contains_mm_input(self) -> bool:
         return has_valid_data(self.image_data) or has_valid_data(self.audio_data)

@@ -417,6 +420,9 @@ def __getitem__(self, i):
             bootstrap_room=(
                 self.bootstrap_room[i] if self.bootstrap_room is not None else None
             ),
+            data_parallel_rank=(
+                self.data_parallel_rank if self.data_parallel_rank is not None else None
+            ),
         )


@@ -464,6 +470,9 @@ class TokenizedGenerateReqInput:
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None

+    # For data parallel rank routing
+    data_parallel_rank: Optional[int] = None
+

 @dataclass
 class EmbeddingReqInput:
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -451,6 +451,7 @@ def __init__(
         bootstrap_host: Optional[str] = None,
         bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -605,6 +606,9 @@ def __init__(
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None

+        # For data parallel rank routing
+        self.data_parallel_rank: Optional[int] = data_parallel_rank
+
         # the start index of the sent kv cache
         # We want to send it chunk by chunk for chunked prefill.
         # After every chunk forward, we do the following:
1 change: 1 addition & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -949,6 +949,7 @@ def handle_generate_request(
             bootstrap_host=recv_req.bootstrap_host,
             bootstrap_port=recv_req.bootstrap_port,
             bootstrap_room=recv_req.bootstrap_room,
+            data_parallel_rank=recv_req.data_parallel_rank,
         )
         req.tokenizer = self.tokenizer

1 change: 1 addition & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -570,6 +570,7 @@ def _create_tokenized_object(
                 session_params=session_params,
                 custom_logit_processor=obj.custom_logit_processor,
                 return_hidden_states=obj.return_hidden_states,
+                data_parallel_rank=obj.data_parallel_rank,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(