Skip to content

Commit 7dcf910

Browse files
1am9trash and sogalin authored
Add support for new aiter version (AR accuracy, is_shuffled PR) (#13554)
Co-authored-by: sogalin <[email protected]>
1 parent c7b37b7 commit 7dcf910

File tree

3 files changed

+25
-13
lines changed

3 files changed

+25
-13
lines changed

docker/rocm.Dockerfile

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@ ENV BUILD_AITER_ALL="1"
2222
ENV BUILD_MOONCAKE="1"
2323
ENV AITER_COMMIT="v0.1.4"
2424
ENV NO_DEPS_FLAG=""
25-
ENV AITER_MXFP4_MOE_SF="0"
2625

2726
# ===============================
2827
# Base image 942 and args
@@ -34,7 +33,6 @@ ENV BUILD_AITER_ALL="1"
3433
ENV BUILD_MOONCAKE="1"
3534
ENV AITER_COMMIT="v0.1.7.post1"
3635
ENV NO_DEPS_FLAG=""
37-
ENV AITER_MXFP4_MOE_SF="0"
3836

3937
# ===============================
4038
# Base image 950 and args
@@ -44,9 +42,8 @@ ENV BUILD_TRITON="0"
4442
ENV BUILD_LLVM="0"
4543
ENV BUILD_AITER_ALL="1"
4644
ENV BUILD_MOONCAKE="1"
47-
ENV AITER_COMMIT="v0.1.7.post1"
45+
ENV AITER_COMMIT="v0.1.7.post2"
4846
ENV NO_DEPS_FLAG=""
49-
ENV AITER_MXFP4_MOE_SF="1"
5047
# ===============================
5148
# Chosen arch and args
5249
FROM ${GPU_ARCH}
@@ -107,8 +104,7 @@ RUN git clone ${AITER_REPO} \
107104
&& git checkout ${AITER_COMMIT} \
108105
&& git submodule update --init --recursive
109106
RUN cd aiter \
110-
&& if [ "$GPU_ARCH" = "gfx950" ]; then export AITER_MXFP4_MOE_SF=1; fi \
111-
&& echo "[AITER] GPU_ARCH=${GPU_ARCH} AITER_MXFP4_MOE_SF=${AITER_MXFP4_MOE_SF:-unset}" \
107+
&& echo "[AITER] GPU_ARCH=${GPU_ARCH}" \
112108
&& if [ "$BUILD_AITER_ALL" = "1" ] && [ "$BUILD_LLVM" = "1" ]; then \
113109
sh -c "HIP_CLANG_PATH=/sgl-workspace/llvm-project/build/bin/ PREBUILD_KERNELS=1 GPU_ARCHS=$GPU_ARCH_LIST python setup.py develop"; \
114110
elif [ "$BUILD_AITER_ALL" = "1" ]; then \
@@ -299,7 +295,6 @@ RUN python3 -m pip install --no-cache-dir \
299295

300296
# -----------------------
301297
# Performance environment variable.
302-
RUN echo "AITER_MXFP4_MOE_SF=${AITER_MXFP4_MOE_SF}" >> /etc/environment
303298

304299
ENV HIP_FORCE_DEV_KERNARG=1
305300
ENV HSA_NO_SCRATCH_RECLAIM=1

python/sglang/srt/layers/quantization/mxfp4.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -39,9 +39,9 @@
3939
from sglang.srt.server_args import get_global_server_args
4040
from sglang.srt.utils import (
4141
direct_register_custom_op,
42-
get_bool_env_var,
4342
is_cuda,
4443
is_flashinfer_available,
44+
is_gfx95_supported,
4545
is_hip,
4646
is_sm100_supported,
4747
is_triton_kernels_available,
@@ -72,7 +72,7 @@
7272
)
7373

7474
_is_hip = is_hip()
75-
_is_shuffle_moe_mxfp4 = get_bool_env_var("AITER_MXFP4_MOE_SF") and _is_hip
75+
_is_shuffle_moe_mxfp4 = is_gfx95_supported()
7676

7777
if _is_hip:
7878
# import aiter
@@ -804,14 +804,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
804804
w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
805805

806806
# Pre-shuffle weight
807-
if _is_shuffle_moe_mxfp4:
807+
is_shuffled = _is_shuffle_moe_mxfp4
808+
if is_shuffled:
808809
w13 = shuffle_weight(w13.contiguous(), (16, 16))
809810
w2 = shuffle_weight(w2.contiguous(), (16, 16))
810811

811812
layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
813+
layer.w13_weight.is_shuffled = is_shuffled
812814
layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)
813815

814816
layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
817+
layer.w2_weight.is_shuffled = is_shuffled
815818
layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
816819

817820
def create_moe_runner(
@@ -842,6 +845,10 @@ def apply(
842845
w13_weight = layer.w13_weight
843846
w2_weight = layer.w2_weight
844847

848+
if hasattr(layer.w13_weight, "is_shuffled"):
849+
w13_weight.is_shuffled = True
850+
w2_weight.is_shuffled = True
851+
845852
output = fused_moe(
846853
x,
847854
w13_weight,

python/sglang/srt/layers/quantization/quark/quark_moe.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,12 @@
1313
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
1414
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
1515
from sglang.srt.layers.quantization.utils import all_close_1d, per_tensor_dequantize
16-
from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
16+
from sglang.srt.utils import (
17+
get_bool_env_var,
18+
is_gfx95_supported,
19+
is_hip,
20+
set_weight_attrs,
21+
)
1722

1823
if TYPE_CHECKING:
1924
from sglang.srt.layers.moe.token_dispatcher import (
@@ -24,8 +29,7 @@
2429

2530
logger = logging.getLogger(__name__)
2631

27-
_is_hip = is_hip()
28-
_is_shuffle_moe_mxfp4 = get_bool_env_var("AITER_MXFP4_MOE_SF") and _is_hip
32+
_is_shuffle_moe_mxfp4 = is_gfx95_supported()
2933

3034
__all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
3135

@@ -190,6 +194,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
190194
layer.w2_weight.data = shuffle_weight(
191195
layer.w2_weight.contiguous(), (16, 16)
192196
)
197+
layer.w13_weight.is_shuffled = True
198+
layer.w2_weight.is_shuffled = True
193199

194200
def create_moe_runner(
195201
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
@@ -220,6 +226,10 @@ def apply(
220226
w13_weight = layer.w13_weight
221227
w2_weight = layer.w2_weight
222228

229+
if hasattr(layer.w13_weight, "is_shuffled"):
230+
w13_weight.is_shuffled = True
231+
w2_weight.is_shuffled = True
232+
223233
output = fused_moe(
224234
x,
225235
w13_weight,

0 commit comments

Comments
 (0)