[DeepseekV32]: use _concat_mla_absorb_q_general to replace torch.cat #12215

Fridge003 merged 5 commits into sgl-project:main

Conversation
Summary of Changes

Hello @bingps, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request focuses on optimizing the DeepseekV32 attention mechanism by addressing an identified performance bottleneck: it replaces the less efficient torch.cat with the existing _concat_mla_absorb_q_general kernel in the NSA attention backend.
Code Review
This pull request replaces torch.cat with _concat_mla_absorb_q_general in several locations within nsa_backend.py to improve efficiency. The changes involve importing the new function and substituting it in the forward_extend and forward_decode methods. I have identified a potential issue where the new function might not be universally applicable, and have raised a critical concern.
```diff
@@ -793,7 +794,7 @@ def forward_extend(
             )
         if NSA_PREFILL_IMPL == "tilelang":
             if q_rope is not None:
```
This change replaces torch.cat with _concat_mla_absorb_q_general. It's crucial to ensure that _concat_mla_absorb_q_general is functionally equivalent to torch.cat in all scenarios where it's being used. If _concat_mla_absorb_q_general does not handle all cases correctly, it could lead to incorrect results or unexpected behavior. Can you provide more details on the scenarios where _concat_mla_absorb_q_general is guaranteed to work correctly and any limitations it might have?
Given the potential for widespread impact, this is a critical issue.
```python
q_all = _concat_mla_absorb_q_general(q_nope, q_rope)  # Ensure this function is equivalent to torch.cat in all cases
```
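For context on what "equivalent to torch.cat" means here, below is a minimal pure-PyTorch sketch of the required behavior. The function name and docstring are illustrative only; this is not the sgl-kernel implementation of _concat_mla_absorb_q_general.

```python
import torch


def concat_mla_absorb_q_reference(q_nope: torch.Tensor, q_rope: torch.Tensor) -> torch.Tensor:
    """Illustrative reference: what a drop-in replacement for
    torch.cat([q_nope, q_rope], dim=-1) must produce."""
    s, h, d_nope = q_nope.shape
    d_rope = q_rope.shape[-1]
    # One pre-allocated output; both halves are copied into their slices.
    out = torch.empty((s, h, d_nope + d_rope), dtype=q_nope.dtype, device=q_nope.device)
    out[..., :d_nope].copy_(q_nope)
    out[..., d_nope:].copy_(q_rope)
    return out
```

Whether the fused kernel matches this for every shape the backend can pass is exactly the question raised above; the benchmark posted later in this thread checks bit-wise equality against torch.cat across a range of sequence lengths.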
@bingps Can you please post some accuracy results, as well as performance benchmarks before/after this change?
Sure~ Here is a simple benchmark.

```python
import torch
import triton

from sglang.srt.layers.attention.trtllm_mla_backend import _concat_mla_absorb_q_general

if __name__ == "__main__":
    H = 128
    D_NOPE = 512
    D_ROPE = 64
    for S in [1, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 30000]:
        print(f"============ Test _concat_mla_absorb_q_general on {S=} ============")
        q_nope = torch.randn(S, H, D_NOPE, dtype=torch.bfloat16, device="cuda")
        q_rope = torch.randn(S, H, D_ROPE, dtype=torch.bfloat16, device="cuda")

        q_ref = torch.cat([q_nope, q_rope], dim=-1)
        q = _concat_mla_absorb_q_general(q_nope, q_rope)
        assert torch.equal(q_ref, q)

        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.cat([q_nope, q_rope], dim=-1),
            quantiles=quantiles,
        )
        print(f"Torch cat: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: _concat_mla_absorb_q_general(q_nope, q_rope),
            quantiles=quantiles,
        )
        print(f"concat_mla: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")
```

The concat results are bit-wise equal, and it is 2.4x faster for the 16K concat, while not slower when S=1.
However, in my testing the kernel fails for ~30K inputs 😿 (see #12250).
Me too.
Might wait for another sgl-kernel bump to include #12453.
Passed all dpsk v32 tests.
Motivation
#11989
torch.cat([q_nope, q_rope], dim=-1) is heavily used in nsa_backend and is relatively inefficient.

Modifications

Use the existing kernel _concat_mla_absorb_q_general as a drop-in replacement; a runnable sketch of the substitution is shown below.
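A minimal, runnable sketch of the substitution, using the import path from the benchmark above. The shapes are illustrative; the real inputs are the MLA query tensors inside forward_extend / forward_decode in nsa_backend.py.

```python
import torch

from sglang.srt.layers.attention.trtllm_mla_backend import _concat_mla_absorb_q_general

# Illustrative shapes only (128 heads, 512 nope dims, 64 rope dims).
q_nope = torch.randn(8, 128, 512, dtype=torch.bfloat16, device="cuda")
q_rope = torch.randn(8, 128, 64, dtype=torch.bfloat16, device="cuda")

# before: q_all = torch.cat([q_nope, q_rope], dim=-1)
q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
assert q_all.shape == (8, 128, 512 + 64)
```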
Checklist