
[Data][LLM] Add should_continue_on_error for graceful error handling in batch inference #59212

Merged
kouroshHakha merged 8 commits into ray-project:master from nrghosh:nrghosh/data-llm-error-handling
Dec 10, 2025
Conversation


@nrghosh nrghosh commented Dec 6, 2025

Description

Add should_continue_on_error parameter to the vLLM batch processor config. When
enabled, non-fatal inference failures yield error rows instead of crashing the job.

Scoped to Ray Data LLM batch inference only; no changes to Ray Data core.

Addresses: #52449
Related: #52457

Problem

When running LLM batch inference at scale, a single bad row (e.g., prompt
exceeding max_model_len) can crash the whole batch.

vLLM AsyncLLM engine distinguishes two types of errors: EngineGenerateError and EngineDeadError.

EngineGenerateError is raised by the per-request generate()
method. This error can be request-specific (and therefore
recoverable, e.g. when input processing fails).

EngineDeadError is raised by the background output_handler
method. This error is global and therefore not recoverable.
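The distinction above can be sketched as follows. The exception classes here are stand-ins for vLLM's real `EngineGenerateError` and `EngineDeadError`, defined locally only to illustrate the control flow:

```python
# Hypothetical stand-ins for vLLM's engine error types; the real classes
# live inside vLLM and are only illustrated here.
class EngineGenerateError(Exception):
    """Raised by a single generate() call: request-specific, recoverable."""

class EngineDeadError(Exception):
    """Raised by the background output handler: global, not recoverable."""

def handle_engine_error(err, should_continue_on_error):
    """Return an error row for recoverable errors; re-raise fatal ones."""
    if isinstance(err, EngineDeadError) or not should_continue_on_error:
        raise err  # fatal error, or opt-out: propagate and fail the job
    # Recoverable: mark the row instead of crashing the whole batch.
    return {"__inference_error__": f"{type(err).__name__}: {err}"}

row = handle_engine_error(EngineGenerateError("prompt too long"), True)
```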

Solution

Add optional should_continue_on_error parameter to processor config. The
parameter defaults to False, preserving existing fail-fast behavior.
When set to True:

  • Exceptions are caught in vLLMEngineStageUDF
  • Failed rows are yielded with the __inference_error__ column set to the error message
  • Successful rows have __inference_error__: None
  • Error rows bypass postprocess (avoiding crashes when expected fields are missing)
  • The job completes with mixed success/failure outputs
  • Users filter downstream: ds.filter(lambda r: r["__inference_error__"] is None)
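The behavior described by these bullets can be sketched as a wrapper around a single inference call. The `udf` and row shapes below are illustrative stand-ins, not the actual `vLLMEngineStageUDF` internals:

```python
import asyncio

async def generate_with_error_handling(udf, row, should_continue_on_error=True):
    """Run one inference call; on failure, return an error row instead of raising."""
    try:
        out = await udf(row)
        out["__inference_error__"] = None  # success rows carry an explicit None
        return out
    except Exception as e:
        if not should_continue_on_error:
            raise  # default fail-fast semantics are preserved
        return {"__inference_error__": f"{type(e).__name__}: {e}"}

async def fake_udf(row):
    # Stand-in for the real engine call.
    if "prompt" not in row:
        raise ValueError("missing prompt")
    return {"generated_text": row["prompt"].upper()}

ok = asyncio.run(generate_with_error_handling(fake_udf, {"prompt": "hi"}))
bad = asyncio.run(generate_with_error_handling(fake_udf, {}))
```

Downstream, `ds.filter(lambda r: r["__inference_error__"] is None)` then separates the two populations, as in the usage example below.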

Design

  1. Default behavior unchanged: should_continue_on_error=False preserves
    existing fail-fast semantics. This is opt-in only.

  2. Error rows bypass postprocess: User's postprocess function likely
    expects generated_text and other output fields. Error rows won't have
    these, so we skip postprocess to avoid secondary crashes.

  3. Error as Optional[str]: The __inference_error__ column is None on
    success, or contains the error message (with type) on failure. This
    provides debuggability while keeping schema simple.

  4. LLM operator only: Per feedback, this is scoped to the LLM processor
    implementation. No changes to Ray Data core primitives.

Questions

  • This primarily relies on vLLM's distinction between per-request and fatal errors. Should the abstraction extend beyond vLLM?
  • Should __inference_error__ always be present, or only when should_continue_on_error=True?
  • Downstream stage handling: fix in the base class or per stage?
  • Keep error-row content minimal (current), or include the original input for debuggability?

Tradeoffs

  • Visible failures: Prefer visible errors (via the error column) over silently dropped rows for observability.
  • Schema change: All outputs include __inference_error__. This enables the success/failure distinction and debugging, at the cost of an extra column for users who don't need error handling.
  • Scope (vLLM only): Only vLLM-stage errors are captured; preprocessing errors (chat template, tokenization) still crash the job. This keeps the implementation focused but yields inconsistent failure modes.
  • Postprocess bypass: Error rows skip the user's postprocess function. This prevents secondary crashes but removes user control over error-row formatting.
  • Unstructured errors: Errors are stored as plain strings rather than typed error objects.
  • No retries: Failed rows are only marked; retry logic is out of scope.

Files Changed

  • python/ray/data/llm.py - Document new parameter in public API
  • python/ray/llm/_internal/batch/processor/base.py - Add should_continue_on_error to config
  • python/ray/llm/_internal/batch/processor/vllm_engine_proc.py - Pass config to stage
  • python/ray/llm/_internal/batch/stages/base.py - Skip postprocess for error rows
  • python/ray/llm/_internal/batch/stages/vllm_engine_stage.py - Catch errors and yield error rows

Example Usage

```python
import ray
from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor

config = vLLMEngineProcessorConfig(
    model_source="meta-llama/Meta-Llama-3.1-8B-Instruct",
    should_continue_on_error=True,  # Enable graceful error handling
)

processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[{"role": "user", "content": row["prompt"]}],
        sampling_params=dict(temperature=0.3, max_tokens=100),
    ),
    postprocess=lambda row: dict(
        response=row["generated_text"],
    ),
)

ds = ray.data.read_json("prompts.json")
result = processor(ds)

# Filter successful results
successful = result.filter(lambda r: r["__inference_error__"] is None)
successful.write_json("outputs/")

# Analyze failures
failed = result.filter(lambda r: r["__inference_error__"] is not None)
print(f"Failed: {failed.count()} rows")
failed.show(5)
```

Tests

python/ray/llm/tests/batch/cpu/stages/test_stage_base.py:

  • test_wrap_postprocess_bypasses_error_rows
  • test_wrap_postprocess_success_rows_run_postprocess

python/ray/llm/tests/batch/gpu/stages/test_vllm_engine_stage.py:

  • test_vllm_udf_default_raises_on_error
  • test_vllm_udf_should_continue_on_error_yields_error_row
  • test_vllm_udf_mixed_success_and_error


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a continue_on_error parameter for graceful error handling during batch inference, which is a great addition for robustness at scale. The implementation correctly isolates the error handling within the vLLM engine stage and provides a mechanism to flag failed rows using the __inference_error__ column.

My review focuses on improving the debuggability of failed requests and ensuring schema consistency in the output. Specifically, I've suggested:

  1. Preserving more context in error rows to make it easier to identify the problematic input.
  2. Ensuring the __inference_error__ column is present in all output rows (both success and failure) for a consistent schema, as described in the PR description.
  3. Including request parameters in the error output for more complete debugging information.

Overall, the changes are well-structured and the addition of tests is thorough. Addressing these points will make the feature even more user-friendly and robust.

@nrghosh nrghosh requested review from a team and richardliaw December 6, 2025 02:39
nrghosh and others added 3 commits December 6, 2025 16:25
- expect continue_on_error to (1) exist, and (2) be false (default)

Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
- Distinguish fatal (EngineDeadError) vs recoverable errors; fatal errors
  always propagate even when continue_on_error=True
- Error rows bypass downstream stage UDFs to prevent crashes when expected
  fields (e.g., generated_tokens) are missing
- Include original prompt in error rows for debuggability
- Add __inference_error__ column to success rows when continue_on_error=True
  for consistent output schema; no schema change when False (backwards compatible)
- Add tests for fatal error propagation, error row bypass, and schema consistency

Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
@nrghosh
Copy link
Contributor Author

nrghosh commented Dec 8, 2025

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a continue_on_error flag for graceful error handling in vLLM batch inference, which is an excellent addition for improving the robustness of large-scale inference jobs. The implementation is well-structured, creating error rows with an __inference_error__ column instead of crashing. My review identified a high-severity issue where debugging information added to error rows is later stripped out during post-processing. I've provided a suggestion to fix this, along with corresponding updates to the tests to ensure the debugging information is preserved as intended.

Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
@nrghosh
Copy link
Contributor Author

nrghosh commented Dec 8, 2025

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a continue_on_error parameter for vLLM batch inference in Ray Data. This allows inference jobs to continue even if some rows fail, by yielding error rows with an __inference_error__ column instead of crashing. The changes are well-structured, with error handling logic localized to the vLLM stage and a generic mechanism for bypassing error rows in subsequent stages. The implementation correctly distinguishes between recoverable and fatal vLLM errors. The changes are also well-covered by new unit tests. My main feedback is a minor suggestion to replace a magic number with a constant for better code maintainability.

Comment on lines +620 to +621
if len(prompt) > 500:
prompt = prompt[:500] + "...[truncated]"

medium

The prompt truncation length 500 is a magic number. It's better to define it as a constant at the module level to improve readability and maintainability. For example: _MAX_PROMPT_LENGTH_IN_ERROR = 500.
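The suggested fix, as a small sketch (the constant name follows the reviewer's proposal):

```python
# Module-level constant replacing the magic number 500.
_MAX_PROMPT_LENGTH_IN_ERROR = 500

def truncate_prompt(prompt):
    """Trim long prompts before embedding them in an error row."""
    if len(prompt) > _MAX_PROMPT_LENGTH_IN_ERROR:
        return prompt[:_MAX_PROMPT_LENGTH_IN_ERROR] + "...[truncated]"
    return prompt
```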

Contributor

+1.


@kouroshHakha kouroshHakha left a comment


LGTM. just some minor comments

Comment on lines +620 to +621
if len(prompt) > 500:
prompt = prompt[:500] + "...[truncated]"
Contributor
+1.

@kouroshHakha kouroshHakha marked this pull request as ready for review December 9, 2025 05:58
@kouroshHakha kouroshHakha requested a review from a team as a code owner December 9, 2025 05:58
# Assign the index of the row in the batch to the idx_in_batch_column.
# This is beacuse the UDF output may be out-of-order (if asyncio.as_completed
# is used interanlly for example), and we need to carry over unused input
# This is because the UDF output may be out-of-order (if asyncio.as_completed

Bug: Validation fails for error rows before bypass logic

The validate_inputs call at line 171 runs on all rows before error rows are separated (lines 181-190). Error rows from vLLMEngineStage lack required keys like generated_tokens that downstream stages such as DetokenizeStage expect. This causes validation to fail with "Required input keys not found" before error rows can be bypassed. The tests avoid this by passing expected_input_keys=None, but in production the DetokenizeStage has required keys. The validation needs to skip or be called after error row separation.


# to avoid bloating the output.
prompt = row.get("prompt", "")
if len(prompt) > 500:
prompt = prompt[:500] + "...[truncated]"

Bug: Error handler crashes if prompt is None

The error handler at line 619 uses row.get("prompt", "") to retrieve the prompt for debugging. However, Python's dict.get() only returns the default when the key is missing - if the key exists with value None, it returns None. Then len(prompt) at line 620 raises TypeError: object of type 'NoneType' has no len(). This crashes inside the except Exception block, causing the job to fail despite continue_on_error=True, which defeats the purpose of graceful error handling.
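One way to make the handler None-safe, as a sketch (helper name is illustrative, not the merged fix):

```python
def safe_prompt(row, limit=500):
    """Fetch the prompt for an error row, tolerating a missing or None value."""
    # dict.get only substitutes the default when the key is absent, so an
    # explicit `or ""` is needed to also cover {"prompt": None}.
    prompt = row.get("prompt") or ""
    if len(prompt) > limit:
        prompt = prompt[:limit] + "...[truncated]"
    return prompt
```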


@ray-gardener ray-gardener bot added data Ray Data-related issues llm labels Dec 9, 2025
Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>

@kouroshHakha kouroshHakha left a comment


LGTM just a few non-blockers?

"""
super().__init__(data_column, expected_input_keys)
self.model = model
self.continue_on_error = continue_on_error
Contributor

can we rename this parameter to should_continue_on_error globally?

Contributor Author

sure, done.

"time_taken_llm": time_taken_llm,
"params": str(request.params),
}
output = await resp
Contributor

yield await resp ??

nrghosh and others added 2 commits December 9, 2025 15:40
@nrghosh nrghosh changed the title [Data][LLM] Add continue_on_error for graceful error handling in batch inference [Data][LLM] Add should_continue_on_error for graceful error handling in batch inference Dec 9, 2025
@nrghosh nrghosh requested a review from kouroshHakha December 9, 2025 23:52
# Include snippet of failed prompt
prompt = row.get("prompt", "")
if len(prompt) > _MAX_PROMPT_LENGTH_IN_ERROR:
prompt = prompt[:_MAX_PROMPT_LENGTH_IN_ERROR] + "...[truncated]"

Bug: Prompt lost in error rows after early pop

The error handling in _generate_with_error_handling tries to capture the prompt via row.get("prompt", "") on line 629. However, generate_async internally calls _prepare_llm_request which pops the prompt from the row (prompt = row.pop("prompt") on line 326) before the actual LLM call. When an error occurs after the prompt is popped (e.g., during vLLM engine processing), the error row will have an empty string for prompt instead of the actual prompt content, defeating the purpose of including it for debugging. The prompt needs to be captured before calling generate_async.
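The fix amounts to snapshotting the prompt before the engine call can consume it. A sketch with illustrative names (`generate_async` here is just a passed-in coroutine, not the real stage method):

```python
import asyncio

async def generate_with_prompt_capture(generate_async, row):
    """Capture the prompt before generate_async mutates the row."""
    prompt_snapshot = row.get("prompt") or ""  # taken before row.pop("prompt")
    try:
        return await generate_async(row)
    except Exception as e:
        return {
            "__inference_error__": f"{type(e).__name__}: {e}",
            "prompt": prompt_snapshot,  # survives even though the row was mutated
        }

async def failing_generate(row):
    row.pop("prompt")          # mimics _prepare_llm_request consuming the prompt
    raise ValueError("boom")   # error occurs after the prompt is gone

err_row = asyncio.run(generate_with_prompt_capture(failing_generate, {"prompt": "hello"}))
```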



@nrghosh nrghosh left a comment


@kouroshHakha addressed comments, thanks

will add support for serve handles <> data in a separate PR

cc @richardliaw

@kouroshHakha kouroshHakha merged commit b7c5f06 into ray-project:master Dec 10, 2025
7 checks passed
nrghosh added a commit to nrghosh/ray that referenced this pull request Dec 11, 2025
…tage

Add continue_on_error parameter to ServeDeploymentProcessorConfig. When
enabled, inference failures yield error rows instead of crashing the job.

Changes:
- Add should_continue_on_error to ServeDeploymentProcessorConfig
- Add error handling to ServeDeploymentStageUDF with fatal error detection
- Wire config through build_serve_deployment_processor
- Update base.py: wrap_postprocess bypasses error rows, StatefulStageUDF
  skips UDF for error rows, include_error_column for schema consistency
- Add tests for error handling in serve deployment stage and base stage

Fatal errors (RayActorError, BackPressureError, DeploymentUnavailableError)
always propagate since the replica/deployment is dead. Non-fatal errors
(e.g., ValueError from vLLM) yield error rows with __inference_error__ set.

Review questions:
- Should should_continue_on_error move to base ProcessorConfig?
- BackPressureError could benefit from retry logic (out of scope here)

Addresses: ray-project#59325
Related: ray-project#59212, ray-project#52449

Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
nrghosh added a commit to nrghosh/ray that referenced this pull request Dec 11, 2025
…tage

Extends ray-project#59212 to support ServeDeploymentStage. When enabled, inference
failures yield error rows with __inference_error__ set instead of crashing.

Changes:
- Add should_continue_on_error to ServeDeploymentProcessorConfig
- Add error handling to ServeDeploymentStageUDF with fatal error detection
- Fatal errors (RayActorError, BackPressureError, DeploymentUnavailableError)
  always propagate; non-fatal errors yield error rows
- Add tests for error handling behavior

Addresses: ray-project#59325
Related: ray-project#59212

Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
nrghosh added a commit to nrghosh/ray that referenced this pull request Dec 11, 2025
…tage

## Why is this change needed?

This is a follow-up to ray-project#59212 which added graceful error handling for
vLLMEngineProcessorConfig. ServeDeploymentStage also needs error handling
to prevent entire batch jobs from crashing due to single bad rows.

When processing large datasets via Ray Serve deployments, a single malformed
request (e.g., invalid parameters, too-long prompt) would crash the entire
job. Users want the ability to:
1. Continue processing despite individual row failures
2. Receive error information for failed rows for debugging/retry
3. Distinguish between recoverable errors (single row) and fatal errors
   (replica/deployment crash)

## Changes

**ServeDeploymentProcessorConfig**:
- Add `should_continue_on_error: bool = False` field

**ServeDeploymentStageUDF**:
- Add `_generate_with_error_handling()` that catches non-fatal exceptions
  and yields error rows with `__inference_error__` set
- Add `_is_fatal_error()` to distinguish fatal vs recoverable errors
- Include truncated `request_kwargs` in error rows for debuggability

**Fatal Error Detection**:
- `RayActorError`: Replica crashed or unavailable
- `BackPressureError`: System overloaded (note: could benefit from retry)
- `DeploymentUnavailableError`: Deployment failed to deploy
- Also handles `RayTaskError` wrapping by inspecting `.cause`

Non-fatal errors (e.g., ValueError from invalid params) yield error rows
when `should_continue_on_error=True`; fatal errors always propagate.
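The fatal/recoverable split above can be sketched with stand-in exception classes (the real `RayActorError`, `BackPressureError`, `DeploymentUnavailableError`, and `RayTaskError` come from Ray; they are defined locally here only for illustration):

```python
# Stand-ins for Ray's exception types, defined only for this sketch.
class RayActorError(Exception): pass
class BackPressureError(Exception): pass
class DeploymentUnavailableError(Exception): pass

class RayTaskError(Exception):
    """Wraps an underlying exception, exposing it as .cause (as Ray does)."""
    def __init__(self, cause):
        super().__init__(str(cause))
        self.cause = cause

_FATAL_ERRORS = (RayActorError, BackPressureError, DeploymentUnavailableError)

def is_fatal_error(err):
    """Fatal errors mean the replica/deployment is gone; always propagate them."""
    if isinstance(err, RayTaskError):  # unwrap task-level wrapping first
        return isinstance(err.cause, _FATAL_ERRORS)
    return isinstance(err, _FATAL_ERRORS)
```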

## Tests

- test_serve_udf_default_raises_on_error
- test_serve_udf_continue_on_error_yields_error_row
- test_serve_udf_mixed_success_and_error
- test_serve_udf_fatal_errors_always_propagate (parametrized for all 3 types)
- test_serve_udf_success_with_continue_on_error_includes_none_error

Related to ray-project#59325
nrghosh added a commit to nrghosh/ray that referenced this pull request Jan 14, 2026
Update test_vllm_udf_fatal_error_always_raises to verify that fatal
errors (EngineDeadError) now trigger ray.actor.exit_actor() for
recovery instead of simply re-raising.

The original intent (PR ray-project#59212) was that fatal errors should NOT be
swallowed by should_continue_on_error. This is preserved - fatal
errors still don't yield error rows. The change is that instead of
re-raising (which caused infinite retry loops on the same broken
actor), we now exit the actor to enable recovery.

Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
peterxcli pushed a commit to peterxcli/ray that referenced this pull request Feb 25, 2026
…in batch inference (ray-project#59212)

Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
Signed-off-by: peterxcli <peterxcli@gmail.com>