
[Fix] concat_mla_absorb_q_kernel fails for long inputs #12453

Merged
Fridge003 merged 2 commits into sgl-project:main from bingps:fix-concat-mla-q
Nov 2, 2025

Conversation


bingps (Contributor) commented Oct 31, 2025

Motivation

#12250

Modifications

This is caused by int32 overflow in address computations such as BBufType* base_addr = reinterpret_cast<BBufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1 + A_LAST_DIM);.

Specifically, out_stride_0 can be 128 * 576 = 73,728, so for S = 30,000 the product idx_0 * out_stride_0 can reach 30,000 * 73,728 = 2,211,840,000, which exceeds the int32 limit (2,147,483,647).

Fixed by declaring the strides as int64_t, so the address arithmetic is implicitly promoted to 64-bit.
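
The wraparound and the effect of widening one operand can be reproduced with NumPy fixed-width scalars standing in for the C++ types (a minimal sketch; the values are taken from the description above):

```python
import numpy as np

idx_0 = 30000
out_stride_0 = 128 * 576  # 73728

# 32-bit product wraps: 2,211,840,000 exceeds INT32_MAX = 2,147,483,647.
with np.errstate(over="ignore"):
    wrapped = np.int32(idx_0) * np.int32(out_stride_0)
assert int(wrapped) == 2_211_840_000 - 2**32  # -2_083_127_296

# Widening the stride to 64 bits promotes the whole expression,
# mirroring the int -> int64_t change in the kernel signature.
correct = np.int32(idx_0) * np.int64(out_stride_0)
assert int(correct) == 2_211_840_000
```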

Accuracy Tests

The kernel output is bit-wise equal to torch.cat:

import torch
import triton
from sglang.srt.layers.attention.trtllm_mla_backend import _concat_mla_absorb_q_general

if __name__ == "__main__":
    H = 128
    D_NOPE = 512
    D_ROPE = 64

    for S in [1, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]:
        print(f"============ Test _concat_mla_absorb_q_general on {S=} ============")
        q_nope = torch.randn(S, H, D_NOPE, dtype=torch.bfloat16, device="cuda")
        q_rope = torch.randn(S, H, D_ROPE, dtype=torch.bfloat16, device="cuda")
        assert q_nope.is_contiguous() and q_rope.is_contiguous()

        q_ref = torch.cat([q_nope, q_rope], dim=-1)
        q = _concat_mla_absorb_q_general(q_nope, q_rope)
        assert torch.equal(q_ref, q)

        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.cat([q_nope, q_rope], dim=-1),
            quantiles=quantiles,
        )
        print(f"Torch  cat: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")

        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: _concat_mla_absorb_q_general(q_nope, q_rope),
            quantiles=quantiles,
        )
        print(f"concat_mla: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")

Benchmarking and Profiling

Before:

============ Test _concat_mla_absorb_q_general on S=1 ============
Torch  cat: median    7 us, min    6 us, max    7 us
concat_mla: median    6 us, min    6 us, max    6 us
============ Test _concat_mla_absorb_q_general on S=128 ============
Torch  cat: median   32 us, min   32 us, max   32 us
concat_mla: median   17 us, min   17 us, max   18 us
============ Test _concat_mla_absorb_q_general on S=256 ============
Torch  cat: median   57 us, min   57 us, max   57 us
concat_mla: median   29 us, min   28 us, max   29 us
============ Test _concat_mla_absorb_q_general on S=512 ============
Torch  cat: median  106 us, min  106 us, max  107 us
concat_mla: median   49 us, min   48 us, max   50 us
============ Test _concat_mla_absorb_q_general on S=1024 ============
Torch  cat: median  206 us, min  206 us, max  206 us
concat_mla: median   91 us, min   90 us, max   91 us
============ Test _concat_mla_absorb_q_general on S=2048 ============
Torch  cat: median  405 us, min  405 us, max  406 us
concat_mla: median  174 us, min  173 us, max  174 us
============ Test _concat_mla_absorb_q_general on S=4096 ============
Torch  cat: median  803 us, min  803 us, max  803 us
concat_mla: median  339 us, min  339 us, max  340 us
============ Test _concat_mla_absorb_q_general on S=8192 ============
Torch  cat: median 1599 us, min 1599 us, max 1599 us
concat_mla: median  673 us, min  672 us, max  673 us
============ Test _concat_mla_absorb_q_general on S=16384 ============
Torch  cat: median 3189 us, min 3182 us, max 3190 us
concat_mla: median 1346 us, min 1342 us, max 1351 us

After:

============ Test _concat_mla_absorb_q_general on S=1 ============
Torch  cat: median    6 us, min    6 us, max    7 us
concat_mla: median    6 us, min    6 us, max    6 us
============ Test _concat_mla_absorb_q_general on S=128 ============
Torch  cat: median   32 us, min   32 us, max   32 us
concat_mla: median   17 us, min   17 us, max   18 us
============ Test _concat_mla_absorb_q_general on S=256 ============
Torch  cat: median   57 us, min   57 us, max   57 us
concat_mla: median   28 us, min   28 us, max   29 us
============ Test _concat_mla_absorb_q_general on S=512 ============
Torch  cat: median  106 us, min  106 us, max  107 us
concat_mla: median   49 us, min   48 us, max   50 us
============ Test _concat_mla_absorb_q_general on S=1024 ============
Torch  cat: median  206 us, min  206 us, max  206 us
concat_mla: median   91 us, min   90 us, max   91 us
============ Test _concat_mla_absorb_q_general on S=2048 ============
Torch  cat: median  405 us, min  405 us, max  406 us
concat_mla: median  174 us, min  173 us, max  175 us
============ Test _concat_mla_absorb_q_general on S=4096 ============
Torch  cat: median  803 us, min  803 us, max  803 us
concat_mla: median  340 us, min  340 us, max  341 us
============ Test _concat_mla_absorb_q_general on S=8192 ============
Torch  cat: median 1599 us, min 1599 us, max 1599 us
concat_mla: median  674 us, min  673 us, max  674 us
============ Test _concat_mla_absorb_q_general on S=16384 ============
Torch  cat: median 3190 us, min 3190 us, max 3191 us
concat_mla: median 1351 us, min 1348 us, max 1360 us
============ Test _concat_mla_absorb_q_general on S=32768 ============
Torch  cat: median 9140 us, min 9139 us, max 9144 us
concat_mla: median 2676 us, min 2675 us, max 2676 us


@gemini-code-assist

Summary of Changes

Hello @bingps, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical integer overflow bug within the concat_mla_absorb_q_kernel CUDA function. The issue arose when calculating memory addresses for long input sequences, where intermediate stride products exceeded the maximum value for a 32-bit integer. By upgrading the relevant stride parameters to 64-bit integers, the kernel can now correctly process significantly larger inputs, ensuring the stability and reliability of operations involving multi-head latent attention (MLA) without any degradation in accuracy or performance.

Highlights

  • Bug Fix: Addresses an int32 overflow issue in the concat_mla_absorb_q_kernel function, which previously caused failures when processing long input sequences due to stride calculations exceeding 32-bit integer limits.
  • Robustness Improvement: Modifies the kernel's stride parameters (a_stride_0, b_stride_0, out_stride_0) from int to int64_t to correctly handle large memory offsets and prevent future overflow issues for very long inputs.
  • Accuracy and Performance: Confirms bit-wise equality with torch.cat and demonstrates consistent or improved performance across various input sizes, including very long sequences (S=32768), after the fix.

gemini-code-assist bot left a comment


Code Review

This pull request correctly identifies and fixes a critical integer overflow bug in the concat_mla_absorb_q_kernel that occurs with long input sequences. The fix, which involves changing the data type of stride parameters from int to int64_t, is appropriate and effectively resolves the issue. While reviewing, I noticed a similar potential overflow issue in another kernel within the same file, which I've detailed in a specific comment. Addressing this would improve the robustness of the code.

Comment on lines +129 to 134
+    const int64_t a_stride_0,
     const int a_stride_1,
-    const int b_stride_0,
+    const int64_t b_stride_0,
     const int b_stride_1,
-    const int out_stride_0,
+    const int64_t out_stride_0,
     const int out_stride_1) {

high

This change correctly prevents integer overflow for large inputs by using int64_t for strides.

A similar potential overflow issue seems to exist in concat_mla_k_kernel. The strides k_stride_0, k_nope_stride_0, and k_rope_stride_0 are defined as int (lines 21, 23, 25), but their product with token_id inside the kernel could overflow if num_tokens is large.

For example, with k_stride_0 being 24,576, an overflow would occur once num_tokens exceeds 2,147,483,647 / 24,576 ≈ 87,381. While current tests might not reach this limit, it's a potential failure point for longer sequences.

It would be beneficial to also change these stride types to int64_t to prevent this issue and ensure consistency across the kernels.
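
The overflow threshold can be checked directly (values taken from the review comment; k_stride_0 = 24,576 corresponds to H * (D_NOPE + D_ROPE) = 128 * 192):

```python
INT32_MAX = 2**31 - 1
k_stride_0 = 24576  # H * (D_NOPE + D_ROPE) = 128 * 192

# Largest token index whose row offset still fits in a 32-bit int.
max_safe_tokens = INT32_MAX // k_stride_0
assert max_safe_tokens == 87_381

# The failing test case reported below (S = 131072) is well past this limit.
assert 131072 * k_stride_0 > INT32_MAX
```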

Collaborator

@bingps Agreed, it's also beneficial to change the datatype for concat_mla_k_kernel
Can you please fix it here

Contributor Author

@Fridge003 Updated~ Tests are in the following comment.


bingps commented Nov 2, 2025

Test concat_mla_k with the following script (imports added; the concat_mla_k import path is assumed to mirror the first script, adjust to the actual module):

import torch
import triton
# NOTE: import path assumed; adjust to where concat_mla_k actually lives.
from sglang.srt.layers.attention.trtllm_mla_backend import concat_mla_k

if __name__ == "__main__":
    H = 128
    D_NOPE = 128
    D_ROPE = 64

    for S in [16384, 32768, 65536, 131072]:
        print(f"============ Test concat_mla_k on {S=} ============")
        k_nope = torch.randn(S, H, D_NOPE, dtype=torch.bfloat16, device="cuda")
        k_rope = torch.randn(S, 1, D_ROPE, dtype=torch.bfloat16, device="cuda")
        assert k_nope.is_contiguous() and k_rope.is_contiguous()

        def ref_concat_k(k, k_nope, k_rope):
            k[..., :D_NOPE] = k_nope
            k[..., D_NOPE:] = k_rope

        k_ref = torch.empty(S, H, D_NOPE + D_ROPE, dtype=torch.bfloat16, device="cuda")
        ref_concat_k(k_ref, k_nope, k_rope)
        k = torch.empty(S, H, D_NOPE + D_ROPE, dtype=torch.bfloat16, device="cuda")
        concat_mla_k(k, k_nope, k_rope)
        assert torch.equal(k_ref, k)

        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: ref_concat_k(k_ref, k_nope, k_rope),
            quantiles=quantiles,
        )
        print(f"Torch  cat: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")

        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: concat_mla_k(k, k_nope, k_rope),
            quantiles=quantiles,
        )
        print(f"concat_mla: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")

Before

============ Test concat_mla_k on S=16384 ============
Torch  cat: median 1450 us, min 1450 us, max 1452 us
concat_mla: median  395 us, min  395 us, max  396 us
============ Test concat_mla_k on S=32768 ============
Torch  cat: median 2890 us, min 2889 us, max 2891 us
concat_mla: median  778 us, min  778 us, max  779 us
============ Test concat_mla_k on S=65536 ============
Torch  cat: median 5779 us, min 5777 us, max 5780 us
concat_mla: median 1544 us, min 1543 us, max 1545 us
============ Test concat_mla_k on S=131072 ============
Traceback (most recent call last):
  File "/root/sglang-fork/python/test_cat.py", line 55, in <module>
    assert torch.equal(k_ref, k)
           ^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
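
The illegal memory access at S = 131072 is consistent with int32 wraparound: the row offset of the last tokens exceeds INT32_MAX and wraps to a negative value, sending the pointer outside the allocation. A quick check (k_stride_0 = 24,576 per the review comment, a sketch using NumPy scalars to mimic the 32-bit arithmetic):

```python
import numpy as np

S, k_stride_0 = 131072, 24576  # k_stride_0 = H * (D_NOPE + D_ROPE) = 128 * 192

offset = (S - 1) * k_stride_0   # element offset of the last token's row
assert offset == 3_221_200_896  # > INT32_MAX = 2_147_483_647

# In 32-bit arithmetic this wraps to a negative offset, i.e. an
# out-of-bounds pointer, matching the illegal-memory-access error above.
with np.errstate(over="ignore"):
    wrapped = np.int32(S - 1) * np.int32(k_stride_0)
assert int(wrapped) == offset - 2**32  # -1_073_766_400
```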

After

============ Test concat_mla_k on S=16384 ============ 
Torch  cat: median 1451 us, min 1450 us, max 1452 us
concat_mla: median  396 us, min  395 us, max  396 us
============ Test concat_mla_k on S=32768 ============
Torch  cat: median 2891 us, min 2890 us, max 2893 us 
concat_mla: median  778 us, min  778 us, max  779 us
============ Test concat_mla_k on S=65536 ============
Torch  cat: median 5780 us, min 5779 us, max 5780 us
concat_mla: median 1544 us, min 1543 us, max 1545 us
============ Test concat_mla_k on S=131072 ============
Torch  cat: median 11557 us, min 11555 us, max 11559 us
concat_mla: median 3073 us, min 3073 us, max 3074 us

@Fridge003

https://github.com/sgl-project/sglang/actions/runs/19016288024?pr=12453
Correctness is verified by the GB200 test.

@Fridge003 Fridge003 merged commit 15ed27d into sgl-project:main Nov 2, 2025
128 of 143 checks passed
