Skip to content

Comments

[5/n]decouple quantization implementation from vLLM dependency#9454

Merged
zhyncs merged 3 commits intosgl-project:mainfrom
Hongbosherlock:fbgemm
Aug 21, 2025
Merged

[5/n]decouple quantization implementation from vLLM dependency#9454
zhyncs merged 3 commits intosgl-project:mainfrom
Hongbosherlock:fbgemm

Conversation

@Hongbosherlock
Copy link
Contributor

Motivation

remove vllm dependency for fbgemm_fp8 quantization.

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @Hongbosherlock, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new FBGEMM FP8 quantization implementation, decoupling it from the vLLM dependency. It provides a dedicated configuration and linear method for FP8 quantization within the sglang framework, including support for leveraging the Marlin kernel on GPUs that do not natively support FP8. This change aims to integrate FP8 quantization directly into sglang's runtime (SRT) layers, enhancing its quantization capabilities.

Highlights

  • FBGEMM FP8 Configuration: Implemented FBGEMMFp8Config to define the configuration for FBGEMM FP8 quantization, including an ignore list for modules and input scale upper bound.
  • FBGEMM FP8 Linear Method: Developed FBGEMMFp8LinearMethod to manage the creation, processing, and application of FP8 quantized weights, supporting both native FP8 and Marlin-accelerated FP8 operations.
  • FP8 Marlin Utilities: Added marlin_utils_fp8.py with helper functions for FP8 Marlin quantization, such as apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, and pack_fp8_to_int32.
  • Marlin Kernel Fallback: Enabled automatic fallback to Marlin kernel for FP8 quantization on GPUs without native FP8 hardware support, improving compatibility and performance.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request decouples the fbgemm_fp8 quantization from its vLLM dependency by introducing new implementations for the quantization configuration and Marlin utilities. The changes are well-structured and the logic appears correct. I have a few suggestions to improve code clarity and maintainability, primarily related to using hasattr for attribute checks and refactoring a utility function for better consistency.

Comment on lines +120 to +124
if "weight_scale" in dir(layer):
scales = layer.weight_scale.to(layer.orig_dtype)
elif "weight_scale_inv" in dir(layer):
scales = layer.weight_scale_inv.to(layer.orig_dtype)
del layer.weight_scale_inv
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using in dir(layer) to check for attribute existence is less idiomatic and potentially less efficient than using hasattr(). dir() can be slow as it inspects the entire object's namespace, including methods and inherited attributes. hasattr() is a more direct and cleaner way to check for the presence of a specific attribute.

Suggested change
if "weight_scale" in dir(layer):
scales = layer.weight_scale.to(layer.orig_dtype)
elif "weight_scale_inv" in dir(layer):
scales = layer.weight_scale_inv.to(layer.orig_dtype)
del layer.weight_scale_inv
if hasattr(layer, "weight_scale"):
scales = layer.weight_scale.to(layer.orig_dtype)
elif hasattr(layer, "weight_scale_inv"):
scales = layer.weight_scale_inv.to(layer.orig_dtype)
del layer.weight_scale_inv

Comment on lines +226 to +233
if name + "_weight_scale" in dir(layer):
new_name = name + "_weight_scale"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)
elif name + "_weight_scale_inv" in dir(layer):
new_name = name + "_weight_scale_inv"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For checking attribute existence, hasattr() is preferred over using in dir() as it is more idiomatic, readable, and avoids potential side effects or performance issues from inspecting the entire object namespace.

Suggested change
if name + "_weight_scale" in dir(layer):
new_name = name + "_weight_scale"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)
elif name + "_weight_scale_inv" in dir(layer):
new_name = name + "_weight_scale_inv"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)
if hasattr(layer, name + "_weight_scale"):
new_name = name + "_weight_scale"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)
elif hasattr(layer, name + "_weight_scale_inv"):
new_name = name + "_weight_scale_inv"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)

Comment on lines 302 to 315
def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
size_k_first: bool = True) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements)
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
assert fp8_tensor.ndim == 2

fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
fp8_tensor = fp8_tensor.contiguous()
# fp8_tensor is contiguous and have shape (N, K) now
# with `.view(torch.int32)`, it become (N, K // 4)
int32_tensor = fp8_tensor.view(torch.int32)
return int32_tensor.T.contiguous() if size_k_first else int32_tensor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The pack_fp8_to_int32 function can be simplified to always return the tensor in the desired GPTQ format (K/4, N), regardless of the size_k_first flag. This makes the function's behavior more consistent and simplifies the calling code in prepare_fp8_layer_for_marlin and prepare_moe_fp8_layer_for_marlin, which currently have to handle transposition conditionally.

Suggested change
def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
size_k_first: bool = True) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements)
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
assert fp8_tensor.ndim == 2
fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
fp8_tensor = fp8_tensor.contiguous()
# fp8_tensor is contiguous and have shape (N, K) now
# with `.view(torch.int32)`, it become (N, K // 4)
int32_tensor = fp8_tensor.view(torch.int32)
return int32_tensor.T.contiguous() if size_k_first else int32_tensor
def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
size_k_first: bool = True) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements with shape (K/4, N)).
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
assert fp8_tensor.ndim == 2
if size_k_first:
# Input shape (K, N), needs to be (N, K) for view
fp8_tensor = fp8_tensor.T
# Shape is now (N, K)
fp8_tensor = fp8_tensor.contiguous()
# fp8_tensor is contiguous and has shape (N, K) now
# with `.view(torch.int32)`, it becomes (N, K // 4)
int32_tensor = fp8_tensor.view(torch.int32)
# Always return in gptq format (K/4, N)
return int32_tensor.T.contiguous()

@AniZpZ AniZpZ self-assigned this Aug 21, 2025
@zhyncs zhyncs merged commit 9c8e4f6 into sgl-project:main Aug 21, 2025
64 of 71 checks passed
@zhyncs
Copy link
Collaborator

zhyncs commented Aug 21, 2025

python3 -m sglang.launch_server --model meta-llama/Llama-3.1-405B-FP8 --tp 8

This doesn't work @AniZpZ @Hongbosherlock

@zhyncs zhyncs mentioned this pull request Aug 21, 2025
4 tasks
@Hongbosherlock
Copy link
Contributor Author

python3 -m sglang.launch_server --model meta-llama/Llama-3.1-405B-FP8 --tp 8

This doesn't work @AniZpZ @Hongbosherlock

I'm sorry...I was going to test the end-to-end accuracy this morning, forgot to mention that this PR wasn't fully ready yet...

@Hongbosherlock Hongbosherlock deleted the fbgemm branch August 22, 2025 03:47
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
@AniZpZ AniZpZ mentioned this pull request Sep 24, 2025
15 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants