[Fix] concat_mla_absorb_q_kernel fails for long inputs#12453
Fridge003 merged 2 commits into sgl-project:main
Code Review
This pull request correctly identifies and fixes a critical integer overflow bug in the concat_mla_absorb_q_kernel that occurs with long input sequences. The fix, which involves changing the data type of stride parameters from int to int64_t, is appropriate and effectively resolves the issue. While reviewing, I noticed a similar potential overflow issue in another kernel within the same file, which I've detailed in a specific comment. Addressing this would improve the robustness of the code.
```diff
-      const int a_stride_0,
+      const int64_t a_stride_0,
       const int a_stride_1,
-      const int b_stride_0,
+      const int64_t b_stride_0,
       const int b_stride_1,
-      const int out_stride_0,
+      const int64_t out_stride_0,
       const int out_stride_1) {
```
This change correctly prevents integer overflow for large inputs by using int64_t for strides.
A similar potential overflow issue seems to exist in concat_mla_k_kernel. The strides k_stride_0, k_nope_stride_0, and k_rope_stride_0 are defined as int (lines 21, 23, 25), but their product with token_id inside the kernel could overflow if num_tokens is large.
For example, with k_stride_0 being 24576, an overflow would occur if num_tokens exceeds 2,147,483,647 / 24576 ~= 87,375. While current tests might not reach this limit, it's a potential failure point for longer sequences.
It would be beneficial to also change these stride types to int64_t to prevent this issue and ensure consistency across the kernels.
@bingps Agreed, it's also beneficial to change the datatype for concat_mla_k_kernel.
Can you please fix it here?
@Fridge003 Updated~ Tests are in the following comment.
Test

```python
import torch
import triton

# Import path for the kernel wrapper is assumed here; it lives in the sgl-kernel package.
from sgl_kernel import concat_mla_k

if __name__ == "__main__":
    H = 128
    D_NOPE = 128
    D_ROPE = 64
    for S in [16384, 32768, 65536, 131072]:
        print(f"============ Test concat_mla_k on {S=} ============")
        k_nope = torch.randn(S, H, D_NOPE, dtype=torch.bfloat16, device="cuda")
        k_rope = torch.randn(S, 1, D_ROPE, dtype=torch.bfloat16, device="cuda")
        assert k_nope.is_contiguous() and k_rope.is_contiguous()

        def ref_concat_k(k, k_nope, k_rope):
            k[..., :D_NOPE] = k_nope
            k[..., D_NOPE:] = k_rope

        k_ref = torch.empty(S, H, D_NOPE + D_ROPE, dtype=torch.bfloat16, device="cuda")
        ref_concat_k(k_ref, k_nope, k_rope)
        k = torch.empty(S, H, D_NOPE + D_ROPE, dtype=torch.bfloat16, device="cuda")
        concat_mla_k(k, k_nope, k_rope)
        assert torch.equal(k_ref, k)

        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: ref_concat_k(k_ref, k_nope, k_rope),
            quantiles=quantiles,
        )
        print(f"Torch cat: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: concat_mla_k(k, k_nope, k_rope),
            quantiles=quantiles,
        )
        print(f"concat_mla: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")
```
https://github.com/sgl-project/sglang/actions/runs/19016288024?pr=12453
Motivation
#12250
Modifications
This is caused by int32 overflow in addressing expressions such as `BBufType* base_addr = reinterpret_cast<BBufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1 + A_LAST_DIM);`. Specifically, `out_stride_0` can be 128 * 576 = 73728, so for S = 30000, `idx_0 * out_stride_0` can be 30000 * 73728 = 2,211,840,000, which exceeds the int32 limit (2,147,483,647). Fixed by declaring the strides as int64, so the multiplication is implicitly promoted to int64.
Accuracy Tests
Bit-wise equal with torch.cat
Benchmarking and Profiling
Before:
After:
Checklist