
[feature] implement dcp for deepseek_v2#14194

Open
staugust wants to merge 22 commits into sgl-project:main from antgroup:yjh/dcp-dev-main

Conversation

@staugust

@staugust staugust commented Dec 1, 2025

Motivation

This is the first step toward fully implementing #12196 to support much longer context with TP 8 on 8xH20.
Currently it only works with the flashinfer attention backend. It is compatible with chunked prefill and decode CUDA graph. It does not support radix cache, PD disaggregation, or MTP.

Update 2025-12-04 21:30:
Prefix cache is now supported.

Modifications

Details are in Decode-Context Parallelism (DCP) for DeepSeek-v2.

Modifications in compute attn_output

With DCP, the KV cache is split across DCP ranks.

  1. Each DCP rank computes attn_output for absorbed_q with num_tp_local_q_heads * dcp_size heads and its partial KV cache. Hence, an all_gather is introduced so that each TP rank inside a DCP group has the full absorbed_q.
  2. After obtaining attn_output and lse for the full absorbed_q and the partial KV, the lse values are gathered across the DCP group to correct attn_output with the computed scaling factors (a minimal sketch of this correction follows the diagram below).
  3. Finally, a reduce_scatter of the corrected attn_output leaves each TP rank with the final attn_output for its tp_local_q_heads over the full KV, just like pure TP.

Here is a simple computation workflow for deepseek_v2 with tp+dcp:
[image: tp+dcp computation workflow]
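To make steps 2 and 3 concrete, here is a minimal PyTorch sketch of the lse-based correction (illustrative only; the tensor layouts and names are assumptions, not the actual Triton kernel in this PR, and it assumes the per-rank partial outputs and lse values have already been gathered along the leading dimension):

import torch

def correct_attn_output(partial_out: torch.Tensor, lse: torch.Tensor) -> torch.Tensor:
    # partial_out: [dcp_size, tokens, heads, v_head_dim], attention over each rank's partial KV
    # lse:         [dcp_size, tokens, heads], log-sum-exp of each partial softmax
    global_lse = torch.logsumexp(lse, dim=0, keepdim=True)  # combine the partial lse values
    scale = torch.exp(lse - global_lse).unsqueeze(-1)       # per-rank correction factor
    return (partial_out * scale).sum(dim=0)                 # equals attention over the full KV

In the real flow, the sum over ranks is realized by the reduce_scatter in step 3, which also leaves each TP rank with only its tp_local_q_heads slice.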

Cache Management

To minimize changes to SGLang's core logic, we implemented a new DCPTokenToKVPoolAllocator and kept the real_kv_size of TokenToKVPool unchanged.

  1. Virtual Capacity Expansion
    The allocatable KV size is real_kv_size * dcp_world_size.
    When ScheduleBatch checks free space or allocates KV buffer slots, DCPTokenToKVPoolAllocator behaves as if it can allocate real_kv_size * dcp_world_size KV cache entries.
  2. Index Allocation Strategy
    Allocate one KV buffer index for each token in the request, and align each request to dcp_world_size * original_page_size as the alignment unit. This ensures that, for the token at position token_idx, token_idx % dcp_world_size equals out_cache_loc % dcp_world_size, which simplifies the mapping between out_cache_loc and its real index in TokenToKVPool.

Here's an example, with page_size=1, dcp_world_size=4, and two requests (r1: green, r2: yellow):
● r1 indices: [0, 1, 2, 3, 8, 9]
● r2 indices: [4, 5, 6]
All ranks see consistent out_cache_loc values, but each rank stores a different slice of the KV cache.
[image: interleaved KV cache index layout]
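A quick illustrative check of the alignment property above (hypothetical snippet, not code from the PR; page_size=1, dcp_world_size=4 as in the example):

dcp_world_size = 4
r1_indices = [0, 1, 2, 3, 8, 9]  # request r1
r2_indices = [4, 5, 6]           # request r2

for indices in (r1_indices, r2_indices):
    for token_idx, out_cache_loc in enumerate(indices):
        # The owning rank is the same whether derived from the token position
        # or from the allocated cache location.
        assert out_cache_loc % dcp_world_size == token_idx % dcp_world_size
        # Position of this token inside the owning rank's local kv_buffer.
        local_loc = out_cache_loc // dcp_world_size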

When reading/writing the KV cache from/into kv_buffer, the code looks like the following:

# Filter cache locations handled by the current rank
local_mask = (forward_batch.out_cache_loc % dcp_world_size) == dcp_rank
filtered_local_indices = forward_batch.out_cache_loc[local_mask]

# Calculate actual positions in the local KV buffer
actual_locations = filtered_local_indices // dcp_world_size

# Store KV cache (only the tokens owned by this rank)
kv_buffer[actual_locations] = kv[local_mask]

# Read KV cache
out_cache_loc = req_to_token_pool[req_idx][:req.seq_len]
filtered_local_indices = out_cache_loc[
    (out_cache_loc % dcp_world_size) == dcp_rank
]
actual_locations = filtered_local_indices // dcp_world_size

local_kv = kv_buffer[actual_locations]

Accuracy Tests

benchmark/gsm8k/bench_sglang.py with dcp8 and chunked-prefill, radix cache enabled.

100%|████████████| 200/200 [00:30<00:00,  6.49it/s]
Accuracy: 0.975
Invalid: 0.000
Latency: 31.077 s
Output throughput: 609.384 token/s

benchmark/gsm8k/bench_sglang.py with tp8, sglang commit: e8ba5a6

100%|████████████| 200/200 [00:24<00:00,  8.07it/s]
Accuracy: 0.970
Invalid: 0.000
Latency: 25.061 s
Output throughput: 751.645 token/s

Benchmarking and Profiling

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @staugust, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces Decode Context Parallel (DCP) capabilities, primarily targeting the deepseek_v2 model, to enhance performance and scalability for large context windows. The changes involve establishing a new distributed group for DCP, implementing an interleaved KV cache storage strategy across ranks, and modifying attention mechanisms to perform collective communication for accurate distributed computation. A new Triton kernel is utilized for correcting attention outputs, and the entire system is validated with a new, comprehensive set of unit tests.

Highlights

  • Decode Context Parallel (DCP) Infrastructure: Introduced core components for Decode Context Parallel (DCP), including new distributed groups, environment variable-based configuration, and helper functions to manage DCP ranks and world sizes. This lays the groundwork for distributing context processing across multiple devices.
  • Interleaved KV Cache Storage: Implemented a novel interleaved storage mechanism for the KV cache, where tokens are distributed across DCP ranks based on their position modulo the world size. A new DcpTokenToKVPoolAllocator manages this distribution, ensuring each rank handles a specific subset of KV cache entries.
  • Distributed Attention Computation for DeepSeek V2: Adapted the attention mechanism, specifically for the deepseek_v2 model, to support DCP. This involves partitioning attention heads, performing all-gather operations on query components, and using a custom Triton kernel with reduce-scatter to correctly combine and normalize attention outputs from distributed computations.
  • Attention Output Correction with Log-Sum-Exp (LSE): Developed a Triton kernel and Python utility functions to correct attention outputs using all-gathered Log-Sum-Exp (LSE) values. This ensures numerical stability and correctness when attention is computed in a distributed fashion.
  • Comprehensive Unit Testing for Interleaved Storage: Added a dedicated and extensive unit test suite (test_dcp_interleaved_storage.py) to rigorously validate the correctness of the interleaved KV cache allocation, deallocation, and indexing logic across various DCP configurations and scenarios.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces Decode Context Parallelism (DCP) for the deepseek_v2 model, which is a significant enhancement for distributed inference. The implementation correctly integrates DCP by interleaving KV cache storage across ranks and adjusting attention calculations. Key changes include modifications to parallel state management, a new DcpTokenToKVPoolAllocator for interleaved KV cache, and updates to attention backends (flashinfer_mla_backend.py, flashmla_backend.py, deepseek_v2.py) to handle DCP-specific logic for query components and attention output. A new test file test_dcp_interleaved_storage.py has been added, providing good coverage for the new allocator logic. Overall, the changes are well-structured and address the requirements for DCP.

# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()

assert input_tensor.shape[0] % world_size == 0
Contributor

high

The assertion input_tensor.shape[0] % world_size == 0 implies a strict requirement that the input tensor's first dimension must be divisible by the world size. This is a critical constraint for the reduce_scatter_along_dim operation. Please add a docstring to the function clearly stating this precondition, or consider how to handle cases where this might not hold true (e.g., padding or a more flexible splitting strategy).
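For illustration, the padding option mentioned above could look roughly like this (hypothetical helper, not part of the PR):

import torch

def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int = 0) -> torch.Tensor:
    # Pad `dim` up to the next multiple so the reduce-scatter divisibility
    # precondition holds; callers would need to drop the padding afterwards.
    remainder = x.shape[dim] % multiple
    if remainder == 0:
        return x
    pad_shape = list(x.shape)
    pad_shape[dim] = multiple - remainder
    return torch.cat([x, x.new_zeros(pad_shape)], dim=dim)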

Comment on lines +369 to +371
assert (
self.dcp_world_size == 1
), "FlashMLA does not support DCP for FP8 kv cache"
Contributor

high

This assertion indicates a critical limitation: FlashMLA does not support DCP for FP8 KV cache. This should be clearly documented in the model's capabilities or a more user-friendly error message should be provided to the user if this configuration is attempted.

Comment on lines 673 to 674
# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
Contributor

medium

This comment highlights a potential issue with reduce_scatter_tensor requiring contiguous input. If this is a known PyTorch bug, it would be beneficial to include a reference to the relevant PyTorch issue tracker. If it's a potential bug in the current implementation, further investigation might be warranted.

)
group_ranks.append(ranks)

# message queue broadcaster is only used in tensor model parallel group
Contributor

medium

The comment "message queue broadcaster is only used in tensor model parallel group" is directly above the initialization of the _DCP group. If the use_message_queue_broadcaster argument is also relevant for DCP, this comment might be misleading. Please clarify if DCP also utilizes the message queue broadcaster, or if this argument is redundant for DCP initialization.

Comment on lines +471 to +472
# Compute local lengths following the same formula as filter_seq_indices.
kv_len_arr_cpu = ((kv_len_arr_cpu - dcp_rank - 1) // dcp_world_size) + 1
Contributor

medium

The comment "Compute local lengths following the same formula as filter_seq_indices" refers to filter_seq_indices, which is defined later in FlashInferMLAIndicesUpdaterDecode. For better code organization and to avoid potential inconsistencies, consider defining filter_seq_indices as a standalone helper function or a static method that can be easily reused and referenced.

return self.real_allocator.free(free_index)

def filter_local_indices(self, indices):
# TODO write a triton kernel to make this faster
Contributor

medium

The TODO comment TODO write a triton kernel to make this faster is a good note for future optimization. It indicates a potential performance bottleneck that could be improved by offloading the filtering logic to a Triton kernel.

else:
self.rotary_emb = None

# TODO(augusto.yjh) 这里要改逻辑, local_heads是all heads, 而且还要返回lse,用来修正attn_out  (translation: the logic here needs to change; local_heads should be all heads, and lse must also be returned to correct attn_out)
Contributor

medium

The TODO comment TODO(augusto.yjh) 这里要改逻辑, local_heads是all heads, 而且还要返回lse,用来修正attn_out is in Chinese. For consistency and clarity within the codebase, please translate this comment to English.

layer_id=self.layer_id,
)

# TODO(augusto.yjh) 这里要all_gather q_pe 和 q_node_out,以 tp8为例, [1, 8, 64] [1, 8, 512] 经过all gather后为 [1, 64, 64] [1, 64, 512], k_pe 为 [1, 1, 64], k_nope 为 [1, 1, 512], 从 local heads到all heads  (translation: all_gather q_pe and q_node_out here; with tp8, [1, 8, 64] and [1, 8, 512] become [1, 64, 64] and [1, 64, 512] after the all_gather; k_pe is [1, 1, 64] and k_nope is [1, 1, 512]; from local heads to all heads)
Contributor

medium

The TODO comment TODO(augusto.yjh) 这里要all_gather q_pe 和 q_node_out,以 tp8为例, [1, 8, 64] [1, 64, 64] [1, 64, 512], k_pe 为 [1, 1, 64], k_nope 为 [1, 1, 512], 从 local heads到all heads is in Chinese. For consistency and clarity within the codebase, please translate this comment to English.

}

attn_output = self.attn_mqa(
# TODO(augusto.yjh) 返回lse, correct attn_output  (translation: return lse, correct attn_output)
Contributor

medium

The TODO comment TODO(augusto.yjh) 返回lse, correct attn_output is in Chinese. For consistency and clarity within the codebase, please translate this comment to English.

Comment on lines +2000 to +2001
# TODO(augusto.yjh) all gather lse,订正attn_output  (translation: all-gather lse to correct attn_output)
# TODO(augusto.yjh) 执行reduce scatter, 先reduce拿到正确的 attn_output, 再按local_num_heads scatter attn_output  (translation: perform reduce_scatter: first reduce to obtain the correct attn_output, then scatter attn_output by local_num_heads)
Contributor

medium

The TODO comments TODO(augusto.yjh) all gather lse,订正attn_output and TODO(augusto.yjh) 执行reduce scatter, 先reduce拿到正确的 attn_output, 再按local_num_heads scatter attn_output are in Chinese. For consistency and clarity within the codebase, please translate these comments to English.

@staugust staugust changed the title from [WIP] dcp for deepseek_v2 to [feature] implement dcp for deepseek_v2 on Dec 3, 2025

When I try to run the code I got

File "/root/sglang-ant/python/sglang/srt/model_executor/model_runner.py", line 2068
    else:
    ^^^^
SyntaxError: invalid syntax

Wondering if this else block is aligned correctly @staugust

Author

I will check the modifications; maybe something went wrong when rebasing onto the main branch.

Author

@Sophie8 Fixed. I'll run both performance and speed benchmarks and paste the results later.

@heroes999

Really need this feature. vLLM TP+DP+DCP has performance issues.

@staugust
Author

staugust commented Dec 8, 2025

Really need this feature. vLLM TP+DP+DCP has performance issues.

Thank you. Could you share your usage scenario? For example, details about the model architecture, GPU type, and parameters such as TP/DP/DCP? Things get complicated when TP+DP+DCP are enabled together.

@staugust staugust requested a review from hanming-lu as a code owner December 8, 2025 02:50
@heroes999

heroes999 commented Dec 8, 2025

Really need this feature. vLLM TP+DP+DCP has performance issues.

Thank you. Could you share your usage scenario? For example, details about the model architecture, GPU type, and parameters such as TP/DP/DCP? Things get complicated when TP+DP+DCP are enabled together.

Sure. Our usage scenario is 128K long context or huge batch size on a 64GB-HBM GPU (not NVIDIA, but compatible); without DCP, the KV cache cost is unaffordable. We'd like to tune performance with 16 GPUs using dp8tp2dcp2 (or dp2tp8dcp8, MoE tp16).

PS: Is there any performance data for long context on H20?

@staugust
Author

staugust commented Dec 8, 2025

Really need this feature. vLLM TP+DP+DCP has performance issues.

Thank you. Could you share your usage scenario? For example, details about the model architecture, GPU type, and parameters such as TP/DP/DCP? Things get complicated when TP+DP+DCP are enabled together.

Sure. Our usage scenario is 128K long context or huge batch size on a 64GB-HBM GPU (not NVIDIA, but compatible); without DCP, the KV cache cost is unaffordable. We'd like to tune performance with 16 GPUs using dp8tp2dcp2 (or dp2tp8dcp8, MoE tp16).

@heroes999 Got it. I've updated the code; maybe you can give it a try to see whether it works. For now, the extra communication operations introduced by DCP work, but they are not tuned yet. I have just rebased the code, and I'll post a speed benchmark later.


# build decode context parallel groups
decode_context_model_parallel_size = get_dcp_size_from_env()
if decode_context_model_parallel_size > 1:

Maybe we can move the check to the argument-parsing part, to give users better guidance on how to enable DCP and the constraints on enabling it?

Author

Sure, defining a new CLI parameter in server_args.py would be better.

grid = (B, H, 1)

regular_args = (
out,

Maybe a dumb question, just confused here: both the pointers outputs_ptr and new_output_ptr are out. Is this intended?

@lavdnone2

What about tp 4?

@staugust
Author

What about tp 4?

tp_size must be divisible by dcp_size. With tp_size set to 4, dcp_size can be 4, 2, or 1. Here's an example of how tp_size 4 and dcp_size 2 work.

The model has H attention heads.
Tensor Parallel (TP): splits heads
DCP: splits kv cache
Example:
H = 16 heads
TP = 4
DCP = 2

TP splits heads:
TP0: heads [0..3]
TP1: heads [4..7]
TP2: heads [8..11]
TP3: heads [12..15]

DCP splits tokens:
Rank assignment = (token_idx % dcp_world_size == dcp_rank)

Final mapping:

TP0+DCP0:
GPU0: q_absorb[0..3], kv_cache_tokens [0,2,4,..]

TP1+DCP1:
GPU1: q_absorb[4..7], kv_cache_tokens [1,3,5,..]

TP2+DCP0:
GPU2: q_absorb[8..11], kv_cache_tokens [0,2,4,..]

TP3+DCP1:
GPU3: q_absorb[12..15], kv_cache_tokens [1,3,5,..]
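A small illustrative script of this mapping (the dcp_rank = tp_rank % dcp_size assignment here is an assumption for this example, not necessarily the exact grouping code in the PR):

H, tp_size, dcp_size = 16, 4, 2
heads_per_rank = H // tp_size

for tp_rank in range(tp_size):
    dcp_rank = tp_rank % dcp_size
    heads = list(range(tp_rank * heads_per_rank, (tp_rank + 1) * heads_per_rank))
    tokens = [t for t in range(8) if t % dcp_size == dcp_rank]  # first few token positions
    print(f"GPU{tp_rank} (TP{tp_rank}+DCP{dcp_rank}): q_absorb heads {heads}, kv_cache_tokens {tokens}...")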

@Kangyan-Zhou
Collaborator

/tag-and-rerun-ci

staugust and others added 17 commits January 8, 2026 09:47
correct params for forward_extend in flashinfer_mla
fix bugs in set dcp_kv_indptr
make chunked req align with dcp_world_size
fix bugs in compute attn for deepseek_v2
estimate pages for dcp with page_size * dcp_world_size
re-org kv indice with same order

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
also gather k_pe
calculate real kv indice
make  prefix_chunk_len align to dcp_world_size
all gather prefix cache kv  which is aligned to dcp_world_size
fix bugs in fetch extend prefix kv cache
fix bugs in gather kv for mha one shot
fix bugs when rebase to main branch
fix pre-commit ast errors
return attn_out and lse when forward_batch is decode otherwise return attn_out without lse
only return lse for dcp mla
correct conditions to return lse for decode

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
…n is incompatible with TP group symm-mem. Modifications will be made after the resolution of the multi-group symmetric memory coexistence issue.)

misc: remove unneed code after rebase

fix: fix ar coredump when dcp use symmetric memory

fea: add symm-mem unit perf test
@staugust
Author

staugust commented Jan 8, 2026

Here's a performance benchmark with the model Kimi-K2-Instruct-0905 on NVIDIA H20-3e, compared against tp8 and attention dp8+tp8.

  1. Under the same batch size, the prefill/decode throughput of DCP8+TP8 remains almost the same as TP8 and DP8+TP8, with no noticeable performance drop.
  2. DCP8+TP8 enables a larger batch size, raising the peak request throughput from 0.48 (DP8+TP8) to 0.65.

parallel strategy | max_concurrency | req/s | input tok/s | output tok/s | peak output tok/s | mean tpot (ms)
dcp8+tp8          | 64              | 0.65  | 2601.77     | 975.66       | 1280.00           | 59.04
dcp8+tp8          | 48              | 0.57  | 2283.99     | 856.50       | 1104.00           | 50.93
dcp8+tp8          | 32              | 0.48  | 1920.62     | 720.23       | 888.00            | 40.75
dcp8+tp8          | 8               | 0.25  | 1002.05     | 375.77       | 416.00            | 20.03
tp8               | 8               | 0.25  | 1015.66     | 380.87       | 439.00            | 19.71
dp8+tp8           | 32              | 0.48  | 1933.49     | 725.06       | 896.00            | 39.83

launch command

MODEL=/home/models/moonshotai/Kimi-K2-Instruct-0905
SGLANG_DCP_SYMM_ONLY=true \
SGLANG_DCP=8 \
NCCL_DEBUG=WARN \
PYTHONUNBUFFERED=1 \
TORCHINDUCTOR_FX_GRAPH_CACHE=1 \
TORCHINDUCTOR_AUTOGRAD_CACHE=1 \
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \
TORCHINDUCTOR_CACHE_DIR=/home/admin/inductor_root_cache \
nohup python3 -m sglang.launch_server \
  --model-path ${MODEL} \
  --host 0.0.0.0 \
  --port 8188 \
  --trust-remote-code \
  --enable-cache-report \
  --log-level info \
  --tp-size 8 \
  --max-running-requests 48 \
  --mem-fraction-static 0.90 \
  --chunked-prefill-size 32768 \
  --context-length 262144 \
  --attention-backend flashinfer \
  --disable-radix-cache \
  --enable-symm-mem \
  &>> /home/local/workspace/scripts/sglang.out 2>&1 &

Bench Result
dcp8+tp8, max concurrency 64 (mem fraction set to 0.92 to enable the larger batch size)

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    1000000.0
Max request concurrency:                 64
Successful requests:                     512
Benchmark duration (s):                  787.16
Total input tokens:                      2048000
Total input text tokens:                 2048000
Total generated tokens:                  768000
Total generated tokens (retokenized):    702661
Request throughput (req/s):              0.65
Input token throughput (tok/s):          2601.77
Output token throughput (tok/s):         975.66
Peak output token throughput (tok/s):    1280.00
Peak concurrent requests:                128
Total token throughput (tok/s):          3577.43
Concurrency:                             63.99
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   98376.54
Median E2E Latency (ms):                 96720.70
P90 E2E Latency (ms):                    109090.95
P99 E2E Latency (ms):                    109104.15
---------------Time to First Token----------------
Mean TTFT (ms):                          9869.27
Median TTFT (ms):                        10171.33
P99 TTFT (ms):                           17175.59
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          59.04
Median TPOT (ms):                        58.73
P99 TPOT (ms):                           71.16
---------------Inter-Token Latency----------------
Mean ITL (ms):                           58.96
Median ITL (ms):                         52.30
P95 ITL (ms):                            69.06
P99 ITL (ms):                            90.80
Max ITL (ms):                            16411.23
==================================================

dcp8+tp8 max concurrency 48

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    1000000.0
Max request concurrency:                 48
Successful requests:                     480
Benchmark duration (s):                  840.64
Total input tokens:                      1920000
Total input text tokens:                 1920000
Total generated tokens:                  720000
Total generated tokens (retokenized):    658869
Request throughput (req/s):              0.57
Input token throughput (tok/s):          2283.99
Output token throughput (tok/s):         856.50
Peak output token throughput (tok/s):    1104.00
Peak concurrent requests:                96
Total token throughput (tok/s):          3140.48
Concurrency:                             47.99
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   84049.63
Median E2E Latency (ms):                 82368.98
P90 E2E Latency (ms):                    86933.74
P99 E2E Latency (ms):                    95900.47
---------------Time to First Token----------------
Mean TTFT (ms):                          7710.26
Median TTFT (ms):                        7704.56
P99 TTFT (ms):                           12933.05
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          50.93
Median TPOT (ms):                        50.41
P99 TPOT (ms):                           62.39
---------------Inter-Token Latency----------------
Mean ITL (ms):                           50.86
Median ITL (ms):                         46.31
P95 ITL (ms):                            53.37
P99 ITL (ms):                            81.61
Max ITL (ms):                            12354.65
==================================================

dcp8+tp8 max concurrency 32

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    1000000.0
Max request concurrency:                 32
Successful requests:                     480
Benchmark duration (s):                  999.68
Total input tokens:                      1920000
Total input text tokens:                 1920000
Total generated tokens:                  720000
Total generated tokens (retokenized):    653258
Request throughput (req/s):              0.48
Input token throughput (tok/s):          1920.62
Output token throughput (tok/s):         720.23
Peak output token throughput (tok/s):    888.00
Peak concurrent requests:                64
Total token throughput (tok/s):          2640.85
Concurrency:                             32.00
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   66635.55
Median E2E Latency (ms):                 66163.16
P90 E2E Latency (ms):                    68120.88
P99 E2E Latency (ms):                    70367.68
---------------Time to First Token----------------
Mean TTFT (ms):                          5546.29
Median TTFT (ms):                        5316.55
P99 TTFT (ms):                           8680.60
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          40.75
Median TPOT (ms):                        40.63
P99 TPOT (ms):                           45.33
---------------Inter-Token Latency----------------
Mean ITL (ms):                           40.72
Median ITL (ms):                         38.32
P95 ITL (ms):                            41.12
P99 ITL (ms):                            51.45
Max ITL (ms):                            8120.84
==================================================

dcp8+tp8 max concurrency 8

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    1000000.0
Max request concurrency:                 8
Successful requests:                     480
Benchmark duration (s):                  1916.08
Total input tokens:                      1920000
Total input text tokens:                 1920000
Total generated tokens:                  720000
Total generated tokens (retokenized):    658729
Request throughput (req/s):              0.25
Input token throughput (tok/s):          1002.05
Output token throughput (tok/s):         375.77
Peak output token throughput (tok/s):    416.00
Peak concurrent requests:                16
Total token throughput (tok/s):          1377.81
Concurrency:                             8.00
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   31932.18
Median E2E Latency (ms):                 31827.15
P90 E2E Latency (ms):                    32111.67
P99 E2E Latency (ms):                    35255.66
---------------Time to First Token----------------
Mean TTFT (ms):                          1908.91
Median TTFT (ms):                        1976.45
P99 TTFT (ms):                           2300.13
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          20.03
Median TPOT (ms):                        19.91
P99 TPOT (ms):                           22.22
---------------Inter-Token Latency----------------
Mean ITL (ms):                           20.01
Median ITL (ms):                         19.68
P95 ITL (ms):                            20.24
P99 ITL (ms):                            23.95
Max ITL (ms):                            1623.93
==================================================

tp8 max concurrency 8

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    1000000.0
Max request concurrency:                 8
Successful requests:                     480
Benchmark duration (s):                  1890.40
Total input tokens:                      1920000
Total input text tokens:                 1920000
Total generated tokens:                  720000
Total generated tokens (retokenized):    656560
Request throughput (req/s):              0.25
Input token throughput (tok/s):          1015.66
Output token throughput (tok/s):         380.87
Peak output token throughput (tok/s):    439.00
Peak concurrent requests:                16
Total token throughput (tok/s):          1396.53
Concurrency:                             8.00
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   31504.31
Median E2E Latency (ms):                 31492.23
P90 E2E Latency (ms):                    31677.08
P99 E2E Latency (ms):                    32740.69
---------------Time to First Token----------------
Mean TTFT (ms):                          1956.22
Median TTFT (ms):                        2027.92
P99 TTFT (ms):                           2354.99
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          19.71
Median TPOT (ms):                        19.63
P99 TPOT (ms):                           20.79
---------------Inter-Token Latency----------------
Mean ITL (ms):                           19.70
Median ITL (ms):                         19.50
P95 ITL (ms):                            20.70
P99 ITL (ms):                            23.01
Max ITL (ms):                            1853.03
==================================================

attention dp8+tp8, max concurrency 32 (mem fraction raised to 0.94; otherwise there is no CUDA memory left for the KV cache)

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    1000000.0
Max request concurrency:                 32
Successful requests:                     480
Benchmark duration (s):                  993.02
Total input tokens:                      1920000
Total input text tokens:                 1920000
Total generated tokens:                  720000
Total generated tokens (retokenized):    658354
Request throughput (req/s):              0.48
Input token throughput (tok/s):          1933.49
Output token throughput (tok/s):         725.06
Peak output token throughput (tok/s):    896.00
Peak concurrent requests:                64
Total token throughput (tok/s):          2658.55
Concurrency:                             32.00
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   66194.06
Median E2E Latency (ms):                 65721.34
P90 E2E Latency (ms):                    67355.91
P99 E2E Latency (ms):                    72074.54
---------------Time to First Token----------------
Mean TTFT (ms):                          6491.52
Median TTFT (ms):                        6501.04
P99 TTFT (ms):                           9620.72
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          39.83
Median TPOT (ms):                        39.55
P99 TPOT (ms):                           45.23
---------------Inter-Token Latency----------------
Mean ITL (ms):                           39.80
Median ITL (ms):                         37.63
P95 ITL (ms):                            49.92
P99 ITL (ms):                            60.57
Max ITL (ms):                            8045.31
==================================================

@heroes999

Is dcp+tp with full cuda graph supported?

@staugust
Author

@heroes999 yes, full cuda graph is supported.

@heroes999

I read the dev doc, what is the meaning of "remove redundant communication"?
[image: screenshot from the dev doc]

@staugust
Author

@heroes999 Each TP rank has to keep the full compressed KV cache for models with MLA attention; with DCP, each DCP rank only keeps part of the full compressed KV cache.

@Rythsman

I read the dev doc, what is the meaning of "remove redundant communication"? [image: screenshot from the dev doc]

@heroes999 I have made some additions to the document, hoping they will help you understand. cc @staugust

@Yangxinhub

Is tp+dcp+mtp supported? Or are there any plans to support it?

@staugust
Author

@Yangxinhub For now, this PR does not support tp+dcp+mtp. We'll support tp+dcp+mtp after this PR is reviewed and merged into the main branch.

@ec-jt

ec-jt commented Jan 29, 2026

@heroes999 Each TP rank has to keep the full compressed KV cache for models with MLA attention; with DCP, each DCP rank only keeps part of the full compressed KV cache.

PP Support?

@Rythsman

Rythsman commented Feb 3, 2026

@heroes999 Each TP rank has to keep the full compressed KV cache for models with MLA attention; with DCP, each DCP rank only keeps part of the full compressed KV cache.

PP Support?

I’ve enabled TP+DCP+PP before, and it worked without any issues. You can give it a try too.
