diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 5cd2e576609d..532350ea890f 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -35,12 +35,12 @@ _is_fp8_fnuz = is_fp8_fnuz()
 
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-if not (_is_npu or _is_hip):
-    pass
-
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
+elif _is_npu:
+    import torch_npu
+
 
 logger = logging.getLogger(__name__)
@@ -314,87 +314,44 @@ def forward_npu(
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
 
-        import torch_npu
-
         from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
 
         # NOTE: Ascend's Dispatch & Combine does not support FP16
         output_dtype = torch.bfloat16
         group_list_type = 1
 
-        def _forward_normal(dispatch_output: DeepEPNormalDispatchOutput):
+        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
             if TYPE_CHECKING:
                 assert isinstance(dispatch_output, DeepEPNormalDispatchOutput)
             hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
                 dispatch_output
             )
-            group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
-                hidden_states.device
+            group_list = torch.tensor(
+                num_recv_tokens_per_expert,
+                dtype=torch.int64,
+                device=hidden_states.device,
             )
-            if self.w13_weight.dtype != torch.int8:
-                # gmm1: gate_up_proj
-                hidden_states = torch_npu.npu_grouped_matmul(
-                    x=[hidden_states],
-                    weight=[self.w13_weight.permute(0, 2, 1)],
-                    # per_token_scale=[hidden_states_scale],
-                    split_item=2,
-                    group_list_type=group_list_type,
-                    group_type=0,
-                    group_list=group_list,
-                    output_dtype=output_dtype,
-                )[0]
-                hidden_states = torch_npu.npu_swiglu(hidden_states)
-                # gmm2: down_proj
-                hidden_states = torch_npu.npu_grouped_matmul(
-                    x=[hidden_states],
-                    weight=[self.w2_weight.permute(0, 2, 1)],
-                    split_item=2,
-                    group_list_type=group_list_type,
-                    group_type=0,
-                    group_list=group_list,
-                    output_dtype=output_dtype,
-                )[0]
+
+            if self.w13_weight.dtype == torch.bfloat16:
+                hidden_states = npu_fused_moe_without_routing_weights_bf16(
+                    self, hidden_states, group_list_type, group_list, output_dtype
+                )
             else:
-                if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
+                input_quant = get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT")
+                if not input_quant and self.w13_weight.dtype != torch.int32:
                     hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant(
                         hidden_states
                     )
-                # gmm1: gate_up_proj
-                hidden_states = torch_npu.npu_grouped_matmul(
-                    x=[hidden_states],
-                    weight=[self.w13_weight],
-                    scale=[self.w13_weight_scale.to(output_dtype)],
-                    per_token_scale=[hidden_states_scale],
-                    split_item=2,
-                    group_list_type=group_list_type,
-                    group_type=0,
-                    group_list=group_list,
-                    output_dtype=output_dtype,
-                )[0]
-
-                # act_fn: swiglu
-                hidden_states = torch_npu.npu_swiglu(hidden_states)
-                hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
-                    hidden_states
+                hidden_states = self.quant_method.apply_without_routing_weights(
+                    self,
+                    hidden_states,
+                    hidden_states_scale,
+                    group_list_type,
+                    group_list,
+                    output_dtype,
                 )
-
-                # gmm2: down_proj
-                hidden_states = torch_npu.npu_grouped_matmul(
-                    x=[hidden_states],
-                    weight=[self.w2_weight],
-                    scale=[self.w2_weight_scale.to(output_dtype)],
-                    per_token_scale=[swiglu_out_scale],
-                    split_item=2,
-                    group_list_type=group_list_type,
-                    group_type=0,
-                    group_list=group_list,
-                    output_dtype=output_dtype,
-                )[0]
-
-            return hidden_states
-
-        def _forward_ll(dispatch_output: DeepEPLLDispatchOutput):
+        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
             if TYPE_CHECKING:
                 assert isinstance(dispatch_output, DeepEPLLDispatchOutput)
             (
@@ -408,75 +365,50 @@ def _forward_ll(dispatch_output: DeepEPLLDispatchOutput):
 
             group_list = group_list.to(torch.int64)
 
-            if self.w13_weight.dtype != torch.int8:
-                # gmm1: gate_up_proj
-                hidden_states = torch_npu.npu_grouped_matmul(
-                    x=[hidden_states],
-                    weight=[self.w13_weight.permute(0, 2, 1)],
-                    # per_token_scale=[hidden_states_scale],
-                    split_item=2,
-                    group_list_type=group_list_type,
-                    group_type=0,
-                    group_list=group_list,
-                    output_dtype=output_dtype,
-                )[0]
-                hidden_states = torch_npu.npu_swiglu(hidden_states)
-                # gmm2: down_proj
-                hidden_states = torch_npu.npu_grouped_matmul(
-                    x=[hidden_states],
-                    weight=[self.w2_weight.permute(0, 2, 1)],
-                    split_item=2,
-                    group_list_type=group_list_type,
-                    group_type=0,
-                    group_list=group_list,
-                    output_dtype=output_dtype,
-                )[0]
+            if self.w13_weight.dtype == torch.bfloat16:
+                hidden_states = npu_fused_moe_without_routing_weights_bf16(
+                    self, hidden_states, group_list_type, group_list, output_dtype
+                )
             else:
-                # gmm1: gate_up_proj
-                hidden_states = torch_npu.npu_grouped_matmul(
-                    x=[hidden_states],
-                    weight=[self.w13_weight],
-                    split_item=2,
-                    group_list_type=group_list_type,
-                    group_type=0,
-                    group_list=group_list,
-                    output_dtype=torch.int32,
-                )[0]
-
-                # act_fn: swiglu
-                hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
-                    x=hidden_states,
-                    weight_scale=self.w13_weight_scale.to(torch.float32),
-                    activation_scale=hidden_states_scale,
-                    bias=None,
-                    quant_scale=None,
-                    quant_offset=None,
-                    group_index=group_list,
-                    activate_left=True,
-                    quant_mode=1,
+                hidden_states = self.quant_method.apply_without_routing_weights(
+                    self,
+                    hidden_states,
+                    hidden_states_scale,
+                    group_list_type,
+                    group_list,
+                    output_dtype,
                 )
+        else:
+            raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
 
-                # gmm2: down_proj
-                hidden_states = torch_npu.npu_grouped_matmul(
-                    x=[hidden_states],
-                    weight=[self.w2_weight],
-                    scale=[self.w2_weight_scale.to(output_dtype)],
-                    per_token_scale=[swiglu_out_scale],
-                    split_item=2,
-                    group_list_type=group_list_type,
-                    group_type=0,
-                    group_list=group_list,
-                    output_dtype=output_dtype,
-                )[0]
+        return hidden_states
 
-                return hidden_states
-
-        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
-            return _forward_normal(dispatch_output)
-        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
-            return _forward_ll(dispatch_output)
-        else:
-            raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
+
+def npu_fused_moe_without_routing_weights_bf16(
+    layer, hidden_states, group_list_type, group_list, output_dtype
+):
+    # gmm1: gate_up_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[layer.w13_weight.permute(0, 2, 1)],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=output_dtype,
+    )[0]
+    hidden_states = torch_npu.npu_swiglu(hidden_states)
+    # gmm2: down_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[layer.w2_weight.permute(0, 2, 1)],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=output_dtype,
+    )[0]
+    return hidden_states
 
 
 def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py
index 5ceba2f67b6c..3212f02cca5f 100644
--- a/python/sglang/srt/layers/quantization/w8a8_int8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
+import logging
 from types import MappingProxyType
 from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
 
 import torch
+from compressed_tensors.quantization import QuantizationStrategy
 from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -21,6 +23,9 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
+    CompressedTensorsConfig,
+)
 from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
@@ -43,11 +48,11 @@ _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
-if _is_cuda:
-    from sgl_kernel import int8_scaled_mm
 
 _is_npu = is_npu()
-if _is_npu:
+if _is_cuda:
+    from sgl_kernel import int8_scaled_mm
+elif _is_npu:
     import torch_npu
 
 try:
@@ -58,6 +63,8 @@
 else:
     useMindIETurbo = True
 
+logger = logging.getLogger(__name__)
+
 
 # func refers to RMSNorm.__init__
 def npu_wrapper_rmsnorm_init(func):
@@ -192,7 +199,7 @@ def npu_fused_experts(
 
 
 class W8A8Int8Config(QuantizationConfig):
-    """Config class for W8A8 Int8 Quantization.
+    """Config class for W8A8 or W4A16 Quantization.
 
     - Weight: static, per-channel, symmetric
     - Activation: dynamic, per-token, symmetric
@@ -202,12 +209,27 @@ def __init__(self, quant_config: Dict[str, Any] = {}):
         super().__init__()
         self.quant_description = quant_config
         self.is_dynamic = quant_config.get("is_dynamic", False)
+        self.is_moe_w4_dynamic = False
         ignore = cast(List[str], quant_config.get("ignore", []))
         self.ignore = ignore if ignore is not None else []
         packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
         self.packed_modules_mapping = (
             packed_modules_mapping if packed_modules_mapping is not None else {}
         )
+        self.target_scheme_map = (
+            CompressedTensorsConfig._quantization_scheme_map_from_config(
+                config=quant_config
+            )
+        )
+        target = "MoEGMM" if "MoEGMM" in self.target_scheme_map else "Linear"
+        target_scheme = self.target_scheme_map.get(target, None)
+        if target_scheme is None:
+            self.is_moe_w4_dynamic = False
+        else:
+            weight_quant = target_scheme.get("weights")
+            input_quant = target_scheme.get("input_activations")
+            self.is_moe_w4_dynamic = self.is_dynamic_token_w4(weight_quant, input_quant)
+            self.is_moe_input_quant = input_quant
 
         if _is_npu:
             # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
@@ -256,7 +278,7 @@ def get_config_filenames(cls) -> List[str]:
     def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
         return cls(config)
 
-    def get_quant_method(
+    def _get_quant_method_npu(
        self,
        layer: torch.nn.Module,
        prefix: str,
@@ -264,45 +286,73 @@ def _get_quant_method_npu(
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
+        if isinstance(layer, LinearBase):
+            if should_ignore_layer(
+                prefix,
+                ignore=self.ignore,
+                fused_mapping=self.packed_modules_mapping,
+            ):
+                return UnquantizedLinearMethod()
+            key = "model"
+            if "vision_model" in prefix:
+                key = "vision_model"
+            elif "visual" in prefix:
+                key = "visual"
+            packed_modules_mapping_subset = self.packed_modules_mapping.get(key, {})
+            prefix_in_quant_config = prefix
+            proj_name = prefix.split(".")[-1]
+            if proj_name in packed_modules_mapping_subset:
+                prefix_in_quant_config = prefix.replace(
+                    proj_name, packed_modules_mapping_subset[proj_name][0]
+                )
+            self.is_dynamic = (
+                self.quant_description[prefix_in_quant_config + ".weight"]
+                == "W8A8_DYNAMIC"
+            )
+            if self.is_layer_skipped(prefix, packed_modules_mapping_subset):
+                return UnquantizedLinearMethod()
+            return (
+                NPU_W8A8DynamicLinearMethod(self)
+                if self.is_dynamic
+                else NPU_W8A8LinearMethod(self)
+            )
+        elif isinstance(layer, FusedMoE):
+            prefix_in_quant_config = prefix + ".0.down_proj.weight"
+            is_moe_w4a8_dynamic = (
+                self.quant_description.get(prefix_in_quant_config, "STATIC")
+                == "W4A8_DYNAMIC"
+            )
+            if (
+                self.is_moe_w4_dynamic and self.is_moe_input_quant is not None
+            ) or is_moe_w4a8_dynamic:
+                raise ValueError("npu does not support W4A8 currently!")
+            elif self.is_moe_w4_dynamic and self.is_moe_input_quant is None:
+                return NPU_W4A16MoEMethod(self)
+            else:
+                return NPU_W8A8MoEMethod(self)
+        return None
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional[QuantizeMethodBase]:
         if _is_npu:
+            return self._get_quant_method_npu(layer, prefix)
+        else:
+            from sglang.srt.layers.linear import LinearBase
+            from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+            if should_ignore_layer(
+                prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+            ):
+                return UnquantizedLinearMethod()
             if isinstance(layer, LinearBase):
-                key = "model"
-                if "vision_model" in prefix:
-                    key = "vision_model"
-                elif "visual" in prefix:
-                    key = "visual"
-                packed_modules_mapping_subset = self.packed_modules_mapping.get(key, {})
-                prefix_in_quant_config = prefix
-                proj_name = prefix.split(".")[-1]
-                if proj_name in packed_modules_mapping_subset:
-                    prefix_in_quant_config = prefix.replace(
-                        proj_name, packed_modules_mapping_subset[proj_name][0]
-                    )
-                self.is_dynamic = (
-                    self.quant_description[prefix_in_quant_config + ".weight"]
-                    == "W8A8_DYNAMIC"
-                )
-                if self.is_layer_skipped(prefix, packed_modules_mapping_subset):
-                    return UnquantizedLinearMethod()
-                return (
-                    NPU_W8A8DynamicLinearMethod(self)
-                    if self.is_dynamic
-                    else NPU_W8A8LinearMethod(self)
-                )
+                return W8A8Int8LinearMethod(self)
             elif isinstance(layer, FusedMoE):
-                return NPU_W8A8MoEMethod(self)
+                return W8A8Int8MoEMethod(self)
             return None
-        if should_ignore_layer(
-            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
-        ):
-            return UnquantizedLinearMethod()
-        if isinstance(layer, LinearBase):
-            return W8A8Int8LinearMethod(self)
-        elif isinstance(layer, FusedMoE):
-            return W8A8Int8MoEMethod(self)
-        return None
-
     def is_layer_skipped(
         self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
     ):
@@ -337,6 +387,27 @@ def is_layer_skipped(
 
     def get_scaled_act_names(self) -> List[str]:
         return []
 
+    def is_dynamic_token_w4(self, weight_quant, input_quant) -> bool:
+        is_w4 = weight_quant.num_bits == 4
+        weight_strategy = (
+            weight_quant.strategy == QuantizationStrategy.TENSOR.value
+            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
+            or weight_quant.strategy == QuantizationStrategy.GROUP.value
+        )
+        if input_quant is not None:
+            is_token = (
+                weight_strategy
+                and input_quant.strategy == QuantizationStrategy.TOKEN.value
+            )
+            is_dynamic = not weight_quant.dynamic and input_quant.dynamic
+        else:
+            is_token = weight_strategy
+            is_dynamic = not weight_quant.dynamic
+
+        # Both symmetric and asymmetric input quantization supported.
+        # Only symmetric weight quantization supported.
+        return is_w4 and weight_quant.symmetric and is_token and is_dynamic
+
 
 class W8A8Int8LinearMethod(LinearMethodBase):
@@ -1050,3 +1121,373 @@ def apply(
             top_k=topk_ids.shape[1],
         )
         return StandardCombineInput(hidden_states=output)
+
+    def apply_without_routing_weights(
+        self,
+        layer,
+        hidden_states,
+        hidden_states_scale,
+        group_list_type,
+        group_list,
+        output_dtype,
+    ):
+        # gmm1: gate_up_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[layer.w13_weight],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=group_list,
+            output_dtype=torch.int32,
+        )[0]
+
+        # act_fn: swiglu
+        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+            x=hidden_states,
+            weight_scale=layer.w13_weight_scale.to(torch.float32),
+            activation_scale=hidden_states_scale,
+            bias=None,
+            quant_scale=None,
+            quant_offset=None,
+            group_index=group_list,
+            activate_left=True,
+            quant_mode=1,
+        )
+
+        # gmm2: down_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[layer.w2_weight],
+            scale=[layer.w2_weight_scale.to(output_dtype)],
+            per_token_scale=[swiglu_out_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=group_list,
+            output_dtype=output_dtype,
+        )[0]
+        return hidden_states
+
+
+class NPU_W4A16MoEMethod(FusedMoEMethodBase):
+    """MoE method for NPU W4A16 quantization.
+
+    This class search for specific quantization
+    implementations supported on NPU hardware for moe methods.
+
+    Args:
+        quant_config: The NPU quantization config.
+    """
+
+    def __init__(self, quantization_config: W8A8Int8Config) -> None:
+        self.quantization_config = quantization_config
+        self.quant_method = self
+        self.pack_factor = 8  # weight dtype is int4, but use int32 to create
+        target = (
+            "MoEGMM" if "MoEGMM" in quantization_config.target_scheme_map else "Linear"
+        )
+        if target in quantization_config.target_scheme_map:
+            self.group_size = quantization_config.target_scheme_map[target][
+                "weights"
+            ].group_size
+        else:
+            self.group_size = 128
+        logger.warning_once("NPU_W4A16MoEMethod !!!")
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        self.num_experts = num_experts
+        if (
+            extra_weight_attrs.get(
+                "intermediate_size_full", intermediate_size_per_partition
+            )
+            // intermediate_size_per_partition
+            > 1
+        ):
+            quant_method = FusedMoeWeightScaleSupported.GROUP.value
+        else:
+            quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
+        extra_weight_attrs.update({"quant_method": quant_method})
+        # weight
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // self.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // self.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # scale
+        weight_scale_dtype = torch.bfloat16
+        w13_weight_scale = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // self.group_size,
+                dtype=weight_scale_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        w2_weight_scale = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // self.group_size,
+                dtype=weight_scale_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # offset
+        w13_weight_offset = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // self.group_size,
+                dtype=weight_scale_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_offset", w13_weight_offset)
+        set_weight_attrs(w13_weight_offset, extra_weight_attrs)
+
+        w2_weight_offset = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // self.group_size,
+                dtype=weight_scale_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_offset", w2_weight_offset)
+        set_weight_attrs(w2_weight_offset, extra_weight_attrs)
+
+    def pack_to_int32(self, weight: torch.Tensor):
+        assert weight.dim() == 3
+        if weight.dtype == torch.int32:
+            # pack 8 int4 to int32, we use a int32 to represent a int4
+            assert (
+                weight.shape[-1] % 8 == 0
+            ), "the last dim of weight needs to be divided by 8"
+            new_weight = torch_npu.npu_convert_weight_to_int4pack(weight.flatten(0, 1))
+            new_weight = new_weight.view(weight.shape[0], weight.shape[1], -1)
+        elif weight.dtype == torch.int8:
+            # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
+            assert (
+                weight.shape[-1] % 4 == 0
+            ), "the last dim of weight needs to be divided by 4"
+            new_weight = weight.view(torch.int32).contiguous()
+        else:
+            raise ValueError(f"{weight.dtype=} is not supported !")
+        return new_weight
+
+    def unpack_from_int32(
+        self,
+        value: torch.Tensor,
+        num_bits: int,
+        shape: torch.Size = None,
+        packed_dim=1,
+    ) -> torch.Tensor:
+        """
+        Unpacks a tensor of packed int32 weights into individual int8s, maintaining the
+        original bit range.
+
+        Return tensors in int8
+
+        :param value: tensor to unpack
+        :param num_bits: number of bits to unpack each data point into
+        :param shape: shape to unpack into, used to remove padding
+        :returns: unpacked int8 tensor
+        """
+        if value.dtype is not torch.int32:
+            raise ValueError(
+                f"Expected {torch.int32} but got {value.dtype}, Aborting unpack."
+            )
+
+        if num_bits > 8:
+            raise ValueError("Unpacking is only supported for less than 8 bits")
+
+        pack_factor = 32 // num_bits
+
+        # unpack
+        mask = (1 << num_bits) - 1
+
+        if packed_dim == 1:
+            unpacked = torch.zeros(
+                (value.shape[0], value.shape[1] * pack_factor),
+                device=value.device,
+                dtype=torch.int32,
+            )
+            for i in range(pack_factor):
+                unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask
+
+            # remove padding
+            if shape is not None:
+                original_row_size = int(shape[1])
+                unpacked = unpacked[:, :original_row_size]
+        else:
+            unpacked = torch.zeros(
+                (value.shape[0] * pack_factor, value.shape[1]),
+                device=value.device,
+                dtype=torch.int32,
+            )
+            for i in range(pack_factor):
+                unpacked[i::pack_factor, :] = (value >> (num_bits * i)) & mask
+
+            # remove padding
+            original_row_size = int(shape[0])
+            unpacked = unpacked[:original_row_size, :]
+
+        # bits are packed in unsigned format, reformat to signed
+        # update the value range from unsigned to signed
+        offset = pow(2, num_bits) // 2
+        unpacked = (unpacked - offset).to(torch.int8)
+
+        return unpacked
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        w13_weight_scale = layer.w13_weight_scale.data.transpose(-1, -2).contiguous()
+        w2_weight_scale = layer.w2_weight_scale.data.transpose(-1, -2).contiguous()
+        layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
+        layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
+
+        layer.w13_weight_offset = Parameter(
+            layer.w13_weight_offset.data.transpose(-1, -2).contiguous(),
+            requires_grad=False,
+        )
+        layer.w2_weight_offset = Parameter(
+            layer.w2_weight_offset.data.transpose(-1, -2).contiguous(),
+            requires_grad=False,
+        )
+
+        # w = [n, k // 8] --> [k, n // 8]
+        # w13_weight = layer.w13_weight.data.transpose(1, 2).contiguous()
+        # w2_weight = layer.w2_weight.data.transpose(1, 2).contiguous()
+        unpacked_w13_weight = (
+            self.unpack_from_int32(layer.w13_weight.data.flatten(0, 1), 4)
+            .view(layer.w13_weight.data.shape[0], layer.w13_weight.data.shape[1], -1)
+            .transpose(1, 2)
+            .contiguous()
+            .int()
+        )
+        unpacked_w2_weight = (
+            self.unpack_from_int32(layer.w2_weight.data.flatten(0, 1), 4)
+            .view(layer.w2_weight.data.shape[0], layer.w2_weight.data.shape[1], -1)
+            .transpose(1, 2)
+            .contiguous()
+            .int()
+        )
+
+        w13_weight = self.pack_to_int32(unpacked_w13_weight)
+        w2_weight = self.pack_to_int32(unpacked_w2_weight)
+
+        layer.w13_weight = Parameter(w13_weight, requires_grad=False)
+        layer.w2_weight = Parameter(w2_weight, requires_grad=False)
+
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
+    def apply(
+        self,
+        layer,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        topk_weights, topk_ids, _ = topk_output
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = topk_weights.to(x.dtype)
+        output = npu_fused_experts(
+            hidden_states=x,
+            w13=layer.w13_weight,
+            w13_scale=layer.w13_weight_scale,
+            w13_offset=layer.w13_weight_offset,
+            w2=layer.w2_weight,
+            w2_scale=layer.w2_weight_scale,
+            w2_offset=layer.w2_weight_offset,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            top_k=topk_ids.shape[1],
+            use_wna16=True,
+        )
+        return StandardCombineInput(hidden_states=output)
+
+    def apply_without_routing_weights(
+        self,
+        layer,
+        hidden_states,
+        hidden_states_scale,
+        group_list_type,
+        group_list,
+        output_dtype,
+    ):
+        if hidden_states_scale is None:
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[layer.w13_weight],
+                antiquant_scale=[layer.w13_weight_scale],
+                antiquant_offset=[layer.w13_weight_offset],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+
+            # gmm2: down_proj
+            out_hidden = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[layer.w2_weight],
+                antiquant_scale=[layer.w2_weight_scale],
+                antiquant_offset=[layer.w2_weight_offset],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+        else:
+            raise ValueError(
+                "when weight is int4, hidden_states only supports non-quant dtype!"
+            )
+
+        return out_hidden
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index c8dc396eca9c..ab41263ddf92 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -217,7 +217,7 @@ def add_chunked_prefix_cache_attention_backend(backend_name):
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 
 # Detect stragger ranks in model loading
-UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
+UNBALANCED_MODEL_LOADING_TIMEOUT_S = 480  # leave more time for post data processing
 
 # the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
 MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 584b15bf40a5..3a2f65c46893 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -3979,6 +3979,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
                     continue
+                if _is_npu:
+                    name = name.replace("weight_packed", "weight")
                 # We have mlp.experts[0].gate_proj in the checkpoint.
                 # Since we handle the experts below in expert_params_mapping,
                 # we need to skip here BEFORE we update the name, otherwise
@@ -4006,7 +4008,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
+                    if _is_npu:
+                        name = name.replace("weight_packed", "weight")
                     name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     maybe_executor_submit(
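
Aside, not part of the patch: the int4-in-int32 layout handled by NPU_W4A16MoEMethod can be sanity-checked with plain PyTorch. The sketch below mirrors the unpack_from_int32(value, num_bits=4, packed_dim=1) logic added in w8a8_int8.py; the pack_int4_to_int32 helper is hypothetical and exists only so the round trip can be exercised without torch_npu.

import torch


def pack_int4_to_int32(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical packer: x is a 2-D int8 tensor with values in [-8, 7];
    # 8 nibbles are packed into each int32 along dim 1, lowest bits first.
    assert x.dim() == 2 and x.shape[1] % 8 == 0
    u = (x.to(torch.int64) + 8) & 0xF  # map signed [-8, 7] to unsigned [0, 15]
    packed = torch.zeros(x.shape[0], x.shape[1] // 8, dtype=torch.int64)
    for i in range(8):
        packed |= u[:, i::8] << (4 * i)  # nibble i occupies bits [4i, 4i+4)
    # fold the unsigned 32-bit pattern into a signed int32 tensor without overflow
    return ((packed + 2**31) % 2**32 - 2**31).to(torch.int32)


def unpack_int32_to_int4(packed: torch.Tensor) -> torch.Tensor:
    # Mirrors unpack_from_int32(..., num_bits=4, packed_dim=1) from the patch.
    out = torch.zeros(packed.shape[0], packed.shape[1] * 8, dtype=torch.int32)
    for i in range(8):
        out[:, i::8] = (packed >> (4 * i)) & 0xF  # extract nibble i
    return (out - 8).to(torch.int8)  # shift back to signed [-8, 7]


w = torch.randint(-8, 8, (4, 16), dtype=torch.int8)
assert torch.equal(unpack_int32_to_int4(pack_int4_to_int32(w)), w)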