
[bug fix] fix ima with get_mla_kv_buffer_kernel overflow #14224

Merged
Fridge003 merged 1 commit into sgl-project:main from XucSh:fix
Dec 4, 2025

Conversation

@XucSh (Collaborator) commented Dec 1, 2025

Motivation

When testing with the commands below on H20:

server:
SGLANG_PP_LAYER_PARTITION="3,3,4,4,4,4,4,4,4,4,4,4,4,4,4,3" CUDA_LAUNCH_BLOCKING=1 python3 -m sglang.launch_server --model-path /work/models/DeepSeek3.1/ --nnodes 2 --port 30000 --dist-init-addr 26.5.27.241:62001 --node-rank 0 --tp 1 --pp-size 16 --trust-remote-code --disable-radix-cache --mem-fraction-static 0.8 --max-running-requests 512 --chunked-prefill-size 6144 --attention-backend fa3 --watchdog-timeout 3600 --host 0.0.0.0

SGLANG_PP_LAYER_PARTITION="3,3,4,4,4,4,4,4,4,4,4,4,4,4,4,3" CUDA_LAUNCH_BLOCKING=1 python3 -m sglang.launch_server --model-path /work/models/DeepSeek3.1/ --nnodes 2 --port 30000 --dist-init-addr 26.5.27.241:62001 --node-rank 1 --tp 1 --pp-size 16 --trust-remote-code --disable-radix-cache --mem-fraction-static 0.8 --max-running-requests 512 --chunked-prefill-size 6144 --attention-backend fa3 --watchdog-timeout 3600

client:
python3 -m sglang.bench_serving --host 127.0.0.1 --port 30000 --dataset-path /root/.cache/modelscope/hub/datasets/gliang1001/ShareGPT_V3_unfiltered_cleaned_split/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompt 512 --random-input 13107 --random-output 1 --request-rate 10 --max-concurrency 512 --warmup-requests 0 --backend sglang --dataset-name random --random-range-ratio 1

An IMA (illegal memory access) shows up with the stack below:

[2025-12-01 09:06:35 PP0] Scheduler hit an exception: Traceback (most recent call last):
  File "/root/sgl-dev/sglang/python/sglang/srt/managers/scheduler.py", line 2767, in run_scheduler_process
    scheduler.event_loop_pp()
  File "/root/sgl-dev/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/managers/scheduler_pp_mixin.py", line 35, in event_loop_pp
    result = self.run_batch(self.cur_batch)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/managers/scheduler.py", line 1955, in run_batch
    batch_result = self.model_worker.forward_batch_generation(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/managers/tp_worker.py", line 419, in forward_batch_generation
    pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/model_executor/model_runner.py", line 2206, in forward
    output = self._forward_raw(
             ^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/model_executor/model_runner.py", line 2257, in _forward_raw
    ret = self.forward_extend(
          ^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/model_executor/model_runner.py", line 2151, in forward_extend
    return self.model.forward(
           ^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/models/deepseek_v2.py", line 3033, in forward
    hidden_states = self.model(
                    ^^^^^^^^^^^
  File "/root/sgl-dev/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/models/deepseek_v2.py", line 2886, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/root/sgl-dev/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/models/deepseek_v2.py", line 2615, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/root/sgl-dev/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/models/deepseek_v2.py", line 1344, in forward
    return self.forward_core(s)
           ^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/models/deepseek_v2.py", line 1433, in forward_core
    return self.forward_normal_chunked_kv_core(*inner_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/models/deepseek_v2.py", line 2365, in forward_normal_chunked_kv_core
    attn_output = self._chunked_prefix_attn_mha(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/models/deepseek_v2.py", line 2299, in _chunked_prefix_attn_mha
    kv_a_normed, k_pe = self._get_mla_kv_buffer(
                        ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/models/deepseek_v2.py", line 2433, in _get_mla_kv_buffer
    kv_a, k_pe = forward_batch.token_to_kv_pool.get_mla_kv_buffer(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/sglang/python/sglang/srt/mem_cache/memory_pool.py", line 1445, in get_mla_kv_buffer
    get_mla_kv_buffer_triton(kv_buffer, loc, cache_k_nope, cache_k_rope)
  File "/root/sgl-dev/sglang/python/sglang/srt/mem_cache/memory_pool.py", line 1262, in get_mla_kv_buffer_triton
    get_mla_kv_buffer_kernel[grid](
  File "/root/sgl-dev/lib/python3.12/site-packages/triton/runtime/jit.py", line 390, in
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sgl-dev/lib/python3.12/site-packages/triton/runtime/jit.py", line 617, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/root/sgl-dev/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 708, in call
    self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

Reported by @whybeyoung
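
For intuition on the overflow, a back-of-the-envelope sketch (not from the PR itself; it assumes DeepSeek-style MLA dims of kv_lora_rank 512 + qk_rope_head_dim 64 = 576 elements per token). With 512 concurrent requests of 13107 input tokens, element offsets into the KV buffer pass 2^31 - 1 and wrap when the index arithmetic is done in int32:

import torch

max_running_requests = 512                    # --max-running-requests 512
input_len = 13107                             # --random-input 13107
elems_per_token = 512 + 64                    # kv_lora_rank + qk_rope_head_dim (assumed MLA dims)

loc = torch.tensor([max_running_requests * input_len], dtype=torch.int32)
print(loc * elems_per_token)                  # tensor([-429555712]): wrapped around in int32
print(loc.to(torch.int64) * elems_per_token)  # tensor([3865411584]): the intended offset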

Modifications

Fix it by extending loc to int64 before it is used in the kernel's offset computation.
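
A minimal sketch of the fix (hypothetical kernel and names, not the actual memory_pool.py code): load the per-token slot index, then cast it to tl.int64 before multiplying by the row stride, so the pointer offset is computed in 64-bit instead of wrapping in int32.

import triton
import triton.language as tl

@triton.jit
def gather_rows_kernel(buf_ptr, loc_ptr, out_ptr, buf_stride, dim, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    # The fix: promote the slot index to int64 so loc * buf_stride cannot wrap.
    loc = tl.load(loc_ptr + pid).to(tl.int64)
    offs = tl.arange(0, BLOCK)
    mask = offs < dim
    row = tl.load(buf_ptr + loc * buf_stride + offs, mask=mask)
    tl.store(out_ptr + pid.to(tl.int64) * dim + offs, row, mask=mask)

# Launch one program per token; buf is the (num_slots, dim) KV pool and
# loc holds int32 slot indices, as in the failing workload.
def gather_rows(buf, loc, out):
    dim = buf.shape[1]
    grid = (loc.numel(),)
    gather_rows_kernel[grid](buf, loc, out, buf.stride(0), dim, BLOCK=triton.next_power_of_2(dim))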


Reported-by: ybyang <ybyang7@iflytek.com>
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @XucSh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical RuntimeError: Triton Error [CUDA]: an illegal memory access that was observed during SGLang operations, specifically within the memory management utilities. The fix ensures the stability and correct memory handling of the system by preventing integer overflow in memory location calculations, thereby improving the robustness of the SGLang runtime.

Highlights

  • Bug Fix: Addresses an "illegal memory access" (IMA) error occurring within the get_mla_kv_buffer_kernel function during specific SGLang server/client tests, which caused a RuntimeError: Triton Error [CUDA].
  • Data Type Correction: The loc variable, which is crucial for memory addressing within Triton kernels, has been explicitly cast to int64 to prevent potential integer overflow issues that were leading to the memory access errors.

@XucSh added the run-ci label Dec 1, 2025
@gemini-code-assist (bot) left a comment:

Code Review

This pull request correctly fixes a critical illegal memory access error in the get_mla_kv_buffer_kernel Triton kernel, which was caused by an integer overflow during pointer arithmetic. Casting the loc variable to tl.int64 is the correct solution. I appreciate that you also proactively applied the same fix to set_mla_kv_buffer_kernel.

While reviewing, I noticed that set_mla_kv_scale_buffer_kernel in the same file (python/sglang/srt/mem_cache/utils.py at line 112) follows a similar pattern for calculating dst_ptr. It might be vulnerable to the same overflow issue. I recommend applying the same .to(tl.int64) cast to the loc variable in that kernel for consistency and to prevent potential future errors.
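
For reference, the reviewer's suggestion applied to a set-style kernel, again as a hedged sketch with hypothetical names rather than the actual set_mla_kv_scale_buffer_kernel:

import triton
import triton.language as tl

@triton.jit
def scatter_rows_kernel(buf_ptr, loc_ptr, src_ptr, buf_stride, dim, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    # Same one-line cast as in the gather path: promote before the stride multiply.
    loc = tl.load(loc_ptr + pid).to(tl.int64)
    offs = tl.arange(0, BLOCK)
    mask = offs < dim
    val = tl.load(src_ptr + pid.to(tl.int64) * dim + offs, mask=mask)
    # dst_ptr = buf_ptr + loc * buf_stride is now computed in int64.
    tl.store(buf_ptr + loc * buf_stride + offs, val, mask=mask)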

@XucSh (Collaborator, Author) commented Dec 1, 2025:

/tag-and-rerun-ci

@XucSh changed the title from "[buf fix] fix ima with get_mla_kv_buffer_kernel overflow" to "[bug fix] fix ima with get_mla_kv_buffer_kernel overflow" Dec 1, 2025
@whybeyoung (Collaborator): LGTM

@ShangmingCai (Collaborator) left a comment:

@xiezhq-hermann @BBuf @Fridge003 do you have time to check this PR? A bug we found while refactoring PP.

@xiezhq-hermann xiezhq-hermann self-assigned this Dec 2, 2025
@Fridge003 Fridge003 merged commit af35023 into sgl-project:main Dec 4, 2025
222 of 241 checks passed
yingluosanqian pushed a commit to yingluosanqian/sglang that referenced this pull request Dec 4, 2025
…#14224)

Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
tonyluj pushed a commit to openanolis/sglang that referenced this pull request Dec 5, 2025
…#14224)

Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
yuchengz816-bot pushed a commit to yuchengz816-bot/sglang that referenced this pull request Dec 8, 2025
…#14224)

Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
Kevin-XiongC pushed a commit to novitalabs/sglang that referenced this pull request Dec 9, 2025
…#14224)

Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
GuoYechang pushed a commit to GuoYechang/sglang that referenced this pull request Jan 13, 2026
…#14224)

Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
