Merged
53 changes: 27 additions & 26 deletions .github/workflows/pr-test.yml
@@ -271,36 +271,36 @@ jobs:

echo "All benchmark tests completed!"

-  # sgl-kernel-b200-test:
-  #   needs: [check-changes, sgl-kernel-build-wheels]
-  #   if: needs.check-changes.outputs.sgl_kernel == 'true'
-  #   runs-on: 4-gpu-b200
-  #   env:
-  #     RUNNER_LABELS: 4-gpu-b200
-  #   steps:
-  #     - uses: actions/checkout@v4
+  sgl-kernel-b200-test:
+    needs: [check-changes, sgl-kernel-build-wheels]
+    if: needs.check-changes.outputs.sgl_kernel == 'true'
+    runs-on: 4-gpu-b200
+    env:
+      RUNNER_LABELS: 4-gpu-b200
+    steps:
+      - uses: actions/checkout@v4

-  #   - name: Cleanup
-  #     run: |
-  #       ls -alh sgl-kernel/dist || true
-  #       rm -rf sgl-kernel/dist/* || true
+      - name: Cleanup
+        run: |
+          ls -alh sgl-kernel/dist || true
+          rm -rf sgl-kernel/dist/* || true

-  #   - name: Download artifacts
-  #     uses: actions/download-artifact@v4
-  #     with:
-  #       path: sgl-kernel/dist/
-  #       merge-multiple: true
-  #       pattern: wheel-python3.10-cuda12.9
+      - name: Download artifacts
+        uses: actions/download-artifact@v4
+        with:
+          path: sgl-kernel/dist/
+          merge-multiple: true
+          pattern: wheel-python3.10-cuda12.9

-  #   - name: Install dependencies
-  #     run: |
-  #       CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} IS_BLACKWELL=1 bash scripts/ci/ci_install_dependency.sh
+      - name: Install dependencies
+        run: |
+          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} IS_BLACKWELL=1 bash scripts/ci/ci_install_dependency.sh

-  #   - name: Run sgl-kernel unit tests on B200
-  #     timeout-minutes: 30
-  #     run: |
-  #       cd sgl-kernel
-  #       pytest tests/
+      - name: Run sgl-kernel unit tests on B200
+        timeout-minutes: 30
+        run: |
+          cd sgl-kernel
+          pytest tests/

# Adding a single CUDA13 smoke test to verify that the kernel builds and runs
# TODO: Add back this test when it can pass on CI
@@ -1094,6 +1094,7 @@ jobs:
 sgl-kernel-unit-test,
 sgl-kernel-mla-test,
 sgl-kernel-benchmark-test,
+sgl-kernel-b200-test,

 multimodal-gen-test-1-gpu,
 multimodal-gen-test-2-gpu,
4 changes: 2 additions & 2 deletions sgl-kernel/tests/test_es_fp8_blockwise_moe.py
@@ -99,8 +99,8 @@ def is_sm90_supported(device=None) -> bool:


 @pytest.mark.skipif(
-    not (is_sm100_supported() or is_sm90_supported()),
-    reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100 or sm90",
+    not is_sm90_supported(),
+    reason="es_fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm90",
 )
 @pytest.mark.parametrize("num_experts", [8, 16, 32, 64, 128])
 @pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
8 changes: 8 additions & 0 deletions sgl-kernel/tests/test_flashmla.py
@@ -38,6 +38,12 @@
 DTYPE = [torch.float16, torch.bfloat16]


+def is_sm90_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
+    )
+
+
Comment on lines +41 to +44 (Contributor, severity: medium):

This is_sm90_supported function is a duplicate of the one in sgl-kernel/tests/test_es_fp8_blockwise_moe.py. To follow the DRY (Don't Repeat Yourself) principle and improve maintainability, this function should be defined in a single, shared location, such as a test utilities file (e.g., sgl-kernel/tests/utils.py or sgl-kernel/tests/conftest.py), and imported into both test modules.
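A minimal sketch of the shared helper the reviewer proposes, assuming it lives in sgl-kernel/tests/utils.py (the file name is one of the reviewer's suggestions, not part of this PR). The packaging-based version parse and the CPU-only guard are assumptions added for robustness; the PR itself compares torch.version.cuda as a raw string.

```python
# Hypothetical sgl-kernel/tests/utils.py: a single shared copy of the
# capability check that both test modules could import instead of duplicating.
import torch
from packaging import version  # assumes packaging is available in the test env


def is_sm90_supported(device=None) -> bool:
    # True on SM 9.x (Hopper) GPUs with CUDA >= 12.3. Parsing the version
    # avoids the lexicographic pitfall of comparing strings, where
    # "12.10" < "12.3".
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    major, _ = torch.cuda.get_device_capability(device)
    return major == 9 and version.parse(torch.version.cuda) >= version.parse("12.3")
```

Both test_flashmla.py and test_es_fp8_blockwise_moe.py would then do `from utils import is_sm90_supported`, keeping the two skip conditions from drifting apart.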



def quantize_k_cache(
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
dv: int,
@@ -362,6 +368,7 @@ def test_flashmla_prefill(
 torch.testing.assert_close(ans_lse, ref_lse, atol=1e-6, rtol=2.01 / 65536)


+@pytest.mark.skipif(not is_sm90_supported(), reason="SM90 required for FP8 support")
 @pytest.mark.parametrize("b", B_DECODE)
 @pytest.mark.parametrize("s_q", S_Q_DECODE)
 @pytest.mark.parametrize("s_k", S_K_DECODE)
@@ -512,6 +519,7 @@ def test_flash_mla_decode(
 torch.testing.assert_close(lse_ans, lse_ref, atol=1e-6, rtol=8.01 / 65536)


+@pytest.mark.skipif(not is_sm90_supported(), reason="SM90 required for FP8 support")
 @pytest.mark.parametrize("b", [128])
 @pytest.mark.parametrize("s_q", [1, 2])
 @pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])