MoE Refactor: Refactor modelopt_quant.py -> flashinfer_trtllm.py #16685

Merged
ch-wan merged 3 commits into sgl-project:main from bzhng-development:brayden/refactor-modelopt-moe-flashinfer-trtllm on Feb 3, 2026

Conversation

b8zhong (Collaborator) commented on Jan 8, 2026

Motivation

Follow-up on #15151 (comment), and part of #8715.

github-actions bot added the quant (LLM Quantization) label on Jan 8, 2026
b8zhong (Collaborator, Author) commented on Jan 8, 2026

/tag-and-rerun-ci one more time?

github-actions bot added the run-ci label on Jan 8, 2026
b8zhong force-pushed the brayden/refactor-modelopt-moe-flashinfer-trtllm branch from 3b79e64 to 25c7e12 on January 9, 2026
ch-wan self-assigned this on Jan 16, 2026
b8zhong force-pushed the brayden/refactor-modelopt-moe-flashinfer-trtllm branch from 25c7e12 to 8ec5536 on January 18, 2026
Review thread on python/sglang/srt/layers/quantization/modelopt_quant.py:

    routing_method_type=routing_method_type,
)

return fused_experts_none_to_flashinfer_trtllm_fp4(
Collaborator:

Why not use the runner?

b8zhong (Collaborator, Author):

Yeah, I think it would be good. The problem, to my understanding, is that we can only register one fused function for flashinfer_trtllm, while the path might need either trtllm_fp4_block_scale_moe or fused_experts_none_to_flashinfer_trtllm_fp8.

So to keep the code simple I didn't use the runner. Otherwise, something like the diff below would let flashinfer_trtllm register a single fused function that dispatches to the two different paths. I am open to either.

git --no-pager diff
diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
index 74c56761a..c2b037bda 100644
--- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
+++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
@@ -207,7 +207,6 @@ class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo):
     use_routing_scales_on_input: bool = False
 
 
-@register_fused_func("none", "flashinfer_trtllm")
 def fused_experts_none_to_flashinfer_trtllm_fp8(
     dispatch_output: StandardDispatchOutput,
     quant_info: FlashInferTrtllmFp8MoeQuantInfo,
@@ -478,3 +477,21 @@ def fused_experts_none_to_flashinfer_trtllm_fp4(
     )[0]
 
     return StandardCombineInput(hidden_states=result)
+
+
+@register_fused_func("none", "flashinfer_trtllm")
+def fused_experts_none_to_flashinfer_trtllm(
+    dispatch_output: StandardDispatchOutput,
+    quant_info: MoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+) -> StandardCombineInput:
+    """Dispatch to FP8 or FP4 FlashInfer TRT-LLM MoE based on quant_info type."""
+    if isinstance(quant_info, FlashInferTrtllmFp4MoeQuantInfo):
+        return fused_experts_none_to_flashinfer_trtllm_fp4(
+            dispatch_output, quant_info, runner_config
+        )
+    return fused_experts_none_to_flashinfer_trtllm_fp8(
+        dispatch_output,
+        cast(FlashInferTrtllmFp8MoeQuantInfo, quant_info),
+        runner_config,
+    )
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index a71c0bc65..3b13ce408 100755
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -1496,6 +1496,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
+        if get_moe_runner_backend().is_flashinfer_trtllm():
+            self.runner = MoeRunner(
+                MoeRunnerBackend.FLASHINFER_TRTLLM, moe_runner_config
+            )
 
     def apply(
         self,
@@ -1514,11 +1518,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         ), f"{activation=} missing from {ACT_STR_TO_TYPE_MAP.keys()=}"
         moe_runner_config = self.moe_runner_config
 
-        # FlashInfer TRTLLM FP4 path - check if layer has shuffled weights
+        # FlashInfer TRTLLM FP4 path - layer has shuffled weights only when
+        # backend is flashinfer_trtllm
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
                 FlashInferTrtllmFp4MoeQuantInfo,
-                fused_experts_none_to_flashinfer_trtllm_fp4,
             )
             from sglang.srt.layers.moe.utils import RoutingMethodType
 
@@ -1543,9 +1547,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 routing_method_type=routing_method_type,
             )
 
-            return fused_experts_none_to_flashinfer_trtllm_fp4(
-                dispatch_output, quant_info, moe_runner_config
-            )
+            return self.runner.run(dispatch_output, quant_info)
 
         if self.enable_flashinfer_cutlass_moe:
             assert (
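
For reference, the dispatcher added in the diff above uses cast, MoeRunner, MoeRunnerBackend, and get_moe_runner_backend without showing their imports. A minimal sketch of the import lines the two files would also need, if they are not already present (module paths are my assumption based on the surrounding code, not verified against the tree):

# Hypothetical import block for the dispatcher sketch above;
# the exact module paths may differ in the actual sglang tree.
from typing import cast  # narrows MoeQuantInfo to the FP8 variant in the fallback branch

from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerBackend
from sglang.srt.layers.moe.utils import get_moe_runner_backend

Registering this single dispatcher keeps the one-fused-function-per-backend constraint intact while still letting both the FP4 and FP8 paths go through the runner.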

b8zhong force-pushed the brayden/refactor-modelopt-moe-flashinfer-trtllm branch 2 times, most recently from bcf041f to 8a7d539 on January 30, 2026
b8zhong added the format (Auto Format Code) label on Jan 30, 2026
b8zhong force-pushed the brayden/refactor-modelopt-moe-flashinfer-trtllm branch from 8a7d539 to 4c8b913 on January 30, 2026
b8zhong and others added 3 commits on January 30, 2026
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
b8zhong force-pushed the brayden/refactor-modelopt-moe-flashinfer-trtllm branch from 4c8b913 to 91c4216 on January 30, 2026
ch-wan merged commit 78bf13d into sgl-project:main on Feb 3, 2026
275 of 318 checks passed
hhu-scitix pushed a commit to scitix/sglang that referenced this pull request Feb 3, 2026
…gl-project#16685)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
charlesHsuGG pushed a commit to charlesHsuGG/sglang that referenced this pull request Feb 5, 2026
…gl-project#16685)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
sfiisf pushed a commit to sfiisf/sglang that referenced this pull request Feb 5, 2026
…gl-project#16685)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
b8zhong mentioned this pull request on Feb 5, 2026 (66 tasks)
b8zhong deleted the brayden/refactor-modelopt-moe-flashinfer-trtllm branch on February 6, 2026
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
…gl-project#16685)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
hhu-scitix pushed a commit to scitix/sglang that referenced this pull request Feb 16, 2026
…gl-project#16685)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>

Labels: format (Auto Format Code), quant (LLM Quantization), run-ci
