Support overlap-spec-v2 with fa3 attention backend #12128

Closed
b8zhong wants to merge 4 commits into sgl-project:main from bzhng-development:enable-overlap-spec

Conversation

@b8zhong
Collaborator

@b8zhong b8zhong commented Oct 25, 2025

Motivation

Enable it on FA3 MLA, which is the default attention backend for DeepSeek (dpsk) on Hopper.

Modifications

python3 -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-R1 \
  --trust-remote-code \
  --enable-beta-spec \
  --tp 8 \
  --speculative-algorithm=EAGLE \
  --model-loader-extra-config '{
    "enable_multithread_load": true,
    "num_threads": 8
  }'

The speed of python3 -m sglang.test.send_one improves from

+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|    3.644    |  512   |   2.296    |     140.52      |
+-------------+--------+------------+-----------------+

To

+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|    3.206    |  512   |   2.296    |     159.71      |
+-------------+--------+------------+-----------------+

So around a 13.7% improvement (140.52 → 159.71 token/s). However, some kernels still seem slow; I will address that separately.

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319
...
Accuracy: 0.959
Invalid: 0.000
Latency: 96.427 s
Output throughput: 1408.604 token/s

@gemini-code-assist
Contributor

Summary of Changes

Hello @b8zhong, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the overlap-spec-v2 feature into the FlashAttention v3 backend, primarily to enhance the performance of large language models on specific hardware configurations like dpsk and hopper. The changes focus on adapting the attention mechanism's metadata initialization and conditional logic to properly support the new draft_extend_v2 forward mode. This foundational work results in a measurable speedup in token generation, laying the groundwork for further optimizations in speculative decoding.

Highlights

  • Overlap-Spec-v2 Support: This pull request introduces support for the overlap-spec-v2 mechanism within the FlashAttention v3 (FA3) attention backend, which is critical for optimizing performance on dpsk and hopper architectures.
  • Performance Improvement: The integration of overlap-spec-v2 yields a significant performance boost, demonstrating a 13% speedup in the send_one test, increasing token generation from 140.52 to 159.71 tokens/s.
  • Enhanced Metadata Handling: Updates have been made to the init_forward_metadata function to correctly handle draft_extend_v2 mode, ensuring accurate calculation of key, value, and query sequence length metadata for efficient attention operations.
  • Unified Draft Extend Logic: The is_draft_extend checks across various functions, including forward_extend and CUDA graph capture/replay, have been updated to consistently include v2, streamlining the logic for speculative decoding (a minimal sketch of this check pattern follows this list).
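
The following is a minimal sketch of the is_draft_extend(include_v2=...) pattern described in the last two highlights, assuming a ForwardMode enum roughly along these lines; the member names and helper signatures are inferred from the review discussion below, not copied from the sglang source.

# Hedged sketch, not the actual sglang implementation: names assumed from the
# review comments in this PR.
from enum import Enum, auto

class ForwardMode(Enum):
    DECODE = auto()
    EXTEND = auto()
    DRAFT_EXTEND = auto()
    DRAFT_EXTEND_V2 = auto()  # overlap-spec-v2 draft extend (assumed member)

    def is_draft_extend(self, include_v2: bool = False) -> bool:
        # With include_v2=True, call sites that previously handled only
        # DRAFT_EXTEND (metadata init, CUDA graph capture/replay) also take
        # the v2 path, which is what this PR does in the FA3 backend.
        if self is ForwardMode.DRAFT_EXTEND:
            return True
        return include_v2 and self is ForwardMode.DRAFT_EXTEND_V2

    def is_draft_extend_v2(self) -> bool:
        return self is ForwardMode.DRAFT_EXTEND_V2

Under this sketch, forward_mode.is_draft_extend() would still exclude DRAFT_EXTEND_V2, while forward_mode.is_draft_extend(include_v2=True) accepts both modes, matching the call-site change shown in the review diff below.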

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for overlap-spec-v2 with the fa3 attention backend by introducing a new DRAFT_EXTEND_V2 forward mode. The changes correctly handle this new mode in most metadata initialization functions. However, I've identified a potential issue in the CUDA graph replay logic where the new logic for DRAFT_EXTEND_V2 is incorrectly applied to the existing DRAFT_EXTEND mode. This could lead to incorrect behavior if that code path is used. My review includes a detailed explanation and a suggested fix for this issue.

         )
-        elif forward_mode.is_draft_extend():
+        elif forward_mode.is_draft_extend(include_v2=True):

Severity: high

While this change correctly includes DRAFT_EXTEND_V2, the logic within this elif block seems to incorrectly handle the original DRAFT_EXTEND mode.

The block now updates max_seq_len_q and cu_seqlens_q based on spec_info.accept_length (lines 1957-1965), which is specific to DRAFT_EXTEND_V2. For the original DRAFT_EXTEND mode, these parameters are static and should not be updated during replay. This could lead to incorrect behavior if DRAFT_EXTEND mode is used with CUDA graphs.

I suggest wrapping the logic that uses accept_length in a condition that checks for DRAFT_EXTEND_V2 mode specifically. Here is a suggested implementation for the whole block:

        elif forward_mode.is_draft_extend(include_v2=True):
            metadata = self.draft_extend_metadata[bs]
            metadata.cache_seqlens_int32.copy_(seq_lens)

            metadata.max_seq_len_k = seq_lens_cpu.max().item()
            metadata.cu_seqlens_k[1:].copy_(
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
            )

            if forward_mode.is_draft_extend_v2():
                accept_length = spec_info.accept_length[:bs]
                if spec_info.accept_length_cpu:
                    metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
                else:
                    metadata.max_seq_len_q = 1
                metadata.cu_seqlens_q[1:].copy_(
                    torch.cumsum(accept_length, dim=0, dtype=torch.int32)
                )

@JustinTong0323
Collaborator

Acc seems to have dropped, check CI

@JustinTong0323
Collaborator

btw, check #12113

@Qiaolin-Yu
Collaborator

Could you share the profile to make sure it's overlapped correctly?

@b8zhong
Collaborator Author

b8zhong commented Oct 29, 2025

#11874

@b8zhong b8zhong closed this Oct 29, 2025
@b8zhong b8zhong deleted the enable-overlap-spec branch December 3, 2025 05:36