Skip to content

Commit 11e27d0

Browse files
hcyz33 and shuaills authored
[PD]: Support Multi Prefill in one node (#5704)
Co-authored-by: shuaills <shishuaiuoe@gmail.com>
1 parent 50eda83 commit 11e27d0

File tree

6 files changed: +55 additions, -9 deletions

python/sglang/srt/disaggregation/decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def add(self, req: Req) -> None:
137137
kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
138138
kv_receiver = kv_receiver_class(
139139
mgr=self.kv_manager,
140-
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
140+
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
141141
bootstrap_room=req.bootstrap_room,
142142
)
143143
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))

python/sglang/srt/disaggregation/mini_lb.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import random
77
import urllib
88
from itertools import chain
9+
from typing import List
910

1011
import aiohttp
1112
import orjson
@@ -14,13 +15,22 @@
1415
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
1516

1617

18+
class PrefillConfig:
19+
def __init__(self, url: str, bootstrap_port: int):
20+
self.url = url
21+
self.bootstrap_port = bootstrap_port
22+
23+
1724
class MiniLoadBalancer:
18-
def __init__(self, prefill_servers, decode_servers):
19-
self.prefill_servers = prefill_servers
25+
def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
26+
self.prefill_configs = prefill_configs
27+
self.prefill_servers = [p.url for p in prefill_configs]
2028
self.decode_servers = decode_servers
2129

2230
def select_pair(self):
23-
return random.choice(self.prefill_servers), random.choice(self.decode_servers)
31+
prefill_config = random.choice(self.prefill_configs)
32+
decode_server = random.choice(self.decode_servers)
33+
return prefill_config.url, prefill_config.bootstrap_port, decode_server
2434

2535
async def generate(
2636
self, modified_request, prefill_server, decode_server, endpoint
@@ -160,7 +170,7 @@ async def get_model_info():
160170

161171
@app.post("/generate")
162172
async def handle_generate_request(request_data: dict):
163-
prefill_server, decode_server = load_balancer.select_pair()
173+
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
164174

165175
# Parse and transform prefill_server for bootstrap data
166176
parsed_url = urllib.parse.urlparse(prefill_server)
@@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict):
172182
modified_request.update(
173183
{
174184
"bootstrap_host": [hostname] * batch_size,
185+
"bootstrap_port": [bootstrap_port] * batch_size,
175186
"bootstrap_room": [
176187
_generate_bootstrap_room() for _ in range(batch_size)
177188
],
@@ -181,6 +192,7 @@ async def handle_generate_request(request_data: dict):
181192
modified_request.update(
182193
{
183194
"bootstrap_host": hostname,
195+
"bootstrap_port": bootstrap_port,
184196
"bootstrap_room": _generate_bootstrap_room(),
185197
}
186198
)
@@ -197,7 +209,7 @@ async def handle_generate_request(request_data: dict):
197209

198210
@app.post("/v1/chat/completions")
199211
async def handle_completion_request(request_data: dict):
200-
prefill_server, decode_server = load_balancer.select_pair()
212+
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
201213

202214
# Parse and transform prefill_server for bootstrap data
203215
parsed_url = urllib.parse.urlparse(prefill_server)
@@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict):
206218
modified_request.update(
207219
{
208220
"bootstrap_host": hostname,
221+
"bootstrap_port": bootstrap_port,
209222
"bootstrap_room": random.randint(0, 2**63 - 1),
210223
}
211224
)
@@ -255,9 +268,9 @@ async def get_models():
255268
raise HTTPException(status_code=500, detail=str(e))
256269

257270

258-
def run(prefill_addrs, decode_addrs, host, port):
271+
def run(prefill_configs, decode_addrs, host, port):
259272
global load_balancer
260-
load_balancer = MiniLoadBalancer(prefill_addrs, decode_addrs)
273+
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
261274
uvicorn.run(app, host=host, port=port)
262275

263276

@@ -268,6 +281,11 @@ def run(prefill_addrs, decode_addrs, host, port):
268281
parser.add_argument(
269282
"--prefill", required=True, help="Comma-separated URLs for prefill servers"
270283
)
284+
parser.add_argument(
285+
"--prefill-bootstrap-ports",
286+
help="Comma-separated bootstrap ports for prefill servers",
287+
default="8998",
288+
)
271289
parser.add_argument(
272290
"--decode", required=True, help="Comma-separated URLs for decode servers"
273291
)
@@ -278,4 +296,23 @@ def run(prefill_addrs, decode_addrs, host, port):
278296
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
279297
)
280298
args = parser.parse_args()
281-
run(args.prefill.split(","), args.decode.split(","), args.host, args.port)
299+
300+
prefill_urls = args.prefill.split(",")
301+
bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
302+
303+
if len(bootstrap_ports) == 1:
304+
bootstrap_ports = bootstrap_ports * len(prefill_urls)
305+
else:
306+
if len(bootstrap_ports) != len(prefill_urls):
307+
raise ValueError(
308+
"Number of prefill URLs must match number of bootstrap ports"
309+
)
310+
exit(1)
311+
312+
prefill_configs = []
313+
for url, port in zip(prefill_urls, bootstrap_ports):
314+
prefill_configs.append(PrefillConfig(url, port))
315+
316+
decode_addrs = args.decode.split(",")
317+
318+
run(prefill_configs, decode_addrs, args.host, args.port)

python/sglang/srt/managers/io_struct.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class GenerateReqInput:
9797

9898
# For disaggregated inference
9999
bootstrap_host: Optional[Union[List[str], str]] = None
100+
bootstrap_port: Optional[Union[List[int], int]] = None
100101
bootstrap_room: Optional[Union[List[int], int]] = None
101102

102103
def normalize_batch_and_arguments(self):
@@ -400,6 +401,9 @@ def __getitem__(self, i):
400401
bootstrap_host=(
401402
self.bootstrap_host[i] if self.bootstrap_host is not None else None
402403
),
404+
bootstrap_port=(
405+
self.bootstrap_port[i] if self.bootstrap_port is not None else None
406+
),
403407
bootstrap_room=(
404408
self.bootstrap_room[i] if self.bootstrap_room is not None else None
405409
),
@@ -447,6 +451,7 @@ class TokenizedGenerateReqInput:
447451

448452
# For disaggregated inference
449453
bootstrap_host: Optional[str] = None
454+
bootstrap_port: Optional[int] = None
450455
bootstrap_room: Optional[int] = None
451456

452457

python/sglang/srt/managers/schedule_batch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@ def __init__(
391391
return_hidden_states: bool = False,
392392
eos_token_ids: Optional[Set[int]] = None,
393393
bootstrap_host: Optional[str] = None,
394+
bootstrap_port: Optional[int] = None,
394395
bootstrap_room: Optional[int] = None,
395396
):
396397
# Input and output info
@@ -526,6 +527,7 @@ def __init__(
526527

527528
# For disaggregation
528529
self.bootstrap_host: str = bootstrap_host
530+
self.bootstrap_port: Optional[int] = bootstrap_port
529531
self.bootstrap_room: Optional[int] = bootstrap_room
530532
self.disagg_kv_sender: Optional[BaseKVSender] = None
531533

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,7 @@ def handle_generate_request(
791791
return_hidden_states=recv_req.return_hidden_states,
792792
eos_token_ids=self.model_config.hf_eos_token_id,
793793
bootstrap_host=recv_req.bootstrap_host,
794+
bootstrap_port=recv_req.bootstrap_port,
794795
bootstrap_room=recv_req.bootstrap_room,
795796
)
796797
req.tokenizer = self.tokenizer

python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ def _create_tokenized_object(
498498
token_ids_logprob,
499499
obj.stream,
500500
bootstrap_host=obj.bootstrap_host,
501+
bootstrap_port=obj.bootstrap_port,
501502
bootstrap_room=obj.bootstrap_room,
502503
lora_path=obj.lora_path,
503504
input_embeds=input_embeds,

0 commit comments

Comments
 (0)