Background
We aim to optimize the code structure of the MoE modules in SGLang to enhance extensibility. Currently, there are three main MoE modules: `FusedMoE`, `EPMoE`, and `DeepEPMoE` (with recent additions such as `FlashInferEPMoE` and `FlashInferFusedMoE` at the time of writing). Their implementations suffer from several issues:
- Inconsistent logic flow. Computation logic varies across modules. For instance, `FusedMoE` computes `select_experts` within its forward function, while `DeepEPMoE` handles it externally. Similarly, some forward functions manage `routed_scaling_factor` internally, but others do not.
- Poor extensibility. We plan to support multiple all-to-all communication backends under EP (e.g., DeepEP, PPLX) and grouped-GEMM backends (e.g., Triton, DeepGEMM, Triton Kernels, FlashInfer MoE). The current design requires a dedicated forward function for each backend combination, leading to redundancy.
- Lengthy and duplicated code. Common variable combinations are repeated across functions. For example, over 10 MoE quantization methods each handle about 15 nearly identical inputs in their `apply` functions, and the DeepEP dispatch outputs (8 in total) are duplicated in multiple model files.
Design
To streamline the code structure, we will deprecate all MoE modules except `FusedMoE` and gradually merge existing functionality into it. Below is an overview of the target code structure:
[input_hidden_states]
|
v
TopK.forward -> `select_experts` / `triton_kernels.routing` / bypass
|
v
[TopKOutput]
|
v
FusedMoE.forward -> Dispatcher.dispatch -> DeepEP / PPLX / bypass
| |
| v
| [DispatchOutput]
| |
| v
| quant_method.apply -> MoeRunner.forward -
| | |
| | v
| | pre-permute + grouped_gemm + post-permute
| | |
| |--------------------------------
| v
| [CombineInput]
| |
| v
| Dispatcher.combine -> DeepEP / PPLX / bypass
| |
|---------------------
v
[final_hidden_states]
In addition to existing arguments like `--quantization`, we will introduce `--moe-a2a-backend` and `--moe-runner-backend` to allow users to select the optimal dispatching and grouped-GEMM backends for their use cases.
If a developer wants to support a new backend, they only need to implement the `Dispatcher` or grouped-GEMM logic and define its input/output formats. A `PermuteMethodPool` will automatically select appropriate pre-permute and post-permute functions for layout conversions (if required). Developers can also register new permute functions for unsupported layouts. The TopK forward method will be determined automatically based on the backend arguments.
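As a rough illustration of how these pieces are intended to fit together, the sketch below mocks up the components named above (`TopKOutput`, `DispatchOutput`, `CombineInput`, `Dispatcher`, `PermuteMethodPool`). All fields, signatures, and registered format names are illustrative assumptions, not the actual SGLang API:

```python
# Minimal sketch of the target structure; field names, signatures, and
# format tags are assumptions for illustration only.
from dataclasses import dataclass
from typing import Callable, Protocol

import torch


@dataclass
class TopKOutput:
    topk_ids: torch.Tensor       # [num_tokens, top_k] selected expert ids
    topk_weights: torch.Tensor   # [num_tokens, top_k] routing weights


@dataclass
class DispatchOutput:
    hidden_states: torch.Tensor  # tokens after the (optional) all-to-all dispatch
    topk_output: TopKOutput
    format: str                  # layout tag, e.g. "standard" or "deepep_ll"


@dataclass
class CombineInput:
    hidden_states: torch.Tensor  # expert outputs to be combined / returned
    format: str


class Dispatcher(Protocol):
    """A2A backend chosen via --moe-a2a-backend (DeepEP, PPLX, or bypass)."""

    def dispatch(self, hidden_states: torch.Tensor, topk_output: TopKOutput) -> DispatchOutput: ...

    def combine(self, combine_input: CombineInput) -> torch.Tensor: ...


class PermuteMethodPool:
    """Registry of layout-conversion functions keyed by (src, dst) format."""

    _pre_permutes: dict[tuple[str, str], Callable] = {}

    @classmethod
    def register_pre_permute(cls, src_format: str, dst_format: str):
        def decorator(fn: Callable) -> Callable:
            cls._pre_permutes[(src_format, dst_format)] = fn
            return fn
        return decorator

    @classmethod
    def get_pre_permute(cls, src_format: str, dst_format: str) -> Callable:
        return cls._pre_permutes[(src_format, dst_format)]


@PermuteMethodPool.register_pre_permute("standard", "triton")
def standard_to_triton(dispatch_output: DispatchOutput) -> DispatchOutput:
    # In this toy case the layouts already match; a real pre-permute would
    # reorder tokens by expert before the grouped GEMM.
    return dispatch_output
```

Under this kind of structure, supporting a new grouped-GEMM backend means implementing its runner and registering any missing permute functions; `FusedMoE.forward` itself stays untouched.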
Tasks
The refactoring process is divided into three stages around `MoeRunner.forward`: preparation, implementation, and adoption.
Stage 1: Preparation
This stage focuses on unifying computation structures across all MoE modules and their forward functions, while wrapping dependent variables for better organization.
- Structure modification
  - Move all `select_experts` computations outside MoE modules. [1/N] MoE Refactor: refactor `select_experts` #7966
  - Move all all-to-all communication (for dispatch and combine) inside MoE modules. [3/N] MoE Refactor: Simplify DeepEP Output #8421
  - Move all `routed_scaling_factor` multiplications inside MoE modules.
  - Unify weight loading and quantization methods across all MoE modules. [2/N] MoE Refactor: Unify weight loader and quant methods #8397
  - Unify Triton kernels for `FusedMoE` and `EPMoE`. [4/N] MoE Refactor: Unified Triton Kernel for FusedMoE and EPMoE #8515
- Variable wrap-up (see the sketch after this list)
  - TopK config (e.g., `use_grouped_topk`, `renormalize`) and TopK output. [1/N] MoE Refactor: refactor `select_experts` #7966
  - Dispatch output. [3/N] MoE Refactor: Simplify DeepEP Output #8421
- Server args update
  - Support `--moe-a2a-backend`. [5/N] MoE Refactor: Update MoE parallelism arguments #8658
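As referenced above, here is a minimal sketch of what the TopK variable wrap-up could look like. Only `use_grouped_topk` and `renormalize` come from the task list; the remaining fields, the class names, and the simplified routing logic are assumptions for illustration:

```python
# Illustrative sketch of bundling TopK arguments; only use_grouped_topk and
# renormalize are taken from the task list, everything else is assumed.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TopKConfig:
    top_k: int
    use_grouped_topk: bool = False
    renormalize: bool = True
    num_expert_group: Optional[int] = None
    topk_group: Optional[int] = None


@dataclass
class TopKOutput:
    topk_ids: torch.Tensor      # [num_tokens, top_k]
    topk_weights: torch.Tensor  # [num_tokens, top_k]


def select_experts(router_logits: torch.Tensor, config: TopKConfig) -> TopKOutput:
    # Plain softmax top-k routing as a stand-in for the real implementation
    # (grouped top-k, correction biases, etc. are omitted).
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, config.top_k, dim=-1)
    if config.renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return TopKOutput(topk_ids=topk_ids, topk_weights=topk_weights)
```

Bundling these arguments lets `select_experts` and every quantization `apply` accept one object instead of a long list of near-identical parameters.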
Stage 2: Implementation
In this stage, we will implement the `MoeRunner` framework; a sketch of the intended runner registration follows the task list below.
- Implement the framework. [7/N] MoE Refactor: the implementation of new framework #9269
- Variable wrap-up
  - MoE model config (e.g., `activation`, `no_combine`). [6/N] MoE Refactor: Cleanup MoE-related configs #8849
  - Quantization utils (e.g., `input_scale`). [7/N] MoE Refactor: the implementation of new framework #9269
  - Combine input. [7/N] MoE Refactor: the implementation of new framework #9269
- Update server args
  - Support `--moe-runner-backend`.
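As mentioned above, the sketch below illustrates one plausible shape for the `MoeRunner` framework: grouped-GEMM backends register themselves under a name that `--moe-runner-backend` selects. The registry, the backend names, and the toy runner are assumptions, not the real implementation:

```python
# Hedged sketch of a MoeRunner with pluggable grouped-GEMM backends.
# The registry, backend names, and the toy torch_native runner are
# illustrative assumptions, not the final SGLang interfaces.
from typing import Callable, Dict

import torch
import torch.nn.functional as F

RunnerFn = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]

_RUNNER_REGISTRY: Dict[str, RunnerFn] = {}


def register_runner(backend: str):
    """Register a grouped-GEMM implementation under a backend name."""
    def decorator(fn: RunnerFn) -> RunnerFn:
        _RUNNER_REGISTRY[backend] = fn
        return fn
    return decorator


@register_runner("torch_native")
def torch_native_runner(hidden_states, w13, w2):
    # Toy stand-in for a fused grouped GEMM: runs every token through every
    # expert and averages, just to keep the example self-contained.
    out = torch.zeros_like(hidden_states)
    for e in range(w13.shape[0]):
        out += F.silu(hidden_states @ w13[e].T) @ w2[e].T
    return out / w13.shape[0]


class MoeRunner:
    def __init__(self, backend: str):
        # `backend` would come from --moe-runner-backend.
        self.run = _RUNNER_REGISTRY[backend]

    def forward(self, hidden_states, w13, w2):
        # In the full design, pre-permute and post-permute (from the
        # PermuteMethodPool) would wrap this call for layout conversion.
        return self.run(hidden_states, w13, w2)
```

A real runner would consume the wrapped dispatch output, weights, and quantization utils, with the pre-/post-permute steps bracketing the call as shown in the design diagram.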
Stage 3: Adoption
The third stage gradually adopts the new framework and replaces existing implementations with the unified structure. This incremental approach allows new grouped-GEMM backends to be merged during refactoring, as long as they are functional and non-invasive.
For MoE backends implemented in quantization files, we need to check the `apply` method (or `apply_with_router_logits` / `apply_without_routing_weights`) and distribute the implementation to the corresponding MoE backend files. The tentative plan for reorganizing the current implementation is listed below, followed by a sketch of the resulting shape of a quant method's `apply`.
- `awq.py`
  - `marlin.py` Refactor Marlin MoeRunner #14554
- `blockwise_int8.py`
- `fp8.py`
  - `intel_amx.py`
  - `aiter.py`
  - `cutlass.py` Refactor Cutlass MoE runner integration #12023
  - `triton.py` [7/N] MoE Refactor: the implementation of new framework #9269
  - `flashinfer_trtllm.py` MoE Refactor: Refactor `fp8.py` -> `flashinfer_trllm.py` #15151
- `gptq.py`
  - `marlin.py` Refactor Marlin MoeRunner #14554
- `modelopt_quant.py`
  - `triton.py` [7/N] MoE Refactor: the implementation of new framework #9269
  - `flashinfer_trtllm.py` MoE Refactor: Refactor `modelopt_quant.py` -> `flashinfer_trllm.py` #16685
  - `flashinfer_cutlass.py`
  - `cutlass.py` Refactor Cutlass MoE runner integration #12023
  - `flashinfer_cutedsl.py`
- `moe_wna16.py`
- `mxfp4.py`
  - `flashinfer_trtllm.py`
  - `triton_kernels.py` Refactor Triton-kernel MoE runner integration #11795
  - `triton.py` [7/N] MoE Refactor: the implementation of new framework #9269
  - `aiter.py`
- `unquant.py`
  - `triton_kernels.py` Refactor Triton-kernel MoE runner integration #11795
  - `aiter.py`
  - `triton.py` [7/N] MoE Refactor: the implementation of new framework #9269
  - `intel_amx.py`
  - `torch_native.py`
  - `npu.py`
- `w4afp8.py`
- `w8a8_fp8.py`
- `w8a8_int8.py`
  - `intel_amx.py`
  - `triton.py` [7/N] MoE Refactor: the implementation of new framework #9269
  - `npu.py`
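To illustrate what distributing a quantization file's `apply` into a backend file could look like, here is a hedged sketch; the class, its fields, and the runner interface are hypothetical stand-ins, not current SGLang code:

```python
# Hedged sketch of a quantization method after the refactor: apply() only
# forwards quantized weights and the wrapped dispatch output to a runner
# backend. All names and fields here are hypothetical.
from dataclasses import dataclass

import torch


@dataclass
class DispatchOutput:
    hidden_states: torch.Tensor
    topk_ids: torch.Tensor
    topk_weights: torch.Tensor


class Fp8MoEMethod:
    """Would live in fp8.py and own only quantized weights and scales."""

    def __init__(self, w13, w2, w13_scale, w2_scale, runner):
        self.w13, self.w2 = w13, w2
        self.w13_scale, self.w2_scale = w13_scale, w2_scale
        self.runner = runner  # e.g. a triton.py / cutlass.py runner instance

    def apply(self, dispatch_output: DispatchOutput) -> torch.Tensor:
        # No permutation or grouped-GEMM logic here anymore; the chosen
        # runner backend handles layout conversion and the kernels.
        return self.runner.run(
            dispatch_output,
            w13=self.w13,
            w2=self.w2,
            w13_scale=self.w13_scale,
            w2_scale=self.w2_scale,
        )
```

The kernel-specific permutation and grouped-GEMM logic would then live entirely in the runner file (e.g., `triton.py` or `cutlass.py`), so each quantization method only describes its weights and scales.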
Some MoE backends are implemented as separate NN modules; their implementations should be distributed into the corresponding MoE backend and quantization files.
- `FlashInferFusedMoE.forward` -> `flashinfer_trtllm.py` + `fp8.py`
- `FlashInferFP4MoE.forward` -> `flashinfer_trtllm.py` + `modelopt_quant.py`
- `EPMoE.forward_deepgemm` -> `deep_gemm.py` + `fp8.py` [8/N] MoE Refactor: deprecate `EPMoE` #11211
- `DeepEPMoE.forward_*`
  - `deep_gemm.py` + `fp8.py` [10/N] MoE Refactor: reorganize deepgemm runner in DeepEPMoE #12054
  - `aiter.py` + `fp8.py`
  - `flashinfer_cutedsl.py` + `modelopt_quant.py`
  - `npu.py` + `fp8.py`