diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml
index 70b476acacfa..f39d3f6fcbe6 100644
--- a/.github/workflows/pr-test-rust.yml
+++ b/.github/workflows/pr-test-rust.yml
@@ -21,36 +21,19 @@ env:
   SCCACHE_GHA_ENABLED: "true"
 
 jobs:
-  maturin-build-test:
+  build-wheel:
     if: |
       github.event_name != 'pull_request' ||
       (github.event.action != 'labeled' && contains(github.event.pull_request.labels.*.name, 'run-ci')) ||
       (github.event.action == 'labeled' && github.event.label.name == 'run-ci')
-    runs-on: ubuntu-latest
+    runs-on: 4-gpu-a10
     steps:
-      - uses: actions/checkout@v4
-        with:
-          path: sglang-repo
-
-      - name: Move sgl-model-gateway folder to root
-        run: |
-          mv sglang-repo/sgl-model-gateway/* .
-          rm -rf sglang-repo
-
-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: "3.13"
+      - name: Checkout code
+        uses: actions/checkout@v4
 
-      - name: Install protoc and dependencies
+      - name: Install rust dependencies
         run: |
-          sudo apt-get update
-          sudo apt-get install -y wget unzip gcc g++ perl make
-          cd /tmp
-          wget https://github.com/protocolbuffers/protobuf/releases/download/v32.0/protoc-32.0-linux-x86_64.zip
-          sudo unzip protoc-32.0-linux-x86_64.zip -d /usr/local
-          rm protoc-32.0-linux-x86_64.zip
-          protoc --version
+          bash scripts/ci/ci_install_rust.sh
 
       - name: Configure sccache
         uses: mozilla-actions/sccache-action@v0.0.9
         with:
           version: "v0.12.0"
           disable_annotations: true
@@ -61,29 +44,63 @@ jobs:
       - name: Rust cache
         uses: Swatinem/rust-cache@v2
         with:
-          workspaces: "."
+          workspaces: sgl-model-gateway
           shared-key: "rust-cache"
           cache-all-crates: true
           cache-on-failure: true
           save-if: true
 
-      - name: Test maturin build
-        uses: PyO3/maturin-action@v1
-        with:
-          working-directory: bindings/python
-          args: --release --out dist --features vendored-openssl
-          rust-toolchain: stable
-          sccache: true
+      - name: Build python binding
+        run: |
+          source "$HOME/.cargo/env"
+          export RUSTC_WRAPPER=sccache
+          cd sgl-model-gateway/bindings/python
+          python3 -m pip install --upgrade pip maturin
+          maturin build --profile ci --features vendored-openssl --out dist
 
       - name: List built wheel
-        run: ls -lh bindings/python/dist/
+        run: ls -lh sgl-model-gateway/bindings/python/dist/
+
+      - name: Upload wheel artifact
+        uses: actions/upload-artifact@v4
+        with:
+          name: smg-wheel
+          path: sgl-model-gateway/bindings/python/dist/*.whl
+          retention-days: 1
 
       - name: Test wheel install
         run: |
-          pip install bindings/python/dist/*.whl
-          python -c "import sglang_router; print('Python package: OK')"
-          python -c "from sglang_router.sglang_router_rs import Router; print('Rust extension: OK')"
-          python -m sglang_router.launch_router --help > /dev/null && echo "Entry point: OK"
+          pip install sgl-model-gateway/bindings/python/dist/*.whl
+          python3 -c "import sglang_router; print('Python package: OK')"
+          python3 -c "from sglang_router.sglang_router_rs import Router; print('Rust extension: OK')"
+          python3 -m sglang_router.launch_router --help > /dev/null && echo "Entry point: OK"
+
+  python-unit-tests:
+    needs: build-wheel
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          path: sglang-repo
+
+      - name: Move sgl-model-gateway folder to root
+        run: |
+          mv sglang-repo/sgl-model-gateway/* .
+          rm -rf sglang-repo
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.13"
+
+      - name: Download wheel artifact
+        uses: actions/download-artifact@v4
+        with:
+          name: smg-wheel
+          path: dist/
+
+      - name: Install wheel
+        run: pip install dist/*.whl
 
       - name: Run Python unit tests
         run: |
@@ -157,6 +174,7 @@ jobs:
 
   gateway-e2e:
     name: ${{ matrix.name }}
+    needs: build-wheel
     if: |
       github.event_name != 'pull_request' ||
       (github.event.action != 'labeled' && contains(github.event.pull_request.labels.*.name, 'run-ci')) ||
@@ -196,31 +214,19 @@
            env_vars: "SHOW_WORKER_LOGS=0 SHOW_ROUTER_LOGS=1"
            reruns: "--reruns 2 --reruns-delay 5"
            parallel_opts: "--workers 1 --tests-per-worker 4"  # Thread-based parallelism
+          - name: chat-completions
+            timeout: 45
+            test_dirs: "e2e_test/chat_completions"
+            extra_deps: ""
+            env_vars: "SHOW_WORKER_LOGS=0 SHOW_ROUTER_LOGS=1"
+            reruns: "--reruns 2 --reruns-delay 5"
+            parallel_opts: ""
 
     runs-on: 4-gpu-a10
     timeout-minutes: ${{ matrix.timeout }}
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
 
-      - name: Install rust dependencies
-        run: |
-          bash scripts/ci/ci_install_rust.sh
-
-      - name: Configure sccache
-        uses: mozilla-actions/sccache-action@v0.0.9
-        with:
-          version: "v0.12.0"
-          disable_annotations: true
-
-      - name: Rust cache
-        uses: Swatinem/rust-cache@v2
-        with:
-          workspaces: sgl-model-gateway
-          shared-key: "rust-cache"
-          cache-all-crates: true
-          cache-on-failure: true
-          save-if: true
-
       - name: Install SGLang dependencies
         run: |
           sudo --preserve-env=PATH bash scripts/ci/ci_install_dependency.sh
@@ -268,15 +274,16 @@ jobs:
           sleep 2
           curl -f --max-time 1 http://localhost:8001/sse > /dev/null 2>&1 && echo "Brave MCP Server is healthy!" || echo "Brave MCP Server responded"
 
-      - name: Build python binding
+      - name: Download wheel artifact
+        uses: actions/download-artifact@v4
+        with:
+          name: smg-wheel
+          path: wheel/
+
+      - name: Install wheel
         run: |
-          source "$HOME/.cargo/env"
-          export RUSTC_WRAPPER=sccache
-          cd sgl-model-gateway/bindings/python
-          python3 -m pip install --upgrade pip maturin
           pip uninstall -y sglang-router || true
-          maturin build --profile ci --features vendored-openssl --out dist
-          pip install dist/*.whl
+          pip install wheel/*.whl
 
       - name: Install e2e test dependencies
         run: |
@@ -289,7 +296,6 @@ jobs:
         run: |
           bash scripts/killall_sglang.sh "nuk_gpus"
           cd sgl-model-gateway
-          source "$HOME/.cargo/env"
           ${{ matrix.env_vars }} ROUTER_LOCAL_MODEL_PATH="/home/ubuntu/models" pytest ${{ matrix.reruns }} ${{ matrix.parallel_opts }} ${{ matrix.test_dirs }} -s -vv -o log_cli=true --log-cli-level=INFO
 
       - name: Upload benchmark results
@@ -335,7 +341,7 @@ jobs:
           cache-to: type=gha,mode=max
 
   finish:
-    needs: [maturin-build-test, unit-tests, gateway-e2e, docker-build-test]
+    needs: [build-wheel, python-unit-tests, unit-tests, gateway-e2e, docker-build-test]
     runs-on: ubuntu-latest
     steps:
      - name: Finish
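For local debugging, the wheel build above can be reproduced outside CI. This is a minimal sketch, assuming a checkout of the repository root and an installed Rust toolchain (CI installs one via `scripts/ci/ci_install_rust.sh`); the sccache wrapper and runner-specific setup are omitted, and the `ci` cargo profile is the one referenced by the workflow:

```bash
# Build the gateway wheel the same way the build-wheel job does.
cd sgl-model-gateway/bindings/python
python3 -m pip install --upgrade pip maturin
maturin build --profile ci --features vendored-openssl --out dist

# Smoke-test the wheel, mirroring the "Test wheel install" step.
pip install dist/*.whl
python3 -c "import sglang_router; print('Python package: OK')"
python3 -c "from sglang_router.sglang_router_rs import Router; print('Rust extension: OK')"
```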
diff --git a/sgl-model-gateway/e2e_test/chat_completions/test_openai_server.py b/sgl-model-gateway/e2e_test/chat_completions/test_openai_server.py
new file mode 100644
index 000000000000..676f9d98a794
--- /dev/null
+++ b/sgl-model-gateway/e2e_test/chat_completions/test_openai_server.py
@@ -0,0 +1,316 @@
+"""Chat Completions API E2E Tests - OpenAI Server Compatibility.
+
+Tests for OpenAI-compatible chat completions API through the gateway.
+
+Source: Migrated from e2e_grpc/basic/test_openai_server.py
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+
+import pytest
+
+logger = logging.getLogger(__name__)
+
+
+# =============================================================================
+# Chat Completion Tests (Llama 8B)
+# =============================================================================
+
+
+@pytest.mark.model("llama-8b")
+@pytest.mark.gateway(extra_args=["--history-backend", "memory"])
+@pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True)
+class TestChatCompletion:
+    """Tests for OpenAI-compatible chat completions API."""
+
+    @pytest.mark.parametrize("logprobs", [None, 5])
+    @pytest.mark.parametrize("parallel_sample_num", [1, 2])
+    def test_chat_completion(self, setup_backend, logprobs, parallel_sample_num):
+        """Test non-streaming chat completion with logprobs and parallel sampling."""
+        _, model, client, gateway = setup_backend
+        self._run_chat_completion(client, model, logprobs, parallel_sample_num)
+
+    @pytest.mark.parametrize("logprobs", [None, 5])
+    @pytest.mark.parametrize("parallel_sample_num", [1, 2])
+    def test_chat_completion_stream(self, setup_backend, logprobs, parallel_sample_num):
+        """Test streaming chat completion with logprobs and parallel sampling."""
+        _, model, client, gateway = setup_backend
+        self._run_chat_completion_stream(client, model, logprobs, parallel_sample_num)
+
+    def test_regex(self, setup_backend):
+        """Test structured output with regex constraint."""
+        _, model, client, gateway = setup_backend
+
+        regex = (
+            r"""\{\n"""
+            + r"""    "name": "[\w]+",\n"""
+            + r"""    "population": [\d]+\n"""
+            + r"""\}"""
+        )
+
+        response = client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": "You are a helpful AI assistant"},
+                {"role": "user", "content": "Introduce the capital of France."},
+            ],
+            temperature=0,
+            max_tokens=128,
+            extra_body={"regex": regex},
+        )
+        text = response.choices[0].message.content
+
+        try:
+            js_obj = json.loads(text)
+        except (TypeError, json.decoder.JSONDecodeError):
+            raise
+        assert isinstance(js_obj["name"], str)
+        assert isinstance(js_obj["population"], int)
+
+    def test_penalty(self, setup_backend):
+        """Test frequency penalty parameter."""
+        _, model, client, gateway = setup_backend
+
+        response = client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": "You are a helpful AI assistant"},
+                {"role": "user", "content": "Introduce the capital of France."},
+            ],
+            temperature=0,
+            max_tokens=32,
+            frequency_penalty=1.0,
+        )
+        text = response.choices[0].message.content
+        assert isinstance(text, str)
+
+    def test_response_prefill(self, setup_backend):
+        """Test assistant message prefill with continue_final_message."""
+        _, model, client, gateway = setup_backend
+
+        response = client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": "You are a helpful AI assistant"},
+                {
+                    "role": "user",
+                    "content": """
+Extract the name, size, price, and color from this product description as a JSON object:
+
+<description>
+The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
+</description>
+""",
+                },
+                {
+                    "role": "assistant",
+                    "content": "{\n",
+                },
+            ],
+            temperature=0,
+            extra_body={"continue_final_message": True},
+        )
+
+        assert (
+            response.choices[0]
+            .message.content.strip()
+            .startswith('"name": "SmartHome Mini",')
+        )
+
+    def test_model_list(self, setup_backend):
+        """Test listing available models."""
+        _, model, client, gateway = setup_backend
+
+        models = list(client.models.list().data)
+        assert len(models) == 1
+
+    @pytest.mark.skip(
+        reason="Skipping retrieve model test as it is not supported by the router"
+    )
+    def test_retrieve_model(self, setup_backend):
+        """Test retrieving a specific model."""
+        import openai
+
+        _, model, client, gateway = setup_backend
+
+        retrieved_model = client.models.retrieve(model)
+        assert retrieved_model.id == model
+        assert retrieved_model.root == model
+
+        with pytest.raises(openai.NotFoundError):
+            client.models.retrieve("non-existent-model")
+
+    # -------------------------------------------------------------------------
+    # Helper methods
+    # -------------------------------------------------------------------------
+
+    def _run_chat_completion(self, client, model, logprobs, parallel_sample_num):
+        """Run a non-streaming chat completion and verify response."""
+        response = client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": "You are a helpful AI assistant"},
+                {
+                    "role": "user",
+                    "content": "What is the capital of France? Answer in a few words.",
+                },
+            ],
+            temperature=0,
+            logprobs=logprobs is not None and logprobs > 0,
+            top_logprobs=logprobs,
+            n=parallel_sample_num,
+        )
+
+        if logprobs:
+            assert isinstance(
+                response.choices[0].logprobs.content[0].top_logprobs[0].token, str
+            )
+
+            ret_num_top_logprobs = len(
+                response.choices[0].logprobs.content[0].top_logprobs
+            )
+            assert (
+                ret_num_top_logprobs == logprobs
+            ), f"{ret_num_top_logprobs} vs {logprobs}"
+
+        assert len(response.choices) == parallel_sample_num
+        assert response.choices[0].message.role == "assistant"
+        assert isinstance(response.choices[0].message.content, str)
+        assert response.id
+        assert response.created
+        assert response.usage.prompt_tokens > 0
+        assert response.usage.completion_tokens > 0
+        assert response.usage.total_tokens > 0
+
+    def _run_chat_completion_stream(
+        self, client, model, logprobs, parallel_sample_num=1
+    ):
+        """Run a streaming chat completion and verify response chunks."""
+        generator = client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": "You are a helpful AI assistant"},
+                {"role": "user", "content": "What is the capital of France?"},
+            ],
+            temperature=0,
+            logprobs=logprobs is not None and logprobs > 0,
+            top_logprobs=logprobs,
+            stream=True,
+            stream_options={"include_usage": True},
+            n=parallel_sample_num,
+        )
+
+        is_firsts = {}
+        is_finished = {}
+        finish_reason_counts = {}
+        for response in generator:
+            usage = response.usage
+            if usage is not None:
+                assert usage.prompt_tokens > 0, "usage.prompt_tokens was zero"
+                assert usage.completion_tokens > 0, "usage.completion_tokens was zero"
+                assert usage.total_tokens > 0, "usage.total_tokens was zero"
+                continue
+
+            index = response.choices[0].index
+            finish_reason = response.choices[0].finish_reason
+            if finish_reason is not None:
+                is_finished[index] = True
+                finish_reason_counts[index] = finish_reason_counts.get(index, 0) + 1
+
+            data = response.choices[0].delta
+
+            if is_firsts.get(index, True):
+                assert (
+                    data.role == "assistant"
+                ), "data.role was not 'assistant' for first chunk"
+                is_firsts[index] = False
+                continue
+
+            if logprobs and not is_finished.get(index, False):
+                assert response.choices[0].logprobs, "logprobs was not returned"
+                assert isinstance(
+                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
+                ), "top_logprobs token was not a string"
+                assert isinstance(
+                    response.choices[0].logprobs.content[0].top_logprobs, list
+                ), "top_logprobs was not a list"
+                ret_num_top_logprobs = len(
+                    response.choices[0].logprobs.content[0].top_logprobs
+                )
+                assert (
+                    ret_num_top_logprobs == logprobs
+                ), f"{ret_num_top_logprobs} vs {logprobs}"
+
+            assert (
+                isinstance(data.content, str)
+                or isinstance(data.reasoning_content, str)
+                or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
+                or response.choices[0].finish_reason
+            )
+            assert response.id
+            assert response.created
+
+        for index in range(parallel_sample_num):
+            assert not is_firsts.get(
+                index, True
+            ), f"index {index} is not found in the response"
+
+        for index in range(parallel_sample_num):
+            assert (
+                index in finish_reason_counts
+            ), f"No finish_reason found for index {index}"
+            assert finish_reason_counts[index] == 1, (
+                f"Expected 1 finish_reason chunk for index {index}, "
+                f"got {finish_reason_counts[index]}"
+            )
+
+
+# =============================================================================
+# Chat Completion Tests (GPT-OSS)
+#
+# NOTE: Some tests are skipped because they don't work with OSS models:
+# - test_regex: OSS models don't support regex constraints
+# - test_penalty: OSS models don't support frequency_penalty
+# - test_response_prefill: OSS models don't support continue_final_message
+# =============================================================================
+
+
+@pytest.mark.model("gpt-oss")
+@pytest.mark.gateway(
+    extra_args=["--reasoning-parser=gpt-oss", "--history-backend", "memory"]
+)
+class TestChatCompletionGptOss(TestChatCompletion):
+    """Tests for chat completions API with GPT-OSS model.
+
+    Inherits from TestChatCompletion and overrides tests that don't work
+    with OSS models.
+ """ + + @pytest.mark.parametrize("logprobs", [None]) # No logprobs for OSS + @pytest.mark.parametrize("parallel_sample_num", [1, 2]) + def test_chat_completion(self, setup_backend, logprobs, parallel_sample_num): + """Test non-streaming chat completion with parallel sampling (no logprobs).""" + super().test_chat_completion(setup_backend, logprobs, parallel_sample_num) + + @pytest.mark.parametrize("logprobs", [None]) # No logprobs for OSS + @pytest.mark.parametrize("parallel_sample_num", [1, 2]) + def test_chat_completion_stream(self, setup_backend, logprobs, parallel_sample_num): + """Test streaming chat completion with parallel sampling (no logprobs).""" + super().test_chat_completion_stream( + setup_backend, logprobs, parallel_sample_num + ) + + @pytest.mark.skip(reason="OSS models don't support regex constraints") + def test_regex(self, setup_backend): + pass + + @pytest.mark.skip(reason="OSS models don't support frequency_penalty") + def test_penalty(self, setup_backend): + pass + + @pytest.mark.skip(reason="OSS models don't support continue_final_message") + def test_response_prefill(self, setup_backend): + pass diff --git a/sgl-model-gateway/e2e_test/fixtures/hooks.py b/sgl-model-gateway/e2e_test/fixtures/hooks.py index a827f0c61836..2cbbb427ada4 100644 --- a/sgl-model-gateway/e2e_test/fixtures/hooks.py +++ b/sgl-model-gateway/e2e_test/fixtures/hooks.py @@ -118,8 +118,21 @@ def calculate_test_gpus( for item in items: # Extract model from marker or use default - model_marker = item.get_closest_marker(PARAM_MODEL) - model_id = model_marker.args[0] if model_marker and model_marker.args else None + # First check the class directly (handles inheritance correctly) + model_id = None + if hasattr(item, "cls") and item.cls is not None: + for marker in ( + item.cls.pytestmark if hasattr(item.cls, "pytestmark") else [] + ): + if marker.name == PARAM_MODEL and marker.args: + model_id = marker.args[0] + break + # Fall back to get_closest_marker for method-level markers + if model_id is None: + model_marker = item.get_closest_marker(PARAM_MODEL) + model_id = ( + model_marker.args[0] if model_marker and model_marker.args else None + ) # Check parametrize for model if model_id is None: