[Move sgl-kernel Kernel to JIT] Add JIT concat MLA kernels#17889
BBuf merged 7 commits into sgl-project:main
Conversation
Summary of Changes

Hello @celve, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request integrates performance-critical JIT-compiled CUDA kernels into the SGLang framework to accelerate the tensor concatenation operations fundamental to Multi-head Latent Attention (MLA). By offloading these computations to optimized GPU kernels, the change aims to significantly improve the efficiency of MLA-related processes. The introduction is accompanied by robust testing to ensure functional correctness and compatibility with existing components.
Code Review
This pull request introduces JIT-compiled CUDA kernels for MLA tensor concatenation, concat_mla_k and concat_mla_absorb_q. The changes include the Python interface, the CUDA C++ kernel implementation, and comprehensive unit tests. The code is well-structured and the tests are thorough, verifying correctness against PyTorch and consistency with AOT versions. My feedback focuses on performance optimizations within the concat_mla_absorb_q_kernel by hoisting loop-invariant computations and suggesting the use of non-temporal memory operations for better performance, consistent with the other kernel in the file.
constexpr int B_LAST_DIM = 64;
constexpr int OUT_LAST_DIM = A_LAST_DIM + B_LAST_DIM;

__global__ void concat_mla_absorb_q_kernel(
The memory access pattern in this kernel is streaming, similar to concat_mla_k_kernel. For better performance and consistency, consider using non-temporal load/store intrinsics (e.g., ld_na_global_..., st_na_global_...) for a, b, and out tensors. This can reduce L1 cache pollution. Using these would likely require adding v4 variants for int4 types to the memory utilities, or composing them from existing v1 and v2 variants.
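For illustration, here is a minimal self-contained sketch of the streaming-access idea using the standard CUDA cache-hint intrinsics __ldcs/__stcs; the exact signatures of the project's ld_na_global_*/st_na_global_* helpers are not shown in this diff, so this sketch only demonstrates the concept, not the suggested change itself.

// Illustration only: uses the standard CUDA cache-hint intrinsics
// __ldcs (load, evict-first) and __stcs (store, streaming) on int4,
// not the project's ld_na_global_* / st_na_global_* helpers.
__global__ void streaming_copy_int4(const int4* __restrict__ src,
                                    int4* __restrict__ dst,
                                    long long num_vec4) {
  long long i = blockIdx.x * (long long)blockDim.x + threadIdx.x;
  if (i < num_vec4) {
    int4 v = __ldcs(src + i);  // streaming load: data is read only once
    __stcs(dst + i, v);        // streaming store: avoid displacing hot cache lines
  }
}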
#pragma unroll
for (int i = 0; i < A_NUM_UNROLL; ++i) {
  const ABufType* base_addr = reinterpret_cast<ABufType*>(a + idx_0 * a_stride_0 + idx_1 * a_stride_1);
  a_buf[i] = *(base_addr + i * 32 + lane_id);
}
For performance, the base address calculation, which is loop-invariant, should be hoisted out of the loop. This avoids redundant calculations in each iteration.
const ABufType* base_addr = reinterpret_cast<ABufType*>(a + idx_0 * a_stride_0 + idx_1 * a_stride_1);
#pragma unroll
for (int i = 0; i < A_NUM_UNROLL; ++i) {
  a_buf[i] = *(base_addr + i * 32 + lane_id);
}
#pragma unroll
for (int i = 0; i < A_NUM_UNROLL; ++i) {
  ABufType* base_addr = reinterpret_cast<ABufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1);
  *(base_addr + i * 32 + lane_id) = a_buf[i];
}
Similarly, the base address calculation for the output tensor should be hoisted out of the loop to avoid redundant computations.
ABufType* base_addr = reinterpret_cast<ABufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1);
#pragma unroll
for (int i = 0; i < A_NUM_UNROLL; ++i) {
  *(base_addr + i * 32 + lane_id) = a_buf[i];
}
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 97084edb58
TensorMatcher({N0_out, N1_out, D_out})
    .with_strides({S0_out, S1_out, 1})
    .with_dtype<bf16_t>()
    .with_device<kDLCUDA>(device)
    .verify(out);
Require output dims to match inputs in concat_mla_absorb_q
The output tensor’s leading dimensions are verified with independent symbols (N0_out, N1_out) but never required to match a/b. The kernel computes num_items and indexing from a’s sizes and then writes into out, so a caller that passes a preallocated out with smaller or different first dimensions will cause out-of-bounds writes or corrupted results. Add a RuntimeCheck tying N0_out/N1_out to N0_a/N1_a (or derive indexing from out) to prevent this.
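As a sketch of the suggested fix, assuming TensorMatcher symbols unify across verify() calls in the way the surrounding validation code implies (the exact symbol semantics are not visible in this diff), the output could simply be verified against the input's leading-dimension symbols:

// Sketch only: reuse a's leading-dimension symbols (N0_a, N1_a) so that a
// preallocated `out` with mismatched leading dimensions fails verification
// instead of causing out-of-bounds writes. TensorMatcher symbol behavior is
// assumed from the surrounding code, not confirmed.
TensorMatcher({N0_a, N1_a, D_out})
    .with_strides({S0_out, S1_out, 1})
    .with_dtype<bf16_t>()
    .with_device<kDLCUDA>(device)
    .verify(out);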
Pull request overview
This pull request adds JIT-compiled CUDA kernels for Multi-head Latent Attention (MLA) tensor concatenation operations, supporting models like DeepSeek-V2/V3/R1.
Changes:
- Added Python interface for two JIT kernels: concat_mla_k and concat_mla_absorb_q
- Implemented optimized CUDA kernels with warp-level parallelism and memory access optimizations
- Added comprehensive unit tests verifying correctness against PyTorch reference and AOT kernel implementations
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| python/sglang/jit_kernel/concat_mla.py | Python interface providing module loading, error handling, and public API for the JIT kernels |
| python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh | CUDA kernel implementations with memory utilities, tensor validation, and optimized memory access patterns |
| python/sglang/jit_kernel/tests/test_concat_mla.py | Unit tests comparing JIT kernels against PyTorch reference and AOT implementations with various input sizes |
We should apply this kernel to DeepSeek-V3/R1.

@DarkSharpness Any other advice?

Do you have any benchmark results, @celve?

Benchmark results on H100:

@celve fix lint

fixed
/tag-and-rerun-ci |
Motivation
Add JIT-compiled CUDA kernels for MLA tensor concatenation: concat_mla_k and concat_mla_absorb_q, used by MLA models such as DeepSeek-V2/V3/R1.
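For context, both kernels perform a concatenation of two tensors along the last dimension (e.g. joining the no-RoPE and RoPE parts of the MLA query/key). The naive, contiguous-only reference below is shown purely to illustrate the semantics; the kernel name and launch shape are illustrative, and the PR's actual kernels additionally handle explicit strides, vectorized loads, and warp-level parallelism.

// Naive reference for out[r, :] = concat(a[r, :], b[r, :]) along the last dim.
// Assumes contiguous tensors; not the optimized kernel added in this PR.
template <typename T>
__global__ void concat_last_dim_ref(const T* a, const T* b, T* out,
                                    long long rows, int Da, int Db) {
  long long row = blockIdx.x;              // one block per flattened row
  if (row >= rows) return;
  const int Dout = Da + Db;
  for (int d = threadIdx.x; d < Dout; d += blockDim.x) {
    out[row * Dout + d] = (d < Da) ? a[row * Da + d] : b[row * Db + (d - Da)];
  }
}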
Modifications
- python/sglang/jit_kernel/concat_mla.py: Python interface, module loading, and public API for the JIT kernels
- python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh: CUDA kernel implementations with tensor validation and optimized memory access
- python/sglang/jit_kernel/tests/test_concat_mla.py: unit tests comparing the JIT kernels against the PyTorch reference and AOT implementations
Accuracy Tests
Verified against PyTorch implementation and AOT sgl_kernel.
Benchmarking and Profiling
Benchmark results on H100 are posted in the conversation above.
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci