[Bug] Memory Leak Check Fails in Qwen3-Next PD Disaggregation Due to Mamba Size Mismatch #12432

@xjx471258437

Description

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

The Qwen3-Next PD disaggregation task deploys successfully. However, when running bench_serving, the prefill stage fails with the following error:

[2025-10-31 10:40:43 TP0] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2797, in run_scheduler_process
    scheduler.event_loop_overlap_disagg_prefill()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/disaggregation/prefill.py", line 371, in event_loop_overlap_disagg_prefill
    self.self_check_during_idle()
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 214, in self_check_during_idle
    self.check_memory()
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 147, in check_memory
    raise ValueError(msg)
ValueError: token_to_kv_pool_allocator memory leak detected! full_available_size=1663695, full_evictable_size=1066, self.token_to_kv_pool_allocator.size=1664761, self.tree_cache.full_protected_size()=0
mamba_available_size=473, mamba_evictable_size=4, self.req_to_token_pool.mamba_pool.size=476, self.tree_cache.mamba_protected_size()=0
NotImplementedError

Root Cause Analysis

Investigating the code history suggests that the issue was introduced by PR #11214 (merged on 2025-10-13), which added Mamba-size validation to the check_memory() function.

Before 2025-10-11

On the Qwen3-Next PD disaggregation branch support_hybrid_pd, check_memory() had no Mamba-specific branch: Qwen3-Next (which is is_hybrid_gdn, not is_hybrid) fell through to the generic else branch, which validates only the full token pool. As a result, PD disaggregation tasks ran successfully:

    def check_memory(self):
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                _,
                _,
                full_available_size,
                full_evictable_size,
                swa_available_size,
                swa_evictable_size,
            ) = self._get_swa_token_info()
            memory_leak = full_num_used != 0 or swa_num_used != 0
            token_msg = (
                f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
                f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
            )
        else:
            _, _, available_size, evictable_size = self._get_token_info()
            protected_size = self.tree_cache.protected_size()
            memory_leak = (available_size + evictable_size) != (
                # self.max_total_num_tokens
                # if not self.enable_hierarchical_cache
                # else self.max_total_num_tokens - protected_size
                self.max_total_num_tokens
                - protected_size
            )
            token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
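
Plugging the numbers from the error message into this older else branch shows why it used to pass: the full token pool balances exactly, and the Mamba pool is never inspected. A minimal sketch, assuming max_total_num_tokens equals the token_to_kv_pool_allocator.size reported in the log:

    # Values taken from the ValueError above; treating max_total_num_tokens as
    # equal to token_to_kv_pool_allocator.size (1664761) is an assumption.
    max_total_num_tokens = 1664761
    available_size, evictable_size, protected_size = 1663695, 1066, 0

    # The pre-2025-10-13 else branch validates only the full token pool:
    memory_leak = (available_size + evictable_size) != (
        max_total_num_tokens - protected_size
    )
    print(memory_leak)  # False: 1663695 + 1066 == 1664761, and Mamba is never checked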

After 2025-10-13

The new logic in check_memory() adds a branch for self.is_hybrid_gdn with MambaRadixCache, which flags a leak whenever pool usage does not match the tree cache's protected sizes:

    def _check_mamba_memory(self: Scheduler):
        (
            full_num_used,
            mamba_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            mamba_available_size,
            mamba_evictable_size,
        ) = self._get_mamba_token_info()
        memory_leak = (
            full_num_used != self.tree_cache.full_protected_size()
            or mamba_num_used != self.tree_cache.mamba_protected_size()
        )
        token_msg = (
            f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
        )
        return memory_leak, token_msg

    def check_memory(self: Scheduler):

        if self.is_hybrid:
            memory_leak, token_msg = self._check_hybrid_memory()
        elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
            memory_leak, token_msg = self._check_mamba_memory()
        else:
            memory_leak, token_msg = self._check_radix_cache_memory()

        if memory_leak:
            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
            raise ValueError(msg)

However, in PD disaggregation scenarios, the Mamba pool accounting does not align with the protected size in the tree cache. Specifically:

self.req_to_token_pool.mamba_pool.size - (mamba_available_size + mamba_evictable_size) != self.tree_cache.mamba_protected_size()

This discrepancy causes memory_leak to be True, triggering the ValueError.
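
Plugging the logged values into the new check makes the imbalance concrete. This is a minimal sketch; the num_used formula (pool size minus available and evictable slots) is an assumption based on the variable names returned by _get_mamba_token_info:

    # Values from the ValueError above; both protected sizes are 0 in the log.
    full_available, full_evictable, full_size = 1663695, 1066, 1664761
    mamba_available, mamba_evictable, mamba_size = 473, 4, 476

    full_num_used = full_size - (full_available + full_evictable)      # 0
    mamba_num_used = mamba_size - (mamba_available + mamba_evictable)  # -1

    memory_leak = full_num_used != 0 or mamba_num_used != 0
    print(memory_leak)  # True: the Mamba pool reports 477 free slots in a pool of 476

Note that the full pool balances exactly, while the Mamba pool shows one more free slot than its capacity (477 vs. 476), which looks more like an over-release of a Mamba slot on the PD prefill path than a classic leak.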

@yizhang2077 @ShangmingCai

Reproduction

Model: Qwen3-Next

prefill command:

export SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
python -m sglang.launch_server \
    --model-path /models/qwen_next/Qwen3-Next-80B-A3B-Instruct \
    --disaggregation-mode prefill \
    --host 0.0.0.0 \
    --port 8000 \
    --tp-size 2 \
    --trust-remote-code \
    --mem-fraction-static 0.8

decode command:

export SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
python -m sglang.launch_server \
    --model-path /models/qwen_next/Qwen3-Next-80B-A3B-Instruct \
    --disaggregation-mode decode \
    --tp-size 2 \
    --trust-remote-code \
    --host 0.0.0.0 \
    --port 8001 \
    --mem-fraction-static 0.8 \
    --cuda-graph-max-bs 640

sglang_router command:

python3 -m sglang_router.launch_router \
    --pd-disaggregation \
    --prefill http://22.13.56.27:8000 \
    --decode http://22.13.56.27:8001 \
    --prefill-policy round_robin \
    --decode-policy round_robin \
    --policy round_robin \
    --host 0.0.0.0 \
    --port 3000

The same problem occurs when using mini_lb or other launchers.

Environment

Python: 3.12.3 (main, Jun 18 2025, 17:59:45) [GCC 13.3.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H20-3e
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.8, V12.8.93
CUDA Driver Version: 570.133.20
PyTorch: 2.8.0+cu128
sglang: 0.5.4.post1
sgl_kernel: 0.3.16.post4
flashinfer_python: 0.4.1
triton: 3.4.0
transformers: 4.57.1
torchao: 0.9.0
numpy: 2.3.2
aiohttp: 3.12.15
fastapi: 0.116.1
hf_transfer: 0.1.9
huggingface_hub: 0.35.3
interegular: 0.3.3
modelscope: 1.28.2
orjson: 3.11.1
outlines: 0.1.11
packaging: 25.0
psutil: 7.0.0
pydantic: 2.11.7
python-multipart: 0.0.20
pyzmq: 27.0.1
uvicorn: 0.35.0
uvloop: 0.21.0
vllm: Module Not Found
xgrammar: 0.1.25
openai: 2.6.1
tiktoken: 0.10.0
anthropic: 0.61.0
litellm: 1.75.0
decord2: 2.0.0
NVIDIA Topology:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18 PIX PHB NODE NODE 0-79 0-1 N/A
GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18 PXB PHB NODE NODE 0-79 0-1 N/A
GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18 PHB PIX NODE NODE 0-79 0-1 N/A
GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18 PHB PXB NODE NODE 0-79 0-1 N/A
GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18 NODE NODE PIX PHB 0-159 0-1 N/A
GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18 NODE NODE PXB PHB 0-159 0-1 N/A
GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18 NODE NODE PHB PIX 0-159 0-1 N/A
GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X NODE NODE PHB PXB 0-159 0-1 N/A
NIC0 PIX PXB PHB PHB NODE NODE NODE NODE X PHB NODE NODE
NIC1 PHB PHB PIX PXB NODE NODE NODE NODE PHB X NODE NODE
NIC2 NODE NODE NODE NODE PIX PXB PHB PHB NODE NODE X PHB
NIC3 NODE NODE NODE NODE PHB PHB PIX PXB NODE NODE PHB X

Legend:

X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks

NIC Legend:

NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_2
NIC3: mlx5_3

Hypervisor vendor: KVM
ulimit soft: 102400
