Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions docker/rocm.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ ENV BUILD_AITER_ALL="1"
ENV BUILD_MOONCAKE="1"
ENV AITER_COMMIT="v0.1.4"
ENV NO_DEPS_FLAG=""
ENV AITER_MXFP4_MOE_SF="0"

# ===============================
# Base image 942 and args
Expand All @@ -34,7 +33,6 @@ ENV BUILD_AITER_ALL="1"
ENV BUILD_MOONCAKE="1"
ENV AITER_COMMIT="v0.1.7.post1"
ENV NO_DEPS_FLAG=""
ENV AITER_MXFP4_MOE_SF="0"

# ===============================
# Base image 950 and args
Expand All @@ -44,9 +42,8 @@ ENV BUILD_TRITON="0"
ENV BUILD_LLVM="0"
ENV BUILD_AITER_ALL="1"
ENV BUILD_MOONCAKE="1"
ENV AITER_COMMIT="v0.1.7.post1"
ENV AITER_COMMIT="v0.1.7.post2"
ENV NO_DEPS_FLAG=""
ENV AITER_MXFP4_MOE_SF="1"
# ===============================
# Chosen arch and args
FROM ${GPU_ARCH}
Expand Down Expand Up @@ -107,8 +104,7 @@ RUN git clone ${AITER_REPO} \
&& git checkout ${AITER_COMMIT} \
&& git submodule update --init --recursive
RUN cd aiter \
&& if [ "$GPU_ARCH" = "gfx950" ]; then export AITER_MXFP4_MOE_SF=1; fi \
&& echo "[AITER] GPU_ARCH=${GPU_ARCH} AITER_MXFP4_MOE_SF=${AITER_MXFP4_MOE_SF:-unset}" \
&& echo "[AITER] GPU_ARCH=${GPU_ARCH}" \
&& if [ "$BUILD_AITER_ALL" = "1" ] && [ "$BUILD_LLVM" = "1" ]; then \
sh -c "HIP_CLANG_PATH=/sgl-workspace/llvm-project/build/bin/ PREBUILD_KERNELS=1 GPU_ARCHS=$GPU_ARCH_LIST python setup.py develop"; \
elif [ "$BUILD_AITER_ALL" = "1" ]; then \
Expand Down Expand Up @@ -299,7 +295,6 @@ RUN python3 -m pip install --no-cache-dir \

# -----------------------
# Performance environment variable.
RUN echo "AITER_MXFP4_MOE_SF=${AITER_MXFP4_MOE_SF}" >> /etc/environment

ENV HIP_FORCE_DEV_KERNARG=1
ENV HSA_NO_SCRATCH_RECLAIM=1
Expand Down
13 changes: 10 additions & 3 deletions python/sglang/srt/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
is_cuda,
is_flashinfer_available,
is_gfx95_supported,
is_hip,
is_sm100_supported,
is_triton_kernels_available,
Expand Down Expand Up @@ -72,7 +72,7 @@
)

_is_hip = is_hip()
_is_shuffle_moe_mxfp4 = get_bool_env_var("AITER_MXFP4_MOE_SF") and _is_hip
_is_shuffle_moe_mxfp4 = is_gfx95_supported()

if _is_hip:
# import aiter
Expand Down Expand Up @@ -804,14 +804,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)

# Pre-shuffle weight
if _is_shuffle_moe_mxfp4:
is_shuffled = _is_shuffle_moe_mxfp4
if is_shuffled:
w13 = shuffle_weight(w13.contiguous(), (16, 16))
w2 = shuffle_weight(w2.contiguous(), (16, 16))

layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
layer.w13_weight.is_shuffled = is_shuffled
layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)

layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
layer.w2_weight.is_shuffled = is_shuffled
layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)

def create_moe_runner(
Expand Down Expand Up @@ -842,6 +845,10 @@ def apply(
w13_weight = layer.w13_weight
w2_weight = layer.w2_weight

if hasattr(layer.w13_weight, "is_shuffled"):
w13_weight.is_shuffled = True
w2_weight.is_shuffled = True

output = fused_moe(
x,
w13_weight,
Expand Down
16 changes: 13 additions & 3 deletions python/sglang/srt/layers/quantization/quark/quark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import all_close_1d, per_tensor_dequantize
from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
from sglang.srt.utils import (
get_bool_env_var,
is_gfx95_supported,
is_hip,
set_weight_attrs,
)

if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
Expand All @@ -24,8 +29,7 @@

logger = logging.getLogger(__name__)

_is_hip = is_hip()
_is_shuffle_moe_mxfp4 = get_bool_env_var("AITER_MXFP4_MOE_SF") and _is_hip
_is_shuffle_moe_mxfp4 = is_gfx95_supported()

__all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]

Expand Down Expand Up @@ -190,6 +194,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w2_weight.data = shuffle_weight(
layer.w2_weight.contiguous(), (16, 16)
)
layer.w13_weight.is_shuffled = True
layer.w2_weight.is_shuffled = True

def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
Expand Down Expand Up @@ -220,6 +226,10 @@ def apply(
w13_weight = layer.w13_weight
w2_weight = layer.w2_weight

if hasattr(layer.w13_weight, "is_shuffled"):
w13_weight.is_shuffled = True
w2_weight.is_shuffled = True

output = fused_moe(
x,
w13_weight,
Expand Down
Loading