Migrate renorm kernels from sgl-kernel to FlashInfer JIT #18854

Open

Johnsonms wants to merge 7 commits into sgl-project:main from Johnsonms:renorm-kernel

Conversation

@Johnsonms (Contributor) commented Feb 15, 2026

Migrate top_k_renorm_probs, top_p_renorm_probs, and top_k_mask_logits from compiled sgl-kernel CUDA implementations to FlashInfer's JIT-compiled public API. This reduces the sgl-kernel wheel size and leverages FlashInfer's optimized and maintained implementations.

Changes:

  • Use flashinfer.sampling public API instead of torch.ops.sgl_kernel
  • Remove CUDA kernel bindings from common_extension.cc and common_extension_musa.cc
  • Remove function declarations from sgl_kernel_ops.h
  • Remove renorm.cu from build (keep sampling.cu for other functions)
  • Add comprehensive benchmark suite (bench_renorm.py)
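
As an illustration of the change, the Python wrappers now delegate straight to FlashInfer. A minimal sketch, assuming the flashinfer.sampling entry points named in this PR (the exact code in sgl-kernel/python/sgl_kernel/sampling.py may differ):

    import flashinfer.sampling
    import torch

    # Thin wrappers replacing the old torch.ops.sgl_kernel bindings.
    # FlashInfer JIT-compiles these kernels on first use, so no prebuilt
    # CUDA binary needs to ship in the sgl-kernel wheel.
    def top_k_renorm_probs(probs: torch.Tensor, top_k: torch.Tensor) -> torch.Tensor:
        return flashinfer.sampling.top_k_renorm_probs(probs, top_k)

    def top_p_renorm_probs(probs: torch.Tensor, top_p: torch.Tensor) -> torch.Tensor:
        return flashinfer.sampling.top_p_renorm_probs(probs, top_p)

    def top_k_mask_logits(logits: torch.Tensor, top_k: torch.Tensor) -> torch.Tensor:
        return flashinfer.sampling.top_k_mask_logits(logits, top_k)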

Motivation

#17865
Move to the (external) FlashInfer implementation flashinfer/csrc/renorm.cu: https://github.com/flashinfer-ai/flashinfer/blob/bc29697ba20b7e6bdb728ded98f04788e16ee021/csrc/renorm.cu (≈4 MB)

Modifications

Accuracy Tests

python -m pytest sgl-kernel/tests/test_sampling.py -s

python -m sglang.launch_server --model Qwen/Qwen3-4B
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 128
python -m sglang.launch_server --model /data/Qwen3-8B/ --disable-radix-cache
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 128

Benchmarking and Profiling

python benchmark/bench_renorm.py
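
For context, a minimal sketch of the kind of CUDA-event timing loop such a benchmark typically uses (assumed structure; the real bench_renorm.py sweeps batch and vocab sizes and also verifies correctness against a Torch reference):

    import torch
    import flashinfer.sampling

    def bench(fn, *args, iters=100, warmup=10):
        # Warmup also triggers FlashInfer's one-time JIT compilation.
        for _ in range(warmup):
            fn(*args)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn(*args)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters  # ms per call

    probs = torch.softmax(torch.randn(64, 32000, device="cuda"), dim=-1)
    top_k = torch.full((64,), 50, device="cuda", dtype=torch.int32)
    print(bench(flashinfer.sampling.top_k_renorm_probs, probs, top_k))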

Correctness check
[screenshot: correctness check results]

Benchmarking

  1. Benchmarking top_k_renorm_probs...
  2. Benchmarking top_p_renorm_probs...
  3. Benchmarking top_k_mask_logits...

Summary of the benchmark data:
[screenshot: benchmark summary table]


@gemini-code-assist (Contributor)

Summary of Changes

Hello @Johnsonms, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the core probability renormalization and logit masking functionalities within the sgl-kernel. By transitioning from custom CUDA implementations to FlashInfer's highly optimized and JIT-compiled API, the project benefits from reduced binary size, enhanced performance, and simplified maintenance. The changes streamline the codebase and integrate with a well-established, performant library for these critical operations, as evidenced by the detailed benchmarks provided.

Highlights

  • FlashInfer Migration: Migrated top_k_renorm_probs, top_p_renorm_probs, and top_k_mask_logits functions from custom sgl-kernel CUDA implementations to FlashInfer's JIT-compiled public API.
  • Code Simplification and Size Reduction: Removed internal _*_internal() functions, CUDA kernel bindings, and associated declarations, leading to a reduction in the sgl-kernel wheel size and leveraging FlashInfer's optimized and maintained implementations.
  • Benchmarking: Introduced a comprehensive benchmark suite (bench_renorm.py) to validate correctness and measure performance improvements of the FlashInfer-based implementations against PyTorch references.


Changelog
  • sgl-kernel/CMakeLists.txt
    • Removed renorm.cu from the list of compiled FlashInfer sources.
  • sgl-kernel/benchmark/bench_renorm.py
    • Added a new benchmark script to compare the performance and correctness of top-k/top-p renorm and top-k mask logits between PyTorch reference implementations and the SGL Kernel (FlashInfer) implementations.
  • sgl-kernel/csrc/common_extension.cc
    • Removed the CUDA binding definitions and implementations for top_k_renorm_probs, top_p_renorm_probs, and top_k_mask_logits.
  • sgl-kernel/csrc/common_extension_musa.cc
    • Removed the MUSA binding definitions and implementations for top_k_renorm_probs, top_p_renorm_probs, and top_k_mask_logits.
  • sgl-kernel/include/sgl_kernel_ops.h
    • Removed the function declarations for top_k_renorm_probs, top_p_renorm_probs, and top_k_mask_logits.
  • sgl-kernel/python/sgl_kernel/sampling.py
    • Updated top_k_renorm_probs, top_p_renorm_probs, and top_k_mask_logits to call flashinfer.sampling functions directly.
    • Removed their previous internal _*_internal implementations and torch.ops.sgl_kernel calls.
    • Imported flashinfer.sampling.
  • sgl-kernel/setup_musa.py
    • Removed renorm.cu from the list of MUSA compilation sources.
Activity
  • The pull request author provided a detailed description outlining the motivation and modifications.
  • Accuracy tests for sgl-kernel/tests/test_sampling.py were run and passed.
  • Benchmarks for bench_sglang.py were executed on Qwen3-4B and Qwen3-8B models, showing consistent performance.
  • A comprehensive benchmark suite (bench_renorm.py) was added and executed, demonstrating significant performance improvements of the SGL Kernel (FlashInfer) implementations over Torch Reference for renorm and mask logits operations across various batch sizes and vocabulary sizes.
  • Correctness checks for top_k_renorm_probs, top_p_renorm_probs, and top_k_mask_logits passed successfully.

@gemini-code-assist bot left a comment


Code Review

This pull request successfully migrates the top_k_renorm_probs, top_p_renorm_probs, and top_k_mask_logits kernels from the internal sgl-kernel CUDA implementations to FlashInfer's JIT-compiled API. This is a great move to reduce the wheel size and leverage a well-maintained, optimized library. The changes are clean and consistent across the codebase, including build files and C++/Python bindings.

The addition of a comprehensive benchmark suite in bench_renorm.py is excellent for verifying correctness and performance. I've provided a few suggestions to improve the reference implementations within the benchmark file to make them more robust and efficient, which will strengthen the benchmark's reliability.

- Reorder imports in sampling.py (flashinfer before torch)
- Fix string quotes and formatting in bench_renorm.py
- Use torch.topk instead of sort+pivot for exact k-element selection
- Vectorize all three reference implementations for better performance
- Handle edge cases: ties, k=0, variable k per batch
- More robust and realistic baseline for benchmarking

Changes:
- torch_top_k_renorm_probs: Vectorized batch operations
- torch_top_p_renorm_probs: Vectorized cumsum and masking
- torch_top_k_mask_logits: Vectorized scatter operations

This makes the performance comparison more meaningful by using
efficient PyTorch operations instead of slow Python loops.
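
For illustration, a vectorized top-k reference along those lines might look like this (hypothetical sketch with a scalar k; the actual bench_renorm.py may handle per-batch k and ties differently):

    import torch

    def torch_top_k_renorm_probs(probs: torch.Tensor, k: int) -> torch.Tensor:
        # torch.topk gives exact k-element selection per row, avoiding the
        # sort + pivot approach and its tie-handling pitfalls.
        topk_vals, topk_idx = torch.topk(probs, k, dim=-1)
        out = torch.zeros_like(probs)
        out.scatter_(-1, topk_idx, topk_vals)
        # Renormalize the surviving probabilities to sum to 1 per row.
        return out / out.sum(dim=-1, keepdim=True)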
@DarkSharpness (Collaborator)

  1. Have you compared the performance of the new FlashInfer with sgl-kernel (the old FlashInfer)? In theory there should be no performance regression, but we need to double-check it.
  2. Does this PR conflict with use flashinfer.sampling #18696?

@Johnsonms (Contributor, Author) commented Feb 15, 2026

  1. Have you compared the performance of the new FlashInfer with sgl-kernel (the old FlashInfer)?
     Thanks @DarkSharpness, it did need checking. I have added the comparison under Benchmarking and Profiling above.
  2. Does this PR conflict with use flashinfer.sampling #18696?
     We are using the same sampling.cu file, but for different kernels: that PR covers min_p_sampling_from_probs, top_k_top_p_sampling_from_probs, and top_p_sampling_from_probs, while my PR covers top_k_renorm_probs, top_p_renorm_probs, and top_k_mask_logits from https://github.com/flashinfer-ai/flashinfer/blob/bc29697ba20b7e6bdb728ded98f04788e16ee021/csrc/renorm.cu

@Johnsonms (Contributor, Author)

This change enables both CUDA and MUSA platforms to use appropriate
kernel implementations:
- CUDA: Uses FlashInfer JIT (requires NVRTC)
- MUSA: Uses compiled sgl_kernel ops (NVRTC not available)

Changes:
- Add platform detection in sampling.py with _internal helper functions
- Restore MUSA torch op registrations for renorm kernels
- Include renorm.cu in MUSA build sources
- Move bench_renorm.py to jit_kernel/benchmark/
- Add comprehensive unit tests for all three renorm kernels

The platform dispatch checks device.type and routes to the appropriate
implementation, ensuring compatibility across hardware platforms.
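
A rough sketch of that dispatch (illustrative only; the exact routing in sgl_kernel/sampling.py may differ):

    import torch

    def top_k_renorm_probs(probs: torch.Tensor, top_k: torch.Tensor) -> torch.Tensor:
        if probs.device.type == "musa":
            # MUSA has no NVRTC, so use the compiled sgl_kernel op.
            return torch.ops.sgl_kernel.top_k_renorm_probs(probs, top_k)
        # CUDA path: FlashInfer's JIT-compiled public API.
        import flashinfer.sampling
        return flashinfer.sampling.top_k_renorm_probs(probs, top_k)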
@BBuf (Collaborator) commented Feb 18, 2026

/tag-and-rerun-ci

Add sgl_kernel to known_first_party and flashinfer to known_third_party
in .isort.cfg so local and CI isort agree on import ordering.
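
The resulting configuration would look roughly like this (assumed snippet; check the repository's actual .isort.cfg):

    [settings]
    known_first_party = sgl_kernel
    known_third_party = flashinfer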
@BBuf (Collaborator) commented Feb 21, 2026

/tag-and-rerun-ci

