diff --git a/python/ray/data/llm.py b/python/ray/data/llm.py index 57ae712c9002..9dac042a5c70 100644 --- a/python/ray/data/llm.py +++ b/python/ray/data/llm.py @@ -113,6 +113,11 @@ class vLLMEngineProcessorConfig(_vLLMEngineProcessorConfig): each batch. The default value may not be optimal when the batch size or the batch processing latency is too small, but it should be good enough for batch size >= 64. + should_continue_on_error: If True, continue processing when inference fails for a row + instead of raising an exception. Failed rows will have a non-null + ``__inference_error__`` column containing the error message, and other + output columns will be None. Error rows bypass postprocess. If False + (default), any inference error will raise an exception. chat_template_stage: Chat templating stage config (bool | dict | ChatTemplateStageConfig). Defaults to True. Use nested config for per-stage control over batch_size, concurrency, runtime_env, num_cpus, and memory. Legacy ``apply_chat_template`` diff --git a/python/ray/llm/_internal/batch/processor/base.py b/python/ray/llm/_internal/batch/processor/base.py index 276248e01bc4..e8d150fcdc2d 100644 --- a/python/ray/llm/_internal/batch/processor/base.py +++ b/python/ray/llm/_internal/batch/processor/base.py @@ -154,6 +154,14 @@ class OfflineProcessorConfig(ProcessorConfig): "or the batch processing latency is too small, but it should be good " "enough for batch size >= 32.", ) + should_continue_on_error: bool = Field( + default=False, + description="If True, continue processing when inference fails for a row " + "instead of raising an exception. Failed rows will have a non-null " + "'__inference_error__' column containing the error message, and other " + "output columns will be None. Error rows bypass postprocess. " + "If False (default), any inference error will raise an exception.", + ) # Processor stage configurations (legacy booleans, will be deprecated). apply_chat_template: bool = Field( @@ -304,9 +312,13 @@ def __init__( self.DATA_COLUMN, ) + # When should_continue_on_error is enabled, include __inference_error__ column + # in all output rows for consistent schema (None for success, message for error). + include_error_column = getattr(config, "should_continue_on_error", False) self.postprocess = wrap_postprocess( postprocess, self.DATA_COLUMN, + include_error_column=include_error_column, ) for stage in stages: diff --git a/python/ray/llm/_internal/batch/processor/vllm_engine_proc.py b/python/ray/llm/_internal/batch/processor/vllm_engine_proc.py index e5da56b4d32f..e0d8ee0e2c6e 100644 --- a/python/ray/llm/_internal/batch/processor/vllm_engine_proc.py +++ b/python/ray/llm/_internal/batch/processor/vllm_engine_proc.py @@ -217,6 +217,7 @@ def build_vllm_engine_processor( max_pending_requests=config.max_pending_requests, dynamic_lora_loading_path=config.dynamic_lora_loading_path, placement_group_config=config.placement_group_config, + should_continue_on_error=config.should_continue_on_error, ), map_batches_kwargs=dict( zero_copy_batch=True, diff --git a/python/ray/llm/_internal/batch/stages/base.py b/python/ray/llm/_internal/batch/stages/base.py index 8e57f1738838..d822ce275808 100644 --- a/python/ray/llm/_internal/batch/stages/base.py +++ b/python/ray/llm/_internal/batch/stages/base.py @@ -40,13 +40,20 @@ def _preprocess(row: dict[str, Any]) -> dict[str, Any]: def wrap_postprocess( fn: UserDefinedFunction, processor_data_column: str, + include_error_column: bool = False, ) -> Callable: """Wrap the postprocess function to remove the processor_data_column. Note that we fully rely on users to determine which columns to carry over. + Error rows (with __inference_error__ set) bypass the user's postprocess + function and return directly with the error information preserved. + Args: fn: The function to be applied. processor_data_column: The internal data column name of the processor. + include_error_column: If True, always include __inference_error__ in output + (None for success rows, error message for failures). This ensures + consistent schema across all output rows. Returns: The wrapped function. @@ -58,7 +65,18 @@ def _postprocess(row: dict[str, Any]) -> dict[str, Any]: f"[Internal] {processor_data_column} not found in row {row}" ) - return fn(row[processor_data_column]) + data = row[processor_data_column] + + # Error rows bypass user postprocess to avoid crashes when + # expected output fields are missing. Return entire data dict + # to preserve debugging info (e.g., prompt). + if data.get("__inference_error__") is not None: + return data + + result = fn(data) + if include_error_column: + result["__inference_error__"] = None + return result return _postprocess @@ -153,37 +171,53 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]] self.validate_inputs(inputs) # Assign the index of the row in the batch to the idx_in_batch_column. - # This is beacuse the UDF output may be out-of-order (if asyncio.as_completed - # is used interanlly for example), and we need to carry over unused input + # This is because the UDF output may be out-of-order (if asyncio.as_completed + # is used internally for example), and we need to carry over unused input # columns to the next stage. Thus, we use the row index in batch to match # the output of the UDF with the input. for idx, row in enumerate(inputs): row[self.IDX_IN_BATCH_COLUMN] = idx + # Separate error rows from normal rows. Error rows (those with + # __inference_error__ set) bypass the UDF to avoid crashes when + # expected fields are missing (e.g., generated_tokens for DetokenizeUDF). + normal_rows = [] + error_row_indices = set() + for idx, row in enumerate(inputs): + if row.get("__inference_error__") is not None: + error_row_indices.add(idx) + else: + normal_rows.append(row) + # Collect all outputs first, then return them in the original order # This is a requirement set by https://github.com/ray-project/ray/pull/54190/ - not_outputed_rows = set(range(len(inputs))) - async for output in self.udf(inputs): - if self.IDX_IN_BATCH_COLUMN not in output: - raise ValueError( - "The output of the UDF must contain the column " - f"{self.IDX_IN_BATCH_COLUMN}." - ) - idx_in_batch = output.pop(self.IDX_IN_BATCH_COLUMN) - if idx_in_batch not in not_outputed_rows: - raise ValueError( - f"The row {idx_in_batch} is outputed twice. " - "This is likely due to the UDF is not one-to-one." - ) - not_outputed_rows.remove(idx_in_batch) - - # Add stage outputs to the data column of the row. - inputs[idx_in_batch].pop(self.IDX_IN_BATCH_COLUMN) - inputs[idx_in_batch].update(output) + not_outputed_rows = set(range(len(inputs))) - error_row_indices + if normal_rows: + async for output in self.udf(normal_rows): + if self.IDX_IN_BATCH_COLUMN not in output: + raise ValueError( + "The output of the UDF must contain the column " + f"{self.IDX_IN_BATCH_COLUMN}." + ) + idx_in_batch = output.pop(self.IDX_IN_BATCH_COLUMN) + if idx_in_batch not in not_outputed_rows: + raise ValueError( + f"The row {idx_in_batch} is outputed twice. " + "This is likely due to the UDF is not one-to-one." + ) + not_outputed_rows.remove(idx_in_batch) + + # Add stage outputs to the data column of the row. + inputs[idx_in_batch].pop(self.IDX_IN_BATCH_COLUMN) + inputs[idx_in_batch].update(output) if not_outputed_rows: raise ValueError(f"The rows {not_outputed_rows} are not outputed.") + # Clean up idx column from error rows (normal rows already cleaned above) + for idx in error_row_indices: + inputs[idx].pop(self.IDX_IN_BATCH_COLUMN, None) + # Return all updated inputs in the original order yield {self.data_column: inputs} diff --git a/python/ray/llm/_internal/batch/stages/vllm_engine_stage.py b/python/ray/llm/_internal/batch/stages/vllm_engine_stage.py index 8107be87475c..7cd7b75ba475 100644 --- a/python/ray/llm/_internal/batch/stages/vllm_engine_stage.py +++ b/python/ray/llm/_internal/batch/stages/vllm_engine_stage.py @@ -33,6 +33,21 @@ logger = logging.getLogger(__name__) +# vLLM fatal errors that should always be re-raised, never swallowed. +# EngineDeadError indicates the vLLM engine process has crashed and is +# unrecoverable - all subsequent requests would fail anyway. +_VLLM_FATAL_ERRORS: Tuple[Type[Exception], ...] = () +try: + from vllm.v1.engine.exceptions import EngineDeadError + + _VLLM_FATAL_ERRORS = (EngineDeadError,) +except ImportError: + # vLLM not installed or older version without this exception + pass + +# Length of prompt snippet to surface in case of recoverable error +_MAX_PROMPT_LENGTH_IN_ERROR = 500 + class vLLMTaskType(str, Enum): """The type of task to run on the vLLM engine.""" @@ -457,6 +472,7 @@ def __init__( task_type: vLLMTaskType = vLLMTaskType.GENERATE, max_pending_requests: Optional[int] = None, dynamic_lora_loading_path: Optional[str] = None, + should_continue_on_error: bool = False, ): """ Initialize the vLLMEngineStageUDF. @@ -471,9 +487,13 @@ def __init__( it will be set to 1.1 * max_num_seqs * pipeline_parallel_size. dynamic_lora_loading_path: The path to the dynamic LoRA adapter. It is expected to hold subfolders each for a different lora checkpoint. + should_continue_on_error: If True, continue processing when inference fails for + a row instead of raising. Failed rows will have '__inference_error__' + set to the error message. """ super().__init__(data_column, expected_input_keys) self.model = model + self.should_continue_on_error = should_continue_on_error # Setup vLLM engine kwargs. self.task_type = task_type @@ -565,6 +585,57 @@ def normalize_engine_kwargs( engine_kwargs["task"] = task_type return engine_kwargs + async def _generate_with_error_handling( + self, + row: Dict[str, Any], + batch_uuid: uuid.UUID, + ) -> Dict[str, Any]: + """Generate output for a single row, catching errors if should_continue_on_error is set. + + Args: + row: The input row. + batch_uuid: The batch UUID for logging. + + Returns: + The output dict, with __inference_error__ set if an error occurred. + """ + idx_in_batch = row[self.IDX_IN_BATCH_COLUMN] + try: + request, output, time_taken_llm = await self.llm.generate_async(row) + return { + **output, + "request_id": request.request_id, + self.IDX_IN_BATCH_COLUMN: request.idx_in_batch, + "batch_uuid": batch_uuid.hex, + "time_taken_llm": time_taken_llm, + "params": str(request.params), + "__inference_error__": None, + } + except _VLLM_FATAL_ERRORS: + # Fatal engine errors (e.g., EngineDeadError) must always propagate. + # The engine is dead and all subsequent requests would fail. + raise + except Exception as e: + if not self.should_continue_on_error: + raise + error_msg = f"{type(e).__name__}: {str(e)}" + logger.warning( + "[vLLM] Inference failed for row %d in batch %s: %s", + idx_in_batch, + batch_uuid.hex, + error_msg, + ) + # Include snippet of failed prompt + prompt = row.get("prompt", "") + if len(prompt) > _MAX_PROMPT_LENGTH_IN_ERROR: + prompt = prompt[:_MAX_PROMPT_LENGTH_IN_ERROR] + "...[truncated]" + return { + self.IDX_IN_BATCH_COLUMN: idx_in_batch, + "batch_uuid": batch_uuid.hex, + "__inference_error__": error_msg, + "prompt": prompt, + } + async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]: """Run the vLLM engine. @@ -577,19 +648,13 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] batch_uuid = uuid.uuid4() batch_start_time = time.perf_counter() - tasks = [asyncio.create_task(self.llm.generate_async(row)) for row in batch] + tasks = [ + asyncio.create_task(self._generate_with_error_handling(row, batch_uuid)) + for row in batch + ] for resp in asyncio.as_completed(tasks): - request, output, time_taken_llm = await resp - - yield { - **output, - "request_id": request.request_id, - self.IDX_IN_BATCH_COLUMN: request.idx_in_batch, - "batch_uuid": batch_uuid.hex, - "time_taken_llm": time_taken_llm, - "params": str(request.params), - } + yield await resp batch_time_taken = time.perf_counter() - batch_start_time # TODO: Add metrics to the UDf wrapper so that we don't need diff --git a/python/ray/llm/tests/batch/cpu/stages/test_stage_base.py b/python/ray/llm/tests/batch/cpu/stages/test_stage_base.py index 3bd73deef5ba..c66ee36c6369 100644 --- a/python/ray/llm/tests/batch/cpu/stages/test_stage_base.py +++ b/python/ray/llm/tests/batch/cpu/stages/test_stage_base.py @@ -39,6 +39,79 @@ def to_string(x: dict) -> dict: wrapped({"wrong_key": 42}) +def test_wrap_postprocess_bypasses_error_rows(): + """Error rows with __inference_error__ set bypass user postprocess.""" + + def user_fn(data: dict) -> dict: + # Would crash if called with error row (missing generated_text) + return {"response": data["generated_text"].upper()} + + wrapped = wrap_postprocess(user_fn, "__data") + + error_row = { + "__data": { + "__inference_error__": "ValueError: prompt too long", + "prompt": "This is a long prompt", + } + } + result = wrapped(error_row) + # Error rows return entire data dict to preserve debugging info + assert result == { + "__inference_error__": "ValueError: prompt too long", + "prompt": "This is a long prompt", + } + + +def test_wrap_postprocess_success_rows_run_postprocess(): + """Success rows (__inference_error__ is None) run user postprocess.""" + + def user_fn(data: dict) -> dict: + return {"response": data["generated_text"], "tokens": data["num_tokens"]} + + wrapped = wrap_postprocess(user_fn, "__data") + + success_row = { + "__data": { + "generated_text": "Hello world", + "num_tokens": 10, + "__inference_error__": None, + } + } + result = wrapped(success_row) + assert result == {"response": "Hello world", "tokens": 10} + + +def test_wrap_postprocess_include_error_column(): + """With include_error_column=True, success rows include __inference_error__: None.""" + + def user_fn(data: dict) -> dict: + return {"response": data["generated_text"]} + + wrapped = wrap_postprocess(user_fn, "__data", include_error_column=True) + + success_row = { + "__data": { + "generated_text": "Hello world", + "__inference_error__": None, + } + } + result = wrapped(success_row) + assert result == {"response": "Hello world", "__inference_error__": None} + + # Error rows return entire data dict to preserve debugging info + error_row = { + "__data": { + "__inference_error__": "ValueError: prompt too long", + "prompt": "a long prompt", + } + } + result = wrapped(error_row) + assert result == { + "__inference_error__": "ValueError: prompt too long", + "prompt": "a long prompt", + } + + class TestStatefulStageUDF: class SimpleUDF(StatefulStageUDF): def __init__( @@ -116,6 +189,75 @@ async def test_missing_idx_in_batch_column(self): async for _ in udf(batch): pass + @pytest.mark.asyncio + async def test_error_rows_bypass_udf(self): + """Error rows with __inference_error__ bypass the UDF entirely.""" + + class FailOnMissingValueUDF(StatefulStageUDF): + async def udf( + self, rows: list[Dict[str, Any]] + ) -> AsyncIterator[Dict[str, Any]]: + for row in rows: + # Would crash on error rows missing 'value' field + yield { + self.IDX_IN_BATCH_COLUMN: row[self.IDX_IN_BATCH_COLUMN], + "processed": row["value"] * 2, + } + + udf = FailOnMissingValueUDF( + data_column="__data", expected_input_keys=None # Skip validation + ) + + batch = { + "__data": [ + {"value": 1}, # Normal row + {"__inference_error__": "ValueError: prompt too long"}, # Error row + {"value": 3}, # Normal row + ] + } + + results = [] + async for result in udf(batch): + results.extend(result["__data"]) + + assert len(results) == 3 + + # Normal rows are processed + assert results[0]["processed"] == 2 + assert results[2]["processed"] == 6 + + # Error row passes through unchanged + assert results[1]["__inference_error__"] == "ValueError: prompt too long" + assert "processed" not in results[1] + + @pytest.mark.asyncio + async def test_all_error_rows_in_batch(self): + """Batch with all error rows should pass through without calling UDF.""" + + class FailIfCalledUDF(StatefulStageUDF): + async def udf( + self, rows: list[Dict[str, Any]] + ) -> AsyncIterator[Dict[str, Any]]: + raise AssertionError("UDF should not be called for all-error batch") + yield # Make this a generator + + udf = FailIfCalledUDF(data_column="__data", expected_input_keys=None) + + batch = { + "__data": [ + {"__inference_error__": "error 1"}, + {"__inference_error__": "error 2"}, + ] + } + + results = [] + async for result in udf(batch): + results.extend(result["__data"]) + + assert len(results) == 2 + assert results[0]["__inference_error__"] == "error 1" + assert results[1]["__inference_error__"] == "error 2" + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/batch/gpu/processor/test_vllm_engine_proc.py b/python/ray/llm/tests/batch/gpu/processor/test_vllm_engine_proc.py index 458967b86e11..3f83159f0cd5 100644 --- a/python/ray/llm/tests/batch/gpu/processor/test_vllm_engine_proc.py +++ b/python/ray/llm/tests/batch/gpu/processor/test_vllm_engine_proc.py @@ -55,6 +55,7 @@ def test_vllm_engine_processor(gpu_type, model_opt_125m): "dynamic_lora_loading_path": None, "max_concurrent_batches": 8, "batch_size": 64, + "should_continue_on_error": False, } runtime_env = stage.map_batches_kwargs.pop("runtime_env") diff --git a/python/ray/llm/tests/batch/gpu/stages/test_vllm_engine_stage.py b/python/ray/llm/tests/batch/gpu/stages/test_vllm_engine_stage.py index 5add5e6a6d4c..fa7158473cef 100644 --- a/python/ray/llm/tests/batch/gpu/stages/test_vllm_engine_stage.py +++ b/python/ray/llm/tests/batch/gpu/stages/test_vllm_engine_stage.py @@ -540,5 +540,149 @@ def test_vllm_output_data_no_logprobs(): assert dumped["prompt_logprobs"] is None +@pytest.mark.asyncio +async def test_vllm_udf_default_raises_on_error(mock_vllm_wrapper): + """Default behavior (should_continue_on_error=False) raises on inference error.""" + mock_vllm_wrapper.return_value.generate_async.side_effect = ValueError( + "prompt too long" + ) + + udf = vLLMEngineStageUDF( + data_column="__data", + expected_input_keys=["prompt", "sampling_params"], + model="/tmp/fake-model", + task_type=vLLMTaskType.GENERATE, + batch_size=32, + max_concurrent_batches=4, + engine_kwargs={}, + should_continue_on_error=False, + ) + + batch = {"__data": [{"prompt": "test", "sampling_params": {"temperature": 0.7}}]} + + with pytest.raises(ValueError, match="prompt too long"): + async for _ in udf(batch): + pass + + +@pytest.mark.asyncio +async def test_vllm_udf_should_continue_on_error_yields_error_row(mock_vllm_wrapper): + """With should_continue_on_error=True, errors yield rows with __inference_error__.""" + mock_vllm_wrapper.return_value.generate_async.side_effect = ValueError( + "prompt too long" + ) + + udf = vLLMEngineStageUDF( + data_column="__data", + expected_input_keys=["prompt", "sampling_params"], + model="/tmp/fake-model", + task_type=vLLMTaskType.GENERATE, + batch_size=32, + max_concurrent_batches=4, + engine_kwargs={}, + should_continue_on_error=True, + ) + + batch = { + "__data": [{"prompt": "test prompt", "sampling_params": {"temperature": 0.7}}] + } + + results = [] + async for result in udf(batch): + results.extend(result["__data"]) + + assert len(results) == 1 + assert "__inference_error__" in results[0] + assert "ValueError" in results[0]["__inference_error__"] + assert "prompt too long" in results[0]["__inference_error__"] + # Error rows include the original prompt for debuggability + assert results[0]["prompt"] == "test prompt" + + +@pytest.mark.asyncio +async def test_vllm_udf_mixed_success_and_error(mock_vllm_wrapper): + """Mixed batch: some rows succeed, some fail.""" + call_count = 0 + + async def mock_generate(row): + nonlocal call_count + call_count += 1 + idx = row["__idx_in_batch"] + if idx == 1: + raise ValueError("prompt too long") + return ( + MagicMock( + request_id=idx, + prompt=row["prompt"], + params=row["sampling_params"], + idx_in_batch=idx, + ), + { + "prompt": row["prompt"], + "generated_text": f"Response to: {row['prompt']}", + }, + 0.1, + ) + + mock_vllm_wrapper.return_value.generate_async.side_effect = mock_generate + + udf = vLLMEngineStageUDF( + data_column="__data", + expected_input_keys=["prompt", "sampling_params"], + model="/tmp/fake-model", + task_type=vLLMTaskType.GENERATE, + batch_size=32, + max_concurrent_batches=4, + engine_kwargs={}, + should_continue_on_error=True, + ) + + batch = { + "__data": [ + {"prompt": "first", "sampling_params": {"temperature": 0.7}}, + {"prompt": "second", "sampling_params": {"temperature": 0.7}}, + {"prompt": "third", "sampling_params": {"temperature": 0.7}}, + ] + } + + results = [] + async for result in udf(batch): + results.extend(result["__data"]) + + assert len(results) == 3 + + errors = [r for r in results if r.get("__inference_error__") is not None] + successes = [r for r in results if r.get("__inference_error__") is None] + + assert len(errors) == 1 + assert len(successes) == 2 + assert "ValueError" in errors[0]["__inference_error__"] + + +@pytest.mark.asyncio +async def test_vllm_udf_fatal_error_always_raises(mock_vllm_wrapper): + """Fatal errors (EngineDeadError) always propagate, even with should_continue_on_error=True.""" + from vllm.v1.engine.exceptions import EngineDeadError + + mock_vllm_wrapper.return_value.generate_async.side_effect = EngineDeadError() + + udf = vLLMEngineStageUDF( + data_column="__data", + expected_input_keys=["prompt", "sampling_params"], + model="/tmp/fake-model", + task_type=vLLMTaskType.GENERATE, + batch_size=32, + max_concurrent_batches=4, + engine_kwargs={}, + should_continue_on_error=True, # Even with this True, fatal errors should raise + ) + + batch = {"__data": [{"prompt": "test", "sampling_params": {"temperature": 0.7}}]} + + with pytest.raises(EngineDeadError): + async for _ in udf(batch): + pass + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__]))