Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions python/ray/llm/_internal/batch/processor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from pydantic import Field, field_validator, model_validator

import ray
from ray.data import Dataset
from ray.data.block import UserDefinedFunction
from ray.llm._internal.batch.stages import (
Expand Down Expand Up @@ -331,14 +330,6 @@ def __init__(
self.postprocess_map_kwargs = postprocess_map_kwargs or {}
self.stages: OrderedDict[str, StatefulStage] = OrderedDict()

# FIXES: https://github.com/ray-project/ray/issues/53124
# TODO (Kourosh): Remove this once the issue is fixed
data_context = ray.data.DataContext.get_current()
data_context.wait_for_min_actors_s = 600
# TODO: Remove this when https://github.com/ray-project/ray/issues/53169
# is fixed.
data_context._enable_actor_pool_on_exit_hook = True

# NOTE (Kourosh): If pre/postprocess is not provided, use the identity function.
# Wrapping is required even if they are identity functions, b/c data_column
# gets inserted/removed via wrap_preprocess/wrap_postprocess.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Test that Ray Data LLM does not override wait_for_min_actors_s.

With default settings (wait_for_min_actors_s <= 0), processing starts
as soon as any actor is ready, regardless of concurrency config.
"""
import sys

import pytest

from ray.data import DataContext
from ray.llm._internal.batch.processor import ProcessorBuilder
from ray.llm._internal.batch.processor.vllm_engine_proc import vLLMEngineProcessorConfig


@pytest.fixture(autouse=True)
def reset_data_context():
    """Force wait_for_min_actors_s to its -1 default for each test and
    restore the pre-test value during teardown."""
    ctx = DataContext.get_current()
    saved_timeout = ctx.wait_for_min_actors_s
    ctx.wait_for_min_actors_s = -1
    yield
    ctx.wait_for_min_actors_s = saved_timeout


class TestWaitForMinActorsNotOverridden:
    """Building a Processor must leave DataContext.wait_for_min_actors_s alone."""

    @staticmethod
    def _build_processor():
        # Minimal vLLM processor config; building the processor is enough to
        # exercise the code path that previously mutated the DataContext.
        cfg = vLLMEngineProcessorConfig(
            model_source="facebook/opt-125m",
            concurrency=4,
        )
        ProcessorBuilder.build(cfg)

    def test_processor_does_not_override_default(self):
        """With the default (-1), a build must not change the setting."""
        ctx = DataContext.get_current()
        ctx.wait_for_min_actors_s = -1

        self._build_processor()

        assert ctx.wait_for_min_actors_s == -1

    @pytest.mark.parametrize("user_value", [60, 600, 1800])
    def test_processor_preserves_user_setting(self, user_value):
        """A user-provided timeout must survive a processor build."""
        ctx = DataContext.get_current()
        ctx.wait_for_min_actors_s = user_value

        self._build_processor()

        assert ctx.wait_for_min_actors_s == user_value


class TestConcurrencyConfigPassthrough:
    """
    Verify the concurrency config is forwarded into the ActorPoolStrategy.

    The pool's min_size is what blocking waits on when
    wait_for_min_actors_s > 0:
    - concurrency=N → min_size=N → blocks for N actors
    - concurrency=(1, N) → min_size=1 → blocks for 1 actor
    """

    @pytest.mark.parametrize(
        "concurrency,expected_min_size,expected_max_size",
        [
            (4, 4, 4),  # int: fixed-size pool
            ((1, 4), 1, 4),  # tuple: autoscaling pool
            ((2, 8), 2, 8),  # tuple: custom minimum
        ],
        ids=["int_concurrency", "tuple_1_to_n", "tuple_custom_min"],
    )
    def test_concurrency_to_actor_pool_strategy(
        self, concurrency, expected_min_size, expected_max_size
    ):
        """Each concurrency form must yield the matching pool min/max sizes."""
        cfg = vLLMEngineProcessorConfig(
            model_source="facebook/opt-125m",
            concurrency=concurrency,
        )
        built = ProcessorBuilder.build(cfg)

        # Inspect the compute strategy attached to the vLLM stage.
        vllm_stage = built.get_stage_by_name("vLLMEngineStage")
        pool = vllm_stage.map_batches_kwargs.get("compute")

        assert (
            pool.min_size == expected_min_size
        ), f"Expected min_size={expected_min_size}, got {pool.min_size}"
        assert (
            pool.max_size == expected_max_size
        ), f"Expected max_size={expected_max_size}, got {pool.max_size}"


# Allow running this file directly; propagate pytest's exit code to the shell.
if __name__ == "__main__":
    sys.exit(pytest.main(["-v", __file__]))