239 changes: 239 additions & 0 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py
@@ -0,0 +1,239 @@
import functools
from typing import Optional

import torch

from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

if _is_cuda:
    from sgl_kernel import silu_and_mul


def get_scalar_type(num_bits: int, has_zp: bool):
    from sgl_kernel.scalar_type import scalar_types

    if has_zp:
        assert num_bits == 4
        return scalar_types.uint4
    else:
        return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128


def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    workspace: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
    inplace: bool = False,
    routed_scaling_factor: float = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.

Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
permutation.
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of topk-k elements.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- num_bits (int): The number of bits in expert weights quantization.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
    from sglang.srt.layers.moe.fused_moe_triton import (
        moe_align_block_size,
        try_get_optimal_moe_config,
    )

    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[1] == w2.shape[2] // (
        num_bits // 2
    ), "Hidden size mismatch w2"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
    assert (
        hidden_states.dtype == w1_scale.dtype
    ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})"
    assert (
        hidden_states.dtype == w2_scale.dtype
    ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})"
    assert num_bits in [4, 8]

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16
    topk = topk_ids.shape[1]

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk_ids.shape[1],
        None,
        is_marlin=True,
    )
    config = get_config_func(M)

    block_size_m = config["BLOCK_SIZE_M"]

    if global_num_experts == -1:
        global_num_experts = E
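    # Sort the token -> expert assignments and pad them so that every block of
    # block_size_m tokens handled by the kernel belongs to a single expert.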
    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, block_size_m, global_num_experts
    )

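    # Allocate the kernel's int32 workspace if the caller did not pass one in;
    # its size is derived from the problem shape and capped at 4 entries per SM.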
    if workspace is None:
        max_workspace_size = (max(2 * N, K) // 64) * (
            sorted_token_ids.size(0) // block_size_m
        )
        device = hidden_states.device
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        max_workspace_size = min(max_workspace_size, sms * 4)
        workspace = torch.zeros(
            max_workspace_size, dtype=torch.int, device=device, requires_grad=False
        )

    scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
    scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)

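    # intermediate_cache13 is a single flat buffer: its front is first viewed as the
    # (M * topk, 2 * N) output of the gate/up GEMM and later reused as the
    # (M * topk, K) output of the down GEMM.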
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache13 = torch.empty(
        (M * topk_ids.shape[1] * max(2 * N, K),),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = intermediate_cache13[: M * topk_ids.shape[1] * 2 * N]
    intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
    intermediate_cache3 = intermediate_cache13[: M * topk_ids.shape[1] * K]
    intermediate_cache3 = intermediate_cache3.view(-1, K)

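    # Atomic-add reduction is enabled for fp16 activations, or for any supported
    # dtype on devices with compute capability >= 9.0.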
    use_atomic_add = (
        hidden_states.dtype == torch.half
        or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
    )

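    # First grouped GEMM: multiply hidden_states by the quantized w1 (gate/up)
    # weights, producing an (M * topk, 2 * N) activation.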
    intermediate_cache1 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
        hidden_states,
        intermediate_cache1,
        w1,
        w1_scale,
        w1_zeros,
        g_idx1,
        sort_indices1,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=topk,
        mul_topk_weights=False,
        is_ep=expert_map is not None,
        b_q_type_id=scalar_type1.id,
        size_m=M,
        size_n=2 * N,
        size_k=K,
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    )

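    # Fused SiLU(gate) * up activation, writing the (M * topk, N) result into
    # intermediate_cache2.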
    silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2)

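    # Under expert parallelism, rows routed to experts that are not held locally are
    # never written by the kernel, so the output buffer is cleared first.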
    if expert_map is not None:
        intermediate_cache3.zero_()

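    # Second grouped GEMM: down-project with the quantized w2 weights; the top-k
    # routing weights are applied inside the kernel (mul_topk_weights=True).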
    intermediate_cache3 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
        intermediate_cache2,
        intermediate_cache3,
        w2,
        w2_scale,
        w2_zeros,
        g_idx2,
        sort_indices2,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=1,
        mul_topk_weights=True,
        is_ep=expert_map is not None,
        b_q_type_id=scalar_type2.id,
        size_m=M * topk,
        size_n=K,
        size_k=N,
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    ).view(-1, topk, K)

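    # Sum the per-expert partial outputs over the top-k dimension and optionally
    # apply the routed scaling factor.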
    output = hidden_states if inplace else torch.empty_like(hidden_states)
    torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=output)
    if routed_scaling_factor is not None:
        output *= routed_scaling_factor
    return output


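# Shape-only fake implementation, presumably registered as the fake/meta function for
# the custom op so it can be traced (e.g. by torch.compile) without running the kernel.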
def fused_marlin_moe_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
    inplace: bool = False,
    routed_scaling_factor: float = None,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)
10 changes: 4 additions & 6 deletions python/sglang/srt/layers/quantization/awq.py
@@ -52,12 +52,7 @@
import torch_npu

if _is_cuda:
    from sgl_kernel import (
        awq_dequantize,
        awq_marlin_moe_repack,
        awq_marlin_repack,
        fused_marlin_moe,
    )
    from sgl_kernel import awq_dequantize, awq_marlin_moe_repack, awq_marlin_repack


elif _is_hip:
@@ -835,6 +830,9 @@ def apply(
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import (
            fused_marlin_moe,
        )
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        assert (
@@ -7,13 +7,6 @@
from enum import Enum
from typing import TYPE_CHECKING

try:
    from sgl_kernel import fused_marlin_moe

    FUSED_MARLIN_MOE_AVAILABLE = True
except ImportError:
    FUSED_MARLIN_MOE_AVAILABLE = False

import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
@@ -56,9 +49,6 @@
from aiter.ops.shuffle import shuffle_weight


if _is_cuda:
    from sgl_kernel import fused_marlin_moe

logger = logging.getLogger(__name__)


@@ -635,7 +625,9 @@ def apply(
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:

        from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import (
            fused_marlin_moe,
        )
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        assert (
@@ -662,7 +654,6 @@
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.num_bits,
            is_k_full=self.is_k_full,
            expert_map=torch.empty(1, device=x.device),
            routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
        )
        return StandardCombineInput(hidden_states=output)
8 changes: 4 additions & 4 deletions python/sglang/srt/layers/quantization/gptq.py
@@ -55,7 +55,7 @@
_is_cuda = is_cuda()

if _is_cuda:
    from sgl_kernel import fused_marlin_moe, gptq_gemm, gptq_marlin_repack, gptq_shuffle
    from sgl_kernel import gptq_gemm, gptq_marlin_repack, gptq_shuffle


logger = logging.getLogger(__name__)
@@ -1059,14 +1059,14 @@ def apply(
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:

        from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import (
            fused_marlin_moe,
        )
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output

        # Delay the import to avoid circular dependency

        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."
2 changes: 1 addition & 1 deletion python/sglang/test/test_marlin_moe.py
@@ -2,10 +2,10 @@

import pytest
import torch
from sgl_kernel import fused_marlin_moe
from sgl_kernel.scalar_type import ScalarType, scalar_types

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import fused_marlin_moe
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize

2 changes: 1 addition & 1 deletion sgl-kernel/python/sgl_kernel/__init__.py
@@ -34,7 +34,7 @@
    silu_and_mul,
)
from sgl_kernel.expert_specialization import es_fp8_blockwise_scaled_grouped_mm
from sgl_kernel.fused_moe import fused_marlin_moe, moe_wna16_marlin_gemm
from sgl_kernel.fused_moe import moe_wna16_marlin_gemm
from sgl_kernel.gemm import (
    awq_dequantize,
    bmm_fp8,