[Data][LLM] Add should_continue_on_error for graceful error handling in batch inference #59212
@@ -33,6 +33,21 @@
 logger = logging.getLogger(__name__)
 
+# vLLM fatal errors that should always be re-raised, never swallowed.
+# EngineDeadError indicates the vLLM engine process has crashed and is
+# unrecoverable - all subsequent requests would fail anyway.
+_VLLM_FATAL_ERRORS: Tuple[Type[Exception], ...] = ()
+try:
+    from vllm.v1.engine.exceptions import EngineDeadError
+
+    _VLLM_FATAL_ERRORS = (EngineDeadError,)
+except ImportError:
+    # vLLM not installed or older version without this exception
+    pass
+
+# Length of prompt snippet to surface in case of recoverable error
+_MAX_PROMPT_LENGTH_IN_ERROR = 500
+
 
 class vLLMTaskType(str, Enum):
     """The type of task to run on the vLLM engine."""
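The guard above leans on a small Python detail: an `except` clause over an empty tuple matches nothing, so when vLLM (or a version that has `EngineDeadError`) is absent, the fatal-error branch becomes a no-op and ordinary exceptions fall through to the generic handler. A standalone sketch of that pattern, using a hypothetical `some_optional_lib` in place of vLLM:

```python
from typing import Tuple, Type

FATAL_ERRORS: Tuple[Type[Exception], ...] = ()
try:
    # Hypothetical optional dependency; the import fails when it is not installed.
    from some_optional_lib import FatalError

    FATAL_ERRORS = (FatalError,)
except ImportError:
    pass

try:
    raise ValueError("recoverable")
except FATAL_ERRORS:  # empty tuple: never matches, so fatal handling is a no-op
    raise
except Exception as e:
    print(f"handled: {type(e).__name__}: {e}")
```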
@@ -457,6 +472,7 @@ def __init__(
         task_type: vLLMTaskType = vLLMTaskType.GENERATE,
         max_pending_requests: Optional[int] = None,
         dynamic_lora_loading_path: Optional[str] = None,
+        should_continue_on_error: bool = False,
     ):
         """
         Initialize the vLLMEngineStageUDF.
@@ -471,9 +487,13 @@ def __init__(
                 it will be set to 1.1 * max_num_seqs * pipeline_parallel_size.
             dynamic_lora_loading_path: The path to the dynamic LoRA adapter. It is expected
                 to hold subfolders each for a different lora checkpoint.
+            should_continue_on_error: If True, continue processing when inference fails for
+                a row instead of raising. Failed rows will have '__inference_error__'
+                set to the error message.
         """
         super().__init__(data_column, expected_input_keys)
         self.model = model
+        self.should_continue_on_error = should_continue_on_error
 
         # Setup vLLM engine kwargs.
         self.task_type = task_type
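For context on how the new flag is meant to be consumed downstream, here is a hedged sketch of filtering on the `__inference_error__` column that failed rows carry. Only the column name comes from the diff above; the dataset contents are made up, and how the flag is surfaced in the public processor config is outside this hunk, so the sketch starts from rows the stage has already produced.

```python
import ray

# Simulated output rows of the engine stage with should_continue_on_error=True:
# successful rows carry __inference_error__=None, failed rows carry the message.
ds = ray.data.from_items(
    [
        {"prompt": "hello", "generated_text": "hi there", "__inference_error__": None},
        {
            "prompt": "bad row",
            "generated_text": None,
            "__inference_error__": "ValueError: token limit exceeded",
        },
    ]
)

ok = ds.filter(lambda row: row["__inference_error__"] is None)
failed = ds.filter(lambda row: row["__inference_error__"] is not None)
print(ok.count(), "succeeded;", failed.count(), "failed")
```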
@@ -565,6 +585,57 @@ def normalize_engine_kwargs(
         engine_kwargs["task"] = task_type
         return engine_kwargs
 
+    async def _generate_with_error_handling(
+        self,
+        row: Dict[str, Any],
+        batch_uuid: uuid.UUID,
+    ) -> Dict[str, Any]:
+        """Generate output for a single row, catching errors if should_continue_on_error is set.
+
+        Args:
+            row: The input row.
+            batch_uuid: The batch UUID for logging.
+
+        Returns:
+            The output dict, with __inference_error__ set if an error occurred.
+        """
+        idx_in_batch = row[self.IDX_IN_BATCH_COLUMN]
+        try:
+            request, output, time_taken_llm = await self.llm.generate_async(row)
+            return {
+                **output,
+                "request_id": request.request_id,
+                self.IDX_IN_BATCH_COLUMN: request.idx_in_batch,
+                "batch_uuid": batch_uuid.hex,
+                "time_taken_llm": time_taken_llm,
+                "params": str(request.params),
+                "__inference_error__": None,
+            }
+        except _VLLM_FATAL_ERRORS:
+            # Fatal engine errors (e.g., EngineDeadError) must always propagate.
+            # The engine is dead and all subsequent requests would fail.
+            raise
+        except Exception as e:
+            if not self.should_continue_on_error:
+                raise
+            error_msg = f"{type(e).__name__}: {str(e)}"
+            logger.warning(
+                "[vLLM] Inference failed for row %d in batch %s: %s",
+                idx_in_batch,
+                batch_uuid.hex,
+                error_msg,
+            )
+            # Include snippet of failed prompt
+            prompt = row.get("prompt", "")
+            if len(prompt) > _MAX_PROMPT_LENGTH_IN_ERROR:
+                prompt = prompt[:_MAX_PROMPT_LENGTH_IN_ERROR] + "...[truncated]"
Bug: Prompt lost in error rows after early pop
The error handling in `_generate_with_error_handling` reads `row.get("prompt", "")` only inside the except branch; if the prompt has already been popped from the row by the time the exception is caught (the "early pop"), the error row is emitted with an empty prompt snippet instead of the prompt that failed.
+            return {
+                self.IDX_IN_BATCH_COLUMN: idx_in_batch,
+                "batch_uuid": batch_uuid.hex,
+                "__inference_error__": error_msg,
+                "prompt": prompt,
+            }
+
     async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:
         """Run the vLLM engine.
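One hypothetical remedy for the review comment above, sketched outside of Ray with made-up names (`snapshot_prompt` is not part of the PR), is to snapshot the truncated prompt before the engine call so the error row keeps it even if the row is mutated during generation:

```python
# Assumed behavior, per the comment: the engine call consumes/pops "prompt"
# from the row, so the snippet must be captured beforehand.
MAX_PROMPT_LENGTH_IN_ERROR = 500


def snapshot_prompt(row: dict) -> str:
    prompt = str(row.get("prompt", ""))
    if len(prompt) > MAX_PROMPT_LENGTH_IN_ERROR:
        prompt = prompt[:MAX_PROMPT_LENGTH_IN_ERROR] + "...[truncated]"
    return prompt


row = {"prompt": "what is 2 + 2?"}
snippet = snapshot_prompt(row)  # taken before the engine call
row.pop("prompt")               # simulate the engine consuming the prompt
print(snippet)                  # the error row can still carry the snippet
```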
@@ -577,19 +648,13 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:
         batch_uuid = uuid.uuid4()
         batch_start_time = time.perf_counter()
 
-        tasks = [asyncio.create_task(self.llm.generate_async(row)) for row in batch]
+        tasks = [
+            asyncio.create_task(self._generate_with_error_handling(row, batch_uuid))
+            for row in batch
+        ]
 
         for resp in asyncio.as_completed(tasks):
-            request, output, time_taken_llm = await resp
-
-            yield {
-                **output,
-                "request_id": request.request_id,
-                self.IDX_IN_BATCH_COLUMN: request.idx_in_batch,
-                "batch_uuid": batch_uuid.hex,
-                "time_taken_llm": time_taken_llm,
-                "params": str(request.params),
-            }
+            yield await resp
 
         batch_time_taken = time.perf_counter() - batch_start_time
         # TODO: Add metrics to the UDf wrapper so that we don't need
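The change above is the heart of the PR: each row's coroutine now converts its own exception into a result dict, so `asyncio.as_completed` keeps yielding the surviving rows instead of raising on the first failure. A self-contained sketch of that pattern with stand-in names (no Ray or vLLM involved):

```python
import asyncio


async def fake_generate(row: dict) -> dict:
    # Stand-in for the engine call; one prompt is made to fail.
    if row["prompt"] == "boom":
        raise RuntimeError("model rejected the prompt")
    return {"prompt": row["prompt"], "generated_text": row["prompt"].upper()}


async def generate_with_error_handling(row: dict) -> dict:
    try:
        out = await fake_generate(row)
        return {**out, "__inference_error__": None}
    except Exception as e:  # a truly fatal engine error would be re-raised instead
        return {"prompt": row["prompt"], "__inference_error__": f"{type(e).__name__}: {e}"}


async def main() -> None:
    batch = [{"prompt": "hello"}, {"prompt": "boom"}, {"prompt": "world"}]
    tasks = [asyncio.create_task(generate_with_error_handling(r)) for r in batch]
    for resp in asyncio.as_completed(tasks):
        print(await resp)  # failed rows arrive as dicts, not exceptions


asyncio.run(main())
```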
Bug: Validation fails for error rows before bypass logic
The `validate_inputs` call at line 171 runs on all rows before error rows are separated (lines 181-190). Error rows from `vLLMEngineStage` lack required keys like `generated_tokens` that downstream stages such as `DetokenizeStage` expect, so validation fails with "Required input keys not found" before the error rows can be bypassed. The tests avoid this by passing `expected_input_keys=None`, but in production the `DetokenizeStage` has required keys. Validation needs to either skip error rows or run after they are separated.
Additional location: python/ray/llm/_internal/batch/stages/base.py#L180-L190
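A hypothetical illustration of the bypass this comment asks for. This is not the actual `base.py` code; `maybe_validate` and its signature are invented, but it shows the idea of letting rows already marked with `__inference_error__` skip required-key validation and flow through untouched:

```python
def maybe_validate(row: dict, required_keys: set) -> None:
    if row.get("__inference_error__") is not None:
        # Error row: bypass validation so it can pass through later stages
        # and surface at the end of the pipeline.
        return
    missing = required_keys - row.keys()
    if missing:
        raise ValueError(f"Required input keys not found: {missing}")


# Example: an error row sails through, a healthy row is still checked.
maybe_validate({"__inference_error__": "ValueError: bad prompt"}, {"generated_tokens"})
maybe_validate({"generated_tokens": [1, 2, 3], "__inference_error__": None}, {"generated_tokens"})
```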