Merged

40 commits
b7208d4
add set_default_server_args
iforgetmyname Nov 27, 2025
f1c806e
add init_npu_backend
iforgetmyname Nov 27, 2025
3fac40f
first remove of is_npu
iforgetmyname Nov 27, 2025
4d416a1
NPUPagedTokenToKVPoolAllocator, NPUMHATokenToKVPool and NPUMLATokenTo…
iforgetmyname Nov 28, 2025
f0e2a5a
fix missing import
iforgetmyname Nov 28, 2025
9de1ad3
second remove of is_npu
iforgetmyname Nov 29, 2025
db7675c
refactor topk
iforgetmyname Nov 29, 2025
b192c23
refactor ascend llm backend
iforgetmyname Nov 29, 2025
0e6e557
fix missing import
iforgetmyname Nov 29, 2025
5bd1985
fix missing import
iforgetmyname Nov 29, 2025
f043976
fix missing import
iforgetmyname Nov 29, 2025
7d11eed
NPUW8A8LinearMethod & NPUW8A8DynamicLinearMethod
iforgetmyname Dec 1, 2025
03e2897
fix caller
iforgetmyname Dec 1, 2025
73cd2aa
fix load warning and shape error
iforgetmyname Dec 1, 2025
c572c17
fix warning msg typo
iforgetmyname Dec 1, 2025
472fad0
renaming
iforgetmyname Dec 1, 2025
e360b40
NPUW8A8Int8DynamicMoEMethod, NPUW4A8Int4DynamicMoEMethod, NPUW4A16Int…
iforgetmyname Dec 1, 2025
affcee9
fix import error
iforgetmyname Dec 1, 2025
e72a47d
fix import error
iforgetmyname Dec 1, 2025
12edd98
add modelslim
iforgetmyname Dec 1, 2025
fccd5e6
refactor mla prepare&core
iforgetmyname Dec 1, 2025
e01b705
fix import error
iforgetmyname Dec 2, 2025
547421a
Merge remote-tracking branch 'upstream/main' into refactor
iforgetmyname Dec 2, 2025
959622f
fix precision issue
iforgetmyname Dec 2, 2025
af965ca
fix import error
iforgetmyname Dec 2, 2025
8f50237
restructer npu folder
iforgetmyname Dec 2, 2025
fffa39d
clean out mla_preprocess
iforgetmyname Dec 2, 2025
662cc00
move graph runners
iforgetmyname Dec 2, 2025
ac55f02
move cmo
iforgetmyname Dec 2, 2025
dbb5a25
fix mlapo get quant config
iforgetmyname Dec 2, 2025
af678cf
fix import error
iforgetmyname Dec 2, 2025
a7b832a
Merge branch 'main' into npu_refactor
iforgetmyname Dec 2, 2025
bd9ec29
change quant type and CODEOWNERS
iforgetmyname Dec 2, 2025
e434a34
fix typo
iforgetmyname Dec 2, 2025
b66f31b
fix deepseek_v2 lite accuracy
iforgetmyname Dec 3, 2025
6ae7e66
revert back topk.py change
iforgetmyname Dec 3, 2025
e4154c1
fix prefixcache start args
iforgetmyname Dec 3, 2025
84acb7f
Merge branch 'main' into npu_refactor
iforgetmyname Dec 3, 2025
60a1482
fix unquantized mtp layer breaks mlapo
iforgetmyname Dec 3, 2025
43f8286
Merge remote-tracking branch 'upstream/main' into npu_refactor
iforgetmyname Dec 3, 2025
4 changes: 1 addition & 3 deletions .github/CODEOWNERS
@@ -14,15 +14,13 @@
/python/sglang/srt/eplb @fzyzcjy @ch-wan
/python/sglang/srt/function_call @CatherineSue @JustinTong0323
/python/sglang/srt/grpc @CatherineSue @slin1237
/python/sglang/srt/hardware_backend/npu @ping1jing2 @iforgetmyname
/python/sglang/srt/layers @merrymercy @Ying1123 @Fridge003 @ispobock @HaiShaw @ch-wan @BBuf @Edwardf0t1
/python/sglang/srt/layers/quantization @ch-wan @BBuf @Edwardf0t1 @FlamingoPg @AniZpZ
/python/sglang/srt/layers/attention/ascend_backend.py @ping1jing2 @iforgetmyname
/python/sglang/srt/lora @Ying1123 @Fridge003 @lifuhuang
/python/sglang/srt/managers @merrymercy @Ying1123 @hnyls2002 @xiezhq-hermann @zhyncs
/python/sglang/srt/mem_cache @merrymercy @Ying1123 @hnyls2002 @xiezhq-hermann
/python/sglang/srt/mem_cache/allocator_ascend.py @ping1jing2 @iforgetmyname
/python/sglang/srt/model_executor @merrymercy @Ying1123 @hnyls2002 @Fridge003 @ispobock
/python/sglang/srt/model_executor/npu_graph_runner.py @ping1jing2 @iforgetmyname
/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @hebiao064
/python/sglang/srt/multimodal @mickqian @JustinTong0323 @yhyang201 @yuan-luo
/python/sglang/srt/speculative @Ying1123 @merrymercy @hnyls2002
3 changes: 3 additions & 0 deletions python/sglang/srt/environ.py
@@ -232,6 +232,9 @@ class Envs:
    SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)
    SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False)

    # NPU
    SGLANG_NPU_DISABLE_ACL_FORMAT_WEIGHT = EnvBool(False)

    # Quantization
    SGLANG_INT4_WEIGHT = EnvBool(False)
    SGLANG_CPU_QUANTIZATION = EnvBool(False)
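The new flag only declares a default here; its consumers live elsewhere in the PR and presumably read it through the Envs registry defined in this file. A minimal sketch of the intended semantics, using a plain environment lookup and hypothetical helper names rather than the actual call site:

```python
# Illustrative sketch only: shows what SGLANG_NPU_DISABLE_ACL_FORMAT_WEIGHT is
# expected to control. The real consumer is not shown in this diff; this
# standalone version uses os.getenv and placeholder names.
import os


def maybe_cast_weight_to_acl_format(weight, cast_fn):
    """Apply an NPU/ACL format cast unless the env flag disables it."""
    disabled = os.getenv("SGLANG_NPU_DISABLE_ACL_FORMAT_WEIGHT", "false").lower() in (
        "1",
        "true",
        "yes",
    )
    return weight if disabled else cast_fn(weight)
```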
10 changes: 2 additions & 8 deletions python/sglang/srt/eplb/expert_distribution.py
@@ -31,9 +31,7 @@
from sglang.srt.metrics.collector import ExpertDispatchCollector
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_int_env_var, is_npu

_is_npu = is_npu()
from sglang.srt.utils import Withable, get_int_env_var

if TYPE_CHECKING:
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
@@ -465,10 +463,6 @@ def _list_sum(a: List, b: List) -> List:
class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
super().__init__(*args, **kwargs)
if not _is_npu:
device = "cuda"
else:
device = "npu"
self._enable_global_physical_experts = enable_global_physical_experts
self._data = torch.zeros(
(
@@ -480,7 +474,7 @@ def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
),
),
dtype=torch.int,
device=device,
device="cuda",
)

def reset(self):
@@ -1,17 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache

from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
from sglang.srt.utils import get_num_new_pages, next_power_of_2

if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache


def alloc_extend_kernel_ascend(
def _alloc_extend_naive(
prefix_lens,
seq_lens,
last_loc,
@@ -65,14 +63,14 @@ def alloc_extend_kernel_ascend(
).view(-1)


class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
class NPUPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
kvcache: "KVCache",
need_sort: bool,
):
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
@@ -130,7 +128,7 @@ def alloc_extend(
dtype=torch.int32,
device=self.device,
)
alloc_extend_kernel_ascend(
_alloc_extend_naive(
prefix_lens,
seq_lens,
last_loc,
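For readers unfamiliar with paged KV allocation, here is a rough sketch of the arithmetic that `alloc_extend`/`_alloc_extend_naive` has to solve. This is not the PR's kernel; the function and variable names are illustrative only.

```python
# Illustrative only -- not the PR's _alloc_extend_naive. Shows the page math a
# paged token-to-KV-pool allocator performs when extending cached sequences.
import math


def pages_needed_for_extend(prefix_lens, seq_lens, page_size):
    """Fresh pages required to grow each request from its cached prefix
    length to its new sequence length."""
    needed = []
    for prefix, seq in zip(prefix_lens, seq_lens):
        pages_before = math.ceil(prefix / page_size)  # pages already backing the prefix
        pages_after = math.ceil(seq / page_size)      # pages covering the full sequence
        needed.append(pages_after - pages_before)
    return needed


# Example: page_size=128, prefix=200, seq=300 -> 3 - 2 = 1 new page.
```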
@@ -7,8 +7,10 @@
import torch_npu

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.hardware_backend.npu.attention.mla_preprocess import (
is_mla_preprocess_enabled,
)
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
@@ -1,23 +1,19 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Optional

import torch
import torch.nn.functional as F

from sglang.srt.utils import get_bool_env_var, is_npu
from sglang.srt.hardware_backend.npu.utils import npu_format_cast
from sglang.srt.utils import get_bool_env_var

_is_npu = is_npu()
_ENABLE_MLA_PREPROCESS_FLAG = get_bool_env_var("SGLANG_NPU_USE_MLAPO")
_NPU_FORMAT_NZ = 29
if TYPE_CHECKING:
from sglang.srt.layers.quantization.base_config import QuantizationConfig


@lru_cache(maxsize=1)
def is_mla_preprocess_enabled() -> bool:
return _is_npu and _ENABLE_MLA_PREPROCESS_FLAG


if is_mla_preprocess_enabled():
import sgl_kernel_npu # noqa: F401
import torch_npu

torch.npu.config.allow_internal_format = True
torch.npu.set_compile_mode(jit_compile=False)
return get_bool_env_var("SGLANG_NPU_USE_MLAPO")


def round_up(val: int, align: int) -> int:
@@ -66,6 +62,7 @@ def __init__(
num_local_heads,
qk_nope_head_dim,
qk_rope_head_dim,
quant_config: Optional["QuantizationConfig"] = None,
):
super().__init__()
self.qkv_a_proj = fused_qkv_a_proj_with_mqa
@@ -75,6 +72,7 @@ def __init__(
self.w_kc = w_kc.contiguous()
self.rotary_emb = rotary_emb
self.layer_id = layer_id
self.quant_config = quant_config
self.has_preprocess_weights = False
self.dtype = None

@@ -124,9 +122,7 @@ def preprocess_weights(self, hidden_states):
.unsqueeze(0)
.contiguous()
)
self.qkv_a_proj_weight_nz = torch_npu.npu_format_cast(
fused_qkv_a_proj_with_mqa_weight_nz, _NPU_FORMAT_NZ
)
self.qkv_a_proj_weight_nz = npu_format_cast(fused_qkv_a_proj_with_mqa_weight_nz)

# matmul_0 deq_scale [2112]
fused_qkv_a_proj_with_mqa_deq_scale_q = self.qkv_a_proj.deq_scale.data[
@@ -198,9 +194,7 @@ def preprocess_weights(self, hidden_states):
q_b_proj_weight_nz = (
transdata(q_b_proj_weight, block_size=(16, 32)).unsqueeze(0).contiguous()
)
self.q_b_proj_weight_nz = torch_npu.npu_format_cast(
q_b_proj_weight_nz, _NPU_FORMAT_NZ
)
self.q_b_proj_weight_nz = npu_format_cast(q_b_proj_weight_nz)

# matmul_1 deq_scale [num_head * 192]
q_b_proj_deq_scale = self.q_b_proj.deq_scale.data.clone()
@@ -280,7 +274,7 @@ def forward_absorb_prepare_npu_rms_norm_cache(
q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim)
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin) # (B,N,S,D)
q_pe = torch.ops.npu.npu_interleave_rope(q_pe, cos, sin) # (B,N,S,D)
q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim)

latent_cache = latent_cache.view(
@@ -300,7 +294,7 @@
1,
forward_batch.attn_backend.qk_rope_head_dim,
)
k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
k_rope, k_nope, _, _ = torch.ops.npu.npu_kv_rmsnorm_rope_cache(
latent_cache,
self.kv_a_layernorm.weight,
cos,
@@ -378,10 +372,11 @@ def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator)
)

def forward(self, positions, hidden_states, forward_batch, zero_allocator):
assert self.quant_config and self.quant_config.get_name() == "modelslim"
# route by `qkv_a_proj` quant type as MTP layers can be unquantized
_is_w8a8 = (
hasattr(self.qkv_a_proj.quant_method, "quantization_config")
and self.qkv_a_proj.quant_method.quantization_config.get_name()
== "w8a8_int8"
hasattr(self.qkv_a_proj.quant_method, "quant_config")
and self.qkv_a_proj.quant_method.quant_config.get_name() == "modelslim"
)
if _is_w8a8:
return self.forward_mlapo(
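Throughout this file, the hard-coded `torch_npu.npu_format_cast(..., _NPU_FORMAT_NZ)` calls (with `_NPU_FORMAT_NZ = 29`, removed above) are replaced by an `npu_format_cast` helper imported from `sglang.srt.hardware_backend.npu.utils`. The helper itself is not shown in this section; a plausible sketch, under the assumption that it defaults to the same NZ format id:

```python
# Sketch of the assumed helper in sglang/srt/hardware_backend/npu/utils.py; the
# actual implementation is not part of this diff. It presumably wraps
# torch_npu.npu_format_cast with the NZ layout id (29) that used to be
# hard-coded here as _NPU_FORMAT_NZ.
import torch

ACL_FORMAT_FRACTAL_NZ = 29  # Ascend ACL id for the FRACTAL_NZ layout


def npu_format_cast(tensor: torch.Tensor, target_format: int = ACL_FORMAT_FRACTAL_NZ):
    import torch_npu  # imported lazily so the sketch loads on non-NPU hosts

    return torch_npu.npu_format_cast(tensor, target_format)
```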
54 changes: 54 additions & 0 deletions python/sglang/srt/hardware_backend/npu/cmo.py
@@ -0,0 +1,54 @@
import torch

cmo_stream = None


def get_cmo_stream():
    """
    Cache Management Operation(CMO).
    Launch a new stream to prefetch the weight of matmul when running other
    AIV or communication kernels, aiming to overlap the memory access time.
    """
    global cmo_stream
    return cmo_stream


def set_cmo_stream(stream):
    global cmo_stream
    cmo_stream = stream


def prepare_weight_cache(handle, cache, PREFETCH_MAX_SIZE=1000000000):
    """
    PREFETCH_MAX_SIZE: maximum size (bytes) for each prefetch operation.
    This affects the time spent in prefetch:
        time ≈ PREFETCH_MAX_SIZE / system_bandwidth
    """
    import torch_npu

    stream = get_cmo_stream()
    if stream is None:
        stream = torch.npu.Stream()
        set_cmo_stream(stream)
    stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(stream):
        if isinstance(cache, list):
            for weight in cache:
                torch_npu.npu_prefetch(
                    weight,
                    handle,
                    PREFETCH_MAX_SIZE,
                )
        else:
            torch_npu.npu_prefetch(
                cache,
                handle,
                PREFETCH_MAX_SIZE,
            )


def wait_cmo_stream():
    stream = get_cmo_stream()
    if stream is not None:
        cur_stream = torch.npu.current_stream()
        cur_stream.wait_stream(stream)
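A hypothetical usage sketch of these helpers (not part of this PR): start a weight prefetch while other work runs, then synchronize before the matmul that consumes the weight. All call-site names are placeholders.

```python
# Hypothetical usage sketch; `attn_fn`, `matmul_fn`, `w`, and `handle` are
# placeholder names, not code from this PR.
from sglang.srt.hardware_backend.npu.cmo import prepare_weight_cache, wait_cmo_stream


def mlp_with_prefetch(attn_fn, matmul_fn, hidden_states, w, handle):
    prepare_weight_cache(handle, w)    # start prefetching `w` on the CMO stream
    attn_out = attn_fn(hidden_states)  # this work overlaps with the prefetch
    wait_cmo_stream()                  # compute stream waits for the prefetch
    return matmul_fn(attn_out, w)      # the matmul now reads the prefetched weight
```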