45 changes: 37 additions & 8 deletions .github/workflows/pr-test-rust.yml
@@ -202,27 +202,56 @@ jobs:
env_vars: "SHOW_WORKER_LOGS=0 SHOW_ROUTER_LOGS=1"
reruns: "--reruns 2 --reruns-delay 5"
parallel_opts: "--workers 1 --tests-per-worker 4" # Thread-based parallelism
- name: chat-completions
- name: chat-completions-sglang
timeout: 45
test_dirs: "e2e_test/chat_completions"
extra_deps: ""
env_vars: "SHOW_WORKER_LOGS=0 SHOW_ROUTER_LOGS=1"
env_vars: "E2E_RUNTIME=sglang SHOW_WORKER_LOGS=0 SHOW_ROUTER_LOGS=1"
reruns: "--reruns 2 --reruns-delay 5"
parallel_opts: ""
test_filter: ""
- name: chat-completions-vllm
timeout: 45
test_dirs: "e2e_test/chat_completions"
extra_deps: ""
env_vars: "E2E_RUNTIME=vllm SHOW_WORKER_LOGS=0 SHOW_ROUTER_LOGS=1"
reruns: "--reruns 2 --reruns-delay 5"
parallel_opts: ""
# TODO: Remove filter when vLLM supports logprobs and n>1 with greedy sampling
# Excludes: 5-grpc (logprobs=5), 2-None-grpc (n=2 with no logprobs), multiple_choices (n=2 tests)
test_filter: "-k 'not (5-grpc or 2-None-grpc or multiple_choices)'"
setup_vllm: true
runs-on: 4-gpu-a10
timeout-minutes: ${{ matrix.timeout }}
steps:
- name: Checkout code
uses: actions/checkout@v6

- name: Clone SGLang repository
- name: Install inference backend (SGLang or vLLM)
run: |
git clone https://github.com/sgl-project/sglang.git
if [ "${{ matrix.setup_vllm }}" == "true" ]; then
echo "Installing vLLM for gRPC testing..."
python3 -m pip install vllm
else
echo "Installing SGLang..."
git clone https://github.com/sgl-project/sglang.git
cd sglang
sudo --preserve-env=PATH bash scripts/ci/cuda/ci_install_dependency.sh
fi
- name: Install SGLang dependencies
- name: Cache flash-attn build
if: matrix.setup_vllm
uses: actions/cache@v4
with:
path: ~/.cache/pip/wheels
key: flash-attn-${{ runner.os }}-py310-${{ hashFiles('**/setup.py') }}

- name: Install flash-attn for vLLM
if: matrix.setup_vllm
run: |
cd sglang
sudo --preserve-env=PATH bash scripts/ci/cuda/ci_install_dependency.sh
python3 -m pip install flash-attn --no-build-isolation
python3 -m pip uninstall -y flashinfer flashinfer-python flashinfer-cubin flashinfer-jit-cache 2>/dev/null || true
python3 -m pip install flashinfer-python==0.5.3 flashinfer-cubin==0.5.3
- name: Setup Oracle Instant Client
if: matrix.setup_oracle
@@ -288,7 +317,7 @@ jobs:
- name: Run E2E tests
run: |
bash scripts/ci_killall_sglang.sh "nuk_gpus"
${{ matrix.env_vars }} ROUTER_LOCAL_MODEL_PATH="/home/ubuntu/models" pytest ${{ matrix.reruns }} ${{ matrix.parallel_opts }} ${{ matrix.test_dirs }} -s -vv -o log_cli=true --log-cli-level=INFO
${{ matrix.env_vars }} ROUTER_LOCAL_MODEL_PATH="/home/ubuntu/models" pytest ${{ matrix.reruns }} ${{ matrix.parallel_opts }} ${{ matrix.test_dirs }} ${{ matrix.test_filter }} -s -vv -o log_cli=true --log-cli-level=INFO
- name: Upload benchmark results
if: matrix.upload_benchmarks && success()
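Note on the new `test_filter`: it is spliced into the pytest invocation (see the updated "Run E2E tests" step), and `-k` deselects tests whose generated IDs match the expression. A minimal sketch of the mechanism, using a hypothetical test rather than one from this suite:

```python
import pytest

# pytest builds test IDs from parametrize values, e.g. "test_logprobs[5-grpc]".
# An expression like -k 'not (5-grpc or multiple_choices)' deselects any test
# whose ID contains a matching substring.
@pytest.mark.parametrize(
    "logprobs,mode",
    [(None, "http"), (5, "grpc")],
    ids=["None-http", "5-grpc"],
)
def test_logprobs(logprobs, mode):
    assert mode in ("http", "grpc")
```

Running `pytest -k 'not 5-grpc'` against this file would keep `test_logprobs[None-http]` and skip `test_logprobs[5-grpc]`, which is how the workflow excludes the vLLM-unsupported cases.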
26 changes: 16 additions & 10 deletions e2e_test/fixtures/hooks.py
@@ -96,6 +96,7 @@ def pytest_collection_modifyitems(
WorkerType,
)


def track_worker(
model_id: str, mode: ConnectionMode, worker_type: WorkerType, count: int
) -> None:
@@ -118,21 +119,26 @@ def calculate_test_gpus(

for item in items:
# Extract model from marker or use default
# First check the class directly (handles inheritance correctly)
# Walk class MRO to prioritize child class markers over parent class
model_id = None
if hasattr(item, "cls") and item.cls is not None:
for marker in (
item.cls.pytestmark if hasattr(item.cls, "pytestmark") else []
):
if marker.name == PARAM_MODEL and marker.args:
model_id = marker.args[0]
# Walk the class MRO to find model marker
# MRO lists classes from most specific (child) to least specific (parent)
for cls in item.cls.__mro__:
if hasattr(cls, "pytestmark"):
markers = cls.pytestmark if isinstance(cls.pytestmark, list) else [cls.pytestmark]
for marker in markers:
if marker.name == PARAM_MODEL and marker.args:
model_id = marker.args[0]
break
if model_id:
break
# Fall back to get_closest_marker for method-level markers

# Fallback to get_closest_marker if not found via MRO
if model_id is None:
model_marker = item.get_closest_marker(PARAM_MODEL)
model_id = (
model_marker.args[0] if model_marker and model_marker.args else None
)
if model_marker and model_marker.args:
model_id = model_marker.args[0]

# Check parametrize for model
if model_id is None:
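The MRO walk above matters because `pytestmark` is a plain class attribute: a subclass that defines its own marker list shadows the parent's, and `__mro__` yields classes from most to least specific. A standalone sketch with a hypothetical stand-in for pytest's `Mark` objects (the class and marker names are illustrative, not from this suite):

```python
from collections import namedtuple

Mark = namedtuple("Mark", ["name", "args"])  # stand-in for pytest's Mark

class BaseChatTest:
    pytestmark = [Mark("model", ("llama-8b",))]   # parent default

class QwenChatTest(BaseChatTest):
    pytestmark = [Mark("model", ("qwen-7b",))]    # child override

model_id = None
for cls in QwenChatTest.__mro__:                  # QwenChatTest, BaseChatTest, object
    for marker in getattr(cls, "pytestmark", []):
        if marker.name == "model" and marker.args:
            model_id = marker.args[0]
            break
    if model_id:
        break

print(model_id)  # "qwen-7b": the child's marker wins, as the fixed hook intends
```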
76 changes: 75 additions & 1 deletion e2e_test/fixtures/setup_backend.py
@@ -14,6 +14,8 @@
if TYPE_CHECKING:
from infra import ModelPool

from infra import get_runtime, is_vllm

from .markers import get_marker_kwargs, get_marker_value

logger = logging.getLogger(__name__)
@@ -102,8 +104,25 @@ def test_chat(self, setup_backend):
is_local = False
connection_mode = None

# Local backends: use worker from pool + launch gateway
# Local backends: check runtime environment variable for gRPC mode
if is_local:
# For gRPC mode, check E2E_RUNTIME environment variable
if connection_mode == ConnectionMode.GRPC:
runtime = get_runtime()
logger.info(
"gRPC backend detected: E2E_RUNTIME=%s, routing to %s backend",
runtime,
"vLLM" if is_vllm() else "SGLang",
)

# Route to vLLM gRPC if runtime is vllm
if is_vllm():
yield from _setup_vllm_grpc_backend(
request, model_pool, model_id, workers_config, gateway_config
)
return

# Otherwise use regular local backend (sglang grpc or http)
yield from _setup_local_backend(
request,
model_pool,
@@ -262,6 +281,61 @@ def _setup_pd_backend(
worker.release()


def _setup_vllm_grpc_backend(
request: pytest.FixtureRequest,
model_pool: "ModelPool",
model_id: str,
workers_config: dict,
gateway_config: dict,
):
"""Setup vLLM gRPC backend."""
import openai
from infra import Gateway

logger.info("Setting up vLLM gRPC backend for model %s", model_id)

# vLLM currently supports only a single worker per test.
# get_vllm_grpc_worker() auto-acquires the returned instance.
try:
instance = model_pool.get_vllm_grpc_worker(model_id)
except RuntimeError as e:
pytest.fail(str(e))

model_path = instance.model_path
worker_urls = [instance.worker_url]

# Launch gateway
gateway = Gateway()
gateway.start(
worker_urls=worker_urls,
model_path=model_path,
policy=gateway_config["policy"],
timeout=gateway_config["timeout"],
extra_args=gateway_config["extra_args"],
)

client = openai.OpenAI(
base_url=f"{gateway.base_url}/v1",
api_key="not-used",
)

logger.info(
"Setup vLLM gRPC backend: model=%s, worker=%s, gateway=%s, policy=%s",
model_id,
instance.worker_url,
gateway.base_url,
gateway_config["policy"],
)

try:
yield "grpc", model_path, client, gateway
finally:
logger.info("Tearing down vLLM gRPC gateway")
gateway.shutdown()
# Release reference to allow eviction
instance.release()


def _setup_local_backend(
request: pytest.FixtureRequest,
model_pool: "ModelPool",
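The fixture delegates with `yield from _setup_vllm_grpc_backend(...)`, so the helper's `try/finally` around its `yield` runs teardown (gateway shutdown, worker release) after the test finishes. A minimal sketch of that delegation pattern, with prints standing in for the real setup and teardown calls:

```python
def _backend_helper():
    gateway = "gateway"             # stand-in for Gateway().start(...)
    try:
        yield "grpc", gateway       # test body executes while suspended here
    finally:
        print("shutdown", gateway)  # stand-in for gateway.shutdown() + instance.release()

def setup_backend():
    # yield from forwards the yielded value and guarantees the helper's
    # finally block runs when this generator is closed.
    yield from _backend_helper()

gen = setup_backend()
mode, gw = next(gen)                # setup runs; the test would run now
gen.close()                         # triggers the helper's finally: "shutdown gateway"
```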
11 changes: 11 additions & 0 deletions e2e_test/infra/__init__.py
@@ -5,10 +5,12 @@
DEFAULT_HOST,
DEFAULT_MODEL,
DEFAULT_ROUTER_TIMEOUT,
DEFAULT_RUNTIME,
DEFAULT_STARTUP_TIMEOUT,
ENV_BACKENDS,
ENV_MODEL,
ENV_MODELS,
ENV_RUNTIME,
ENV_SHOW_ROUTER_LOGS,
ENV_SHOW_WORKER_LOGS,
ENV_SKIP_BACKEND_SETUP,
@@ -25,6 +27,9 @@
ConnectionMode,
Runtime,
WorkerType,
get_runtime,
is_sglang,
is_vllm,
)
from .gateway import Gateway, WorkerInfo, launch_cloud_gateway
from .gpu_allocator import (
@@ -82,6 +87,7 @@
# Defaults
"DEFAULT_MODEL",
"DEFAULT_HOST",
"DEFAULT_RUNTIME",
"DEFAULT_STARTUP_TIMEOUT",
"DEFAULT_ROUTER_TIMEOUT",
"HEALTH_CHECK_INTERVAL",
@@ -91,11 +97,16 @@
"ENV_MODELS",
"ENV_BACKENDS",
"ENV_MODEL",
"ENV_RUNTIME",
"ENV_STARTUP_TIMEOUT",
"ENV_SKIP_MODEL_POOL",
"ENV_SKIP_BACKEND_SETUP",
"ENV_SHOW_ROUTER_LOGS",
"ENV_SHOW_WORKER_LOGS",
# Runtime helpers
"get_runtime",
"is_vllm",
"is_sglang",
# GPU allocation
"GPUAllocator",
"GPUInfo",
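With these re-exports, test code can read the runtime without importing `constants` directly. A sketch of the intended usage (assuming `e2e_test` is on `sys.path` so `infra` imports as it does elsewhere in this diff):

```python
from infra import DEFAULT_RUNTIME, get_runtime, is_sglang, is_vllm

runtime = get_runtime()             # "vllm" or "sglang", read from E2E_RUNTIME
if is_vllm():
    print(f"routing to vLLM gRPC backend ({runtime})")
elif is_sglang():
    print(f"E2E_RUNTIME unset or sglang; default is {DEFAULT_RUNTIME}")
```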
40 changes: 40 additions & 0 deletions e2e_test/infra/constants.py
@@ -41,13 +41,53 @@ class Runtime(str, Enum):
# Default model
DEFAULT_MODEL = "llama-8b"

# Default runtime for gRPC tests
DEFAULT_RUNTIME = "sglang"

# Environment variable names
ENV_MODELS = "E2E_MODELS"
ENV_BACKENDS = "E2E_BACKENDS"
ENV_MODEL = "E2E_MODEL"
ENV_RUNTIME = "E2E_RUNTIME" # Runtime for gRPC tests: "sglang" or "vllm"
ENV_STARTUP_TIMEOUT = "E2E_STARTUP_TIMEOUT"
ENV_SKIP_MODEL_POOL = "SKIP_MODEL_POOL"
ENV_SKIP_BACKEND_SETUP = "SKIP_BACKEND_SETUP"


# Runtime detection helpers
_RUNTIME_CACHE = None


def get_runtime() -> str:
"""Get the current test runtime (sglang or vllm).
Returns:
Runtime name from E2E_RUNTIME environment variable, defaults to "sglang".
"""
global _RUNTIME_CACHE
if _RUNTIME_CACHE is None:
import os

_RUNTIME_CACHE = os.environ.get(ENV_RUNTIME, DEFAULT_RUNTIME)
return _RUNTIME_CACHE


def is_vllm() -> bool:
"""Check if tests are running with vLLM runtime.
Returns:
True if E2E_RUNTIME is "vllm", False otherwise.
"""
return get_runtime() == "vllm"


def is_sglang() -> bool:
"""Check if tests are running with SGLang runtime.
Returns:
True if E2E_RUNTIME is "sglang", False otherwise.
"""
return get_runtime() == "sglang"
ENV_SHOW_ROUTER_LOGS = "SHOW_ROUTER_LOGS"
ENV_SHOW_WORKER_LOGS = "SHOW_WORKER_LOGS"

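One behavioral note on the helpers above: `_RUNTIME_CACHE` is populated on the first `get_runtime()` call, so later changes to `E2E_RUNTIME` are ignored for the life of the process. A sketch (the `infra.constants` module path is an assumption based on the file location in this diff):

```python
import os

os.environ["E2E_RUNTIME"] = "vllm"
from infra.constants import get_runtime, is_vllm  # assumed import path

assert get_runtime() == "vllm" and is_vllm()

os.environ["E2E_RUNTIME"] = "sglang"
assert get_runtime() == "vllm"  # still "vllm": cached on first call
```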