diff --git a/docker/rocm.Dockerfile b/docker/rocm.Dockerfile
index 1cc106a2ec09..d591400c6ce1 100644
--- a/docker/rocm.Dockerfile
+++ b/docker/rocm.Dockerfile
@@ -22,7 +22,6 @@ ENV BUILD_AITER_ALL="1"
 ENV BUILD_MOONCAKE="1"
 ENV AITER_COMMIT="v0.1.4"
 ENV NO_DEPS_FLAG=""
-ENV AITER_MXFP4_MOE_SF="0"
 
 # ===============================
 # Base image 942 and args
@@ -34,7 +33,6 @@ ENV BUILD_AITER_ALL="1"
 ENV BUILD_MOONCAKE="1"
 ENV AITER_COMMIT="v0.1.7.post1"
 ENV NO_DEPS_FLAG=""
-ENV AITER_MXFP4_MOE_SF="0"
 
 # ===============================
 # Base image 950 and args
@@ -44,9 +42,8 @@ ENV BUILD_TRITON="0"
 ENV BUILD_LLVM="0"
 ENV BUILD_AITER_ALL="1"
 ENV BUILD_MOONCAKE="1"
-ENV AITER_COMMIT="v0.1.7.post1"
+ENV AITER_COMMIT="v0.1.7.post2"
 ENV NO_DEPS_FLAG=""
-ENV AITER_MXFP4_MOE_SF="1"
 # ===============================
 # Chosen arch and args
 FROM ${GPU_ARCH}
@@ -107,8 +104,7 @@ RUN git clone ${AITER_REPO} \
     && git checkout ${AITER_COMMIT} \
     && git submodule update --init --recursive
 RUN cd aiter \
-    && if [ "$GPU_ARCH" = "gfx950" ]; then export AITER_MXFP4_MOE_SF=1; fi \
-    && echo "[AITER] GPU_ARCH=${GPU_ARCH} AITER_MXFP4_MOE_SF=${AITER_MXFP4_MOE_SF:-unset}" \
+    && echo "[AITER] GPU_ARCH=${GPU_ARCH}" \
     && if [ "$BUILD_AITER_ALL" = "1" ] && [ "$BUILD_LLVM" = "1" ]; then \
        sh -c "HIP_CLANG_PATH=/sgl-workspace/llvm-project/build/bin/ PREBUILD_KERNELS=1 GPU_ARCHS=$GPU_ARCH_LIST python setup.py develop"; \
     elif [ "$BUILD_AITER_ALL" = "1" ]; then \
@@ -299,7 +295,6 @@ RUN python3 -m pip install --no-cache-dir \
 
 # -----------------------
 # Performance environment variable.
-RUN echo "AITER_MXFP4_MOE_SF=${AITER_MXFP4_MOE_SF}" >> /etc/environment
 ENV HIP_FORCE_DEV_KERNARG=1
 ENV HSA_NO_SCRATCH_RECLAIM=1
 
diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py
index 847eaf0ee250..d44444a3a8c9 100644
--- a/python/sglang/srt/layers/quantization/mxfp4.py
+++ b/python/sglang/srt/layers/quantization/mxfp4.py
@@ -39,9 +39,9 @@ from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     direct_register_custom_op,
-    get_bool_env_var,
     is_cuda,
     is_flashinfer_available,
+    is_gfx95_supported,
     is_hip,
     is_sm100_supported,
     is_triton_kernels_available,
 )
@@ -72,7 +72,7 @@
 )
 
 _is_hip = is_hip()
-_is_shuffle_moe_mxfp4 = get_bool_env_var("AITER_MXFP4_MOE_SF") and _is_hip
+_is_shuffle_moe_mxfp4 = is_gfx95_supported()
 
 if _is_hip:
     # import aiter
@@ -804,14 +804,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
 
         # Pre-shuffle weight
-        if _is_shuffle_moe_mxfp4:
+        is_shuffled = _is_shuffle_moe_mxfp4
+        if is_shuffled:
             w13 = shuffle_weight(w13.contiguous(), (16, 16))
             w2 = shuffle_weight(w2.contiguous(), (16, 16))
 
         layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
+        layer.w13_weight.is_shuffled = is_shuffled
         layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)
 
         layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
+        layer.w2_weight.is_shuffled = is_shuffled
         layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
 
     def create_moe_runner(
@@ -842,6 +845,10 @@ def apply(
         w13_weight = layer.w13_weight
         w2_weight = layer.w2_weight
 
+        if hasattr(layer.w13_weight, "is_shuffled"):
+            w13_weight.is_shuffled = True
+            w2_weight.is_shuffled = True
+
         output = fused_moe(
             x,
             w13_weight,
diff --git a/python/sglang/srt/layers/quantization/quark/quark_moe.py b/python/sglang/srt/layers/quantization/quark/quark_moe.py
index 497e69b8e679..e4839220103f 100644
--- a/python/sglang/srt/layers/quantization/quark/quark_moe.py
+++ b/python/sglang/srt/layers/quantization/quark/quark_moe.py
@@ -13,7 +13,12 @@
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import all_close_1d, per_tensor_dequantize
-from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_gfx95_supported,
+    is_hip,
+    set_weight_attrs,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
@@ -24,8 +29,7 @@
 
 logger = logging.getLogger(__name__)
 
-_is_hip = is_hip()
-_is_shuffle_moe_mxfp4 = get_bool_env_var("AITER_MXFP4_MOE_SF") and _is_hip
+_is_shuffle_moe_mxfp4 = is_gfx95_supported()
 
 __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
 
@@ -190,6 +194,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 layer.w2_weight.data = shuffle_weight(
                     layer.w2_weight.contiguous(), (16, 16)
                 )
+            layer.w13_weight.is_shuffled = True
+            layer.w2_weight.is_shuffled = True
 
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
@@ -220,6 +226,10 @@ def apply(
         w13_weight = layer.w13_weight
         w2_weight = layer.w2_weight
 
+        if hasattr(layer.w13_weight, "is_shuffled"):
+            w13_weight.is_shuffled = True
+            w2_weight.is_shuffled = True
+
         output = fused_moe(
             x,
             w13_weight,
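
Note (illustration only, not part of the patch): rather than gating pre-shuffling on the
AITER_MXFP4_MOE_SF environment variable, the patch tags the weight Parameters with an
`is_shuffled` attribute, so the layout information travels with the tensors themselves.
A minimal consumer-side sketch of that pattern follows; `weights_are_shuffled` and
`select_moe_kernel` are hypothetical helpers, not aiter's actual fused_moe entry point.

    import torch

    def weights_are_shuffled(w13: torch.Tensor, w2: torch.Tensor) -> bool:
        # process_weights_after_loading sets `is_shuffled` on the Parameters when
        # the (16, 16) pre-shuffle was applied on gfx95; a missing attribute means
        # the weights are still in their original layout.
        return getattr(w13, "is_shuffled", False) and getattr(w2, "is_shuffled", False)

    def select_moe_kernel(w13: torch.Tensor, w2: torch.Tensor) -> str:
        # Hypothetical dispatch: pick the kernel variant matching the weight layout.
        return "shuffled_16x16_moe" if weights_are_shuffled(w13, w2) else "plain_moe"

Reading the flag off the tensors keeps the build-time environment (the Dockerfile) free
of per-arch MoE switches, which is why the AITER_MXFP4_MOE_SF plumbing is deleted above.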