Skip to content

Commit 34688c4

Browse files
Make ServeDeploymentStage agnostic to LLM engines
Signed-off-by: jeffreyjeffreywang <jeffjeffreywang@gmail.com>
1 parent 2f1e085 commit 34688c4

File tree

11 files changed

+312
-315
lines changed

11 files changed

+312
-315
lines changed

python/ray/data/llm.py

Lines changed: 15 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -249,14 +249,12 @@ class SGLangEngineProcessorConfig(_SGLangEngineProcessorConfig):
249249
class ServeDeploymentProcessorConfig(_ServeDeploymentProcessorConfig):
250250
"""The configuration for the serve deployment processor.
251251
252-
This processor enables sharing LLM engines across multiple processors and is compatible with
253-
both vLLM and SGLang engines, depending on the underlying serve deployment.
252+
This processor enables sharing serve deployments across multiple processors. This is useful
253+
for sharing the same LLM engine across multiple processors.
254254
255255
Args:
256-
llm_config: The LLM config used to build the serve deployment.
257-
name_prefix: The name prefix of the serve application to use.
256+
deployment_name: The name of the serve deployment to use.
258257
app_name: The name of the serve application to use.
259-
method: The method to invoke on the serve deployment.
260258
batch_size: The batch size to send to the serve deployment. Large batch sizes are
261259
likely to saturate the compute resources and could achieve higher throughput.
262260
On the other hand, small batch sizes are more fault-tolerant and could
@@ -296,24 +294,28 @@ class ServeDeploymentProcessorConfig(_ServeDeploymentProcessorConfig):
296294
),
297295
)
298296
299-
APP_NAME = "facebook"
297+
APP_NAME = "facebook_opt_app"
298+
DEPLOYMENT_NAME = "facebook_deployment"
300299
301-
llm_app = build_llm_deployment(llm_config, name_prefix="chat_completions")
300+
llm_app = build_llm_deployment(llm_config, deployment_name=DEPLOYMENT_NAME)
302301
app = serve.run(llm_app, name=APP_NAME)
303302
304303
config=ServeDeploymentProcessorConfig(
305-
llm_config=llm_config,
304+
deployment_name=DEPLOYMENT_NAME,
306305
app_name=APP_NAME,
307-
method="completions",
308306
batch_size=1,
309307
concurrency=1,
310308
)
311309
processor = build_llm_processor(
312310
config,
313-
preprocess=lambda row: CompletionRequest(
314-
model="facebook/opt-1.3b",
315-
prompt=row["prompt"],
316-
stream=False
311+
preprocess=lambda row: dict(
312+
method="completions",
313+
dtype="CompletionRequest",
314+
request_kwargs=dict(
315+
model="facebook/opt-1.3b",
316+
prompt=row["prompt"],
317+
stream=False
318+
)
317319
),
318320
postprocess=lambda row: dict(
319321
resp=row["choices"][0]["text"],

python/ray/llm/_internal/batch/processor/serve_deployment_proc.py

Lines changed: 6 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,10 @@
11
"""The processor that runs serve deployment."""
22

3-
from typing import Optional
3+
from typing import Optional, Type, Any, Tuple
44
from pydantic import Field
55

66
import ray
77
from ray.data.block import UserDefinedFunction
8-
from ray.llm._internal.batch.observability.usage_telemetry.usage import (
9-
BatchModelTelemetry,
10-
TelemetryAgent,
11-
get_or_create_telemetry_agent,
12-
)
138
from ray.llm._internal.batch.processor.base import (
149
ProcessorConfig,
1510
Processor,
@@ -18,30 +13,24 @@
1813
from ray.llm._internal.batch.stages import (
1914
ServeDeploymentStage,
2015
)
21-
from ray.serve.llm import LLMConfig
2216

2317

2418
class ServeDeploymentProcessorConfig(ProcessorConfig):
2519
"""The configuration for the serve deployment processor."""
2620

27-
# Configurations that was used to build the serve deployment
28-
llm_config: LLMConfig = Field(
29-
description="The LLM config to use for the serve deployment.",
21+
# Configurations used to build the serve deployment
22+
deployment_name: str = Field(
23+
description="The name of the serve deployment to use.",
3024
)
3125
app_name: str = Field(
3226
description="The name of the serve application to use.",
3327
)
34-
# Method to invoke on the serve deployment
35-
method: str = Field(
36-
description="The method to use for the serve deployment.",
37-
)
3828

3929

4030
def build_serve_deployment_processor(
4131
config: ServeDeploymentProcessorConfig,
4232
preprocess: Optional[UserDefinedFunction] = None,
4333
postprocess: Optional[UserDefinedFunction] = None,
44-
telemetry_agent: Optional[TelemetryAgent] = None,
4534
) -> Processor:
4635
"""
4736
Construct a processor that runs a serve deployment.
@@ -60,26 +49,15 @@ def build_serve_deployment_processor(
6049
stages = [
6150
ServeDeploymentStage(
6251
fn_constructor_kwargs=dict(
63-
deployment_name=config.llm_config.deployment_name,
52+
deployment_name=config.deployment_name,
6453
app_name=config.app_name,
65-
method=config.method,
6654
),
6755
map_batches_kwargs=dict(
6856
concurrency=config.concurrency,
6957
),
7058
)
7159
]
72-
telemetry_agent = get_or_create_telemetry_agent()
73-
telemetry_agent.push_telemetry_report(
74-
BatchModelTelemetry(
75-
processor_config_name=type(config).__name__,
76-
model_architecture=config.llm_config.model_architecture,
77-
batch_size=config.batch_size,
78-
accelerator_type=config.llm_config.accelerator_type,
79-
concurrency=config.concurrency,
80-
task_type=config.method,
81-
)
82-
)
60+
# TODO (Kourosh): Add telemetry for ServeDeploymentStage
8361
processor = Processor(
8462
config,
8563
stages,

python/ray/llm/_internal/batch/stages/serve_deployment_stage.py

Lines changed: 42 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -3,17 +3,21 @@
33
import logging
44
import time
55
import uuid
6+
from typing import Any, AsyncIterator, Dict, List, Type, Union, Tuple, Optional
7+
import asyncio
8+
from pydantic import BaseModel
69

710
from ray.llm._internal.batch.stages.base import (
811
StatefulStage,
912
StatefulStageUDF,
1013
)
1114
from ray import serve
12-
from typing import Any, AsyncIterator, Dict, List, Type, Union, Tuple
13-
import asyncio
15+
16+
# The following imports are necessary to resolve class references in the global namespace
1417
from ray.llm._internal.serve.configs.openai_api_models import (
1518
CompletionRequest,
1619
ChatCompletionRequest,
20+
EmbeddingRequest,
1721
)
1822

1923
logger = logging.getLogger(__name__)
@@ -26,7 +30,6 @@ def __init__(
2630
expected_input_keys: List[str],
2731
deployment_name: str,
2832
app_name: str,
29-
method: str,
3033
):
3134
"""
3235
Initialize the ServeDeploymentStageUDF.
@@ -36,52 +39,39 @@ def __init__(
3639
expected_input_keys: The expected input keys of the stage.
3740
deployment_name: The name of the deployment.
3841
app_name: The name of the deployment app.
39-
method: The method to call on the deployment.
4042
"""
4143
super().__init__(data_column, expected_input_keys)
4244
self._dh = serve.get_deployment_handle(deployment_name, app_name).options(
4345
stream=True
4446
)
45-
self._method = method
46-
self._request_type = self._resolve_request_type()
4747
self.request_id = 0
4848

49-
def _prepare_request(self, row: Dict[str, Any]) -> Dict[str, Any]:
49+
def _prepare_request(
50+
self, row: Dict[str, Any]
51+
) -> Tuple[Dict[str, Any], Optional[Type[Any]], str]:
5052
"""
5153
Decorate the request with metadata related to the batch.
5254
5355
Args:
5456
row: The row.
5557
5658
Returns:
57-
The decorated request.
59+
A tuple of (decorated_request, dtype, method_name). dtype is the class of the request object and
60+
can be None if the serve deployment accepts a raw dict. method_name is the name of the method to
61+
invoke on the serve deployment.
5862
"""
63+
method = row.get("method")
64+
dtype = globals()[row.get("dtype")] if row.get("dtype") else None
65+
66+
request_kwargs = row.pop("request_kwargs")
5967
request = {
6068
"request_id": str(self.request_id),
6169
"idx_in_batch": row[self.IDX_IN_BATCH_COLUMN],
62-
**row,
70+
**request_kwargs,
6371
}
6472
self.request_id += 1
65-
return request
6673

67-
def _resolve_request_type(self) -> Union[ChatCompletionRequest, CompletionRequest]:
68-
"""
69-
Resolve the request type based on the method.
70-
71-
Returns:
72-
The request type.
73-
"""
74-
if getattr(self._dh, self._method) is None:
75-
raise ValueError(
76-
f"Method {self._method} is not supported by the serve deployment."
77-
)
78-
79-
if self._method == "chat":
80-
return ChatCompletionRequest
81-
elif self._method == "completions":
82-
return CompletionRequest
83-
else:
84-
raise ValueError(f"Unsupported method: {self._method}")
74+
return request, dtype, method
8575

8676
async def generate_async(
8777
self, row: Dict[str, Any]
@@ -95,14 +85,25 @@ async def generate_async(
9585
Returns:
9686
The response from the serve deployment.
9787
"""
98-
request = self._prepare_request(row)
99-
t = time.perf_counter()
88+
request, dtype, method = self._prepare_request(row)
10089

101-
output_data = await anext(
102-
getattr(self._dh, self._method).remote(self._request_type(**request))
103-
)
90+
if dtype is not None:
91+
request_obj = dtype(**request)
92+
else:
93+
request_obj = request
94+
95+
if getattr(self._dh, method) is None:
96+
raise ValueError(f"Method {method} not found in the serve deployment.")
97+
98+
t = time.perf_counter()
99+
output_data = await anext(getattr(self._dh, method).remote(request_obj))
104100
time_taken = time.perf_counter() - t
105-
return request, output_data.model_dump(), time_taken
101+
102+
# Convert the output data to a dict if it is a Pydantic model.
103+
if isinstance(output_data, BaseModel):
104+
output_data = output_data.model_dump()
105+
106+
return request, output_data, time_taken
106107

107108
async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:
108109
"""
@@ -122,7 +123,7 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]
122123
"request_id": request["request_id"],
123124
self.IDX_IN_BATCH_COLUMN: request["idx_in_batch"],
124125
"batch_uuid": batch_uuid.hex,
125-
"time_taken_llm": time_taken,
126+
"time_taken": time_taken,
126127
**output,
127128
}
128129

@@ -137,3 +138,9 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]
137138

138139
class ServeDeploymentStage(StatefulStage):
139140
fn: Type[StatefulStageUDF] = ServeDeploymentStageUDF
141+
142+
def get_required_input_keys(self) -> Dict[str, str]:
143+
return {
144+
"method": "Name of the method to invoke on the serve deployment.",
145+
"request_kwargs": "The request_kwargs to construct the request to the serve deployment.",
146+
}

python/ray/llm/_internal/batch/stages/sglang_engine_stage.py

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -177,24 +177,22 @@ async def _prepare_llm_request(self, row: Dict[str, Any]) -> SGLangEngineRequest
177177

178178
async def generate_async(
179179
self, row: Dict[str, Any]
180-
) -> Tuple[SGLangEngineRequest, Dict[str, Any], float]:
180+
) -> Tuple[SGLangEngineRequest, Dict[str, Any]]:
181181
"""Process a single request.
182182
183183
Args:
184184
request: The request.
185185
186186
Returns:
187-
A tuple of index in batch, request output, bypassed custom fields, and time taken.
187+
A tuple of index in batch, request output and bypassed custom fields.
188188
"""
189189
request = await self._prepare_llm_request(row)
190-
t = time.perf_counter()
191190

192191
async with self.semaphore:
193192
output = await self._generate_async(request)
194193

195-
time_taken = time.perf_counter() - t
196194
output_data = SGLangOutputData.from_sglang_engine_output(output)
197-
return request, output_data.model_dump(), time_taken
195+
return request, output_data.model_dump()
198196

199197
async def _generate_async(self, request: SGLangEngineRequest) -> Any:
200198
"""Process a single request.
@@ -327,24 +325,25 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]
327325

328326
tasks = [asyncio.create_task(self.llm.generate_async(row)) for row in batch]
329327

328+
time_taken = -1.0
330329
for resp in asyncio.as_completed(tasks):
331-
request, output, time_taken_llm = await resp
330+
request, output = await resp
331+
time_taken = time.perf_counter() - t
332332

333333
yield {
334334
**output,
335335
"request_id": request.request_id,
336336
self.IDX_IN_BATCH_COLUMN: request.idx_in_batch,
337337
"batch_uuid": batch_uuid.hex,
338-
"time_taken_llm": time_taken_llm,
338+
"time_taken_llm": time_taken,
339339
"params": str(request.params),
340340
}
341341

342-
batch_time_taken = time.perf_counter() - t
343342
logger.info(
344343
"[SGLang] Elapsed time for batch %s with size %d: %s",
345344
batch_uuid.hex,
346345
len(batch),
347-
batch_time_taken,
346+
time_taken,
348347
)
349348

350349
def __del__(self):

0 commit comments

Comments
 (0)