[DeepSeek V3.2] Enable trtllm NSA with bf16 kvcache #16758

Merged
Fridge003 merged 4 commits into sgl-project:main from akhilg-nv:ds_sparse_trtllm_gen
Jan 23, 2026
Conversation

@akhilg-nv
Contributor

@akhilg-nv commented Jan 8, 2026

Motivation

Enables the NSA backend with trtllm kernels for sparse attention. This can be more efficient than FlashMLA when the head size isn't a multiple of 64 and hence requires padding. This PR enables it with a BF16 KV cache; FP8 support will follow in another PR.

Modifications

This change interfaces with the new kernel added in flashinfer to use the trtllm kernel for decode in the NSA backend. server_args.py is also modified to expose the new option.
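
For context, the decode dispatch works roughly like the sketch below. This is a minimal illustration only, not the actual diff: nsa_impl and the forward_decode/forward_extend split exist in nsa_backend.py, but trtllm_sparse_decode is a hypothetical stand-in for the new flashinfer entry point.

def forward_decode(self, q, kv_cache, topk_indices, layer):
    # Illustrative sketch only; trtllm_sparse_decode is a hypothetical
    # stand-in for the new flashinfer kernel call added in this PR.
    if self.nsa_impl == "flashmla":
        # Existing path: FlashMLA sparse kernel (pads the head size up
        # to a multiple of 64 when needed).
        return self._flashmla_sparse_decode(q, kv_cache, topk_indices, layer)
    if self.nsa_impl == "trtllm":
        # New path: trtllm decode kernel via flashinfer, bf16 KV cache only.
        return trtllm_sparse_decode(q, kv_cache, topk_indices)
    raise ValueError(f"Unsupported {self.nsa_impl = }")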

Accuracy Tests

python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3.2 --tp 8 --dp 8 --enable-dp-attention --nsa-decode-backend trtllm --kv-cache-dtype bfloat16 --reasoning-parser deepseek-v3 --trust-remote-code

python3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 120000 --repeat 8 --thinking-mode deepseek-v3 |& tee logs/trtllm_mla_sparse_gpqa.log

Repeat: 8, mean: 0.853
Scores: ['0.854', '0.854', '0.859', '0.879', '0.848', '0.838', '0.859', '0.833']
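
As a quick sanity check, the reported mean is just the average of the eight per-repeat scores:

scores = [0.854, 0.854, 0.859, 0.879, 0.848, 0.838, 0.859, 0.833]
print(f"mean: {sum(scores) / len(scores):.3f}")  # -> mean: 0.853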

Testing MTP with --speculative-algorithm EAGLE --speculative-num-steps 2 --speculative-eagle-topk 1 --speculative-num-draft-tokens 3:

GPQA:

Repeat: 8, mean: 0.84898
Scores: ['0.838', '0.854', '0.869', '0.818', '0.848', '0.843', '0.874', '0.838']

Avg accept len: 2.3740
Min accept len: 1.9800
Max accept len: 2.8500

Avg accept rate: 0.7913
Min accept rate: 0.6600
Max accept rate: 0.9500
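
These numbers are internally consistent: with --speculative-num-steps 2, each verify step accepts between 1 and 3 tokens, and under a simple i.i.d. acceptance model (an approximation, not how the statistic is computed) the expected accept length lands near the measured average:

r = 0.7913           # measured average accept rate
print(1 + r + r**2)  # ~2.42, close to the measured avg accept len of 2.374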

LongBenchV2:

python -m sglang.test.run_eval \
  --eval-name longbench_v2 \
  --host 127.0.0.1 \
  --port 30000 \
  --model deepseek-ai/DeepSeek-V3.2 \
  --max-context-length 128000 \
  --max-tokens 16384 \
  --num-threads 16

Total latency: 1313.224 s
Score: 0.550

Avg accept len: 2.5248
Min accept len: 2.0000
Max accept len: 2.9800

Avg accept rate: 0.8419
Min accept rate: 0.6700
Max accept rate: 0.9900

Benchmarking and Profiling

Trace shows the new kernel call (about 14 microseconds):

[trace screenshot]

Trace with the FlashMLA path with bf16 KVCache shows the FlashMLA sparse kernel is slower (about 26.5 microseconds):

[trace screenshot]

Comparing FlashMLA sparse and TRTLLM (both with a bf16 KV cache) shows similar average results for decode.

python3 -m sglang.bench_serving --model deepseek-ai/DeepSeek-V3.2 --warmup-requests 5 --max-concurrency 16 --num-prompts 80 --random-range-ratio 0.8 --random-input-len 4096 --random-output-len 4096 --profile --profile-by-stage

FlashMLA Sparse (BF16 KVCache)

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 16        
Successful requests:                     80        
Benchmark duration (s):                  57.26     
Total input tokens:                      29016     
Total input text tokens:                 29016     
Total generated tokens:                  17012     
Total generated tokens (retokenized):    16921     
Request throughput (req/s):              1.40      
Input token throughput (tok/s):          506.74    
Output token throughput (tok/s):         297.10    
Peak output token throughput (tok/s):    576.00    
Peak concurrent requests:                21        
Total token throughput (tok/s):          803.83    
Concurrency:                             10.87     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   7780.15   
Median E2E Latency (ms):                 5522.24   
P90 E2E Latency (ms):                    17340.74  
P99 E2E Latency (ms):                    33557.02  
---------------Time to First Token----------------
Mean TTFT (ms):                          348.70    
Median TTFT (ms):                        276.23    
P99 TTFT (ms):                           683.36    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          34.78     
Median TPOT (ms):                        36.56     
P99 TPOT (ms):                           57.35     
---------------Inter-Token Latency----------------
Mean ITL (ms):                           35.21     
Median ITL (ms):                         28.07     
P95 ITL (ms):                            151.14    
P99 ITL (ms):                            185.14    
Max ITL (ms):                            333.13   

TRTLLM Sparse NSA

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 16        
Successful requests:                     80        
Benchmark duration (s):                  58.32     
Total input tokens:                      29016     
Total input text tokens:                 29016     
Total generated tokens:                  17012     
Total generated tokens (retokenized):    16922     
Request throughput (req/s):              1.37      
Input token throughput (tok/s):          497.55    
Output token throughput (tok/s):         291.71    
Peak output token throughput (tok/s):    592.00    
Peak concurrent requests:                21        
Total token throughput (tok/s):          789.27    
Concurrency:                             11.27     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   8217.63   
Median E2E Latency (ms):                 6536.17   
P90 E2E Latency (ms):                    18150.71  
P99 E2E Latency (ms):                    33348.26  
---------------Time to First Token----------------
Mean TTFT (ms):                          659.34    
Median TTFT (ms):                        287.29    
P99 TTFT (ms):                           2945.19   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          37.39     
Median TPOT (ms):                        37.52     
P99 TPOT (ms):                           90.29     
---------------Inter-Token Latency----------------
Mean ITL (ms):                           35.82     
Median ITL (ms):                         27.18     
P95 ITL (ms):                            152.73    
P99 ITL (ms):                            189.19    
Max ITL (ms):                            2460.31 
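
Pulling the headline numbers from the two runs above side by side (same workload: 80 prompts, max concurrency 16):

                                   FlashMLA    TRTLLM
Output token throughput (tok/s)    297.10      291.71
Mean TPOT (ms)                     34.78       37.39
Mean ITL (ms)                      35.21       35.82
Mean TTFT (ms)                     348.70      659.34
P99 TTFT (ms)                      683.36      2945.19

Steady-state decode throughput and latency are within a few percent; the trtllm run shows a heavier TTFT tail in this particular sample.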

Unit benchmarking shows higher bandwidth for the new trtllm kernel compared to FlashMLA.

[Figure: bandwidth vs. sequence length, bf16 KV cache comparison (bar chart)]
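
For reference, bandwidth plots like this are typically produced by timing the kernel in isolation and dividing the bytes it moves by the elapsed time. A minimal sketch of such a measurement, where run_decode_kernel is a hypothetical stand-in for whichever kernel is under test:

import torch

def measure_bandwidth_gbps(run_decode_kernel, bytes_moved, iters=100):
    # Time the kernel with CUDA events and convert to GB/s.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    run_decode_kernel()  # warmup
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        run_decode_kernel()
    end.record()
    torch.cuda.synchronize()
    ms_per_iter = start.elapsed_time(end) / iters
    return bytes_moved / (ms_per_iter * 1e-3) / 1e9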

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments (/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci) or contact authorized users to do so.
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@akhilg-nv changed the title from "Enable trtllm_gen NSA with bf16 kvcache" to "[DeepSeek V3.2] Enable trtllm_gen NSA with bf16 kvcache" on Jan 16, 2026
@agaction

@Fridge003 Could you please take a look? I only just now saw #15546; the changes look similar, though here I am using concat_mla_absorb_q_general.

@Fridge003
Collaborator

Fridge003 commented Jan 19, 2026

@akhilg-nv Please fix the conflict

@akhilg-nv force-pushed the ds_sparse_trtllm_gen branch from 393f224 to 0c2da34 on January 19, 2026 at 20:38
@Fridge003
Collaborator

@akhilg-nv Please fix lint with

pre-commit install
pre-commit run --all-files

@Fridge003
Collaborator

/tag-and-rerun-ci

@Fridge003
Collaborator

/rerun-failed-ci

@akhilg-nv force-pushed the ds_sparse_trtllm_gen branch 4 times, most recently from 5622465 to 03ac8f3 on January 22, 2026 at 23:56
@akhilg-nv force-pushed the ds_sparse_trtllm_gen branch from 03ac8f3 to 2162e7e on January 22, 2026 at 23:58
@akhilg-nv changed the title from "[DeepSeek V3.2] Enable trtllm_gen NSA with bf16 kvcache" to "[DeepSeek V3.2] Enable trtllm NSA with bf16 kvcache" on Jan 22, 2026
@akhilg-nv force-pushed the ds_sparse_trtllm_gen branch from 2162e7e to 306072d on January 23, 2026 at 00:37
@Fridge003
Collaborator

/tag-and-rerun-ci

@Fridge003 dismissed their stale review on January 23, 2026 at 12:20

resolved

@Fridge003 merged commit 2fb3281 into sgl-project:main on Jan 23, 2026
208 of 225 checks passed
@mmangkad
Contributor

mmangkad commented Jan 23, 2026

The PR mentions speculative decoding testing. I tried it but got:

[2026-01-23 18:37:35 TP3] Scheduler hit an exception: Traceback (most recent call last):
  File "/sglang/python/sglang/srt/managers/scheduler.py", line 2937, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/sglang/python/sglang/srt/managers/scheduler.py", line 346, in __init__
    self.init_model_worker()
  File "/sglang/python/sglang/srt/managers/scheduler.py", line 535, in init_model_worker
    self.init_tp_model_worker()
  File "/sglang/python/sglang/srt/managers/scheduler.py", line 497, in init_tp_model_worker
    self.tp_worker = TpModelWorker(
                     ^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/managers/tp_worker.py", line 246, in __init__
    self._init_model_runner()
  File "/sglang/python/sglang/srt/managers/tp_worker.py", line 329, in _init_model_runner
    self._model_runner = ModelRunner(
                         ^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 390, in __init__
    self.initialize(min_per_gpu_memory)
  File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 581, in initialize
    self.kernel_warmup()
  File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 1706, in kernel_warmup
    self._flashinfer_autotune()
  File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 1742, in _flashinfer_autotune
    self._dummy_run(batch_size=self.req_to_token_pool.size)
  File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 1986, in _dummy_run
    run_once()
  File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 1976, in run_once
    logits_output_or_pp_proxy_tensors = self.model.forward(
                                        ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 2898, in forward
    hidden_states = self.model(
                    ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 2711, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 2385, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 1375, in forward
    return self.forward_core(s)
           ^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 1463, in forward_core
    return self.forward_absorb_core(*inner_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 1768, in forward_absorb_core
    attn_output = self.attn_mqa(
                  ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/radix_attention.py", line 124, in forward
    return forward_batch.attn_backend.forward(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/attention/base_attn_backend.py", line 113, in forward
    return self.forward_extend(
           ^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/attention/nsa_backend.py", line 1343, in forward_extend
    raise ValueError(f"Unsupported {nsa_impl = }")
ValueError: Unsupported nsa_impl = 'trtllm'

It looks like the extend path wasn’t updated for trtllm.

Edit: Opened a quick follow-up PR (#17662) to address this.
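
For anyone hitting the same error before #17662 lands: the traceback shows forward_extend in nsa_backend.py rejecting nsa_impl = 'trtllm', since this PR only wires up decode. One plausible shape for the fix, purely illustrative (the fallback choice here is hypothetical; see #17662 for the actual change):

def forward_extend(self, *args, **kwargs):
    if self.nsa_impl == "trtllm":
        # The trtllm kernel in this PR is decode-only, so extend can
        # reuse an existing implementation (hypothetical fallback).
        return self._forward_extend_flashmla(*args, **kwargs)
    raise ValueError(f"Unsupported {self.nsa_impl = }")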

Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
