
Fix EPLB + FP4 Quantization Compatibility Issue #13715

Merged
Fridge003 merged 5 commits into sgl-project:main from shifangx:shifang/fix-eplb-fp4
Jan 10, 2026

Conversation

@shifangx (Contributor) commented Nov 21, 2025


Note: The following description was generated by AI and may not be accurate.

EPLB + FP4 Quantization Compatibility Issue Analysis and Fix

Problem Description

When using EPLB (Expert Parallelism Load Balancer) for expert rebalancing, the system crashes with an AssertionError:

AssertionError: num_local_physical_experts=6 [x.shape for x in routed_experts_weights]=[
    torch.Size([6, 4096, 3584]), 
    torch.Size([6, 7168, 1024]), 
    torch.Size([6, 4096, 448]), 
    torch.Size([6, 7168, 128]), 
    torch.Size([6]), 
    torch.Size([6]), 
    torch.Size([288, 2]),  # ❌ First dimension is 288, not 6
    torch.Size([288]),      # ❌ First dimension is 288, not 6
    torch.Size([6]), 
    torch.Size([6]), 
    torch.Size([]),         # ❌ Missing dimension
    torch.Size([6])
]

Error Location: sglang/python/sglang/srt/eplb/expert_location_updater.py:176

assert all(
    tensor.shape[0] == num_local_physical_experts
    for tensor in routed_experts_weights
), f"{num_local_physical_experts=} {[x.shape for x in routed_experts_weights]=}"

Environment Configuration

  • Model: DeepSeek-V3-0324
  • Quantization Method: FP4 (ModelOpt NVFP4)
  • EP Size: 48 (6 local physical experts per rank)
  • Total Experts: 288
  • EPLB: Enabled, using --enable-eplb and --expert-distribution-recorder-mode stat
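
With 288 total experts spread over EP size 48, each rank holds 288 / 48 = 6 local physical experts, which is exactly the dim-0 size the failing assertion expects on every tensor.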

Root Cause

When using FP4 quantization, ModelOptNvFp4FusedMoEMethod creates some special parameters whose shapes do not meet the requirements for EPLB expert migration:

1. Global Expert Scale Parameters

In sglang/python/sglang/srt/layers/quantization/modelopt_quant.py:1240-1252:

w13_input_scale = PerTensorScaleParameter(
    data=torch.empty(layer.num_experts, 2, dtype=torch.float32),  # ❌ Uses global expert count
    weight_loader=weight_loader,
)
w13_input_scale._sglang_require_global_experts = True  # Marked as global parameter
layer.register_parameter("w13_input_scale", w13_input_scale)

w2_input_scale = PerTensorScaleParameter(
    data=torch.empty(layer.num_experts, dtype=torch.float32),  # ❌ Uses global expert count
    weight_loader=weight_loader,
)
w2_input_scale._sglang_require_global_experts = True  # Marked as global parameter
layer.register_parameter("w2_input_scale", w2_input_scale)

Problem:

  • The first dimension of these parameters is num_experts (288), not num_local_physical_experts (6)
  • They are marked with _sglang_require_global_experts = True, indicating they are global parameters
  • EPLB expects all parameters to have the local expert count as the first dimension
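
As a minimal, self-contained sketch (toy meta tensors reusing the shapes from the error above, not real sglang objects), the shape check EPLB applies reduces to:

import torch

# Toy reproduction of the EPLB shape check, with shapes from the error above.
# EPLB migrates expert i by indexing tensor[i], so every tensor it touches
# must have num_local_physical_experts as its first dimension.
num_local_physical_experts = 6

tensors = [
    torch.empty(6, 4096, 3584, device="meta"),  # w13_weight: per-local-expert, passes
    torch.empty(288, 2, device="meta"),         # w13_input_scale: global, fails
    torch.empty(288, device="meta"),            # w2_input_scale: global, fails
    torch.empty((), device="meta"),             # scalar: fails, no dim 0 at all
]

for t in tensors:
    ok = t.ndim > 0 and t.shape[0] == num_local_physical_experts
    print(tuple(t.shape), "passes" if ok else "would trip the assertion")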

2. Swizzled Blockscale Parameters

In sglang/python/sglang/srt/layers/quantization/modelopt_quant.py:1197 and 1214:

layer.w13_blockscale_swizzled = Parameter(
    self.swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
)

layer.w2_blockscale_swizzled = Parameter(
    self.swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
)

The swizzle_blockscale function reshapes the tensor, losing the expert dimension:

def swizzle_blockscale(self, scale: torch.Tensor):
    # ...
    return (
        swizzled_scale.reshape(M_padded, K_padded)  # ❌ Lost batch (expert) dimension
        if scale_ndim == 2
        else swizzled_scale
    )

Problem:

  • Original w13_weight_scale shape: [num_local_experts, M, K]
  • After swizzling shape: [M_padded, K_padded] - missing expert dimension
  • EPLB cannot handle this shape correctly
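
To make the shape loss concrete, here is a toy stand-in (hypothetical tile sizes; the real swizzle_blockscale also interleaves elements rather than just padding) showing why a padded 2D layout no longer carries an expert index in dim 0:

import torch

def toy_swizzle(scale: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in only: pad a 2D block-scale to hypothetical tile
    # boundaries. The output is [M_padded, K_padded], so dim 0 is a padded
    # row count rather than an expert index.
    m, k = scale.shape
    m_padded = (m + 127) // 128 * 128
    k_padded = (k + 3) // 4 * 4
    out = torch.zeros(m_padded, k_padded, dtype=scale.dtype)
    out[:m, :k] = scale
    return out

print(toy_swizzle(torch.randn(112, 28)).shape)  # torch.Size([128, 28]): 128 != 6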

3. Why Are These Parameters Included?

In deepseek_v2.py, the get_moe_weights() method returns all expert parameters:

def get_moe_weights(self):
    return [
        x.data
        for name, x in self.experts.named_parameters()
        if name not in ["correction_bias"]  # Only excludes correction_bias
    ]

This includes all parameters registered via register_parameter, including the problematic parameters above.

Code Call Chain

EPLBManager.rebalance()
  ↓
ModelRunner.update_expert_location()
  ↓
ExpertLocationUpdater.update()
  ↓
_update_expert_weights_raw()
  ↓
update_expert_weights_single_layer()
    ↓
    assert all(tensor.shape[0] == num_local_physical_experts ...)  # ❌ Assertion fails

Where routed_experts_weights_of_layer comes from:

# model_runner.py:977
self.expert_location_updater.update(
    self.model.routed_experts_weights_of_layer,  # Obtained from model
    ...
)

# deepseek_v2.py:3158
@property
def routed_experts_weights_of_layer(self):
    return self._routed_experts_weights_of_layer.value

# deepseek_v2.py:3148
self._routed_experts_weights_of_layer = LazyValue(
    lambda: {
        layer_id: layer.mlp.get_moe_weights()  # ← Calls get_moe_weights()
        for layer_id, layer in enumerate(self.model.layers)
        if isinstance(layer.mlp, DeepseekV2MoE)
    }
)

Solution

Modify the get_moe_weights() method in all MoE models to filter out parameters that should not participate in EPLB migration:

def get_moe_weights(self):
    return [
        x.data
        for name, x in self.experts.named_parameters()
        if name not in ["correction_bias"]
            and not getattr(x, "_sglang_require_global_experts", False)
            and not name.endswith("_blockscale_swizzled")
            and x.data.ndim > 0  # Exclude scalar tensors
            and x.data.shape[0] == self.experts.num_local_experts  # Exclude tensors with wrong first dimension
    ]
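
A quick, self-contained sanity check of this filter (mock nn.Parameters on the meta device, with shapes borrowed from the error above; the names and the local expert count are illustrative):

import torch
from torch import nn

num_local_experts = 6

# Mock parameters mirroring the problematic registrations described above.
params = {
    "w13_weight": nn.Parameter(torch.empty(6, 4096, 3584, device="meta")),
    "w13_input_scale": nn.Parameter(torch.empty(288, 2, device="meta")),           # global
    "w13_blockscale_swizzled": nn.Parameter(torch.empty(128, 28, device="meta")),  # swizzled
    "w2_input_scale_quant": nn.Parameter(torch.empty((), device="meta")),          # scalar
}
params["w13_input_scale"]._sglang_require_global_experts = True

kept = [
    name
    for name, x in params.items()
    if name not in ["correction_bias"]
    and not getattr(x, "_sglang_require_global_experts", False)
    and not name.endswith("_blockscale_swizzled")
    and x.data.ndim > 0
    and x.data.shape[0] == num_local_experts
]
print(kept)  # ['w13_weight']: only the per-local-expert weight survives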


@gemini-code-assist (Contributor)

Summary of Changes

Hello @shifangx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a compatibility issue arising from the simultaneous use of expert-parallel load balancing and FP4 quantization within Mixture-of-Experts models. The fix involves refining the parameter selection logic in the get_moe_weights method across various MoE model architectures, preventing incorrect inclusion of certain parameter types that led to conflicts. This ensures the stability and proper functioning of these models under the specified configurations.

Highlights

  • MoE Weight Collection Logic: Modified the get_moe_weights function across seven different Mixture-of-Experts (MoE) model implementations. The update introduces additional filtering conditions to exclude parameters marked as _sglang_require_global_experts or those ending with _blockscale_swizzled from the collected expert weights.
  • Issue Resolution: This change specifically addresses and fixes an issue that occurred when using the --enable-eplb (expert-parallel load balancing) flag in combination with fp4 (4-bit floating point) quantization, ensuring correct parameter handling in these scenarios.

@shifangx changed the title from "fix issue with --enable-eplb and fp4" to "EPLB + FP4 Quantization Compatibility Issue Analysis and Fix" on Nov 21, 2025
@shifangx changed the title from "EPLB + FP4 Quantization Compatibility Issue Analysis and Fix" to "Fix EPLB + FP4 Quantization Compatibility Issue" on Nov 21, 2025
@gemini-code-assist bot left a comment

Code Review

This pull request addresses an issue related to --enable-eplb and fp4 by correctly filtering parameters in the get_moe_weights method. The change is applied consistently across multiple model files, which is good. However, this has led to significant code duplication, as the get_moe_weights method is now identical in seven different files. I recommend refactoring this duplicated logic into a common base class. This would greatly improve the maintainability of the codebase, making future changes to this logic much simpler. I've added a specific comment with a suggestion on how to implement this refactoring.
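
The bot's inline suggestion is not reproduced on this page, but one hypothetical shape of that refactor, reusing the filter from the PR description and assuming each model exposes self.experts with a num_local_experts attribute, would be:

from torch import nn

class MoEWeightsMixin:
    # Hypothetical shared base: the seven MoE models would inherit this
    # instead of each duplicating get_moe_weights(). Assumes the concrete
    # model sets self.experts and that it exposes num_local_experts.
    experts: nn.Module

    def get_moe_weights(self):
        return [
            x.data
            for name, x in self.experts.named_parameters()
            if name not in ["correction_bias"]
            and not getattr(x, "_sglang_require_global_experts", False)
            and not name.endswith("_blockscale_swizzled")
            and x.data.ndim > 0
            and x.data.shape[0] == self.experts.num_local_experts
        ]

# e.g. class DeepseekV2MoE(MoEWeightsMixin, nn.Module): ...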

@kaixih (Collaborator) commented Nov 21, 2025

Original w13_weight_scale shape: [num_local_experts, M, K]
After swizzling shape: [M_padded, K_padded] - missing expert dimension
EPLB cannot handle this shape correctly

Can you remind me why the swizzled shape has no expert dim?

@kaixih (Collaborator) commented Nov 22, 2025

I got an error of:

          ^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                             
  File "/kaixih/workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 2173, in forward_extend                                                                                                                           
    return self.model.forward(                                                                                                                                                                                                             
           ^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                             
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context                                                                                                                                 
    return func(*args, **kwargs)                                                                                                                                                                                                           
           ^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                           
  File "/kaixih/workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 3509, in forward                                                                                                                                           
    hidden_states = self.model(                                                                                                                                                                                                            
                    ^^^^^^^^^^^                                                                                                                                                                                                            
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl                                                                                                                              
    return self._call_impl(*args, **kwargs)                                                                                                                                                                                                
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl                                                                                                                                      
    return forward_call(*args, **kwargs)                                                                                                                                                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                   
  File "/kaixih/workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 3319, in forward                                                                                                                                           
    hidden_states, residual = layer(                                                                                                                                                                                                       
                              ^^^^^^                                                                                                                                                                                                       
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl                                                                                                                              
    return self._call_impl(*args, **kwargs)                                                                                                                                                                                                
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl                                                                                                                                      
    return forward_call(*args, **kwargs)                                                                                                                                                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                   
  File "/kaixih/workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 3057, in forward                                                                                                                                           
    hidden_states = self.mlp(                                                                                                                                                                                                              
                    ^^^^^^^^^                                                                                                                                                                                                              
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl                                                                                                                              
    return self._call_impl(*args, **kwargs)                                                                                                                                                                                                
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl                                                                                                                                      
    return forward_call(*args, **kwargs)                                                                                                                                                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                   
  File "/kaixih/workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 801, in forward                                                                                                                                            
    return self.forward_normal_dual_stream(                                                                                                                                                                                                
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                
  File "/kaixih/workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 835, in forward_normal_dual_stream                                                                                                                         
    final_hidden_states = self.experts(hidden_states, topk_output)                                                                                                                                                                         
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                         
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl                                                                                                                              
    return self._call_impl(*args, **kwargs)                                                                                                                                                                                                
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl                                                                                                                                      
    return forward_call(*args, **kwargs)                                                                                                                                                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                   
  File "/kaixih/workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/layer.py", line 873, in forward                                                                                                                             
    dispatch_output = self.dispatcher.dispatch(                                                                                                                                                                                            
                      ^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                            
  File "/kaixih/workspace/sglang/python/sglang/srt/layers/moe/token_dispatcher/standard.py", line 106, in dispatch                                                                                                                         
    assert global_scale is not None, "input_global_scale is not set"                                                                                                                                                                       
           ^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                        
AssertionError: input_global_scale is not set    

Here is my script: https://gist.github.com/kaixih/32bdc4fec4feabe9305d1acb2e1f96db (adapted from your Slack message)

@shifangx changed the title from "Fix EPLB + FP4 Quantization Compatibility Issue" to "[draft] Fix EPLB + FP4 Quantization Compatibility Issue" on Nov 23, 2025
@Fridge003 (Collaborator)

Setting SGLANG_MOE_NVFP4_DISPATCH=1 for the prefill node should solve this.
There have been some refactors of MoE recently.
#13715 (comment)

@shifangx (Contributor, Author) commented Nov 24, 2025

Perhaps it would be better for a colleague who is familiar with the FP4 weight format to handle this issue, so I will just close this PR.

@shifangx (Contributor, Author)

Original w13_weight_scale shape: [num_local_experts, M, K]
After swizzling shape: [M_padded, K_padded] - missing expert dimension
EPLB cannot handle this shape correctly

Can you remind me why the swizzled shape has no expert dim?

The description was generated by AI and may not be accurate.

@wenscarl (Collaborator) commented Dec 9, 2025

Just tried this fix in the 0.5.5.post.2 container with the script here, except I had to manually start the router and client. I didn't observe any issue running e2e.

@wenscarl force-pushed the shifang/fix-eplb-fp4 branch from dd53378 to ee86ef4 on December 31, 2025 21:25
@wenscarl requested a review from Fridge003 on December 31, 2025 21:25
@shifangx changed the title from "[draft] Fix EPLB + FP4 Quantization Compatibility Issue" to "Fix EPLB + FP4 Quantization Compatibility Issue" on Jan 8, 2026
@Fridge003 (Collaborator)

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Jan 9, 2026
@Fridge003 Fridge003 merged commit d27f16f into sgl-project:main Jan 10, 2026
163 of 167 checks passed
Fridge003 pushed a commit that referenced this pull request Jan 10, 2026
Co-authored-by: Shu Wang <shuw@nvidia.com>
Fridge003 pushed a commit that referenced this pull request Jan 11, 2026
Co-authored-by: Shu Wang <shuw@nvidia.com>
