
[Feature] support dpsk v32 flashinfer decode #15546

Closed
DarkSharpness wants to merge 2 commits into sgl-project:main from DarkSharpness:dpsk_v32_fi

Conversation


@DarkSharpness DarkSharpness commented Dec 20, 2025

Motivation

The next flashinfer release will add support for sparse MLA decode on Blackwell.

Modifications

  1. Integrate flashinfer decode for sparse MLA (NSA).
  2. Use flashinfer as the default NSA decode backend on Blackwell (see the sketch below).

Support for fp8 is a work in progress.
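
For illustration, a minimal sketch of the default-selection logic in item 2 above; the helper name, the exact dtype check, and the fallback backend string are assumptions, not the code in this PR:

```python
import torch

def choose_nsa_decode_impl(kv_cache_dtype: torch.dtype) -> str:
    # Hypothetical helper: prefer flashinfer for sparse-MLA (NSA) decode on
    # Blackwell (SM major >= 10) with a bf16 KV cache; otherwise fall back to
    # the previous default decode path.
    major, _ = torch.cuda.get_device_capability()
    if major >= 10 and kv_cache_dtype == torch.bfloat16:
        return "flashinfer"
    return "flashmla_sparse"
```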

Accuracy Tests

Benchmarking and Profiling

Setup: 8 * B200 GPUs.

# server
python3 -m sglang.launch_server --model-path "deepseek-ai/DeepSeek-V3.2" --tp 8 --enable-dp-attention --dp 8 --kv-cache-dtype bf16 --disable-radix

# client
python3 -m sglang.test.send_one

Before (flashmla_sparse for decode):

+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|   10.264    |  512   |   1.000    |      49.88      |
+-------------+--------+------------+-----------------+

After (flashinfer for decode):

+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|    9.670    |  512   |   1.000    |      52.95      |
+-------------+--------+------------+-----------------+
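
As a sanity check on the tables, Speed is Tokens / Latency: 512 / 10.264 ≈ 49.9 token/s before and 512 / 9.670 ≈ 52.9 token/s after, i.e. roughly a 6% decode speedup on this benchmark.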

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @DarkSharpness, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces and integrates flashinfer as an optimized backend for sparse Multi-Layer Attention (MLA) decoding, particularly targeting Blackwell GPUs. The change aims to leverage flashinfer's capabilities to enhance decoding performance, making it the default choice for specific hardware and data type configurations.

Highlights

  • FlashInfer Integration: Added 'flashinfer' as a new supported backend for Native Sparse Attention (NSA) decode operations, expanding the available high-performance options.
  • Blackwell Optimization: Configured 'flashinfer' to be the default NSA decode backend for Blackwell GPUs (SM major >= 10) when using 'bfloat16' KV cache, aiming for optimal performance on this architecture.
  • Performance Improvement: Benchmarking results indicate a notable speed increase for decode operations, moving from 49.88 tokens/s to 52.95 tokens/s on an 8 * B200 setup when utilizing 'flashinfer'.


@DarkSharpness DarkSharpness changed the title from "[Featuure] support dpsk v32 flashinfer decode" to "[Feature] support dpsk v32 flashinfer decode" on Dec 20, 2025
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request integrates flashinfer for sparse MLA decode on Blackwell GPUs, which is a great performance enhancement. The implementation looks solid, but I have a few suggestions to improve code quality. Specifically, I've pointed out a debug print statement that should be removed or replaced with logging and some hardcoded model-specific parameters that should be dynamically retrieved from the model configuration for better maintainability.

self.workspace_buffer = global_workspace_buffer
else:
self.workspace_buffer = None
print(f"{self.nsa_prefill_impl = } {self.nsa_decode_impl = }")
Contributor

medium

This print statement appears to be for debugging. Please remove it before merging to keep the production logs clean.
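
For reference, a sketch of the "replace with logging" option; it assumes a module-level logger is available in this file (not verified) and that the line lives in the backend's __init__:

```python
# At module scope (assumed; not verified against this file):
import logging

logger = logging.getLogger(__name__)

# In __init__, instead of the print above; DEBUG level keeps it out of
# production logs by default:
logger.debug(
    "nsa_prefill_impl=%s nsa_decode_impl=%s",
    self.nsa_prefill_impl,
    self.nsa_decode_impl,
)
```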

Collaborator Author

true

Comment on lines +1323 to +1329
return self._forward_flashinfer(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
metadata=metadata,
)
Contributor

medium

To avoid hardcoding model-specific parameters in _forward_flashinfer, please pass the layer object to it. This will allow accessing the model configuration dynamically in the next step.

Suggested change
return self._forward_flashinfer(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
metadata=metadata,
)
return self._forward_flashinfer(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
metadata=metadata,
layer=layer,
)

Comment on lines +1544 to +1568
def _forward_flashinfer(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
page_table_1: torch.Tensor,
sm_scale: float,
metadata: NSAMetadata,
) -> torch.Tensor:
import flashinfer

assert self.workspace_buffer is not None
return flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=q_all.unsqueeze(1), # TODO(dark): support MTP
kv_cache=kv_cache.view(-1, 1, self.real_page_size, self.kv_cache_dim),
workspace_buffer=self.workspace_buffer,
qk_nope_head_dim=128,
kv_lora_rank=512,
qk_rope_head_dim=64,
block_tables=page_table_1.unsqueeze(1), # NOTE: 1 is MTP length
seq_lens=metadata.nsa_seqlens_expanded,
max_seq_len=metadata.nsa_max_seqlen_q,
sparse_mla_top_k=self.nsa_index_topk,
bmm1_scale=sm_scale,
enable_pdl=True,
)
Contributor

medium

Instead of hardcoding values for qk_nope_head_dim, kv_lora_rank, and qk_rope_head_dim, please retrieve them from the layer.model_config. This makes the implementation more generic and avoids magic numbers. This change depends on passing the layer object from forward_decode.

    def _forward_flashinfer(
        self,
        q_all: torch.Tensor,
        kv_cache: torch.Tensor,
        page_table_1: torch.Tensor,
        sm_scale: float,
        metadata: NSAMetadata,
        layer: "RadixAttention",
    ) -> torch.Tensor:
        import flashinfer

        assert self.workspace_buffer is not None
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
            query=q_all.unsqueeze(1),  # TODO(dark): support MTP
            kv_cache=kv_cache.view(-1, 1, self.real_page_size, self.kv_cache_dim),
            workspace_buffer=self.workspace_buffer,
            qk_nope_head_dim=layer.model_config.qk_nope_head_dim,
            kv_lora_rank=layer.model_config.kv_lora_rank,
            qk_rope_head_dim=layer.model_config.qk_rope_head_dim,
            block_tables=page_table_1.unsqueeze(1),  # NOTE: 1 is MTP length
            seq_lens=metadata.nsa_seqlens_expanded,
            max_seq_len=metadata.nsa_max_seqlen_q,
            sparse_mla_top_k=self.nsa_index_topk,
            bmm1_scale=sm_scale,
            enable_pdl=True,
        )

self.workspace_buffer = global_workspace_buffer
else:
self.workspace_buffer = None
print(f"{self.nsa_prefill_impl = } {self.nsa_decode_impl = }")
Collaborator

Suggested change
print(f"{self.nsa_prefill_impl = } {self.nsa_decode_impl = }")

)
elif self.nsa_decode_impl == "flashinfer":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
Collaborator

Can we apply _concat_mla_absorb_q_general(q_nope, q_rope), which might be faster?
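
A sketch of that suggestion, assuming _concat_mla_absorb_q_general is the existing helper in this file and takes (q_nope, q_rope); whether it actually beats torch.cat here would need profiling:

```python
elif self.nsa_decode_impl == "flashinfer":
    if q_rope is not None:
        # Hypothetical drop-in for torch.cat([q_nope, q_rope], dim=-1):
        q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
```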

@Fridge003
Collaborator

/tag-and-rerun-ci

@Fridge003
Collaborator

Fridge003 commented Jan 20, 2026

This PR is covered by #16758.
Will add @DarkSharpness as a co-author.

@Fridge003 Fridge003 closed this Jan 20, 2026
@DarkSharpness DarkSharpness deleted the dpsk_v32_fi branch January 20, 2026 11:39