[Feature] NSA optimization roadmap #11989

@hlu1

Description

Attention algorithm

[Image: attention algorithm comparison table (link to original table)]

The parts highlighted in blue are work that has been done or is in progress.

To summarize:

  • Use MHA for short context lengths, as suggested in https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/DeepSeek_V3_2.pdf
  • Below context lengths of 2K, sparse attention is equivalent to non-sparse attention, so we can skip the logits computation and directly generate the indices for the sparse MLA kernel, or use MHA when possible
  • The current flashmla_decode kernel is not well optimized on B200, so a separate dequant kernel + flashmla_sparse_bf16 works better for prefill with an fp8 kv cache when the kv cache is not too long relative to the q sequence length. The heuristics will need to be updated after new optimizations to either the prefill or decode kernels, which makes them a bit hard to use in practice (see the dispatch sketch after this list). Detailed analysis is here
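
Below is a minimal sketch of the dispatch heuristic described above. The names (choose_attention_path, AttentionPath, SHORT_CONTEXT_THRESHOLD) and the kv-to-q length ratio used as the prefill cutoff are hypothetical illustrations under the assumptions stated in the comments, not the actual implementation.

```python
from enum import Enum, auto

SHORT_CONTEXT_THRESHOLD = 2048  # assumption: below ~2K tokens, sparse == dense


class AttentionPath(Enum):
    MHA = auto()                  # dense multi-head attention
    SPARSE_MLA = auto()           # sparse MLA kernel with top-k indices
    DEQUANT_SPARSE_BF16 = auto()  # dequant kernel + flashmla_sparse_bf16


def choose_attention_path(context_len: int, q_len: int, kv_cache_is_fp8: bool) -> AttentionPath:
    """Pick an attention path for a DeepSeek-V3.2-style sparse-attention layer (sketch only)."""
    if context_len < SHORT_CONTEXT_THRESHOLD:
        # Top-k over fewer than 2K tokens selects everything, so skip the
        # indexer logits and fall back to dense MHA (or emit trivial indices).
        return AttentionPath.MHA
    if kv_cache_is_fp8 and context_len <= 4 * q_len:
        # Hypothetical ratio: prefer dequant + bf16 sparse prefill when the
        # KV cache is not much longer than the query sequence.
        return AttentionPath.DEQUANT_SPARSE_BF16
    return AttentionPath.SPARSE_MLA
```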

Kernel optimizations

Indexer optimizations

  • [Decode] Optimize dual stream in Indexer ([Deepseek V3.2] Optimize use of dual_stream in nsa_indexer/attention #13546)
  • [Decode] Move deep_gemm.get_paged_mqa_logits_metadata to init time, similar to attention kernel metadata compute
  • [Prefill] Optimize _get_topk_ragged, which launches many small kernels. Try multi-stream (see the sketch after this list), torch.compile, and add new kernels where necessary.
  • [MTP] Enable nextn = 2/4 in deep_gemm.fp8_paged_mqa_logits, which is faster than the current implementation that uses the nextn = 1 kernel regardless of MTP size.
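
The sketch below shows the pattern behind two of the items above: hoisting metadata computation out of the per-step path (similar to the attention kernel metadata compute) and running the indexer's small kernels on a side CUDA stream so they overlap with work on the main stream. compute_metadata, run_indexer_topk, and run_sparse_attention are hypothetical placeholders, not the actual SGLang / DeepGEMM APIs; only the stream-synchronization pattern is the point.

```python
import torch


class IndexerDecodeRunner:
    """Toy wrapper illustrating (1) metadata computed once up front instead of
    on every call, and (2) dual-stream overlap for the indexer kernels."""

    def __init__(self, compute_metadata, run_indexer_topk, run_sparse_attention):
        # Hypothetical callables standing in for the real kernel entry points.
        self.run_indexer_topk = run_indexer_topk
        self.run_sparse_attention = run_sparse_attention
        # (1) Compute scheduler metadata here, alongside the attention metadata,
        # rather than inside every forward call.
        self.mqa_logits_metadata = compute_metadata()
        self.indexer_stream = torch.cuda.Stream()

    def forward(self, q, kv_cache):
        main_stream = torch.cuda.current_stream()
        # (2) The side stream waits for inputs produced on the main stream...
        self.indexer_stream.wait_stream(main_stream)
        with torch.cuda.stream(self.indexer_stream):
            topk_indices = self.run_indexer_topk(q, kv_cache, self.mqa_logits_metadata)
        # ...and the main stream waits for the indices before consuming them.
        main_stream.wait_stream(self.indexer_stream)
        return self.run_sparse_attention(q, kv_cache, topk_indices)
```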

Min latency optimizations

  • Enable TP in Attention
