From f8245f83e4372c3753d4083a1a4852428ed00ef6 Mon Sep 17 00:00:00 2001 From: alcanderian Date: Sat, 7 Jun 2025 09:05:20 +0000 Subject: [PATCH] chore: upgrade sgl-kernel v0.1.6 --- python/pyproject.toml | 2 +- python/sglang/srt/entrypoints/engine.py | 2 +- .../srt/layers/quantization/deep_gemm.py | 98 ++++++++----------- 3 files changed, 44 insertions(+), 58 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 7aaaa7de95c6..12f7a74b37db 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -49,7 +49,7 @@ runtime_common = [ srt = [ "sglang[runtime_common]", - "sgl-kernel==0.1.5", + "sgl-kernel==0.1.6", "flashinfer_python==0.2.5", "torch==2.6.0", "torchvision==0.21.0", diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 813fc4c7d17c..d02e344d74b3 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -579,7 +579,7 @@ def _set_envs_and_config(server_args: ServerArgs): if _is_cuda: assert_pkg_version( "sgl-kernel", - "0.1.5", + "0.1.6", "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", ) diff --git a/python/sglang/srt/layers/quantization/deep_gemm.py b/python/sglang/srt/layers/quantization/deep_gemm.py index ef454bc0fd3f..c49a6bb6d7bc 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm.py +++ b/python/sglang/srt/layers/quantization/deep_gemm.py @@ -17,10 +17,10 @@ try: import deep_gemm from deep_gemm import get_num_sms + from deep_gemm.jit import build from deep_gemm.jit.compiler import get_nvcc_compiler from deep_gemm.jit_kernels.gemm import get_best_configs from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType - from deep_gemm.jit_kernels.tuner import jit_tuner sm_version = get_device_sm() if sm_version == 90: @@ -148,32 +148,28 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one( block_k = 128 num_tma_threads = 128 num_math_threads_per_group = 128 + kwargs = { + "GEMM_TYPE": GemmType.GroupedMasked, "NUM_TMA_THREADS": num_tma_threads, "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group, + "N": n, + "K": k, + "NUM_GROUPS": 1, + "BLOCK_M": block_m, + "BLOCK_N": block_n, "BLOCK_K": block_k, + "SWIZZLE_D_MODE": smem_config[1], + "BLOCK_N_PADDING": smem_config[2], + "NUM_STAGES": num_stages, + "NUM_TMA_MULTICAST": tma_multicast_config[0], + "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1], "NUM_SMS": num_sms, "SMEM_SIZE": smem_config[0], } - _, _ = jit_tuner.compile_and_tune( - name="m_grouped_gemm_fp8_fp8_bf16_nt", - keys={ - "N": n, - "K": k, - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "SWIZZLE_D_MODE": smem_config[1], - "BLOCK_N_PADDING": smem_config[2], - "NUM_GROUPS": num_groups, - "NUM_STAGES": num_stages, - "NUM_TMA_MULTICAST": tma_multicast_config[0], - "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1], - "GEMM_TYPE": GemmType.GroupedMasked, - }, - space=(), - kwargs=kwargs, - runtime_cls=FP8GemmRuntime, - ) + + code = FP8GemmRuntime.generate(kwargs) + _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs) def _compile_grouped_gemm_nt_f8f8bf16_contig_one( @@ -187,31 +183,26 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one( num_tma_threads = 128 num_math_threads_per_group = 128 kwargs = { + "GEMM_TYPE": GemmType.GroupedContiguous, "NUM_TMA_THREADS": num_tma_threads, "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group, + "N": n, + "K": k, + "NUM_GROUPS": 1, + "BLOCK_M": block_m, + "BLOCK_N": block_n, "BLOCK_K": block_k, + "SWIZZLE_D_MODE": smem_config[1], + "BLOCK_N_PADDING": smem_config[2], + "NUM_STAGES": num_stages, + "NUM_TMA_MULTICAST": tma_multicast_config[0], + "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1], "NUM_SMS": num_sms, "SMEM_SIZE": smem_config[0], } - _, _ = jit_tuner.compile_and_tune( - name="m_grouped_gemm_fp8_fp8_bf16_nt", - keys={ - "N": n, - "K": k, - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "SWIZZLE_D_MODE": smem_config[1], - "BLOCK_N_PADDING": smem_config[2], - "NUM_GROUPS": num_groups, - "NUM_STAGES": num_stages, - "NUM_TMA_MULTICAST": tma_multicast_config[0], - "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1], - "GEMM_TYPE": GemmType.GroupedContiguous, - }, - space=(), - kwargs=kwargs, - runtime_cls=FP8GemmRuntime, - ) + + code = FP8GemmRuntime.generate(kwargs) + _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs) def _compile_gemm_nt_f8f8bf16_one( @@ -228,28 +219,23 @@ def _compile_gemm_nt_f8f8bf16_one( "GEMM_TYPE": GemmType.Normal, "NUM_TMA_THREADS": num_tma_threads, "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group, + "N": n, + "K": k, "NUM_GROUPS": 1, + "BLOCK_M": block_m, + "BLOCK_N": block_n, "BLOCK_K": block_k, + "SWIZZLE_D_MODE": smem_config[1], + "BLOCK_N_PADDING": smem_config[2], + "NUM_STAGES": num_stages, + "NUM_TMA_MULTICAST": tma_multicast_config[0], + "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1], "NUM_SMS": num_sms, "SMEM_SIZE": smem_config[0], } - _, _ = jit_tuner.compile_and_tune( - name="gemm_fp8_fp8_bf16_nt", - keys={ - "N": n, - "K": k, - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "SWIZZLE_D_MODE": smem_config[1], - "BLOCK_N_PADDING": smem_config[2], - "NUM_STAGES": num_stages, - "NUM_TMA_MULTICAST": tma_multicast_config[0], - "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1], - }, - space=(), - kwargs=kwargs, - runtime_cls=FP8GemmRuntime, - ) + + code = FP8GemmRuntime.generate(kwargs) + _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs) _KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {