349 changes: 69 additions & 280 deletions python/sglang/srt/layers/moe/ep_moe/layer.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -839,7 +839,7 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs
dispatch_output=dispatch_output,
**kwargs,
)
final_hidden_states = self.dispatcher.combine(combine_input)
final_hidden_states = self.dispatcher.combine(combine_input=combine_input)

# TODO: should we add some conditions here?
final_hidden_states = final_hidden_states[
309 changes: 287 additions & 22 deletions python/sglang/srt/layers/moe/moe_runner/deep_gemm.py
@@ -5,6 +5,7 @@

import torch

from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe.moe_runner.base import (
MoeQuantInfo,
MoeRunnerConfig,
@@ -15,14 +16,28 @@
register_pre_permute,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend
from sglang.srt.utils import dispose_tensor
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
from sglang.srt.utils.offloader import get_offloader

if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher.deepep import (
DeepEPLLCombineInput,
DeepEPLLDispatchOutput,
DeepEPNormalCombineInput,
DeepEPNormalDispatchOutput,
)
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
StandardDispatchOutput,
)

_is_hip = is_hip()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if not (_is_npu or _is_hip):
from sgl_kernel import silu_and_mul


# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@@ -40,13 +55,23 @@ def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
return new_x.transpose(1, 2).contiguous().transpose(1, 2)


def copy_list_to_gpu_no_ce(arr: List[int]):
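"""Copy a small host-side int list to the GPU via a dedicated kernel instead of the
copy engine, so the transfer does not contend with offloading traffic (used when
get_offloader().forbid_copy_engine_usage is set)."""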
from sgl_kernel.elementwise import copy_to_gpu_no_ce

tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
return tensor_gpu


@dataclass
class DeepGemmRunnerInput(RunnerInput):
hidden_states: torch.Tensor
hidden_states_scale: torch.Tensor
masked_m: torch.Tensor
expected_m: int
use_masked_gemm: bool
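# The masked grouped GEMM path (standard and DeepEP low-latency dispatch) uses
# masked_m and expected_m; the contiguous path (DeepEP normal dispatch) uses m_indices.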
masked_m: Optional[torch.Tensor] = None
expected_m: Optional[int] = None
m_indices: Optional[torch.Tensor] = None

@property
def runner_backend(self) -> MoeRunnerBackend:
@@ -84,20 +109,100 @@ def run(
running_state: dict,
) -> DeepGemmRunnerOutput:
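"""Dispatch to the contiguous or masked grouped GEMM depending on how the tokens
were pre-permuted."""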

if runner_input.use_masked_gemm:
hidden_states = self._run_masked_gemm(
runner_input,
quant_info,
running_state,
if not runner_input.use_masked_gemm:
hidden_states = self._run_contiguous_gemm(
runner_input, quant_info, running_state
)
else:
hidden_states = self._run_contiguous_gemm(
runner_input,
quant_info,
running_state,
hidden_states = self._run_masked_gemm(
runner_input, quant_info, running_state
)
return DeepGemmRunnerOutput(hidden_states=hidden_states)

def _run_contiguous_gemm(
self,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
) -> torch.Tensor:
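"""Contiguous grouped FP8 GEMM path (DeepEP normal dispatch): gate/up GEMM over
expert-contiguous rows, SiLU-and-mul activation, per-128-group FP8 requantization,
then the down-projection GEMM."""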

from sglang.srt.layers.moe.ep_moe.kernels import tma_align_input_scale
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)

hidden_states = runner_input.hidden_states
hidden_states_scale = runner_input.hidden_states_scale
all_tokens = running_state["all_tokens"]
hidden_states_device = running_state["hidden_states_device"]
hidden_states_dtype = running_state["hidden_states_dtype"]
hidden_states_shape = running_state["hidden_states_shape"]
m_indices = runner_input.m_indices

N = quant_info.w13_weight.size(1)
K = hidden_states_shape[1]
scale_block_size = 128

w13_weight_fp8 = (
quant_info.w13_weight,
quant_info.w13_scale,
)
w2_weight_fp8 = (quant_info.w2_weight, quant_info.w2_scale)

gateup_output = torch.empty(
(all_tokens, N),
device=hidden_states_device,
dtype=torch.bfloat16,
)
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
hidden_states_scale = tma_align_input_scale(hidden_states_scale)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
(hidden_states, hidden_states_scale),
w13_weight_fp8,
gateup_output,
m_indices,
)

dispose_tensor(hidden_states)
dispose_tensor(hidden_states_scale)

down_input = torch.empty(
(
all_tokens,
N // 2,
),
device=gateup_output.device,
dtype=torch.bfloat16,
)
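# SiLU-activate the gate half of gateup_output and multiply it elementwise with the up half.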
silu_and_mul(gateup_output.view(-1, N), down_input)
del gateup_output

down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
down_input,
scale_block_size,
column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del down_input

down_output = torch.empty(
(all_tokens, K),
device=hidden_states_device,
dtype=torch.bfloat16,
)
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
down_input_scale = tma_align_input_scale(down_input_scale)

deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
(down_input_fp8, down_input_scale),
w2_weight_fp8,
down_output,
m_indices,
)

return down_output

def _run_masked_gemm(
self,
runner_input: DeepGemmRunnerInput,
@@ -149,6 +254,7 @@ def _run_masked_gemm(
expected_m,
)
dispose_tensor(hidden_states)
dispose_tensor(hidden_states_scale)

# Act
down_input = torch.empty(
@@ -198,18 +304,9 @@ def _run_masked_gemm(
masked_m,
expected_m,
)
del down_input

return down_output

def _run_contiguous_gemm(
self,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
) -> torch.Tensor:
pass

@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.DEEP_GEMM
@@ -222,6 +319,7 @@ def pre_permute_standard_to_deep_gemm(
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepGemmRunnerInput:
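"""Standard dispatch: preprocess topk_output into a per-expert masked layout
(masked_m / expected_m) and route through the masked grouped GEMM."""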

from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess

hidden_states, topk_output = dispatch_output
@@ -257,9 +355,9 @@ def pre_permute_standard_to_deep_gemm(
return DeepGemmRunnerInput(
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
use_masked_gemm=True,
masked_m=masked_m,
expected_m=expected_m,
use_masked_gemm=True,
)


@@ -302,3 +400,170 @@ def post_permute_deep_gemm_to_standard(
return StandardCombineInput(
hidden_states=output,
)


@register_pre_permute("deepep_ll", "deep_gemm")
def pre_permute_deepep_ll_to_deep_gemm(
dispatch_output: DeepEPLLDispatchOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepGemmRunnerInput:
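"""DeepEP low-latency dispatch already yields per-expert masked tensors, so tokens
go straight to the masked grouped GEMM without re-permutation."""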

hidden_states, hidden_states_scale, topk_ids, topk_weights, masked_m, expected_m = (
dispatch_output
)

running_state["topk_ids"] = topk_ids
running_state["topk_weights"] = topk_weights
running_state["hidden_states_shape"] = hidden_states.shape
running_state["hidden_states_dtype"] = hidden_states.dtype
running_state["hidden_states_device"] = hidden_states.device

return DeepGemmRunnerInput(
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
use_masked_gemm=True,
masked_m=masked_m,
expected_m=expected_m,
)


@register_post_permute("deep_gemm", "deepep_ll")
def post_permute_deep_gemm_to_deepep_ll(
runner_output: DeepGemmRunnerOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepEPLLCombineInput:

from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPLLCombineInput

return DeepEPLLCombineInput(
hidden_states=runner_output.hidden_states,
topk_ids=running_state["topk_ids"],
topk_weights=running_state["topk_weights"],
)


@register_pre_permute("deepep_normal", "deep_gemm")
def pre_permute_deepep_normal_to_deep_gemm(
dispatch_output: DeepEPNormalDispatchOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepGemmRunnerInput:
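"""DeepEP normal dispatch returns a flat token buffer plus per-expert receive counts;
ep_scatter permutes it into an expert-contiguous layout for the contiguous grouped GEMM."""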

from sglang.srt.layers.moe.ep_moe.kernels import ep_scatter

(
hidden_states,
hidden_states_scale,
topk_ids,
topk_weights,
num_recv_tokens_per_expert,
) = dispatch_output
assert runner_config.activation == "silu"

all_tokens = sum(num_recv_tokens_per_expert)
running_state["all_tokens"] = all_tokens

K = hidden_states.shape[1]

hidden_states_shape = hidden_states.shape
hidden_states_device = hidden_states.device
hidden_states_dtype = hidden_states.dtype

running_state["hidden_states_shape"] = hidden_states_shape
running_state["hidden_states_device"] = hidden_states_device
running_state["hidden_states_dtype"] = hidden_states_dtype
running_state["topk_ids"] = topk_ids
running_state["topk_weights"] = topk_weights

input_tensor = torch.empty(
(all_tokens, K),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
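# With UE8M0 enabled, the scale tensor holds e8m0 group scales packed four per int32
# (hence ceil_div(K // 128, 4)) in the transposed layout DeepGEMM expects; otherwise
# one fp32 scale per 128-wide group is kept per row.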
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
# TODO check whether need `zeros`
input_tensor_scale = torch.zeros(
(ceil_div(K // 128, 4), all_tokens),
device=hidden_states.device,
dtype=torch.int,
).transpose(0, 1)
else:
input_tensor_scale = torch.empty(
(all_tokens, K // 128),
device=hidden_states.device,
dtype=torch.float32,
)
m_indices = torch.empty(all_tokens, device=hidden_states.device, dtype=torch.int32)
output_index = torch.empty_like(topk_ids)

if get_offloader().forbid_copy_engine_usage:
num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
num_recv_tokens_per_expert
)
else:
num_recv_tokens_per_expert_gpu = torch.tensor(
num_recv_tokens_per_expert,
dtype=torch.int32,
pin_memory=True,
device="cpu",
).cuda(non_blocking=True)
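# ep_scatter fills expert_start_loc (per-expert row offsets) and permutes tokens and
# scales into the expert-contiguous buffers; m_indices maps each row to its expert and
# output_index records where each (token, expert) pair landed, for use by ep_gather.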
expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

ep_scatter(
hidden_states,
hidden_states_scale,
topk_ids,
num_recv_tokens_per_expert_gpu,
expert_start_loc,
input_tensor,
input_tensor_scale,
m_indices,
output_index,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
dispose_tensor(hidden_states)
dispose_tensor(hidden_states_scale)

running_state["output_index"] = output_index

return DeepGemmRunnerInput(
hidden_states=input_tensor,
hidden_states_scale=input_tensor_scale,
use_masked_gemm=False,
m_indices=m_indices,
)


@register_post_permute("deep_gemm", "deepep_normal")
def post_permute_deep_gemm_to_deepep_normal(
runner_output: DeepGemmRunnerOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepEPNormalCombineInput:

from sglang.srt.layers.moe.ep_moe.kernels import ep_gather
from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPNormalCombineInput

hidden_states = runner_output.hidden_states
topk_ids = running_state["topk_ids"]
topk_weights = running_state["topk_weights"]
output_index = running_state["output_index"]

gather_out = torch.empty(
running_state["hidden_states_shape"],
device=running_state["hidden_states_device"],
dtype=torch.bfloat16,
)
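# Gather the expert-contiguous GEMM output back to the original token order, combining
# each token's top-k expert rows weighted by topk_weights.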
ep_gather(hidden_states, topk_ids, topk_weights, output_index, gather_out)

return DeepEPNormalCombineInput(
hidden_states=gather_out,
topk_ids=running_state["topk_ids"],
topk_weights=running_state["topk_weights"],
)