Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions doc/source/serve/doc_code/multiplexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,34 @@ async def __call__(self, request: starlette.requests.Request):
serve.run(Upstream.bind(Downstream.bind()))
resp = requests.get("http://localhost:8000")
# __serve_model_composition_example_end__


# __serve_multiplexed_batching_example_begin__
from typing import List # noqa: E402
from starlette.requests import Request


@serve.deployment(max_ongoing_requests=15)
class BatchedMultiplexModel:
    """Example deployment that combines @serve.multiplexed with @serve.batch."""

    @serve.multiplexed(max_num_models_per_replica=3)
    async def get_model(self, model_id: str):
        """Load and return the model identified by ``model_id``.

        A real deployment would load model weights here; this example
        simply returns the ID as a stand-in for the model object.
        """
        return model_id

    @serve.batch(max_batch_size=10, batch_wait_timeout_s=0.1)
    async def batched_predict(self, inputs: List[str]) -> List[str]:
        """Run inference on a batch of inputs using the batch's model.

        Calling ``get_multiplexed_model_id()`` is safe here because Serve
        splits batches by model ID, so every request in this batch targets
        the same model.
        """
        model_id = serve.get_multiplexed_model_id()
        model = await self.get_model(model_id)

        # Apply the loaded model to every input in the batch.
        results = []
        for inp in inputs:
            results.append(f"{model}:{inp}")
        return results

    async def __call__(self, request: Request):
        """HTTP entry point: decode the request body and run a prediction."""
        raw_body = await request.body()
        return await self.batched_predict(raw_body.decode())


# __serve_multiplexed_batching_example_end__
16 changes: 16 additions & 0 deletions doc/source/serve/model-multiplexing.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,19 @@ When using model composition, you can send requests from an upstream deployment
:start-after: __serve_model_composition_example_begin__
:end-before: __serve_model_composition_example_end__
```

## Using model multiplexing with batching

You can combine model multiplexing with the `@serve.batch` decorator for efficient batched inference. When you use both features together, Ray Serve collects requests under a single `max_batch_size`/`batch_wait_timeout_s` budget across all models, and then automatically splits each collected batch by model ID so that every sub-batch contains only requests targeting the same model. Note that Serve doesn't wait for each individual model to accumulate a full batch: because the batch is split after it's collected, the per-model sub-batches can be smaller than `max_batch_size` (for example, with `max_batch_size=8`, a batch may be processed as sub-batches of sizes `[1, 4, 3]`).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way I understand this description is that Serve will treat each model's batch independently, i.e. waiting to reach the max_batch_size or the timeout before firing for each model, but in reality, it waits for the max_batch_size or timeout across all models. For example if our max_batch_size=8, Serve will process sub batches of size [1, 4, 3] instead of waiting for each model to have 8 request.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are right.


The following example shows how to combine multiplexing with batching:

```{literalinclude} doc_code/multiplexed.py
:language: python
:start-after: __serve_multiplexed_batching_example_begin__
:end-before: __serve_multiplexed_batching_example_end__
```

:::{note}
`serve.get_multiplexed_model_id()` works correctly inside functions decorated with `@serve.batch`. Ray Serve guarantees that all requests in a batch have the same `multiplexed_model_id`, so you can safely use this value to load and apply the appropriate model for the entire batch.
:::
13 changes: 13 additions & 0 deletions python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,11 @@ def get_multiplexed_model_id() -> str:
This is used with a function decorated with `@serve.multiplexed`
to retrieve the model ID for the current request.

When called from within a batched function (decorated with `@serve.batch`),
this returns the multiplexed model ID that is common to all requests in
the current batch. This works because batches are automatically split
by model ID to ensure all requests in a batch target the same model.

.. code-block:: python

import ray
Expand All @@ -911,6 +916,14 @@ def get_multiplexed_model_id() -> str:
def my_deployment_function(request):
assert serve.get_multiplexed_model_id() == "model_1"
"""
# First check if we're inside a batch context. If so, get the model ID
# from the batch request context. All requests in a batch are guaranteed
# to have the same multiplexed_model_id (batches are split by model ID).
batch_request_context = ray.serve.context._get_serve_batch_request_context()
if batch_request_context:
return batch_request_context[0].multiplexed_model_id

# Fall back to the regular request context
_request_context = ray.serve.context._get_serve_request_context()
return _request_context.multiplexed_model_id

Expand Down
209 changes: 131 additions & 78 deletions python/ray/serve/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,105 +435,158 @@ async def _assign_func_results(
for future in futures:
_set_exception_if_not_done(future, e)

def _split_batch_by_model_id(
    self, batch: List[_SingleRequest]
) -> List[List[_SingleRequest]]:
    """Split a batch into sub-batches based on multiplexed_model_id.

    When using model multiplexing with batching, requests for different models
    may end up in the same batch. This method ensures that each sub-batch only
    contains requests for the same model, preventing issues where a single batch
    contains requests for different models.

    If no requests have a multiplexed_model_id set, they all share the same
    (empty-string) model ID, so the original batch is returned as a single
    sub-batch.

    Args:
        batch: The batch of requests to split.

    Returns:
        A list of sub-batches, where each sub-batch contains requests for the
        same multiplexed_model_id. Insertion order of model IDs is preserved.
    """
    # Group requests by their multiplexed_model_id. setdefault avoids the
    # hand-rolled "if key not in dict" dance while keeping insertion order.
    model_id_to_requests: Dict[str, List[_SingleRequest]] = {}
    for request in batch:
        model_id = request.request_context.multiplexed_model_id
        model_id_to_requests.setdefault(model_id, []).append(request)

    # One sub-batch per distinct model_id.
    return list(model_id_to_requests.values())

async def _process_batches(self, func: Callable) -> None:
"""Loops infinitely and processes queued request batches."""
# When asyncio task is created, the task will inherit the request context from the current context.
# So we unset the request context so the current context is not inherited by the task, _process_batch.
serve.context._unset_request_context()
while not self._loop.is_closed():
batch, computed_batch_size = await self.wait_for_batch()
promise = self._process_batch(func, batch, computed_batch_size)
batch, _ = await self.wait_for_batch()

# Split batch by multiplexed_model_id to ensure requests for different
# models are processed in separate batches. This is necessary when using
# model multiplexing with batching, as a single batch containing requests
# for different models would not work correctly.
sub_batches = self._split_batch_by_model_id(batch)

# Process all sub-batches together under a single semaphore permit.
# This ensures sub-batches from the same original batch run concurrently
# rather than being serialized by the semaphore.
promise = self._process_sub_batches(func, sub_batches)
task = asyncio.create_task(promise)
self.tasks.add(task)
self.curr_iteration_start_times[task] = time.time()
task.add_done_callback(self._handle_completed_task)

async def _process_batch(
self, func: Callable, batch: List[_SingleRequest], computed_batch_size: int
async def _process_sub_batches(
self, func: Callable, sub_batches: List[List[_SingleRequest]]
) -> None:
"""Processes queued request batch."""
"""Processes multiple sub-batches concurrently under a single semaphore permit.

This method acquires the semaphore once and then processes all sub-batches
in parallel, ensuring that sub-batches from the same original batch don't
compete for semaphore permits.
"""
# NOTE: this semaphore caps the number of concurrent batches specified by `max_concurrent_batches`
async with self.semaphore:
# Remove requests that have been cancelled from the batch. If
# all requests have been cancelled, simply return and wait for
# the next batch.
original_batch_len = len(batch)
batch = [req for req in batch if not req.future.cancelled()]
if len(batch) == 0:
return

# Record batch utilization metric.
# Use computed_batch_size from wait_for_batch for efficiency.
# If requests were cancelled, we need to recompute since the batch changed.
if len(batch) != original_batch_len:
computed_batch_size = self._compute_batch_size(batch)

# Calculate and record batch utilization percentage.
batch_utilization_percent = (
computed_batch_size / self.max_batch_size
) * 100
self._batch_utilization_histogram.observe(
batch_utilization_percent, tags={"function_name": self._function_name}
)
# Create tasks for each sub-batch. We use asyncio.create_task() instead
# of passing coroutines directly to asyncio.gather() because create_task
# copies the current context, giving each sub-batch its own isolated
# contextvars. This prevents concurrent sub-batches from overwriting
# each other's _serve_batch_request_context, which would cause
# get_multiplexed_model_id() to return wrong values.
tasks = [
asyncio.create_task(self._process_batch_inner(func, sub_batch))
for sub_batch in sub_batches
]
await asyncio.gather(*tasks)

async def _process_batch_inner(
self, func: Callable, batch: List[_SingleRequest]
) -> None:
"""Processes a single batch without acquiring the semaphore.

# Record actual batch size (number of requests in the batch computed by the batch_size_fn).
self._batch_size_histogram.observe(
computed_batch_size, tags={"function_name": self._function_name}
)
This is the inner implementation called by _process_sub_batches after
the semaphore has already been acquired.
"""
# Remove requests that have been cancelled from the batch. If
# all requests have been cancelled, simply return and wait for
# the next batch.
batch = [req for req in batch if not req.future.cancelled()]
if len(batch) == 0:
return

# Increment batches processed counter.
self._batches_processed_counter.inc(
tags={"function_name": self._function_name}
)
# Compute batch size for this sub-batch. Each sub-batch may have a different
# size, especially when splitting by model_id, so we compute it here.
computed_batch_size = self._compute_batch_size(batch)

futures = [item.future for item in batch]
# Calculate and record batch utilization percentage.
batch_utilization_percent = (computed_batch_size / self.max_batch_size) * 100
self._batch_utilization_histogram.observe(
batch_utilization_percent, tags={"function_name": self._function_name}
)

# Most of the logic in the function should be wrapped in this try-
# except block, so the futures' exceptions can be set if an exception
# occurs. Otherwise, the futures' requests may hang indefinitely.
batch_execution_start_time = time.time()
try:
self_arg = batch[0].self_arg
args, kwargs = _batch_args_kwargs(
[item.flattened_args for item in batch]
)
# Record actual batch size (number of requests in the batch computed by the batch_size_fn).
self._batch_size_histogram.observe(
computed_batch_size, tags={"function_name": self._function_name}
)

# Method call.
if self_arg is not None:
func_future_or_generator = func(self_arg, *args, **kwargs)
# Normal function call.
else:
func_future_or_generator = func(*args, **kwargs)
# Increment batches processed counter.
self._batches_processed_counter.inc(tags={"function_name": self._function_name})

# Add individual request context to the batch request context
serve.context._set_batch_request_context(
[req.request_context for req in batch]
)
futures = [item.future for item in batch]

if isasyncgenfunction(func):
func_generator = func_future_or_generator
await self._consume_func_generator(
func_generator, futures, len(batch)
)
else:
func_future = func_future_or_generator
await self._assign_func_results(func_future, futures, len(batch))

# Reset the batch request context after the batch is processed
serve.context._set_batch_request_context([])
except Exception as e:
logger.exception("_process_batch ran into an unexpected exception.")

for future in futures:
_set_exception_if_not_done(future, e)
finally:
# Record batch execution time.
batch_execution_time_ms = (
time.time() - batch_execution_start_time
) * 1000
self._batch_execution_time_histogram.observe(
batch_execution_time_ms, tags={"function_name": self._function_name}
)
# Most of the logic in the function should be wrapped in this try-
# except block, so the futures' exceptions can be set if an exception
# occurs. Otherwise, the futures' requests may hang indefinitely.
batch_execution_start_time = time.time()
try:
self_arg = batch[0].self_arg
args, kwargs = _batch_args_kwargs([item.flattened_args for item in batch])

# Method call.
if self_arg is not None:
func_future_or_generator = func(self_arg, *args, **kwargs)
# Normal function call.
else:
func_future_or_generator = func(*args, **kwargs)

# Add individual request context to the batch request context
serve.context._set_batch_request_context(
[req.request_context for req in batch]
)

if isasyncgenfunction(func):
func_generator = func_future_or_generator
await self._consume_func_generator(func_generator, futures, len(batch))
else:
func_future = func_future_or_generator
await self._assign_func_results(func_future, futures, len(batch))

# Reset the batch request context after the batch is processed
serve.context._set_batch_request_context([])
except Exception as e:
logger.exception("_process_batch ran into an unexpected exception.")

for future in futures:
_set_exception_if_not_done(future, e)
finally:
# Record batch execution time.
batch_execution_time_ms = (time.time() - batch_execution_start_time) * 1000
self._batch_execution_time_histogram.observe(
batch_execution_time_ms, tags={"function_name": self._function_name}
)

def _handle_completed_task(self, task: asyncio.Task) -> None:
self.tasks.remove(task)
Expand Down
Loading