Attention algorithm
The parts highlighted in blue are work that has been done or is in progress.
To summarize:
- Use MHA for short context lengths, as suggested in https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/DeepSeek_V3_2.pdf
- Below context lengths of 2K, sparse attention is equivalent to non-sparse attention, so we can skip the logits computation and directly generate the indices for the sparse MLA kernel, or use MHA when possible
- The current flashmla_decode kernel is not well optimized on B200, so a separate dequant kernel + flashmla_sparse_bf16 works better for prefill with an fp8 KV cache when the KV cache is not too long relative to the q sequence length. These heuristics will need to be updated after new optimizations to either the prefill or decode kernels, making them a bit hard to use in practice. Detailed analysis is here. A sketch of the resulting dispatch logic follows this list.
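To make the branching above concrete, here is a minimal, runnable sketch of the dispatch heuristic. The threshold constant, the prefill cutoff, and all function and path names are illustrative assumptions, not the actual implementation:

```python
# Illustrative sketch only; values and names are assumptions, not SGLang code.
SHORT_CONTEXT_THRESHOLD = 2048  # below this, top-k selects every token anyway

def choose_attention_path(context_len: int, q_len: int,
                          kv_is_fp8: bool, is_prefill: bool) -> str:
    """Return which kernel path a request would take under the heuristic above."""
    if context_len <= SHORT_CONTEXT_THRESHOLD:
        # Sparse attention is identical to dense attention here: skip the
        # indexer logits and use MHA (or pass trivial all-token indices
        # to the sparse MLA kernel).
        return "mha"
    if is_prefill and kv_is_fp8 and context_len <= 4 * q_len:
        # Assumed cutoff: dequantize the fp8 KV cache once and run the bf16
        # sparse prefill kernel while the KV cache is not too long relative
        # to the query sequence.
        return "dequant + flashmla_sparse_bf16"
    # Default: indexer logits + top-k, then the sparse decode kernel.
    return "flashmla_decode (fp8 sparse)"

if __name__ == "__main__":
    print(choose_attention_path(1024, 1024, kv_is_fp8=True, is_prefill=True))
    print(choose_attention_path(32768, 8192, kv_is_fp8=True, is_prefill=True))
    print(choose_attention_path(65536, 1, kv_is_fp8=True, is_prefill=False))
```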
Kernel optimizations
- FP8 per tensor sparse MLA kernel from trtllm
- More optimizations to flashmla_decode on B200
- Adaptive MHA attention pathway: DeepSeek-V3.2: Add Adaptive MHA Attention Pathway for Short-Sequence Prefill #11892; [DeepSeek-V3.2][NSA] Enable MHA Pathway for Short Sequence Prefill on B200 (SM100) #12788; [DeepseekV3.2] Deepseek fp8 support for MHA path #12964
- Layernorm optimization: Enable mixed type LayerNorm kernel for NSA indexer #12044
- quantize_k_cache_fast (currently 3 kernels, 6 us). No need to cat; pass in two separate tensors instead.
- Optimize torch.cat([q_nope, q_rope]) by either writing a fast triton/cuda kernel or using torch.compile. It's used for both prefill and decode, but the prefill one is much bigger and has more room for optimization. The trtllm kernel supports separate q_nope and q_rope, but flashmla doesn't (see the sketch after this list). [DeepseekV32]: use _concat_mla_absorb_q_general to replace torch.cat #12215; [Deepseek V3.2] Use torch.compile to speed up torch.cat in nsa #13022
- DeepGEMM fp8_mqa_logits optimizations: [sgl-kernel] chore: update deepgemm version #13402
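For the concat item above, a hedged sketch of the torch.compile approach: the shapes below are only illustrative of the MLA absorbed path (512-dim q_nope + 64-dim q_rope per head), and the function name is a placeholder rather than the real kernel wrapper.

```python
import torch

@torch.compile(dynamic=True)
def concat_absorbed_q(q_nope: torch.Tensor, q_rope: torch.Tensor) -> torch.Tensor:
    # torch.compile fuses the copy into a single generated kernel instead of an
    # eager torch.cat launch, which matters most for the large prefill case.
    return torch.cat([q_nope, q_rope], dim=-1)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    q_nope = torch.randn(8192, 128, 512, dtype=torch.bfloat16, device=device)
    q_rope = torch.randn(8192, 128, 64, dtype=torch.bfloat16, device=device)
    out = concat_absorbed_q(q_nope, q_rope)
    print(out.shape)  # torch.Size([8192, 128, 576])
```

A kernel that takes q_nope and q_rope as separate inputs (as the trtllm kernel does) avoids this copy entirely, which is why passing two tensors is the preferred long-term direction where the kernel supports it.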
Indexer optimizations
- [Decode] Optimize dual stream in Indexer (see the stream-overlap sketch after this list): [Deepseek V3.2] Optimize use of dual_stream in nsa_indexer/attention #13546
- [Decode] Move deep_gemm.get_paged_mqa_logits_metadata to init time, similar to attention kernel metadata compute
- [Prefill] Optimize _get_topk_ragged where there are a lot of small kernels. Try multi-stream, torch.compile, and add new kernels when necessary.
- [MTP] Enable nextn = 2/4 in deep_gemm.fp8_paged_mqa_logits, which would be faster than the current implementation that uses the nextn = 1 kernel regardless of MTP size.
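For the dual-stream item above, a minimal sketch of the generic overlap pattern: run the indexer's independent work (e.g. logits + top-k) on a secondary CUDA stream while the main stream continues, then synchronize before the sparse kernel consumes the indices. The helper and the workload in the example are placeholders, not the real nsa_indexer code.

```python
import torch

def run_on_side_stream(fn, side_stream: torch.cuda.Stream, *args, **kwargs):
    """Launch fn on a secondary stream, overlapping it with work already queued
    on the current (main) stream, and make the main stream wait for its result."""
    main_stream = torch.cuda.current_stream()
    side_stream.wait_stream(main_stream)   # see tensors produced on main so far
    with torch.cuda.stream(side_stream):
        out = fn(*args, **kwargs)
    main_stream.wait_stream(side_stream)   # sync before consuming `out`
    if isinstance(out, torch.Tensor):
        # Tell the caching allocator the tensor is consumed on the main stream.
        out.record_stream(main_stream)
    return out

if __name__ == "__main__":
    if torch.cuda.is_available():
        side = torch.cuda.Stream()
        a = torch.randn(4096, 4096, device="cuda")
        b = torch.randn(4096, 4096, device="cuda")
        # Placeholder workload standing in for indexer logits + top-k, overlapped
        # with whatever the main stream keeps executing.
        topk = run_on_side_stream(lambda: (a @ b).topk(k=2048, dim=-1).indices, side)
        print(topk.shape)
```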
Min latency optimizations
- Enable TP in Attention