[Move sgl-kernel Kernel to JIT] Add JIT concat MLA kernels #17889

Merged
BBuf merged 7 commits into sgl-project:main from celve:jit-concat-mla on Feb 3, 2026

Conversation

@celve (Collaborator) commented Jan 28, 2026

Motivation

Add JIT-compiled CUDA kernels for MLA tensor concatenation:

  • concat_mla_k
  • concat_mla_absorb_q

Modifications

  • python/sglang/jit_kernel/concat_mla.py: Python interface
  • python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh: CUDA kernel implementations
  • python/sglang/jit_kernel/tests/test_concat_mla.py: Unit tests

Accuracy Tests

Verified against the PyTorch reference implementation and the AOT sgl_kernel implementation.

Benchmarking and Profiling

Benchmark results on H100 are posted in the comments below.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

celve requested a review from DarkSharpness as a code owner (January 28, 2026 15:48)
Copilot AI review requested due to automatic review settings (January 28, 2026 15:48)
celve requested a review from BBuf as a code owner (January 28, 2026 15:48)
@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @celve, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates performance-critical JIT-compiled CUDA kernels into the SGLang framework to accelerate tensor concatenation operations fundamental to Multi-head Latent Attention. By offloading these computations to optimized GPU kernels, the change aims to significantly improve the efficiency of MLA-related processes. The introduction is accompanied by robust testing to ensure functional correctness and compatibility with existing components.

Highlights

  • New JIT-compiled CUDA Kernels: Introduced two new JIT-compiled CUDA kernels: concat_mla_k and concat_mla_absorb_q, designed for efficient tensor concatenation in Multi-head Latent Attention (MLA) operations.
  • Python Interfaces: Provided Python interfaces in python/sglang/jit_kernel/concat_mla.py to allow easy integration and usage of the new CUDA kernels within the SGLang framework.
  • Optimized concat_mla_k: The concat_mla_k kernel efficiently broadcasts k_rope across all heads while directly copying k_nope values, optimizing a common MLA tensor manipulation pattern (a naive reference sketch of this operation follows this list).
  • Optimized concat_mla_absorb_q: The concat_mla_absorb_q kernel performs concatenation of two tensors (a and b) specifically for MLA absorbed Q computation, aiming for performance gains.
  • Comprehensive Unit Testing: Added unit tests in python/sglang/jit_kernel/tests/test_concat_mla.py to verify the correctness of the JIT kernels against both PyTorch reference implementations and existing AOT (Ahead-Of-Time) compiled SGLang kernels, ensuring bitwise equivalence where applicable.
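
For readers unfamiliar with the operation, the following is a naive, unoptimized reference sketch of concat_mla_k as described above (the dtype, dimension names, and grid-stride layout are illustrative assumptions, not the PR's kernel); concat_mla_absorb_q is analogous but simply writes a into out[..., :d_a] and b into out[..., d_a:].

#include <cuda_bf16.h>

// Reference semantics: out[token, head, :] = concat(k_nope[token, head, :], k_rope[token, 0, :]).
// k_rope has a single head and is broadcast to every head; k_nope is copied directly.
__global__ void concat_mla_k_ref(
    const __nv_bfloat16* __restrict__ k_nope,  // [num_tokens, num_heads, d_nope]
    const __nv_bfloat16* __restrict__ k_rope,  // [num_tokens, 1, d_rope]
    __nv_bfloat16* __restrict__ out,           // [num_tokens, num_heads, d_nope + d_rope]
    int num_tokens, int num_heads, int d_nope, int d_rope) {
  const int d_out = d_nope + d_rope;
  const long long total = 1LL * num_tokens * num_heads * d_out;
  for (long long idx = blockIdx.x * 1LL * blockDim.x + threadIdx.x; idx < total;
       idx += 1LL * gridDim.x * blockDim.x) {
    const int d = idx % d_out;
    const long long head = (idx / d_out) % num_heads;
    const long long token = idx / d_out / num_heads;
    out[idx] = (d < d_nope)
                   ? k_nope[(token * num_heads + head) * d_nope + d]  // direct copy
                   : k_rope[token * d_rope + (d - d_nope)];           // broadcast across heads
  }
}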


@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces JIT-compiled CUDA kernels for MLA tensor concatenation, concat_mla_k and concat_mla_absorb_q. The changes include the Python interface, the CUDA C++ kernel implementation, and comprehensive unit tests. The code is well-structured and the tests are thorough, verifying correctness against PyTorch and consistency with the AOT versions. My feedback focuses on performance optimizations in concat_mla_absorb_q_kernel: hoisting loop-invariant computations out of the loops and using non-temporal memory operations, consistent with the other kernel in the file.

constexpr int B_LAST_DIM = 64;
constexpr int OUT_LAST_DIM = A_LAST_DIM + B_LAST_DIM;

__global__ void concat_mla_absorb_q_kernel(

Severity: medium

The memory access pattern in this kernel is streaming, similar to concat_mla_k_kernel. For better performance and consistency, consider using non-temporal load/store intrinsics (e.g., ld_na_global_..., st_na_global_...) for a, b, and out tensors. This can reduce L1 cache pollution. Using these would likely require adding v4 variants for int4 types to the memory utilities, or composing them from existing v1 and v2 variants.
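
For illustration, here is one way to express streaming (non-temporal) accesses with stock CUDA intrinsics; the ld_na_global_/st_na_global_ helpers mentioned above are the repo's own utilities and may be implemented differently, so treat this as a sketch of the idea rather than the suggested change itself.

#include <cuda_runtime.h>

// __ldcs/__stcs hint "cache streaming" (evict-first), so data that is read and
// written exactly once does not displace useful lines in L1.
__global__ void streaming_copy_int4(const int4* __restrict__ src,
                                    int4* __restrict__ dst, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    const int4 v = __ldcs(src + i);  // non-temporal (streaming) load
    __stcs(dst + i, v);              // non-temporal (streaming) store
  }
}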

Comment on lines +232 to +236

#pragma unroll
  for (int i = 0; i < A_NUM_UNROLL; ++i) {
    const ABufType* base_addr = reinterpret_cast<ABufType*>(a + idx_0 * a_stride_0 + idx_1 * a_stride_1);
    a_buf[i] = *(base_addr + i * 32 + lane_id);
  }

Severity: medium

For performance, the base address calculation, which is loop-invariant, should be hoisted out of the loop. This avoids redundant calculations in each iteration.

  const ABufType* base_addr = reinterpret_cast<ABufType*>(a + idx_0 * a_stride_0 + idx_1 * a_stride_1);
#pragma unroll
  for (int i = 0; i < A_NUM_UNROLL; ++i) {
    a_buf[i] = *(base_addr + i * 32 + lane_id);
  }

Comment on lines +243 to +247

#pragma unroll
  for (int i = 0; i < A_NUM_UNROLL; ++i) {
    ABufType* base_addr = reinterpret_cast<ABufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1);
    *(base_addr + i * 32 + lane_id) = a_buf[i];
  }

Severity: medium

Similarly, the base address calculation for the output tensor should be hoisted out of the loop to avoid redundant computations.

  ABufType* base_addr = reinterpret_cast<ABufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1);
#pragma unroll
  for (int i = 0; i < A_NUM_UNROLL; ++i) {
    *(base_addr + i * 32 + lane_id) = a_buf[i];
  }

@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 97084edb58


Comment on lines +291 to +295

  TensorMatcher({N0_out, N1_out, D_out})
      .with_strides({S0_out, S1_out, 1})
      .with_dtype<bf16_t>()
      .with_device<kDLCUDA>(device)
      .verify(out);


P2: Require output dims to match inputs in concat_mla_absorb_q

The output tensor’s leading dimensions are verified with independent symbols (N0_out, N1_out) but never required to match a/b. The kernel computes num_items and indexing from a’s sizes and then writes into out, so a caller that passes a preallocated out with smaller or different first dimensions will cause out-of-bounds writes or corrupted results. Add a RuntimeCheck tying N0_out/N1_out to N0_a/N1_a (or derive indexing from out) to prevent this.
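
A minimal host-side sketch of the kind of guard being suggested (a hypothetical helper; the repo's actual TensorMatcher/RuntimeCheck API is not reproduced here): tie out's leading dimensions to a's before launching, so indexing derived from a can never write past out.

#include <cstdint>
#include <stdexcept>

// Hypothetical guard: reject a preallocated `out` whose leading dims differ from `a`'s.
inline void check_concat_mla_absorb_q_out(const int64_t* a_shape,    // {N0_a, N1_a, D_a}
                                          const int64_t* out_shape)  // {N0_out, N1_out, D_out}
{
  if (out_shape[0] != a_shape[0] || out_shape[1] != a_shape[1]) {
    throw std::runtime_error(
        "concat_mla_absorb_q: out leading dimensions must match a's");
  }
}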


Copilot AI left a comment

Pull request overview

This pull request adds JIT-compiled CUDA kernels for Multi-head Latent Attention (MLA) tensor concatenation operations, supporting models like DeepSeek-V2/V3/R1.

Changes:

  • Added Python interface for two JIT kernels: concat_mla_k and concat_mla_absorb_q
  • Implemented optimized CUDA kernels with warp-level parallelism and memory access optimizations (the access pattern is sketched after this list)
  • Added comprehensive unit tests verifying correctness against PyTorch reference and AOT kernel implementations
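
The warp-level pattern referenced above (and visible in the kernel fragments quoted earlier, e.g. the i * 32 + lane_id indexing) can be illustrated with the generic sketch below; the function name and template parameter are assumptions for illustration, not the PR's code.

#include <cuda_runtime.h>

// Each warp copies one row cooperatively: lane `lane_id` handles elements
// lane_id, lane_id + 32, lane_id + 64, ..., so every iteration of the
// unrolled loop issues one fully coalesced 32-lane access.
template <int NUM_UNROLL, typename T>
__device__ void warp_copy_row(const T* __restrict__ src, T* __restrict__ dst) {
  const int lane_id = threadIdx.x % 32;
  T buf[NUM_UNROLL];
#pragma unroll
  for (int i = 0; i < NUM_UNROLL; ++i) {
    buf[i] = src[i * 32 + lane_id];  // stage the row in registers
  }
#pragma unroll
  for (int i = 0; i < NUM_UNROLL; ++i) {
    dst[i * 32 + lane_id] = buf[i];  // write the staged values back out
  }
}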

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

  • python/sglang/jit_kernel/concat_mla.py: Python interface providing module loading, error handling, and the public API for the JIT kernels
  • python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh: CUDA kernel implementations with memory utilities, tensor validation, and optimized memory access patterns
  • python/sglang/jit_kernel/tests/test_concat_mla.py: Unit tests comparing the JIT kernels against PyTorch reference and AOT implementations with various input sizes


@BBuf (Collaborator) commented Jan 29, 2026

We should apply these kernels to DeepSeek V3/R1.

@BBuf (Collaborator) commented Jan 29, 2026

@DarkSharpness Any other advice?

@DarkSharpness (Collaborator) commented

Do you have any benchmark results, @celve?

@celve (Collaborator, Author) commented Jan 30, 2026

Do you have any benchmark results, @celve?

Benchmark results on H100:

concat-mla-k-performance:
   num_tokens  SGL AOT Kernel  SGL JIT Kernel      PyTorch
0       256.0        5.475417        5.668444    15.477764
1       512.0       11.878400       11.931831    31.441071
2      1024.0       29.546901       32.658746    79.164883
3      2048.0       62.138934       60.133575   138.081330
4      4096.0      125.988555      132.837618   305.827204
5      8192.0      248.622000      231.221624   551.673912
6     16384.0      514.511665      526.252979  1185.922662
7     32768.0     1003.911972      924.476412  2152.489662
concat-mla-absorb-q-performance:
    dim_0  dim_1  SGL AOT Kernel  SGL JIT Kernel   PyTorch
0     1.0    1.0        1.693475        1.665798  2.882798
1     1.0    8.0        1.770887        1.777346  4.797667
2     1.0   32.0        2.511637        2.509455  4.951636
3     1.0  128.0        2.525998        2.513862  5.099835
4     4.0    1.0        1.675436        1.677608  4.361981
5     4.0    8.0        2.512797        2.509239  4.951825
6     4.0   32.0        2.531210        2.514974  5.102467
7     4.0  128.0        2.562861        2.590036  5.345432
8     8.0    1.0        1.772115        1.777019  4.801221
9     8.0    8.0        2.529712        2.523697  5.007935
10    8.0   32.0        2.561970        2.545978  5.166949
11    8.0  128.0        2.600290        2.594486  5.825511
12   16.0    1.0        2.020227        2.006557  4.937140
13   16.0    8.0        2.528193        2.514921  5.113745
14   16.0   32.0        2.564538        2.593589  5.372622
15   16.0  128.0        2.606951        2.608138  6.601896
16   32.0    1.0        2.523933        2.518549  4.940041
17   32.0    8.0        2.562259        2.546113  5.180123
18   32.0   32.0        2.605903        2.591558  5.857756
19   32.0  128.0        2.975260        2.977352  8.561575

@BBuf (Collaborator) commented Feb 1, 2026

@celve fix lint

@celve (Collaborator, Author) commented Feb 1, 2026

@celve fix lint

fixed

@BBuf (Collaborator) commented Feb 2, 2026

/tag-and-rerun-ci

github-actions bot added the run-ci label (Feb 2, 2026)
BBuf changed the title from "[Kernel] Add JIT concat MLA kernels" to "[Move sgl-kernel Kernel to JIT] Add JIT concat MLA kernels" (Feb 3, 2026)
@BBuf (Collaborator) commented Feb 3, 2026

BBuf merged commit 9b1619c into sgl-project:main (Feb 3, 2026)
200 of 217 checks passed
charlesHsuGG pushed a commit to charlesHsuGG/sglang that referenced this pull request Feb 5, 2026
sfiisf pushed a commit to sfiisf/sglang that referenced this pull request Feb 5, 2026
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026