Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
74e3d91
[Feature] Optimize DeepSeek's DeepEP on Ascend NPU
iforgetmyname Jul 26, 2025
aec03b6
Merge branch 'main' into feature/ascend_deepep_optimize
ping1jing2 Jul 26, 2025
68a8d23
Merge remote-tracking branch 'origin/main' into feature/ascend_deepep…
iforgetmyname Jul 28, 2025
1e0c9b5
fix ascenddeepepmoe input order
iforgetmyname Jul 28, 2025
79b0dc3
fix NPU_W8A8EPMoEMethod input args
iforgetmyname Jul 28, 2025
cffa861
Merge remote-tracking branch 'origin/main' into feature/ascend_deepep…
iforgetmyname Jul 29, 2025
ae04a83
fix allgather_into_tensor using torch.ops.sglang
iforgetmyname Jul 29, 2025
2a76c2a
Merge branch 'main' into feature/ascend_deepep_optimize
iforgetmyname Jul 29, 2025
72c974c
reconstruct ascend deepep
Jul 29, 2025
7e5abbb
Merge branch 'main' into feature/ascend_deepep_optimize
iforgetmyname Jul 30, 2025
d25f1de
refactor AscendDeepEP to support latest impl
iforgetmyname Jul 30, 2025
bac0889
fix per-commit
iforgetmyname Jul 30, 2025
4beb9f3
Merge branch 'main' into feature/ascend_deepep_optimize
iforgetmyname Jul 31, 2025
abf1e75
Merge branch 'main' into feature/ascend_deepep_optimize
iforgetmyname Jul 31, 2025
dec761c
fix missing deepepmoe init args
iforgetmyname Jul 31, 2025
985c430
Merge remote-tracking branch 'origin/main' into feature/ascend_deepep…
iforgetmyname Aug 2, 2025
4479094
Merge branch 'main' into feature/ascend_deepep_optimize
iforgetmyname Aug 3, 2025
a7a6f95
Merge branch 'main' into feature/ascend_deepep_optimize
iforgetmyname Aug 4, 2025
a3a4b87
fix missing import
iforgetmyname Aug 4, 2025
b9c1a9f
matching the moe refactor
iforgetmyname Aug 4, 2025
5ce4ef2
fix per-commit
iforgetmyname Aug 4, 2025
9497a22
fix deepgemm warning
iforgetmyname Aug 4, 2025
0cd1921
Merge branch 'main' into feature/ascend_deepep_optimize
ping1jing2 Aug 4, 2025
fe26f2e
fix a merge error
iforgetmyname Aug 4, 2025
edc945b
Merge branch 'main' into feature/ascend_deepep_optimize
iforgetmyname Aug 6, 2025
7de6efc
wrapping with _is_npu
iforgetmyname Aug 6, 2025
b591daa
fix a typo
iforgetmyname Aug 6, 2025
cc5b0f1
fix import issue
iforgetmyname Aug 6, 2025
0bbd282
Merge branch 'main' into feature/ascend_deepep_optimize
iforgetmyname Aug 6, 2025
752153b
Merge branch 'main' into feature/ascend_deepep_optimize
ping1jing2 Aug 6, 2025
f0d4b1c
Merge branch 'main' into feature/ascend_deepep_optimize
Alcanderian Aug 6, 2025
6abac2a
Merge branch 'main' into feature/ascend_deepep_optimize
iforgetmyname Aug 7, 2025
c178204
Merge branch 'main' into feature/ascend_deepep_optimize
Alcanderian Aug 7, 2025
c2f1043
Merge branch 'main' into feature/ascend_deepep_optimize
Alcanderian Aug 7, 2025
75f1dd8
Merge branch 'main' into feature/ascend_deepep_optimize
Alcanderian Aug 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
supports_custom_op,
)

_is_npu = is_npu()


@dataclass
class GraphCaptureContext:
Expand Down Expand Up @@ -591,7 +593,7 @@ def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
)

def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
if not supports_custom_op():
if _is_npu or not supports_custom_op():
self._all_gather_into_tensor(output, input)
else:
torch.ops.sglang.reg_all_gather_into_tensor(
Expand Down Expand Up @@ -1127,7 +1129,7 @@ def init_model_parallel_group(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=not is_npu(),
use_pynccl=not _is_npu,
use_pymscclpp=use_mscclpp_allreduce,
use_custom_allreduce=use_custom_allreduce,
use_hpu_communicator=True,
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/layers/attention/ascend_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
)
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()

def get_cuda_graph_seq_len_fill_value(self):
    """Return the fill value (1) used for sequence lengths during graph capture.

    NOTE(review): presumably 1 (rather than 0) keeps padded/unused batch slots
    valid for the Ascend attention kernels — confirm against the graph runner.
    """
    return 1

def forward_extend(
self,
q,
Expand Down
62 changes: 60 additions & 2 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
AscendDeepEPLLOutput,
DeepEPLLOutput,
DeepEPNormalOutput,
DispatchOutput,
Expand Down Expand Up @@ -387,7 +388,8 @@ def __init__(
return_recv_hook=True,
)

if self.deepep_mode.enable_low_latency():
if self.deepep_mode.enable_low_latency() and not _is_npu:
# NPU supports low_latency deepep without deepgemm
assert (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
Expand All @@ -404,7 +406,7 @@ def __init__(
)
# the last one is invalid rank_id
self.expert_mask[:-1] = 1
else:
elif not _is_npu:
self.w13_weight_fp8 = (
self.w13_weight,
(
Expand Down Expand Up @@ -459,6 +461,8 @@ def moe_impl(self, dispatch_output: DispatchOutput):
if _use_aiter:
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
return self.forward_aiter(dispatch_output)
if _is_npu:
return self.forward_npu(dispatch_output)
if dispatch_output.format.is_deepep_normal():
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output)
Expand Down Expand Up @@ -723,6 +727,60 @@ def forward_deepgemm_masked(

return down_output

def forward_npu(
    self,
    dispatch_output: DeepEPLLOutput,
):
    """Ascend NPU expert forward: grouped gate_up matmul -> swiglu ->
    dynamic re-quantization -> grouped down matmul, all via torch_npu kernels.

    Expects an AscendDeepEPLLOutput whose hidden states are a
    (quantized tensor, per-token scale) pair and whose seg_indptr drives the
    grouped matmuls as the group_list.
    """
    if TYPE_CHECKING:
        assert isinstance(dispatch_output, AscendDeepEPLLOutput)
    assert self.quant_method is not None
    assert self.activation == "silu"

    import torch_npu

    quantized, _topk_idx, _topk_weights, _, group_list, _ = dispatch_output
    # Dispatch hands back a (tensor, per-token scale) pair.
    gmm_input, pertoken_scale = quantized[0], quantized[1]

    # NOTE: Ascend's Dispatch & Combine does not support FP16
    output_dtype = torch.bfloat16

    group_list_type = 1
    group_list = group_list.to(torch.int64)

    # gmm1: gate_up_proj
    gate_up_out = torch_npu.npu_grouped_matmul(
        x=[gmm_input],
        weight=[self.w13_weight],
        scale=[self.w13_weight_scale.to(output_dtype)],
        per_token_scale=[pertoken_scale],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=output_dtype,
    )[0]

    # act_fn: swiglu, then re-quantize the activation for the second matmul
    act_out = torch_npu.npu_swiglu(gate_up_out)
    act_out, swiglu_out_scale = torch_npu.npu_dynamic_quant(act_out)

    # gmm2: down_proj
    down_out = torch_npu.npu_grouped_matmul(
        x=[act_out],
        weight=[self.w2_weight],
        scale=[self.w2_weight_scale.to(output_dtype)],
        per_token_scale=[swiglu_out_scale],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=output_dtype,
    )[0]

    return down_out


def get_moe_impl_class():
if global_server_args_dict["moe_a2a_backend"].is_deepep():
Expand Down
85 changes: 61 additions & 24 deletions python/sglang/srt/layers/moe/token_dispatcher/deepep.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,23 @@
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
from sglang.srt.utils import (
get_bool_env_var,
get_int_env_var,
is_hip,
is_npu,
load_json_config,
)

_is_npu = is_npu()

try:
from deep_ep import Buffer, Config

from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
if not _is_npu:
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)

use_deepep = True
except ImportError:
Expand Down Expand Up @@ -80,8 +89,24 @@ def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_ll


class AscendDeepEPLLOutput(NamedTuple):
    """AscendDeepEP low latency dispatch output."""

    # (quantized hidden states, per-token scale) pair; the consumer reads
    # element [0] as the grouped-matmul input and [1] as per_token_scale.
    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    masked_m: torch.Tensor
    # Extra field vs DeepEPLLOutput: cast to int64 and passed as group_list
    # to torch_npu.npu_grouped_matmul on the NPU forward path.
    seg_indptr: torch.Tensor
    expected_m: int

    @property
    def format(self) -> DispatchOutputFormat:
        # Tagged with the same format as DeepEPLLOutput so generic
        # is_deepep_ll-style checks treat both outputs alike.
        return DispatchOutputFormat.deepep_ll


assert isinstance(DeepEPNormalOutput, DispatchOutput)
assert isinstance(DeepEPLLOutput, DispatchOutput)
assert isinstance(AscendDeepEPLLOutput, DispatchOutput)


class DeepEPDispatchMode(IntEnum):
Expand Down Expand Up @@ -150,19 +175,20 @@ def get_deepep_buffer(
else:
raise NotImplementedError

total_num_sms = torch.cuda.get_device_properties(
device="cuda"
).multi_processor_count
if (
(deepep_mode != DeepEPMode.LOW_LATENCY)
and not global_server_args_dict["enable_two_batch_overlap"]
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
):
logger.warning(
f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
f"This may result in highly suboptimal performance. "
f"Consider using --deepep-config to change the behavior."
)
if not _is_npu:
total_num_sms = torch.cuda.get_device_properties(
device="cuda"
).multi_processor_count
if (
(deepep_mode != DeepEPMode.LOW_LATENCY)
and not global_server_args_dict["enable_two_batch_overlap"]
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
):
logger.warning(
f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
f"This may result in highly suboptimal performance. "
f"Consider using --deepep-config to change the behavior."
)

cls._buffer = Buffer(
group,
Expand Down Expand Up @@ -507,13 +533,24 @@ def dispatch_b(
masked_m
)

return DeepEPLLOutput(
hidden_states,
topk_idx,
topk_weights,
masked_m,
expected_m,
)
if _is_npu:
deepep_output = AscendDeepEPLLOutput(
hidden_states,
topk_idx,
topk_weights,
masked_m,
self.handle[1],
expected_m,
)
else:
deepep_output = DeepEPLLOutput(
hidden_states,
topk_idx,
topk_weights,
masked_m,
expected_m,
)
return deepep_output

def _dispatch_core(
self,
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,11 @@ def forward_npu(

# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
router_logits = router_logits.to(torch.float32)
return torch_npu.npu_moe_gating_top_k(
router_logits,
k=self.top_k,
bias=self.correction_bias,
bias=self.correction_bias.to(torch.float32),
k_group=self.topk_group,
group_count=self.num_expert_group,
group_select_mode=1,
Expand Down
70 changes: 39 additions & 31 deletions python/sglang/srt/layers/quantization/w8a8_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@
import importlib
import sys
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Union,
cast,
)

import torch
from torch.nn.parameter import Parameter
Expand Down Expand Up @@ -79,22 +90,16 @@ def _rmsnorm_forward_oot(
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
x = x.contiguous()
original_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(original_dtype)

x = (
torch_npu.npu_rms_norm(
x, self.weight.to(torch.float32), self.variance_epsilon
)[0]
+ self.bias
)
out, _, residual_out = torch_npu.npu_add_rms_norm(
residual, x, self.weight.data, self.variance_epsilon
)
out = out + self.bias
return out.to(x.dtype), residual_out

if residual is None:
return x.to(original_dtype)
return x.to(original_dtype), residual
out = torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
out = out + self.bias
return out.to(x.dtype)

return _rmsnorm_forward_oot

Expand Down Expand Up @@ -571,8 +576,10 @@ def apply(
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
# To prevent import loops
from sglang.srt.layers.linear import RowParallelLinear

original_dtype = x.dtype
if original_dtype != torch.int8:
x = torch_npu.npu_quantize(
Expand All @@ -583,8 +590,12 @@ def apply(
-1,
True,
)

quant_bias = layer.quant_bias if tp_rank == 0 else None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in Attention TP>1 case)
if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
quant_bias = None
else:
quant_bias = layer.quant_bias
return torch_npu.npu_quant_matmul(
x,
layer.weight,
Expand Down Expand Up @@ -651,13 +662,21 @@ def apply(
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
# To prevent import loops
from sglang.srt.layers.linear import RowParallelLinear

original_dtype = x.dtype
if original_dtype != torch.int8:
x = quant_per_tensor(x, layer.input_scale, layer.input_offset)

quant_bias = layer.quant_bias if tp_rank == 0 else None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in Attention TP>1 case)
if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
quant_bias = None
else:
quant_bias = layer.quant_bias

return ops.quant_matmul(
x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
)
Expand Down Expand Up @@ -737,11 +756,6 @@ def apply(
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from sglang.srt.layers.linear import RowParallelLinear

if isinstance(layer, RowParallelLinear):
tp_rank = get_tensor_model_parallel_rank()
return self.quant_method.apply(layer, x, bias, tp_rank)
return self.quant_method.apply(layer, x, bias)


Expand Down Expand Up @@ -780,7 +794,6 @@ def apply(
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
original_dtype = x.dtype
# use ATB quantize
quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
return torch_npu.npu_quant_matmul(
quant_out,
Expand Down Expand Up @@ -863,11 +876,6 @@ def apply(
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from sglang.srt.layers.linear import RowParallelLinear

if isinstance(layer, RowParallelLinear):
tp_rank = get_tensor_model_parallel_rank()
return self.quant_method.apply(layer, x, bias, tp_rank)
return self.quant_method.apply(layer, x, bias)


Expand Down
Loading
Loading