66import random
77import urllib
88from itertools import chain
9+ from typing import List
910
1011import aiohttp
1112import orjson
1415from fastapi .responses import ORJSONResponse , Response , StreamingResponse
1516
1617
18+ class PrefillConfig :
19+ def __init__ (self , url : str , bootstrap_port : int ):
20+ self .url = url
21+ self .bootstrap_port = bootstrap_port
22+
23+
1724class MiniLoadBalancer :
18- def __init__ (self , prefill_servers , decode_servers ):
19- self .prefill_servers = prefill_servers
25+ def __init__ (self , prefill_configs : List [PrefillConfig ], decode_servers : List [str ]):
26+ self .prefill_configs = prefill_configs
27+ self .prefill_servers = [p .url for p in prefill_configs ]
2028 self .decode_servers = decode_servers
2129
2230 def select_pair (self ):
23- return random .choice (self .prefill_servers ), random .choice (self .decode_servers )
31+ prefill_config = random .choice (self .prefill_configs )
32+ decode_server = random .choice (self .decode_servers )
33+ return prefill_config .url , prefill_config .bootstrap_port , decode_server
2434
2535 async def generate (
2636 self , modified_request , prefill_server , decode_server , endpoint
@@ -160,7 +170,7 @@ async def get_model_info():
160170
161171@app .post ("/generate" )
162172async def handle_generate_request (request_data : dict ):
163- prefill_server , decode_server = load_balancer .select_pair ()
173+ prefill_server , bootstrap_port , decode_server = load_balancer .select_pair ()
164174
165175 # Parse and transform prefill_server for bootstrap data
166176 parsed_url = urllib .parse .urlparse (prefill_server )
@@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict):
172182 modified_request .update (
173183 {
174184 "bootstrap_host" : [hostname ] * batch_size ,
185+ "bootstrap_port" : [bootstrap_port ] * batch_size ,
175186 "bootstrap_room" : [
176187 _generate_bootstrap_room () for _ in range (batch_size )
177188 ],
@@ -181,6 +192,7 @@ async def handle_generate_request(request_data: dict):
181192 modified_request .update (
182193 {
183194 "bootstrap_host" : hostname ,
195+ "bootstrap_port" : bootstrap_port ,
184196 "bootstrap_room" : _generate_bootstrap_room (),
185197 }
186198 )
@@ -197,7 +209,7 @@ async def handle_generate_request(request_data: dict):
197209
198210@app .post ("/v1/chat/completions" )
199211async def handle_completion_request (request_data : dict ):
200- prefill_server , decode_server = load_balancer .select_pair ()
212+ prefill_server , bootstrap_port , decode_server = load_balancer .select_pair ()
201213
202214 # Parse and transform prefill_server for bootstrap data
203215 parsed_url = urllib .parse .urlparse (prefill_server )
@@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict):
206218 modified_request .update (
207219 {
208220 "bootstrap_host" : hostname ,
221+ "bootstrap_port" : bootstrap_port ,
209222 "bootstrap_room" : random .randint (0 , 2 ** 63 - 1 ),
210223 }
211224 )
@@ -255,9 +268,9 @@ async def get_models():
255268 raise HTTPException (status_code = 500 , detail = str (e ))
256269
257270
258- def run (prefill_addrs , decode_addrs , host , port ):
271+ def run (prefill_configs , decode_addrs , host , port ):
259272 global load_balancer
260- load_balancer = MiniLoadBalancer (prefill_addrs , decode_addrs )
273+ load_balancer = MiniLoadBalancer (prefill_configs , decode_addrs )
261274 uvicorn .run (app , host = host , port = port )
262275
263276
@@ -268,6 +281,11 @@ def run(prefill_addrs, decode_addrs, host, port):
268281 parser .add_argument (
269282 "--prefill" , required = True , help = "Comma-separated URLs for prefill servers"
270283 )
284+ parser .add_argument (
285+ "--prefill-bootstrap-ports" ,
286+ help = "Comma-separated bootstrap ports for prefill servers" ,
287+ default = "8998" ,
288+ )
271289 parser .add_argument (
272290 "--decode" , required = True , help = "Comma-separated URLs for decode servers"
273291 )
@@ -278,4 +296,23 @@ def run(prefill_addrs, decode_addrs, host, port):
278296 "--port" , type = int , default = 8000 , help = "Port to bind the server (default: 8000)"
279297 )
280298 args = parser .parse_args ()
281- run (args .prefill .split ("," ), args .decode .split ("," ), args .host , args .port )
299+
300+ prefill_urls = args .prefill .split ("," )
301+ bootstrap_ports = [int (p ) for p in args .prefill_bootstrap_ports .split ("," )]
302+
303+ if len (bootstrap_ports ) == 1 :
304+ bootstrap_ports = bootstrap_ports * len (prefill_urls )
305+ else :
306+ if len (bootstrap_ports ) != len (prefill_urls ):
307+ raise ValueError (
308+ "Number of prefill URLs must match number of bootstrap ports"
309+ )
310+ exit (1 )
311+
312+ prefill_configs = []
313+ for url , port in zip (prefill_urls , bootstrap_ports ):
314+ prefill_configs .append (PrefillConfig (url , port ))
315+
316+ decode_addrs = args .decode .split ("," )
317+
318+ run (prefill_configs , decode_addrs , args .host , args .port )
0 commit comments