Skip to content

Conversation

@Qiaolin-Yu
Copy link
Collaborator

Motivation

For long context (e.g., 50000 input len, 10000 output len), target_verify will have very bad performance when using prefill kernel.

image (5) image (6)

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @Qiaolin-Yu, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of speculative decoding, particularly for models handling very long contexts. By conditionally employing a specialized TRT-LLM MHA decode kernel for the target_verify phase, it mitigates the performance bottlenecks previously encountered when using a generic prefill kernel, leading to more efficient processing of large input and output sequences.

Highlights

  • Performance Optimization: Switched the target_verify step in speculative decoding from using a prefill kernel to a more efficient TRT-LLM MHA decode kernel.
  • Context Length Improvement: Addresses significant performance degradation observed with the prefill kernel for target_verify in long context scenarios (e.g., 50,000 input length, 10,000 output length).
  • Metadata Update: Introduced an update to the attention metadata to correctly set max_seq_len_q based on speculative_num_draft_tokens when initializing for target_verify mode.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@Qiaolin-Yu Qiaolin-Yu self-assigned this Nov 26, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a performance optimization for speculative decoding by using the trtllm_mha_decode_kernel for the target_verify step. This is a sensible change, as the decode kernel is better optimized for the fixed query length characteristic of the verification phase. The changes correctly dispatch to the appropriate kernel based on the forward mode and ensure the necessary metadata is updated for CUDA graph replay. I have one suggestion to improve code maintainability by reducing duplication.

Comment on lines +618 to +653
if forward_batch.forward_mode.is_target_verify():
o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=q,
kv_cache=kv_cache,
workspace_buffer=self.workspace_buffer,
block_tables=self.forward_metadata.page_table,
seq_lens=self.forward_metadata.cache_seqlens_int32,
max_seq_len=self.max_context_len,
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
window_left=layer.sliding_window_size,
# TODO: add attention_sink operation or nvfp4 scale factor if needed
sinks=attention_sink,
out_dtype=self.q_data_type, # model_runner.dtype
q_len_per_req=self.forward_metadata.max_seq_len_q,
)
else:

o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=q,
kv_cache=kv_cache,
workspace_buffer=self.workspace_buffer,
block_tables=self.forward_metadata.page_table,
seq_lens=self.forward_metadata.cache_seqlens_int32,
max_q_len=self.forward_metadata.max_seq_len_q,
max_kv_len=self.max_context_len,
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
batch_size=forward_batch.batch_size,
cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
window_left=layer.sliding_window_size,
# TODO: add attention_sink operation or nvfp4 scale factor if needed
sinks=attention_sink,
out_dtype=self.q_data_type, # model_runner.dtype
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is significant code duplication in the arguments passed to trtllm_batch_decode_with_kv_cache and trtllm_batch_context_with_kv_cache. To improve readability and maintainability, you could extract the common arguments into a dictionary.

        common_args = {
            "query": q,
            "kv_cache": kv_cache,
            "workspace_buffer": self.workspace_buffer,
            "block_tables": self.forward_metadata.page_table,
            "seq_lens": self.forward_metadata.cache_seqlens_int32,
            "bmm1_scale": bmm1_scale,
            "bmm2_scale": bmm2_scale,
            "window_left": layer.sliding_window_size,
            # TODO: add attention_sink operation or nvfp4 scale factor if needed
            "sinks": attention_sink,
            "out_dtype": self.q_data_type,  # model_runner.dtype
        }

        if forward_batch.forward_mode.is_target_verify():
            o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
                **common_args,
                max_seq_len=self.max_context_len,
                q_len_per_req=self.forward_metadata.max_seq_len_q,
            )
        else:
            o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
                **common_args,
                max_q_len=self.forward_metadata.max_seq_len_q,
                max_kv_len=self.max_context_len,
                batch_size=forward_batch.batch_size,
                cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
                cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
            )

@Qiaolin-Yu
Copy link
Collaborator Author

/tag-and-rerun-ci

@Kangyan-Zhou Kangyan-Zhou merged commit 7cb04dc into sgl-project:main Nov 27, 2025
242 of 275 checks passed
harvenstar pushed a commit to harvenstar/sglang that referenced this pull request Dec 4, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants