Support batch size > 1 when CP is enabled #23269
Conversation
/tag-run-ci-label
/tag-and-rerun-ci
/rerun-failed-ci
@Shunkangz Thank you for the contribution, I left some reviews. There will be another round of more thorough review. We can discuss in the slack channel as well.
Couple of things I want to call out:
- During development of MLA CP, I found two notable bugs in the current MHA CP impl:
  1. sglang/python/sglang/srt/layers/attention/flashattention_backend.py, lines 492 to 508 in 5466e9f
  2. sglang/python/sglang/srt/layers/utils/cp_utils.py, lines 471 to 472 in 5466e9f
- We need an MHA CP zigzag mode test for attn_cp_size == 4 and bs > 1
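The requested case can be sketched as a property check. `zigzag_ranges` is a hypothetical helper that mirrors the usual zigzag assignment (rank r owns chunks r and 2*cp_size-1-r); it is not the actual sglang code:

```python
def zigzag_ranges(rank: int, cp_size: int, padded_len: int):
    """Token ranges owned by `rank` under a zigzag split: the sequence is cut
    into 2*cp_size equal chunks and rank r takes chunks r and 2*cp_size-1-r."""
    chunk = padded_len // (2 * cp_size)
    lo, hi = rank, 2 * cp_size - 1 - rank
    return [(lo * chunk, (lo + 1) * chunk), (hi * chunk, (hi + 1) * chunk)]

cp_size = 4
batch_padded_lens = [16, 32]  # bs > 1: one padded length per request

for padded_len in batch_padded_lens:
    covered = sorted(t for r in range(cp_size)
                     for lo, hi in zigzag_ranges(r, cp_size, padded_len)
                     for t in range(lo, hi))
    # Every token is owned by exactly one rank: no gaps, no overlaps.
    assert covered == list(range(padded_len))
```

A real test would additionally compare the per-rank attention output against a single-rank reference, but the coverage invariant above is the part that breaks first when bs > 1 metadata is wrong.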
Thank you for pointing this out. In this PR, I only want to support bs > 1 with GQA models. For the MLA CP part, I left the original implementation as-is. I believe the MLA CP path should be refactored and aligned with our existing logic (args, the layer communicator, and so on).
Ah, sorry, I should be clearer here:
For 1, let's discuss the details through Slack. For 2, I think the existing TestContextParallelMetadata already covers this. Can you confirm?
Yes, let's discuss 1 further in Slack tomorrow. For 2, sorry, what test is this?
I mean that the TestContextParallelMetadata might already cover this.
prepare_mlp_sync_batch was changed to pad global_num_tokens to (attn_cp_size * 2) (commit 2f27b0d "Pad seq to 2 * cp") so the zigzag CP split is load-balanced. But cal_padded_tokens in nsa/utils.py still padded to attn_cp_size, and its comment promised the two were in sync. The drift broke DRAFT_EXTEND_V2 on the NSA FA3 path:

- q / topk_indices end up at ceil_align(N, 2*cp_size) rows (prepare_mlp_sync_batch -> _pad_topk_indices).
- nsa_cache_seqlens / nsa_cu_seqlens_q only reach ceil_align(N, cp_size) (pad_nsa_cache_seqlens via cal_padded_tokens).

When N is divisible by cp_size but not by 2*cp_size, the two disagree and flash_attn_with_kvcache fails the check:

RuntimeError: batch_size must be equal to batch_size_k

Fix by aligning to attn_cp_size * 2 in cal_padded_tokens, matching the sync-batch comment's invariant.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
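The mismatch can be reproduced with a small sketch. `ceil_align` and the two pad targets here are assumptions based on the commit message, not the actual sglang helpers:

```python
def ceil_align(n: int, k: int) -> int:
    """Round n up to the nearest multiple of k."""
    return (n + k - 1) // k * k

attn_cp_size = 4
N = 12  # divisible by cp_size (4) but not by 2 * cp_size (8)

# What prepare_mlp_sync_batch pads q / topk_indices to after the 2*cp change:
q_rows = ceil_align(N, 2 * attn_cp_size)
# What the old cal_padded_tokens still padded nsa_cache_seqlens to:
seqlen_rows = ceil_align(N, attn_cp_size)

print(q_rows, seqlen_rows)  # 16 vs 12: the batch_size / batch_size_k mismatch

# After the fix, both sides align to attn_cp_size * 2 and the row counts agree:
fixed_seqlen_rows = ceil_align(N, 2 * attn_cp_size)
assert fixed_seqlen_rows == q_rows
```

Any N that is a multiple of cp_size but not of 2*cp_size (12, 20, 28, ... for cp_size = 4) triggers the same disagreement.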
Motivation
Enable batch size > 1 with context parallelism (CP).
Modifications
The main modification is to the context_parallel_metadata for attention.
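A minimal sketch of what batch > 1 changes about that metadata: instead of one global token range, the rank must track one zigzag range pair per request. `build_cp_metadata` and `zigzag_ranges` are illustrative names, not the actual context_parallel_metadata implementation:

```python
def zigzag_ranges(rank: int, cp_size: int, padded_len: int):
    """Zigzag split: rank r owns chunks r and 2*cp_size-1-r, balancing
    causal-attention work between early and late chunks."""
    chunk = padded_len // (2 * cp_size)
    lo, hi = rank, 2 * cp_size - 1 - rank
    return [(lo * chunk, (lo + 1) * chunk), (hi * chunk, (hi + 1) * chunk)]

def build_cp_metadata(seq_lens, cp_size, rank):
    """Per-request token ranges owned by this rank; batch size > 1 means
    one entry per request rather than a single global range."""
    meta = []
    for n in seq_lens:
        padded = -(-n // (2 * cp_size)) * (2 * cp_size)  # pad to 2*cp_size
        meta.append(zigzag_ranges(rank, cp_size, padded))
    return meta

# Two requests in one batch, cp_size = 4, viewed from rank 0:
print(build_cp_metadata([16, 24], cp_size=4, rank=0))
```

The per-request padding to a multiple of 2*cp_size is what keeps every rank's chunk sizes equal within each request.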
Accuracy Tests
Speed Tests and Profiling
Checklist
Review and Merge Process