
[DeepseekV32]: use _concat_mla_absorb_q_general to replace torch.cat #12215

Merged
Fridge003 merged 5 commits into sgl-project:main from bingps:dsv32-q-cat
Nov 12, 2025

Conversation

@bingps
Contributor

bingps commented Oct 27, 2025

Motivation

#11989
torch.cat([q_nope, q_rope], dim=-1) is used heavily in nsa_backend, and this generic concatenation is relatively inefficient.

Modifications

Use the existing kernel _concat_mla_absorb_q_general to replace these torch.cat calls, as sketched below.
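
For illustration only, the substitution at each call site looks roughly like the sketch below; the actual call sites sit under the various NSA_PREFILL_IMPL / NSA_DECODE_IMPL branches in nsa_backend.py, and the variable names follow the existing code.

from sglang.srt.layers.attention.trtllm_mla_backend import _concat_mla_absorb_q_general

# Before: generic concatenation of the no-RoPE and RoPE parts of the query
# q_all = torch.cat([q_nope, q_rope], dim=-1)

# After: dedicated kernel producing the same concatenated tensor
q_all = _concat_mla_absorb_q_general(q_nope, q_rope)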

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @bingps, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on optimizing the DeepseekV32 attention mechanism by addressing an identified performance bottleneck. It replaces the less efficient torch.cat operation, which was frequently used for concatenating query tensors (q_nope and q_rope) within the nsa_backend, with a specialized and more performant kernel, _concat_mla_absorb_q_general. This modification aims to improve the overall efficiency of attention computations by leveraging an existing optimized kernel.

Highlights

  • Performance Optimization: Replaced multiple instances of torch.cat with the more efficient _concat_mla_absorb_q_general kernel within the nsa_backend for DeepseekV32 attention mechanisms.
  • Kernel Integration: Integrated the _concat_mla_absorb_q_general kernel from trtllm_mla_backend to handle the concatenation of q_nope and q_rope tensors.
  • Affected Functions: The change impacts the forward_extend and forward_decode functions, specifically under various NSA_PREFILL_IMPL and NSA_DECODE_IMPL conditions.

Contributor

gemini-code-assist bot left a comment


Code Review

This pull request replaces torch.cat with _concat_mla_absorb_q_general in several locations within nsa_backend.py to improve efficiency. The changes involve importing the new function and substituting it in the forward_extend and forward_decode methods. I have identified a potential issue where the new function might not be universally applicable, and have raised a critical concern.

@@ -793,7 +794,7 @@ def forward_extend(
)
if NSA_PREFILL_IMPL == "tilelang":
if q_rope is not None:


critical

This change replaces torch.cat with _concat_mla_absorb_q_general. It's crucial to ensure that _concat_mla_absorb_q_general is functionally equivalent to torch.cat in all scenarios where it's being used. If _concat_mla_absorb_q_general does not handle all cases correctly, it could lead to incorrect results or unexpected behavior. Can you provide more details on the scenarios where _concat_mla_absorb_q_general is guaranteed to work correctly and any limitations it might have?

Given the potential for widespread impact, this is a critical issue.

                q_all = _concat_mla_absorb_q_general(q_nope, q_rope) # Ensure this function is equivalent to torch.cat in all cases
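
For reference, the equivalence in question is plain last-dimension concatenation. A minimal check, with illustrative shapes matching the benchmark later in this thread, would be:

import torch
from sglang.srt.layers.attention.trtllm_mla_backend import _concat_mla_absorb_q_general

# Illustrative shapes only: 4 tokens, 128 heads, 512-dim no-RoPE part, 64-dim RoPE part.
q_nope = torch.randn(4, 128, 512, dtype=torch.bfloat16, device="cuda")
q_rope = torch.randn(4, 128, 64, dtype=torch.bfloat16, device="cuda")

# The kernel output should match torch.cat along the last dimension bit for bit.
q_ref = torch.cat([q_nope, q_rope], dim=-1)
q_out = _concat_mla_absorb_q_general(q_nope, q_rope)
assert torch.equal(q_ref, q_out)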

@Fridge003
Collaborator

Fridge003 commented Oct 27, 2025

@bingps Can you please post some accuracy results, as well as performance benchmarks before/after this change?

@bingps
Contributor Author

bingps commented Oct 28, 2025

@bingps Can you please post some accuracy results, as well as performance benchmarks before/after this change?

Sure~ Here is a simple benchmark.

import torch
import triton
from sglang.srt.layers.attention.trtllm_mla_backend import _concat_mla_absorb_q_general


if __name__ == "__main__":
    # Shapes matching DeepSeek MLA absorbed queries: 128 attention heads,
    # a 512-dim no-RoPE part and a 64-dim RoPE part per head.
    H = 128
    D_NOPE = 512
    D_ROPE = 64

    for S in [1, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 30000]:
        print(f"============ Test _concat_mla_absorb_q_general on {S=} ============")
        q_nope = torch.randn(S, H, D_NOPE, dtype=torch.bfloat16, device="cuda")
        q_rope = torch.randn(S, H, D_ROPE, dtype=torch.bfloat16, device="cuda")

        q_ref = torch.cat([q_nope, q_rope], dim=-1)
        q = _concat_mla_absorb_q_general(q_nope, q_rope)
        assert torch.equal(q_ref, q)

        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.cat([q_nope, q_rope], dim=-1),
            quantiles=quantiles,
        )
        print(f"Torch  cat: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")

        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: _concat_mla_absorb_q_general(q_nope, q_rope),
            quantiles=quantiles,
        )
        print(f"concat_mla: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")

The concat results are bit-wise equal, and the kernel is about 2.4x faster at S=16384 while being no slower at S=1.

============ Test _concat_mla_absorb_q_general on S=1 ============
Torch  cat: median    7 us, min    7 us, max    7 us
concat_mla: median    6 us, min    6 us, max    7 us
============ Test _concat_mla_absorb_q_general on S=128 ============
Torch  cat: median   32 us, min   32 us, max   32 us
concat_mla: median   18 us, min   17 us, max   18 us
============ Test _concat_mla_absorb_q_general on S=256 ============
Torch  cat: median   57 us, min   57 us, max   57 us
concat_mla: median   28 us, min   28 us, max   29 us
============ Test _concat_mla_absorb_q_general on S=512 ============
Torch  cat: median  106 us, min  106 us, max  107 us
concat_mla: median   49 us, min   48 us, max   50 us
============ Test _concat_mla_absorb_q_general on S=1024 ============
Torch  cat: median  206 us, min  206 us, max  206 us
concat_mla: median   90 us, min   90 us, max   91 us
============ Test _concat_mla_absorb_q_general on S=2048 ============
Torch  cat: median  405 us, min  405 us, max  406 us
concat_mla: median  174 us, min  172 us, max  174 us
============ Test _concat_mla_absorb_q_general on S=4096 ============
Torch  cat: median  804 us, min  803 us, max  804 us
concat_mla: median  340 us, min  340 us, max  341 us
============ Test _concat_mla_absorb_q_general on S=8192 ============
Torch  cat: median 1599 us, min 1599 us, max 1600 us
concat_mla: median  672 us, min  671 us, max  673 us
============ Test _concat_mla_absorb_q_general on S=16384 ============
Torch  cat: median 3190 us, min 3183 us, max 3191 us
concat_mla: median 1337 us, min 1337 us, max 1338 us

@bingps
Contributor Author

bingps commented Oct 28, 2025

However in my testing, the kernel fails for ~30K inputs 😿

See #12250

@Johnsonms
Contributor

However in my testing, the kernel fails for ~30K inputs 😿

See #12250

me too

@Fridge003
Collaborator

Might wait for another sgl-kernel bump to include #12453

Collaborator

Fridge003 left a comment


Waiting for CI

@Fridge003
Collaborator

Fridge003 merged commit 4eda996 into sgl-project:main Nov 12, 2025
146 of 185 checks passed