
Revert "optimize get_topk_ragged by fusing get k and k_scale triton kernel"#18471

Merged
Fridge003 merged 1 commit into main from revert-16043-opt-get-topk-ragged on Feb 9, 2026
Conversation

@Fridge003
Collaborator

Reverts #16043

@gemini-code-assist
Contributor

Summary of Changes

Hello @Fridge003, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request rolls back a prior optimization that attempted to improve performance by fusing K and scale data retrieval into a single Triton kernel capable of handling batched, ragged sequences. The revert simplifies the underlying Triton kernel to process one sequence at a time, shifting the responsibility for batch iteration back to the Python layer. This change likely addresses stability or performance issues encountered with the more complex batched kernel implementation.

Highlights

  • Revert of Batched Triton Kernel: This pull request fully reverts a previous optimization that aimed to fuse K and scale data gathering into a single Triton kernel designed for ragged batches of sequences.
  • Simplified Triton Kernel API: The Triton kernel for K/S data access (_get_k_and_s_triton_kernel) has been simplified to operate on a single sequence at a time, removing complex batching logic from its implementation.
  • Python-level Batch Processing: Batch processing for K/S data gathering is now handled by iterating through individual sequences in Python, rather than attempting to process an entire batch within a single Triton kernel launch.
  • API Signature Changes: The GetKAndS.triton method and its callers (get_index_k_scale_buffer, _get_topk_ragged) have updated their signatures to reflect the per-sequence processing, removing parameters like seq_len_tensor, seq_len_sum, and max_seq_len.
  • Test File Removal: A dedicated test file (test_get_k_scale_triton_kernel.py) that validated the previous batched Triton kernel implementation has been removed, and existing tests for index_buf_accessor have been updated to match the new API.
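
The per-sequence pattern described in the highlights can be sketched in plain Python. This is a hypothetical illustration, not the real sglang API: the name get_index_k_scale_buffer comes from the PR, but the stub below and the list-based "tensors" are stand-ins for the actual Triton-backed accessor and torch tensors.

```python
# Hypothetical sketch of the reverted per-sequence gathering pattern.
# get_index_k_scale_buffer is named after the PR's accessor, but this stub
# and the list-based "tensors" are illustrative stand-ins only.
def gather_per_sequence(seq_lens, get_index_k_scale_buffer):
    k_fp8_list = []
    for seq_len in seq_lens:  # one accessor (kernel) call per sequence
        k_fp8_list.append(get_index_k_scale_buffer(seq_len))
    # concatenate the per-sequence results once at the end
    return [row for chunk in k_fp8_list for row in chunk]
```

The batch loop lives in Python, so each kernel launch only needs a single seq_len; the batched parameters (seq_len_tensor, seq_len_sum, max_seq_len) become unnecessary.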


Changelog
  • python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
    • Simplified the triton method of GetKAndS by removing batch-specific parameters (seq_len_tensor, seq_len_sum, max_seq_len) and introducing a single seq_len.
    • Refactored the _get_k_and_s_triton function and its kernel (_get_k_and_s_triton_kernel) to process data for a single sequence, eliminating internal batching logic.
    • Removed batch_id, seq_len_ptr, seq_len_num_pow, and page_indice_batch_offset from the Triton kernel parameters and implementation.
  • python/sglang/srt/layers/attention/nsa/nsa_indexer.py
    • Removed the enable_dual_stream parameter from the _get_topk_ragged method.
    • Modified _get_topk_ragged to iterate through each sequence in a batch and call get_index_k_scale_buffer individually, collecting and concatenating the results.
    • Removed assertions related to forward_batch.seq_lens_cpu and forward_batch.extend_seq_lens_cpu.
  • python/sglang/srt/mem_cache/memory_pool.py
    • Updated the get_index_k_scale_buffer method signature to accept a single seq_len instead of batch-related parameters.
    • Adjusted the call to index_buf_accessor.GetKAndS.execute to conform to the simplified API.
  • test/manual/layers/attention/nsa/test_get_k_scale_triton_kernel.py
    • Removed the entire test file, which was dedicated to testing the batched Triton kernel functionality that has now been reverted.
  • test/manual/layers/attention/nsa/test_index_buf_accessor.py
    • Updated all relevant test cases (test_get_k_and_s_correctness, test_get_k_and_s_sequential_pages, etc.) to use the simplified GetKAndS.triton API, removing seq_len_tensor and page_indices_ from test calls.
Activity
  • This pull request is a direct revert of commit sgl-project/sglang#16043, indicating a rollback of a previous optimization.

@Fridge003
Collaborator Author

/rerun-stage stage-c-test-8-gpu-h200

@Fridge003
Collaborator Author

/rerun-stage stage-c-test-4-gpu-b200

github-actions bot commented Feb 9, 2026

✅ Triggered stage-c-test-8-gpu-h200 to run independently (skipping dependencies).

github-actions bot commented Feb 9, 2026

🔗 View workflow run

github-actions bot commented Feb 9, 2026

✅ Triggered stage-c-test-4-gpu-b200 to run independently (skipping dependencies).

github-actions bot commented Feb 9, 2026

🔗 View workflow run


@gemini-code-assist bot left a comment


Code Review

This pull request reverts a previous optimization that fused the get_k and k_scale Triton kernel for batched requests. The changes correctly revert the batched kernel to a simpler, non-batched version that is called in a loop for each sequence. The related function signatures, calls, and tests are updated consistently across the codebase. I've found one minor opportunity for improvement to avoid a redundant tensor concatenation. Overall, the revert appears to be implemented correctly.

Comment on lines 526 to +529:

     if _is_fp8_fnuz:
-        k_fp8 = k_fp8.view(torch.float8_e4m3fnuz)
+        k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fnuz)
     else:
-        k_fp8 = k_fp8.view(torch.float8_e4m3fn)
-        k_scale = k_scale.view(torch.float32).squeeze(-1)
+        k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
Severity: medium

To avoid concatenating the k_fp8_list twice, you can perform the concatenation once and then apply the .view().

        k_fp8_cat = torch.cat(k_fp8_list, dim=0)
        k_fp8 = k_fp8_cat.view(torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn)
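
The reviewer's point can be demonstrated with plain-Python stand-ins (lists replace torch tensors and a string replaces the torch.float8_* dtype, so this runs without torch; the function name finalize_k_fp8 is hypothetical):

```python
# Plain-Python stand-in for the reviewer's suggestion: build the concatenated
# buffer once, then choose the dtype conditionally, instead of repeating the
# concatenation in each branch of the if/else.
def finalize_k_fp8(k_fp8_list, is_fp8_fnuz):
    k_fp8_cat = [x for chunk in k_fp8_list for x in chunk]  # single "cat"
    dtype = "float8_e4m3fnuz" if is_fp8_fnuz else "float8_e4m3fn"
    return k_fp8_cat, dtype
```

torch.cat allocates and copies a new tensor, so hoisting it out of the branches avoids a redundant allocation; torch.Tensor.view on a dtype is a cheap reinterpretation and is fine to apply afterwards.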

@Fridge003 Fridge003 merged commit 615a02d into main Feb 9, 2026
110 of 120 checks passed
@Fridge003 Fridge003 deleted the revert-16043-opt-get-topk-ragged branch February 9, 2026 08:37
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
