Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/ray/data/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ class vLLMEngineProcessorConfig(_vLLMEngineProcessorConfig):
each batch. The default value may not be optimal when the batch size
or the batch processing latency is too small, but it should be good
enough for batch size >= 64.
should_continue_on_error: If True, continue processing when inference fails for a row
instead of raising an exception. Failed rows will have a non-null
``__inference_error__`` column containing the error message, and other
output columns will be None. Error rows bypass postprocess. If False
(default), any inference error will raise an exception.
chat_template_stage: Chat templating stage config (bool | dict | ChatTemplateStageConfig).
Defaults to True. Use nested config for per-stage control over batch_size,
concurrency, runtime_env, num_cpus, and memory. Legacy ``apply_chat_template``
Expand Down
12 changes: 12 additions & 0 deletions python/ray/llm/_internal/batch/processor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,14 @@ class OfflineProcessorConfig(ProcessorConfig):
"or the batch processing latency is too small, but it should be good "
"enough for batch size >= 32.",
)
should_continue_on_error: bool = Field(
default=False,
description="If True, continue processing when inference fails for a row "
"instead of raising an exception. Failed rows will have a non-null "
"'__inference_error__' column containing the error message, and other "
"output columns will be None. Error rows bypass postprocess. "
"If False (default), any inference error will raise an exception.",
)

# Processor stage configurations (legacy booleans, will be deprecated).
apply_chat_template: bool = Field(
Expand Down Expand Up @@ -304,9 +312,13 @@ def __init__(
self.DATA_COLUMN,
)

# When should_continue_on_error is enabled, include __inference_error__ column
# in all output rows for consistent schema (None for success, message for error).
include_error_column = getattr(config, "should_continue_on_error", False)
self.postprocess = wrap_postprocess(
postprocess,
self.DATA_COLUMN,
include_error_column=include_error_column,
)

for stage in stages:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def build_vllm_engine_processor(
max_pending_requests=config.max_pending_requests,
dynamic_lora_loading_path=config.dynamic_lora_loading_path,
placement_group_config=config.placement_group_config,
should_continue_on_error=config.should_continue_on_error,
),
map_batches_kwargs=dict(
zero_copy_batch=True,
Expand Down
76 changes: 55 additions & 21 deletions python/ray/llm/_internal/batch/stages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,20 @@ def _preprocess(row: dict[str, Any]) -> dict[str, Any]:
def wrap_postprocess(
    fn: UserDefinedFunction,
    processor_data_column: str,
    include_error_column: bool = False,
) -> Callable:
    """Wrap the postprocess function to remove the processor_data_column.
    Note that we fully rely on users to determine which columns to carry over.

    Error rows (with __inference_error__ set) bypass the user's postprocess
    function and return directly with the error information preserved.

    Args:
        fn: The function to be applied.
        processor_data_column: The internal data column name of the processor.
        include_error_column: If True, always include __inference_error__ in output
            (None for success rows, error message for failures). This ensures
            consistent schema across all output rows.

    Returns:
        The wrapped function.
    """

    def _postprocess(row: dict[str, Any]) -> dict[str, Any]:
        if processor_data_column not in row:
            raise ValueError(
                f"[Internal] {processor_data_column} not found in row {row}"
            )

        data = row[processor_data_column]

        # Error rows bypass user postprocess to avoid crashes when
        # expected output fields are missing. Return entire data dict
        # to preserve debugging info (e.g., prompt).
        if data.get("__inference_error__") is not None:
            return data

        result = fn(data)
        if include_error_column:
            # Use setdefault (not plain assignment) so a user postprocess that
            # deliberately emits __inference_error__ is not silently clobbered,
            # while still guaranteeing the column exists for schema consistency.
            result.setdefault("__inference_error__", None)
        return result

    return _postprocess

Expand Down Expand Up @@ -153,37 +171,53 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]]
self.validate_inputs(inputs)

# Assign the index of the row in the batch to the idx_in_batch_column.
# This is beacuse the UDF output may be out-of-order (if asyncio.as_completed
# is used interanlly for example), and we need to carry over unused input
# This is because the UDF output may be out-of-order (if asyncio.as_completed
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Validation fails for error rows before bypass logic

The validate_inputs call at line 171 runs on all rows before error rows are separated (lines 181-190). Error rows from vLLMEngineStage lack required keys like generated_tokens that downstream stages such as DetokenizeStage expect. This causes validation to fail with "Required input keys not found" before error rows can be bypassed. The tests avoid this by passing expected_input_keys=None, but in production the DetokenizeStage has required keys. The validation needs to skip or be called after error row separation.

Additional Locations (1)

Fix in Cursor Fix in Web

# is used internally for example), and we need to carry over unused input
# columns to the next stage. Thus, we use the row index in batch to match
# the output of the UDF with the input.
for idx, row in enumerate(inputs):
row[self.IDX_IN_BATCH_COLUMN] = idx

# Separate error rows from normal rows. Error rows (those with
# __inference_error__ set) bypass the UDF to avoid crashes when
# expected fields are missing (e.g., generated_tokens for DetokenizeUDF).
normal_rows = []
error_row_indices = set()
for idx, row in enumerate(inputs):
if row.get("__inference_error__") is not None:
error_row_indices.add(idx)
else:
normal_rows.append(row)

# Collect all outputs first, then return them in the original order
# This is a requirement set by https://github.com/ray-project/ray/pull/54190/
not_outputed_rows = set(range(len(inputs)))
async for output in self.udf(inputs):
if self.IDX_IN_BATCH_COLUMN not in output:
raise ValueError(
"The output of the UDF must contain the column "
f"{self.IDX_IN_BATCH_COLUMN}."
)
idx_in_batch = output.pop(self.IDX_IN_BATCH_COLUMN)
if idx_in_batch not in not_outputed_rows:
raise ValueError(
f"The row {idx_in_batch} is outputed twice. "
"This is likely due to the UDF is not one-to-one."
)
not_outputed_rows.remove(idx_in_batch)

# Add stage outputs to the data column of the row.
inputs[idx_in_batch].pop(self.IDX_IN_BATCH_COLUMN)
inputs[idx_in_batch].update(output)
not_outputed_rows = set(range(len(inputs))) - error_row_indices
if normal_rows:
async for output in self.udf(normal_rows):
if self.IDX_IN_BATCH_COLUMN not in output:
raise ValueError(
"The output of the UDF must contain the column "
f"{self.IDX_IN_BATCH_COLUMN}."
)
idx_in_batch = output.pop(self.IDX_IN_BATCH_COLUMN)
if idx_in_batch not in not_outputed_rows:
raise ValueError(
f"The row {idx_in_batch} is outputed twice. "
"This is likely due to the UDF is not one-to-one."
)
not_outputed_rows.remove(idx_in_batch)

# Add stage outputs to the data column of the row.
inputs[idx_in_batch].pop(self.IDX_IN_BATCH_COLUMN)
inputs[idx_in_batch].update(output)

if not_outputed_rows:
raise ValueError(f"The rows {not_outputed_rows} are not outputed.")

# Clean up idx column from error rows (normal rows already cleaned above)
for idx in error_row_indices:
inputs[idx].pop(self.IDX_IN_BATCH_COLUMN, None)

# Return all updated inputs in the original order
yield {self.data_column: inputs}

Expand Down
87 changes: 76 additions & 11 deletions python/ray/llm/_internal/batch/stages/vllm_engine_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@

logger = logging.getLogger(__name__)

# vLLM fatal errors that should always be re-raised, never swallowed.
# EngineDeadError indicates the vLLM engine process has crashed and is
# unrecoverable - all subsequent requests would fail anyway.
_VLLM_FATAL_ERRORS: Tuple[Type[Exception], ...] = ()
try:
from vllm.v1.engine.exceptions import EngineDeadError

_VLLM_FATAL_ERRORS = (EngineDeadError,)
except ImportError:
# vLLM not installed or older version without this exception
pass

# Length of prompt snippet to surface in case of recoverable error
_MAX_PROMPT_LENGTH_IN_ERROR = 500


class vLLMTaskType(str, Enum):
"""The type of task to run on the vLLM engine."""
Expand Down Expand Up @@ -457,6 +472,7 @@ def __init__(
task_type: vLLMTaskType = vLLMTaskType.GENERATE,
max_pending_requests: Optional[int] = None,
dynamic_lora_loading_path: Optional[str] = None,
should_continue_on_error: bool = False,
):
"""
Initialize the vLLMEngineStageUDF.
Expand All @@ -471,9 +487,13 @@ def __init__(
it will be set to 1.1 * max_num_seqs * pipeline_parallel_size.
dynamic_lora_loading_path: The path to the dynamic LoRA adapter. It is expected
to hold subfolders each for a different lora checkpoint.
should_continue_on_error: If True, continue processing when inference fails for
a row instead of raising. Failed rows will have '__inference_error__'
set to the error message.
"""
super().__init__(data_column, expected_input_keys)
self.model = model
self.should_continue_on_error = should_continue_on_error

# Setup vLLM engine kwargs.
self.task_type = task_type
Expand Down Expand Up @@ -565,6 +585,57 @@ def normalize_engine_kwargs(
engine_kwargs["task"] = task_type
return engine_kwargs

async def _generate_with_error_handling(
    self,
    row: Dict[str, Any],
    batch_uuid: uuid.UUID,
) -> Dict[str, Any]:
    """Generate output for a single row, catching errors if should_continue_on_error is set.

    Args:
        row: The input row.
        batch_uuid: The batch UUID for logging.

    Returns:
        The output dict, with __inference_error__ set if an error occurred.
    """
    idx_in_batch = row[self.IDX_IN_BATCH_COLUMN]
    # Capture the prompt BEFORE calling generate_async: request preparation
    # pops "prompt" off the row, so reading it only in the except branch
    # would yield an empty string and lose the debugging context.
    prompt = row.get("prompt", "")
    try:
        request, output, time_taken_llm = await self.llm.generate_async(row)
        return {
            **output,
            "request_id": request.request_id,
            self.IDX_IN_BATCH_COLUMN: request.idx_in_batch,
            "batch_uuid": batch_uuid.hex,
            "time_taken_llm": time_taken_llm,
            "params": str(request.params),
            "__inference_error__": None,
        }
    except _VLLM_FATAL_ERRORS:
        # Fatal engine errors (e.g., EngineDeadError) must always propagate.
        # The engine is dead and all subsequent requests would fail.
        raise
    except Exception as e:
        if not self.should_continue_on_error:
            raise
        error_msg = f"{type(e).__name__}: {str(e)}"
        logger.warning(
            "[vLLM] Inference failed for row %d in batch %s: %s",
            idx_in_batch,
            batch_uuid.hex,
            error_msg,
        )
        # Include a snippet of the failed prompt, truncated to keep error
        # rows from ballooning the output schema.
        if len(prompt) > _MAX_PROMPT_LENGTH_IN_ERROR:
            prompt = prompt[:_MAX_PROMPT_LENGTH_IN_ERROR] + "...[truncated]"
        return {
            self.IDX_IN_BATCH_COLUMN: idx_in_batch,
            "batch_uuid": batch_uuid.hex,
            "__inference_error__": error_msg,
            "prompt": prompt,
        }

async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:
"""Run the vLLM engine.

Expand All @@ -577,19 +648,13 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]
batch_uuid = uuid.uuid4()
batch_start_time = time.perf_counter()

tasks = [asyncio.create_task(self.llm.generate_async(row)) for row in batch]
tasks = [
asyncio.create_task(self._generate_with_error_handling(row, batch_uuid))
for row in batch
]

for resp in asyncio.as_completed(tasks):
request, output, time_taken_llm = await resp

yield {
**output,
"request_id": request.request_id,
self.IDX_IN_BATCH_COLUMN: request.idx_in_batch,
"batch_uuid": batch_uuid.hex,
"time_taken_llm": time_taken_llm,
"params": str(request.params),
}
yield await resp

batch_time_taken = time.perf_counter() - batch_start_time
# TODO: Add metrics to the UDf wrapper so that we don't need
Expand Down
Loading