From 03e0b33542379aef43af6ff15870cf0ac51c73fc Mon Sep 17 00:00:00 2001 From: Celve Date: Wed, 28 Jan 2026 13:10:09 +0800 Subject: [PATCH 1/7] wip: add jit concat mla --- python/sglang/jit_kernel/concat_mla.py | 82 +++++ .../csrc/elementwise/concat_mla.cuh | 336 ++++++++++++++++++ 2 files changed, 418 insertions(+) create mode 100644 python/sglang/jit_kernel/concat_mla.py create mode 100644 python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh diff --git a/python/sglang/jit_kernel/concat_mla.py b/python/sglang/jit_kernel/concat_mla.py new file mode 100644 index 000000000000..3109594c0ce1 --- /dev/null +++ b/python/sglang/jit_kernel/concat_mla.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_concat_mla_k_module() -> Module: + return load_jit( + "concat_mla_k", + cuda_files=["elementwise/concat_mla.cuh"], + cuda_wrappers=[("concat_mla_k", "ConcatMlaKKernel::run")], + ) + + +@cache_once +def _jit_concat_mla_absorb_q_module() -> Module: + return load_jit( + "concat_mla_absorb_q", + cuda_files=["elementwise/concat_mla.cuh"], + cuda_wrappers=[("concat_mla_absorb_q", "ConcatMlaAbsorbQKernel::run")], + ) + + +@cache_once +def can_use_jit_concat_mla_k() -> bool: + logger = logging.getLogger(__name__) + try: + _jit_concat_mla_k_module() + return True + except Exception as e: + logger.warning(f"Failed to load JIT concat_mla_k kernel: {e}") + return False + + +@cache_once +def can_use_jit_concat_mla_absorb_q() -> bool: + logger = logging.getLogger(__name__) + try: + _jit_concat_mla_absorb_q_module() + return True + except Exception as e: + logger.warning(f"Failed to load JIT concat_mla_absorb_q kernel: {e}") + return False + + +def concat_mla_k(k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor) -> None: + """ + Concatenate k_nope and k_rope into k for MLA (Multi-head Latent Attention). + + This kernel efficiently broadcasts k_rope across all heads while copying + k_nope values directly. + + Args: + k: Output tensor of shape [num_tokens, num_heads=128, k_head_dim=192], dtype=bfloat16 + k_nope: Input tensor of shape [num_tokens, num_heads=128, nope_head_dim=128], dtype=bfloat16 + k_rope: Input tensor of shape [num_tokens, 1, rope_head_dim=64], dtype=bfloat16 + """ + module = _jit_concat_mla_k_module() + module.concat_mla_k(k, k_nope, k_rope) + + +def concat_mla_absorb_q( + a: torch.Tensor, b: torch.Tensor, out: torch.Tensor +) -> None: + """ + Concatenate tensors a and b into out for MLA absorbed Q computation. 
+ + Args: + a: Input tensor of shape [dim_0, dim_1, 512], dtype=bfloat16 + b: Input tensor of shape [dim_0, dim_1, 64], dtype=bfloat16 + out: Output tensor of shape [dim_0, dim_1, 576], dtype=bfloat16 + """ + module = _jit_concat_mla_absorb_q_module() + module.concat_mla_absorb_q(a, b, out) diff --git a/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh b/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh new file mode 100644 index 000000000000..5acbb90a7a77 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh @@ -0,0 +1,336 @@ +#include +#include +#include + +#include + +#include +#include + +namespace { + +// ======================= Memory Utilities ======================= +// Adapted from DeepEP: https://github.com/deepseek-ai/DeepEP/blob/main/csrc/kernels/utils.cuh + +__forceinline__ __device__ int get_lane_id() { + int lane_id; + asm("mov.s32 %0, %laneid;" : "=r"(lane_id)); + return lane_id; +} + +__device__ __forceinline__ void st_na_global_v1(const int* ptr, int v) { + asm volatile("st.global.L1::no_allocate.s32 [%0], %1;" ::"l"(ptr), "r"(v) : "memory"); +} + +__device__ __forceinline__ void st_na_global_v2(const int2* ptr, const int2& v) { + asm volatile("st.global.L1::no_allocate.v2.s32 [%0], {%1, %2};" ::"l"(ptr), "r"(v.x), "r"(v.y) : "memory"); +} + +__device__ __forceinline__ int ld_na_global_v1(const int* ptr) { + int r; + asm volatile("ld.global.nc.L1::no_allocate.s32 %0, [%1];" : "=r"(r) : "l"(ptr)); + return r; +} + +__device__ __forceinline__ int2 ld_na_global_v2(const int2* ptr) { + int2 r; + asm volatile("ld.global.nc.L1::no_allocate.v2.s32 {%0, %1}, [%2];" : "=r"(r.x), "=r"(r.y) : "l"(ptr)); + return r; +} + +__device__ __forceinline__ void prefetch_L2(const void* p) { +#if defined(ENABLE_L2_PREFETCH) + asm volatile("prefetch.global.L2 [%0];" ::"l"(p)); +#endif +} + +// ======================= concat_mla_k Kernel ======================= + +constexpr int NUM_LOCAL_HEADS = 128; +constexpr int QK_NOPE_HEAD_DIM = 128; +constexpr int QK_ROPE_HEAD_DIM = 64; +constexpr int K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM; + +constexpr int HEAD_CHUNK_SIZE = 16; +constexpr int NUM_HEAD_CHUNKS = NUM_LOCAL_HEADS / HEAD_CHUNK_SIZE; + +__global__ void concat_mla_k_kernel( + nv_bfloat16* __restrict__ k, + const nv_bfloat16* __restrict__ k_nope, + const nv_bfloat16* __restrict__ k_rope, + const int num_tokens, + const int64_t k_stride_0, + const int k_stride_1, + const int64_t k_nope_stride_0, + const int k_nope_stride_1, + const int64_t k_rope_stride_0) { + const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + const int token_id = flat_warp_id / NUM_HEAD_CHUNKS; + const int head_chunk_id = flat_warp_id % NUM_HEAD_CHUNKS; + const int lane_id = get_lane_id(); + if (token_id >= num_tokens) return; + + using NopeVec = int2; // 8B/thread, 32 threads = 256B/row + using RopeVec = int; // 4B/thread, 32 threads = 128B/row + static_assert(sizeof(NopeVec) * 32 == QK_NOPE_HEAD_DIM * sizeof(nv_bfloat16), "nope vec mismatch"); + static_assert(sizeof(RopeVec) * 32 == QK_ROPE_HEAD_DIM * sizeof(nv_bfloat16), "rope vec mismatch"); + + const int head_row0 = head_chunk_id * HEAD_CHUNK_SIZE; + + const int2* __restrict__ nope_src = + reinterpret_cast(k_nope + token_id * k_nope_stride_0 + head_row0 * k_nope_stride_1) + lane_id; + + int2* __restrict__ nope_dst = reinterpret_cast(k + token_id * k_stride_0 + head_row0 * k_stride_1) + lane_id; + + int* __restrict__ rope_dst = + reinterpret_cast(k + token_id * k_stride_0 + head_row0 * k_stride_1 + 
QK_NOPE_HEAD_DIM) + lane_id; + + const int nope_src_stride_v = (k_nope_stride_1 >> 2); // int2 covers 4 bf16 + const int nope_dst_stride_v = (k_stride_1 >> 2); + const int rope_dst_stride_v = (k_stride_1 >> 1); // int covers 2 bf16 + + const int* rope_base = reinterpret_cast(k_rope + token_id * k_rope_stride_0); + const RopeVec rope_val = ld_na_global_v1(rope_base + lane_id); + + prefetch_L2(nope_src); + NopeVec cur = ld_na_global_v2(nope_src); + +#pragma unroll + for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) { + NopeVec next; + if (i + 1 < HEAD_CHUNK_SIZE) { + const int2* next_src = nope_src + nope_src_stride_v; + prefetch_L2(next_src); + next = ld_na_global_v2(next_src); + } + + st_na_global_v2(nope_dst, cur); + st_na_global_v1(rope_dst, rope_val); + + nope_src += nope_src_stride_v; + nope_dst += nope_dst_stride_v; + rope_dst += rope_dst_stride_v; + + cur = next; + } +} + +struct ConcatMlaKKernel { + static void run(tvm::ffi::TensorView k, tvm::ffi::TensorView k_nope, tvm::ffi::TensorView k_rope) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto H = SymbolicSize{"num_heads"}; + auto D = SymbolicSize{"k_head_dim"}; + auto D_nope = SymbolicSize{"nope_head_dim"}; + auto D_rope = SymbolicSize{"rope_head_dim"}; + auto S0_k = SymbolicSize{"k_stride_0"}; + auto S1_k = SymbolicSize{"k_stride_1"}; + auto S0_k_nope = SymbolicSize{"k_nope_stride_0"}; + auto S1_k_nope = SymbolicSize{"k_nope_stride_1"}; + auto S0_k_rope = SymbolicSize{"k_rope_stride_0"}; + auto device = SymbolicDevice{}; + + // Set known fixed values + H.set_value(NUM_LOCAL_HEADS); + D.set_value(K_HEAD_DIM); + D_nope.set_value(QK_NOPE_HEAD_DIM); + D_rope.set_value(QK_ROPE_HEAD_DIM); + + // Verify k: [num_tokens, num_heads, k_head_dim] + TensorMatcher({N, H, D}) + .with_strides({S0_k, S1_k, 1}) + .with_dtype() + .with_device(device) + .verify(k); + + // Verify k_nope: [num_tokens, num_heads, nope_head_dim] + TensorMatcher({N, H, D_nope}) + .with_strides({S0_k_nope, S1_k_nope, 1}) + .with_dtype() + .with_device(device) + .verify(k_nope); + + // Verify k_rope: [num_tokens, 1, rope_head_dim] + TensorMatcher({N, 1, D_rope}) + .with_strides({S0_k_rope, -1, 1}) + .with_dtype() + .with_device(device) + .verify(k_rope); + + // Check alignment + RuntimeCheck( + reinterpret_cast(k.data_ptr()) % 16 == 0, "Tensor k must be 16-byte aligned"); + RuntimeCheck( + reinterpret_cast(k_nope.data_ptr()) % 16 == 0, "Tensor k_nope must be 16-byte aligned"); + RuntimeCheck( + reinterpret_cast(k_rope.data_ptr()) % 16 == 0, "Tensor k_rope must be 16-byte aligned"); + + const int num_tokens = static_cast(N.unwrap()); + + constexpr int num_warps_per_block = 32; + const int grid_size = div_ceil(num_tokens * NUM_HEAD_CHUNKS, num_warps_per_block); + const int block_size = num_warps_per_block * 32; + + LaunchKernel(grid_size, block_size, device.unwrap())( + concat_mla_k_kernel, + static_cast(k.data_ptr()), + static_cast(k_nope.data_ptr()), + static_cast(k_rope.data_ptr()), + num_tokens, + S0_k.unwrap(), + static_cast(S1_k.unwrap()), + S0_k_nope.unwrap(), + static_cast(S1_k_nope.unwrap()), + S0_k_rope.unwrap()); + } +}; + +// ======================= concat_mla_absorb_q Kernel ======================= + +constexpr int A_LAST_DIM = 512; +constexpr int B_LAST_DIM = 64; +constexpr int OUT_LAST_DIM = A_LAST_DIM + B_LAST_DIM; + +__global__ void concat_mla_absorb_q_kernel( + nv_bfloat16* a, + nv_bfloat16* b, + nv_bfloat16* out, + const int num_items, + const int dim_1, + const int64_t a_stride_0, + const int a_stride_1, + const int64_t b_stride_0, + const 
int b_stride_1, + const int64_t out_stride_0, + const int out_stride_1) { + const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + const int lane_id = get_lane_id(); + + const int idx_0 = flat_warp_id / dim_1; + const int idx_1 = flat_warp_id % dim_1; + + if (flat_warp_id >= num_items) { + return; + } + + using ABufType = int4; + constexpr int A_NUM_UNROLL = 2; + static_assert(sizeof(ABufType) * A_NUM_UNROLL == A_LAST_DIM * sizeof(a[0]) / 32); + ABufType a_buf[A_NUM_UNROLL]; + + using BBufType = int; + constexpr int B_NUM_UNROLL = 1; + static_assert(sizeof(BBufType) * B_NUM_UNROLL == B_LAST_DIM * sizeof(b[0]) / 32); + BBufType b_buf; + + { + const BBufType* base_addr = reinterpret_cast(b + idx_0 * b_stride_0 + idx_1 * b_stride_1); + b_buf = *(base_addr + lane_id); + } + +#pragma unroll + for (int i = 0; i < A_NUM_UNROLL; ++i) { + const ABufType* base_addr = reinterpret_cast(a + idx_0 * a_stride_0 + idx_1 * a_stride_1); + a_buf[i] = *(base_addr + i * 32 + lane_id); + } + + { + BBufType* base_addr = reinterpret_cast(out + idx_0 * out_stride_0 + idx_1 * out_stride_1 + A_LAST_DIM); + *(base_addr + lane_id) = b_buf; + } + +#pragma unroll + for (int i = 0; i < A_NUM_UNROLL; ++i) { + ABufType* base_addr = reinterpret_cast(out + idx_0 * out_stride_0 + idx_1 * out_stride_1); + *(base_addr + i * 32 + lane_id) = a_buf[i]; + } +} + +struct ConcatMlaAbsorbQKernel { + static void run(tvm::ffi::TensorView a, tvm::ffi::TensorView b, tvm::ffi::TensorView out) { + using namespace host; + + auto N0_a = SymbolicSize{"a_dim_0"}; + auto N1_a = SymbolicSize{"a_dim_1"}; + auto D_a = SymbolicSize{"a_last_dim"}; + auto N0_b = SymbolicSize{"b_dim_0"}; + auto N1_b = SymbolicSize{"b_dim_1"}; + auto D_b = SymbolicSize{"b_last_dim"}; + auto N0_out = SymbolicSize{"out_dim_0"}; + auto N1_out = SymbolicSize{"out_dim_1"}; + auto D_out = SymbolicSize{"out_last_dim"}; + auto S0_a = SymbolicSize{"a_stride_0"}; + auto S1_a = SymbolicSize{"a_stride_1"}; + auto S0_b = SymbolicSize{"b_stride_0"}; + auto S1_b = SymbolicSize{"b_stride_1"}; + auto S0_out = SymbolicSize{"out_stride_0"}; + auto S1_out = SymbolicSize{"out_stride_1"}; + auto device = SymbolicDevice{}; + + // Set known fixed values + D_a.set_value(A_LAST_DIM); + D_b.set_value(B_LAST_DIM); + D_out.set_value(OUT_LAST_DIM); + + // Verify a: [dim_0, dim_1, A_LAST_DIM] + TensorMatcher({N0_a, N1_a, D_a}) + .with_strides({S0_a, S1_a, 1}) + .with_dtype() + .with_device(device) + .verify(a); + + // Verify b: [dim_0, dim_1, B_LAST_DIM] + TensorMatcher({N0_b, N1_b, D_b}) + .with_strides({S0_b, S1_b, 1}) + .with_dtype() + .with_device(device) + .verify(b); + + // Verify out: [dim_0, dim_1, OUT_LAST_DIM] + TensorMatcher({N0_out, N1_out, D_out}) + .with_strides({S0_out, S1_out, 1}) + .with_dtype() + .with_device(device) + .verify(out); + + // Check alignment + RuntimeCheck( + reinterpret_cast(a.data_ptr()) % 16 == 0, "Tensor a must be 16-byte aligned"); + RuntimeCheck( + reinterpret_cast(b.data_ptr()) % 16 == 0, "Tensor b must be 16-byte aligned"); + RuntimeCheck( + reinterpret_cast(out.data_ptr()) % 16 == 0, "Tensor out must be 16-byte aligned"); + + // Verify dimensions match: a.size(0) * a.size(1) == b.size(0) * b.size(1) + RuntimeCheck( + N0_a.unwrap() * N1_a.unwrap() == N0_b.unwrap() * N1_b.unwrap(), + "Dimension mismatch: a.size(0) * a.size(1) must equal b.size(0) * b.size(1)"); + RuntimeCheck( + N1_a.unwrap() == N1_b.unwrap(), + "Dimension mismatch: a.size(1) must equal b.size(1)"); + + const int num_items = static_cast(N0_a.unwrap() * N1_a.unwrap()); + const 
int dim_1 = static_cast(N1_a.unwrap()); + + constexpr int num_warps_per_block = 32; + const int grid_size = div_ceil(num_items, num_warps_per_block); + const int block_size = num_warps_per_block * 32; + + LaunchKernel(grid_size, block_size, device.unwrap())( + concat_mla_absorb_q_kernel, + static_cast(a.data_ptr()), + static_cast(b.data_ptr()), + static_cast(out.data_ptr()), + num_items, + dim_1, + S0_a.unwrap(), + static_cast(S1_a.unwrap()), + S0_b.unwrap(), + static_cast(S1_b.unwrap()), + S0_out.unwrap(), + static_cast(S1_out.unwrap())); + } +}; + +} // namespace From 5b35ae53d9519541a6ba4a027710526a19e2c19d Mon Sep 17 00:00:00 2001 From: Celve Date: Wed, 28 Jan 2026 13:21:33 +0800 Subject: [PATCH 2/7] feat: add test --- .../jit_kernel/tests/test_concat_mla.py | 155 ++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 python/sglang/jit_kernel/tests/test_concat_mla.py diff --git a/python/sglang/jit_kernel/tests/test_concat_mla.py b/python/sglang/jit_kernel/tests/test_concat_mla.py new file mode 100644 index 000000000000..820d29e3ffb6 --- /dev/null +++ b/python/sglang/jit_kernel/tests/test_concat_mla.py @@ -0,0 +1,155 @@ +import itertools + +import pytest +import torch +import triton + + +def torch_concat_mla_k( + k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor +) -> None: + """Reference PyTorch implementation for concat_mla_k.""" + # k_nope: [num_tokens, num_heads, nope_head_dim] + # k_rope: [num_tokens, 1, rope_head_dim] + # k: [num_tokens, num_heads, nope_head_dim + rope_head_dim] + nope_head_dim = k_nope.shape[-1] + k[:, :, :nope_head_dim] = k_nope + # Broadcast k_rope across all heads + k[:, :, nope_head_dim:] = k_rope.expand(-1, k.shape[1], -1) + + +def torch_concat_mla_absorb_q( + a: torch.Tensor, b: torch.Tensor, out: torch.Tensor +) -> None: + """Reference PyTorch implementation for concat_mla_absorb_q.""" + # a: [dim_0, dim_1, a_last_dim] + # b: [dim_0, dim_1, b_last_dim] + # out: [dim_0, dim_1, a_last_dim + b_last_dim] + a_last_dim = a.shape[-1] + out[:, :, :a_last_dim] = a + out[:, :, a_last_dim:] = b + + +def sgl_kernel_concat_mla_k( + k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor +) -> None: + """AOT compiled sgl_kernel implementation.""" + from sgl_kernel import concat_mla_k + + concat_mla_k(k, k_nope, k_rope) + + +def sgl_kernel_concat_mla_absorb_q( + a: torch.Tensor, b: torch.Tensor, out: torch.Tensor +) -> None: + """AOT compiled sgl_kernel implementation.""" + from sgl_kernel import concat_mla_absorb_q + + concat_mla_absorb_q(a, b, out) + + +def jit_concat_mla_k( + k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor +) -> None: + """JIT compiled implementation.""" + from sglang.jit_kernel.concat_mla import concat_mla_k + + concat_mla_k(k, k_nope, k_rope) + + +def jit_concat_mla_absorb_q( + a: torch.Tensor, b: torch.Tensor, out: torch.Tensor +) -> None: + """JIT compiled implementation.""" + from sglang.jit_kernel.concat_mla import concat_mla_absorb_q + + concat_mla_absorb_q(a, b, out) + + +# Constants matching the kernel +NUM_LOCAL_HEADS = 128 +QK_NOPE_HEAD_DIM = 128 +QK_ROPE_HEAD_DIM = 64 +K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM + +A_LAST_DIM = 512 +B_LAST_DIM = 64 +OUT_LAST_DIM = A_LAST_DIM + B_LAST_DIM + +DEVICE = "cuda" +DTYPE = torch.bfloat16 + +# Test configurations +NUM_TOKENS_LIST = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS_LIST) +def test_concat_mla_k_jit_vs_torch(num_tokens: int) -> None: + """Test JIT kernel against PyTorch reference.""" + 
k_jit = torch.empty(num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE) + k_torch = torch.empty(num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE) + + k_nope = torch.randn(num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM, device=DEVICE, dtype=DTYPE) + k_rope = torch.randn(num_tokens, 1, QK_ROPE_HEAD_DIM, device=DEVICE, dtype=DTYPE) + + torch_concat_mla_k(k_torch, k_nope, k_rope) + jit_concat_mla_k(k_jit, k_nope, k_rope) + + triton.testing.assert_close(k_jit, k_torch, atol=0, rtol=0) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS_LIST) +def test_concat_mla_k_jit_vs_aot(num_tokens: int) -> None: + """Test JIT kernel against AOT kernel for bitwise equivalence.""" + k_jit = torch.empty(num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE) + k_aot = torch.empty(num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE) + + k_nope = torch.randn(num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM, device=DEVICE, dtype=DTYPE) + k_rope = torch.randn(num_tokens, 1, QK_ROPE_HEAD_DIM, device=DEVICE, dtype=DTYPE) + + sgl_kernel_concat_mla_k(k_aot, k_nope, k_rope) + jit_concat_mla_k(k_jit, k_nope, k_rope) + + triton.testing.assert_close(k_jit, k_aot, atol=0, rtol=0) + + +DIM_0_LIST = [1, 2, 4, 8, 16, 32] +DIM_1_LIST = [1, 2, 4, 8, 16, 128] + + +@pytest.mark.parametrize( + "dim_0,dim_1", + list(itertools.product(DIM_0_LIST, DIM_1_LIST)), +) +def test_concat_mla_absorb_q_jit_vs_torch(dim_0: int, dim_1: int) -> None: + """Test JIT kernel against PyTorch reference.""" + a = torch.randn(dim_0, dim_1, A_LAST_DIM, device=DEVICE, dtype=DTYPE) + b = torch.randn(dim_0, dim_1, B_LAST_DIM, device=DEVICE, dtype=DTYPE) + out_jit = torch.empty(dim_0, dim_1, OUT_LAST_DIM, device=DEVICE, dtype=DTYPE) + out_torch = torch.empty(dim_0, dim_1, OUT_LAST_DIM, device=DEVICE, dtype=DTYPE) + + torch_concat_mla_absorb_q(a, b, out_torch) + jit_concat_mla_absorb_q(a, b, out_jit) + + triton.testing.assert_close(out_jit, out_torch, atol=0, rtol=0) + + +@pytest.mark.parametrize( + "dim_0,dim_1", + list(itertools.product(DIM_0_LIST, DIM_1_LIST)), +) +def test_concat_mla_absorb_q_jit_vs_aot(dim_0: int, dim_1: int) -> None: + """Test JIT kernel against AOT kernel for bitwise equivalence.""" + a = torch.randn(dim_0, dim_1, A_LAST_DIM, device=DEVICE, dtype=DTYPE) + b = torch.randn(dim_0, dim_1, B_LAST_DIM, device=DEVICE, dtype=DTYPE) + out_jit = torch.empty(dim_0, dim_1, OUT_LAST_DIM, device=DEVICE, dtype=DTYPE) + out_aot = torch.empty(dim_0, dim_1, OUT_LAST_DIM, device=DEVICE, dtype=DTYPE) + + sgl_kernel_concat_mla_absorb_q(a, b, out_aot) + jit_concat_mla_absorb_q(a, b, out_jit) + + triton.testing.assert_close(out_jit, out_aot, atol=0, rtol=0) + + +if __name__ == "__main__": + pytest.main([__file__]) From 97084edb5856a83264e27962145448fd5d45c98c Mon Sep 17 00:00:00 2001 From: Celve Date: Wed, 28 Jan 2026 14:33:23 +0800 Subject: [PATCH 3/7] wip: different return style --- python/sglang/jit_kernel/concat_mla.py | 20 ++++++++++++------- .../jit_kernel/tests/test_concat_mla.py | 8 +++++--- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/python/sglang/jit_kernel/concat_mla.py b/python/sglang/jit_kernel/concat_mla.py index 3109594c0ce1..a45752562778 100644 --- a/python/sglang/jit_kernel/concat_mla.py +++ b/python/sglang/jit_kernel/concat_mla.py @@ -67,16 +67,22 @@ def concat_mla_k(k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor) -> module.concat_mla_k(k, k_nope, k_rope) -def concat_mla_absorb_q( - a: torch.Tensor, b: torch.Tensor, out: torch.Tensor -) -> 
None: +def concat_mla_absorb_q(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """ - Concatenate tensors a and b into out for MLA absorbed Q computation. + Concatenate tensors a and b for MLA absorbed Q computation. Args: - a: Input tensor of shape [dim_0, dim_1, 512], dtype=bfloat16 - b: Input tensor of shape [dim_0, dim_1, 64], dtype=bfloat16 - out: Output tensor of shape [dim_0, dim_1, 576], dtype=bfloat16 + a: Input tensor of shape [dim_0, dim_1, a_last_dim], dtype=bfloat16 + b: Input tensor of shape [dim_0, dim_1, b_last_dim], dtype=bfloat16 + + Returns: + Output tensor of shape [dim_0, dim_1, a_last_dim + b_last_dim], dtype=bfloat16 """ + out = torch.empty( + (*a.shape[:-1], a.shape[-1] + b.shape[-1]), + dtype=a.dtype, + device=a.device, + ) module = _jit_concat_mla_absorb_q_module() module.concat_mla_absorb_q(a, b, out) + return out diff --git a/python/sglang/jit_kernel/tests/test_concat_mla.py b/python/sglang/jit_kernel/tests/test_concat_mla.py index 820d29e3ffb6..5ed84c443b8d 100644 --- a/python/sglang/jit_kernel/tests/test_concat_mla.py +++ b/python/sglang/jit_kernel/tests/test_concat_mla.py @@ -45,7 +45,8 @@ def sgl_kernel_concat_mla_absorb_q( """AOT compiled sgl_kernel implementation.""" from sgl_kernel import concat_mla_absorb_q - concat_mla_absorb_q(a, b, out) + result = concat_mla_absorb_q(a, b) # AOT returns output + out.copy_(result) # Copy to provided tensor for comparison def jit_concat_mla_k( @@ -60,10 +61,11 @@ def jit_concat_mla_k( def jit_concat_mla_absorb_q( a: torch.Tensor, b: torch.Tensor, out: torch.Tensor ) -> None: - """JIT compiled implementation.""" + """JIT compiled implementation - wrapper for test compatibility.""" from sglang.jit_kernel.concat_mla import concat_mla_absorb_q - concat_mla_absorb_q(a, b, out) + result = concat_mla_absorb_q(a, b) + out.copy_(result) # Constants matching the kernel From 299a76ae70ddb0d3a89211c8aeb4c2115ea3962d Mon Sep 17 00:00:00 2001 From: Celve Date: Thu, 29 Jan 2026 14:19:32 +0800 Subject: [PATCH 4/7] wip: remove can use helper --- python/sglang/jit_kernel/concat_mla.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/python/sglang/jit_kernel/concat_mla.py b/python/sglang/jit_kernel/concat_mla.py index a45752562778..4945b73bc27f 100644 --- a/python/sglang/jit_kernel/concat_mla.py +++ b/python/sglang/jit_kernel/concat_mla.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging from typing import TYPE_CHECKING import torch @@ -29,28 +28,6 @@ def _jit_concat_mla_absorb_q_module() -> Module: ) -@cache_once -def can_use_jit_concat_mla_k() -> bool: - logger = logging.getLogger(__name__) - try: - _jit_concat_mla_k_module() - return True - except Exception as e: - logger.warning(f"Failed to load JIT concat_mla_k kernel: {e}") - return False - - -@cache_once -def can_use_jit_concat_mla_absorb_q() -> bool: - logger = logging.getLogger(__name__) - try: - _jit_concat_mla_absorb_q_module() - return True - except Exception as e: - logger.warning(f"Failed to load JIT concat_mla_absorb_q kernel: {e}") - return False - - def concat_mla_k(k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor) -> None: """ Concatenate k_nope and k_rope into k for MLA (Multi-head Latent Attention). 
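
A minimal caller-side sketch of the two public entry points as of this patch series (illustrative only, not part of the patches themselves): concat_mla_k fills a preallocated k in place, while concat_mla_absorb_q allocates and returns its output as of patch 3/7. Shapes and dtypes follow the docstrings above; the token count and middle dimension below are assumed values for illustration.

import torch

from sglang.jit_kernel.concat_mla import concat_mla_absorb_q, concat_mla_k

num_tokens = 8  # illustrative

# concat_mla_k: in-place concat, k_rope is broadcast across all 128 heads.
k_nope = torch.randn(num_tokens, 128, 128, device="cuda", dtype=torch.bfloat16)
k_rope = torch.randn(num_tokens, 1, 64, device="cuda", dtype=torch.bfloat16)
k = torch.empty(num_tokens, 128, 192, device="cuda", dtype=torch.bfloat16)
concat_mla_k(k, k_nope, k_rope)

# concat_mla_absorb_q: allocates and returns the concatenated tensor.
a = torch.randn(num_tokens, 16, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(num_tokens, 16, 64, device="cuda", dtype=torch.bfloat16)
out = concat_mla_absorb_q(a, b)  # shape: [num_tokens, 16, 576]
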
From 8b3094e61e115d9f4cbf94820cc7da07a5c01729 Mon Sep 17 00:00:00 2001 From: Celve Date: Fri, 30 Jan 2026 10:12:28 +0800 Subject: [PATCH 5/7] wip: add benchmark --- .../jit_kernel/benchmark/bench_concat_mla.py | 163 ++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 python/sglang/jit_kernel/benchmark/bench_concat_mla.py diff --git a/python/sglang/jit_kernel/benchmark/bench_concat_mla.py b/python/sglang/jit_kernel/benchmark/bench_concat_mla.py new file mode 100644 index 000000000000..87c7ae56f91e --- /dev/null +++ b/python/sglang/jit_kernel/benchmark/bench_concat_mla.py @@ -0,0 +1,163 @@ +import itertools + +import torch +import triton +import triton.testing +from sgl_kernel import concat_mla_absorb_q as aot_absorb_q +from sgl_kernel import concat_mla_k as aot_k + +from sglang.jit_kernel.benchmark.utils import is_in_ci +from sglang.jit_kernel.concat_mla import concat_mla_absorb_q as jit_absorb_q +from sglang.jit_kernel.concat_mla import concat_mla_k as jit_k + +IS_CI = is_in_ci() + +# Constants +NUM_LOCAL_HEADS = 128 +QK_NOPE_HEAD_DIM = 128 +QK_ROPE_HEAD_DIM = 64 +K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM + +A_LAST_DIM = 512 +B_LAST_DIM = 64 + +DTYPE = torch.bfloat16 +DEVICE = "cuda" + + +def aot_concat_mla_k(k, k_nope, k_rope): + aot_k(k, k_nope, k_rope) + + +def jit_concat_mla_k(k, k_nope, k_rope): + jit_k(k, k_nope, k_rope) + + +def torch_concat_mla_k(k, k_nope, k_rope): + nope_head_dim = k_nope.shape[-1] + k[:, :, :nope_head_dim] = k_nope + k[:, :, nope_head_dim:] = k_rope.expand(-1, k.shape[1], -1) + + +def aot_concat_mla_absorb_q(a, b): + return aot_absorb_q(a, b) + + +def jit_concat_mla_absorb_q(a, b): + return jit_absorb_q(a, b) + + +def torch_concat_mla_absorb_q(a, b, out): + a_last_dim = a.shape[-1] + out[:, :, :a_last_dim] = a + out[:, :, a_last_dim:] = b + + +if IS_CI: + NUM_TOKENS_VALS = [256, 1024] +else: + NUM_TOKENS_VALS = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768] + +K_LINE_VALS = ["aot", "jit", "torch"] +K_LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "PyTorch"] +K_STYLES = [("orange", "-"), ("blue", "--"), ("green", "-.")] + + +def _create_concat_mla_k_data(num_tokens): + """Allocate oversized containers and slice to produce non-contiguous tensors.""" + k_nope_container = torch.randn( + (num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM + 128), + dtype=DTYPE, + device=DEVICE, + ) + k_nope = k_nope_container[:, :, :QK_NOPE_HEAD_DIM] + + k_rope_container = torch.randn( + (num_tokens, 1, 128 + QK_ROPE_HEAD_DIM), + dtype=DTYPE, + device=DEVICE, + ) + k_rope = k_rope_container[:, :, -QK_ROPE_HEAD_DIM:] + + k = torch.empty( + (num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM), + dtype=DTYPE, + device=DEVICE, + ) + return k, k_nope, k_rope + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens"], + x_vals=NUM_TOKENS_VALS, + line_arg="provider", + line_vals=K_LINE_VALS, + line_names=K_LINE_NAMES, + styles=K_STYLES, + ylabel="us", + plot_name="concat-mla-k-performance", + args={}, + ) +) +def bench_concat_mla_k(num_tokens: int, provider: str): + k, k_nope, k_rope = _create_concat_mla_k_data(num_tokens) + + FN_MAP = { + "aot": aot_concat_mla_k, + "jit": jit_concat_mla_k, + "torch": torch_concat_mla_k, + } + fn = lambda: FN_MAP[provider](k, k_nope, k_rope) + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if IS_CI: + ABSORB_Q_VALS = list(itertools.product([4, 16], [16])) +else: + ABSORB_Q_VALS = 
list(itertools.product([1, 4, 8, 16, 32], [1, 8, 32, 128])) + +Q_LINE_VALS = ["aot", "jit", "torch"] +Q_LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "PyTorch"] +Q_STYLES = [("orange", "-"), ("blue", "--"), ("green", "-.")] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["dim_0", "dim_1"], + x_vals=ABSORB_Q_VALS, + line_arg="provider", + line_vals=Q_LINE_VALS, + line_names=Q_LINE_NAMES, + styles=Q_STYLES, + ylabel="us", + plot_name="concat-mla-absorb-q-performance", + args={}, + ) +) +def bench_concat_mla_absorb_q(dim_0: int, dim_1: int, provider: str): + a = torch.randn(dim_0, dim_1, A_LAST_DIM, dtype=DTYPE, device=DEVICE) + b = torch.randn(dim_0, dim_1, B_LAST_DIM, dtype=DTYPE, device=DEVICE) + + if provider == "torch": + out = torch.empty( + dim_0, dim_1, A_LAST_DIM + B_LAST_DIM, dtype=DTYPE, device=DEVICE + ) + fn = lambda: torch_concat_mla_absorb_q(a, b, out) + else: + FN_MAP = { + "aot": aot_concat_mla_absorb_q, + "jit": jit_concat_mla_absorb_q, + } + fn = lambda: FN_MAP[provider](a, b) + + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + bench_concat_mla_k.run(print_data=True) + bench_concat_mla_absorb_q.run(print_data=True) From 79a102688047ef316cbafda2ce4c9d4e535927bd Mon Sep 17 00:00:00 2001 From: Celve Date: Fri, 30 Jan 2026 10:28:04 +0800 Subject: [PATCH 6/7] wip: align with utils --- .../csrc/elementwise/concat_mla.cuh | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh b/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh index 5acbb90a7a77..5d5ce5e5cccd 100644 --- a/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh +++ b/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh @@ -12,33 +12,33 @@ namespace { // ======================= Memory Utilities ======================= // Adapted from DeepEP: https://github.com/deepseek-ai/DeepEP/blob/main/csrc/kernels/utils.cuh -__forceinline__ __device__ int get_lane_id() { +SGL_DEVICE int get_lane_id() { int lane_id; asm("mov.s32 %0, %laneid;" : "=r"(lane_id)); return lane_id; } -__device__ __forceinline__ void st_na_global_v1(const int* ptr, int v) { +SGL_DEVICE void st_na_global_v1(const int* ptr, int v) { asm volatile("st.global.L1::no_allocate.s32 [%0], %1;" ::"l"(ptr), "r"(v) : "memory"); } -__device__ __forceinline__ void st_na_global_v2(const int2* ptr, const int2& v) { +SGL_DEVICE void st_na_global_v2(const int2* ptr, const int2& v) { asm volatile("st.global.L1::no_allocate.v2.s32 [%0], {%1, %2};" ::"l"(ptr), "r"(v.x), "r"(v.y) : "memory"); } -__device__ __forceinline__ int ld_na_global_v1(const int* ptr) { +SGL_DEVICE int ld_na_global_v1(const int* ptr) { int r; asm volatile("ld.global.nc.L1::no_allocate.s32 %0, [%1];" : "=r"(r) : "l"(ptr)); return r; } -__device__ __forceinline__ int2 ld_na_global_v2(const int2* ptr) { +SGL_DEVICE int2 ld_na_global_v2(const int2* ptr) { int2 r; asm volatile("ld.global.nc.L1::no_allocate.v2.s32 {%0, %1}, [%2];" : "=r"(r.x), "=r"(r.y) : "l"(ptr)); return r; } -__device__ __forceinline__ void prefetch_L2(const void* p) { +SGL_DEVICE void prefetch_L2(const void* p) { #if defined(ENABLE_L2_PREFETCH) asm volatile("prefetch.global.L2 [%0];" ::"l"(p)); #endif @@ -55,9 +55,9 @@ constexpr int HEAD_CHUNK_SIZE = 16; constexpr int NUM_HEAD_CHUNKS = NUM_LOCAL_HEADS / HEAD_CHUNK_SIZE; __global__ void concat_mla_k_kernel( - 
nv_bfloat16* __restrict__ k, - const nv_bfloat16* __restrict__ k_nope, - const nv_bfloat16* __restrict__ k_rope, + bf16_t* __restrict__ k, + const bf16_t* __restrict__ k_nope, + const bf16_t* __restrict__ k_rope, const int num_tokens, const int64_t k_stride_0, const int k_stride_1, @@ -72,8 +72,8 @@ __global__ void concat_mla_k_kernel( using NopeVec = int2; // 8B/thread, 32 threads = 256B/row using RopeVec = int; // 4B/thread, 32 threads = 128B/row - static_assert(sizeof(NopeVec) * 32 == QK_NOPE_HEAD_DIM * sizeof(nv_bfloat16), "nope vec mismatch"); - static_assert(sizeof(RopeVec) * 32 == QK_ROPE_HEAD_DIM * sizeof(nv_bfloat16), "rope vec mismatch"); + static_assert(sizeof(NopeVec) * 32 == QK_NOPE_HEAD_DIM * sizeof(bf16_t), "nope vec mismatch"); + static_assert(sizeof(RopeVec) * 32 == QK_ROPE_HEAD_DIM * sizeof(bf16_t), "rope vec mismatch"); const int head_row0 = head_chunk_id * HEAD_CHUNK_SIZE; @@ -174,9 +174,9 @@ struct ConcatMlaKKernel { LaunchKernel(grid_size, block_size, device.unwrap())( concat_mla_k_kernel, - static_cast(k.data_ptr()), - static_cast(k_nope.data_ptr()), - static_cast(k_rope.data_ptr()), + static_cast(k.data_ptr()), + static_cast(k_nope.data_ptr()), + static_cast(k_rope.data_ptr()), num_tokens, S0_k.unwrap(), static_cast(S1_k.unwrap()), @@ -193,9 +193,9 @@ constexpr int B_LAST_DIM = 64; constexpr int OUT_LAST_DIM = A_LAST_DIM + B_LAST_DIM; __global__ void concat_mla_absorb_q_kernel( - nv_bfloat16* a, - nv_bfloat16* b, - nv_bfloat16* out, + bf16_t* a, + bf16_t* b, + bf16_t* out, const int num_items, const int dim_1, const int64_t a_stride_0, @@ -319,9 +319,9 @@ struct ConcatMlaAbsorbQKernel { LaunchKernel(grid_size, block_size, device.unwrap())( concat_mla_absorb_q_kernel, - static_cast(a.data_ptr()), - static_cast(b.data_ptr()), - static_cast(out.data_ptr()), + static_cast(a.data_ptr()), + static_cast(b.data_ptr()), + static_cast(out.data_ptr()), num_items, dim_1, S0_a.unwrap(), From c0001694f312624266c1d59b8e226e5b234ffe64 Mon Sep 17 00:00:00 2001 From: Celve Date: Sun, 1 Feb 2026 21:13:14 +0800 Subject: [PATCH 7/7] wip: fix lint issues --- .../csrc/elementwise/concat_mla.cuh | 29 ++++++------------- .../jit_kernel/tests/test_concat_mla.py | 28 +++++++++++++----- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh b/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh index 5d5ce5e5cccd..eee33318fc83 100644 --- a/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh +++ b/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh @@ -1,5 +1,6 @@ #include #include + #include #include @@ -138,11 +139,7 @@ struct ConcatMlaKKernel { D_rope.set_value(QK_ROPE_HEAD_DIM); // Verify k: [num_tokens, num_heads, k_head_dim] - TensorMatcher({N, H, D}) - .with_strides({S0_k, S1_k, 1}) - .with_dtype() - .with_device(device) - .verify(k); + TensorMatcher({N, H, D}).with_strides({S0_k, S1_k, 1}).with_dtype().with_device(device).verify(k); // Verify k_nope: [num_tokens, num_heads, nope_head_dim] TensorMatcher({N, H, D_nope}) @@ -159,12 +156,9 @@ struct ConcatMlaKKernel { .verify(k_rope); // Check alignment - RuntimeCheck( - reinterpret_cast(k.data_ptr()) % 16 == 0, "Tensor k must be 16-byte aligned"); - RuntimeCheck( - reinterpret_cast(k_nope.data_ptr()) % 16 == 0, "Tensor k_nope must be 16-byte aligned"); - RuntimeCheck( - reinterpret_cast(k_rope.data_ptr()) % 16 == 0, "Tensor k_rope must be 16-byte aligned"); + RuntimeCheck(reinterpret_cast(k.data_ptr()) % 16 == 0, "Tensor k must be 16-byte aligned"); + 
RuntimeCheck(reinterpret_cast(k_nope.data_ptr()) % 16 == 0, "Tensor k_nope must be 16-byte aligned"); + RuntimeCheck(reinterpret_cast(k_rope.data_ptr()) % 16 == 0, "Tensor k_rope must be 16-byte aligned"); const int num_tokens = static_cast(N.unwrap()); @@ -295,20 +289,15 @@ struct ConcatMlaAbsorbQKernel { .verify(out); // Check alignment - RuntimeCheck( - reinterpret_cast(a.data_ptr()) % 16 == 0, "Tensor a must be 16-byte aligned"); - RuntimeCheck( - reinterpret_cast(b.data_ptr()) % 16 == 0, "Tensor b must be 16-byte aligned"); - RuntimeCheck( - reinterpret_cast(out.data_ptr()) % 16 == 0, "Tensor out must be 16-byte aligned"); + RuntimeCheck(reinterpret_cast(a.data_ptr()) % 16 == 0, "Tensor a must be 16-byte aligned"); + RuntimeCheck(reinterpret_cast(b.data_ptr()) % 16 == 0, "Tensor b must be 16-byte aligned"); + RuntimeCheck(reinterpret_cast(out.data_ptr()) % 16 == 0, "Tensor out must be 16-byte aligned"); // Verify dimensions match: a.size(0) * a.size(1) == b.size(0) * b.size(1) RuntimeCheck( N0_a.unwrap() * N1_a.unwrap() == N0_b.unwrap() * N1_b.unwrap(), "Dimension mismatch: a.size(0) * a.size(1) must equal b.size(0) * b.size(1)"); - RuntimeCheck( - N1_a.unwrap() == N1_b.unwrap(), - "Dimension mismatch: a.size(1) must equal b.size(1)"); + RuntimeCheck(N1_a.unwrap() == N1_b.unwrap(), "Dimension mismatch: a.size(1) must equal b.size(1)"); const int num_items = static_cast(N0_a.unwrap() * N1_a.unwrap()); const int dim_1 = static_cast(N1_a.unwrap()); diff --git a/python/sglang/jit_kernel/tests/test_concat_mla.py b/python/sglang/jit_kernel/tests/test_concat_mla.py index 5ed84c443b8d..6c5d3631da53 100644 --- a/python/sglang/jit_kernel/tests/test_concat_mla.py +++ b/python/sglang/jit_kernel/tests/test_concat_mla.py @@ -88,10 +88,16 @@ def jit_concat_mla_absorb_q( @pytest.mark.parametrize("num_tokens", NUM_TOKENS_LIST) def test_concat_mla_k_jit_vs_torch(num_tokens: int) -> None: """Test JIT kernel against PyTorch reference.""" - k_jit = torch.empty(num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE) - k_torch = torch.empty(num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE) - - k_nope = torch.randn(num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM, device=DEVICE, dtype=DTYPE) + k_jit = torch.empty( + num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE + ) + k_torch = torch.empty( + num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE + ) + + k_nope = torch.randn( + num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM, device=DEVICE, dtype=DTYPE + ) k_rope = torch.randn(num_tokens, 1, QK_ROPE_HEAD_DIM, device=DEVICE, dtype=DTYPE) torch_concat_mla_k(k_torch, k_nope, k_rope) @@ -103,10 +109,16 @@ def test_concat_mla_k_jit_vs_torch(num_tokens: int) -> None: @pytest.mark.parametrize("num_tokens", NUM_TOKENS_LIST) def test_concat_mla_k_jit_vs_aot(num_tokens: int) -> None: """Test JIT kernel against AOT kernel for bitwise equivalence.""" - k_jit = torch.empty(num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE) - k_aot = torch.empty(num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE) - - k_nope = torch.randn(num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM, device=DEVICE, dtype=DTYPE) + k_jit = torch.empty( + num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE + ) + k_aot = torch.empty( + num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE + ) + + k_nope = torch.randn( + num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM, device=DEVICE, dtype=DTYPE + ) k_rope = torch.randn(num_tokens, 1, 
QK_ROPE_HEAD_DIM, device=DEVICE, dtype=DTYPE) sgl_kernel_concat_mla_k(k_aot, k_nope, k_rope)