diff --git a/python/sglang/srt/configs/qwen3_next.py b/python/sglang/srt/configs/qwen3_next.py
index 8d0981c39854..94b74cfcc7a2 100644
--- a/python/sglang/srt/configs/qwen3_next.py
+++ b/python/sglang/srt/configs/qwen3_next.py
@@ -20,8 +20,11 @@
 from transformers.utils import logging
 
 from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
+from sglang.srt.configs.update_config import adjust_tp_num_heads_if_necessary
+from sglang.srt.utils import is_cpu
 
 logger = logging.get_logger(__name__)
+_is_cpu = is_cpu()
 
 
 class HybridLayerType(enum.Enum):
@@ -276,6 +279,10 @@ def full_attention_layer_ids(self):
     def mamba2_cache_params(self) -> Mamba2CacheParams:
         from sglang.srt.layers.dp_attention import get_attention_tp_size
 
+        if _is_cpu:
+            world_size = get_attention_tp_size()
+            adjust_tp_num_heads_if_necessary(self, world_size, False)
+
         shape = Mamba2StateShape.create(
             tp_world_size=get_attention_tp_size(),
             intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
diff --git a/python/sglang/srt/configs/update_config.py b/python/sglang/srt/configs/update_config.py
index 46d156d7b90a..f357a7635608 100644
--- a/python/sglang/srt/configs/update_config.py
+++ b/python/sglang/srt/configs/update_config.py
@@ -54,6 +54,44 @@ def get_num_heads_padding_size(tp_size, weight_block_size, head_dim):
     return pad_size
 
 
+def adjust_tp_num_heads_if_necessary(model_config, tp_size, is_post_update):
+    # is_post_update: if True, store the padded values in *_cpu attributes instead of overwriting the originals
+    from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
+
+    # Pad linear-attention head counts that are not divisible by tp_size
+    if hasattr(model_config, "linear_num_key_heads") and hasattr(
+        model_config, "linear_num_value_heads"
+    ):
+        if (
+            model_config.linear_num_key_heads % tp_size != 0
+            or model_config.linear_num_value_heads % tp_size != 0
+        ):
+            pad_size = tp_size
+            linear_num_key_heads_cpu = pad_vocab_size(
+                model_config.linear_num_key_heads, pad_size
+            )
+            linear_num_value_heads_cpu = (
+                linear_num_key_heads_cpu
+                * model_config.linear_num_value_heads
+                // model_config.linear_num_key_heads
+            )
+            if is_post_update:
+                model_config.linear_num_key_heads_cpu = linear_num_key_heads_cpu
+                model_config.linear_num_value_heads_cpu = linear_num_value_heads_cpu
+            else:
+                model_config.linear_num_key_heads = linear_num_key_heads_cpu
+                model_config.linear_num_value_heads = linear_num_value_heads_cpu
+
+        else:
+            if is_post_update:
+                model_config.linear_num_key_heads_cpu = (
+                    model_config.linear_num_key_heads
+                )
+                model_config.linear_num_value_heads_cpu = (
+                    model_config.linear_num_value_heads
+                )
+
+
 def update_intermediate_size(model_config, attr_name, intermediate_padding_size):
     attr_value = intermediate_padding_size
     if hasattr(model_config, "hf_config") and hasattr(
@@ -137,6 +175,8 @@ def adjust_config_with_unaligned_cpu_tp(
         model_config.hf_config.num_attention_heads = num_attention_heads
         model_config.hf_text_config.num_attention_heads = num_attention_heads
+    adjust_tp_num_heads_if_necessary(model_config.hf_config, tp_size, True)
+
     intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size)
     model_config = update_intermediate_size(
         model_config, "moe_intermediate_size", intermediate_padding_size
     )
@@ -147,6 +187,9 @@
     model_config = update_intermediate_size(
         model_config, "intermediate_size_mlp", intermediate_padding_size
     )
+    model_config = update_intermediate_size(
+        model_config, "shared_expert_intermediate_size", intermediate_padding_size
+    )
     if (
         hasattr(model_config.hf_config, "vision_config")
         and model_config.hf_config.vision_config.model_type == "siglip_vision_model"
diff --git a/python/sglang/srt/layers/amx_utils.py b/python/sglang/srt/layers/amx_utils.py
index 8e1209ea0202..1fb704f9517a 100644
--- a/python/sglang/srt/layers/amx_utils.py
+++ b/python/sglang/srt/layers/amx_utils.py
@@ -7,13 +7,17 @@
 logger = logging.getLogger(__name__)
 
 
-def amx_process_weight_after_loading(weight):
+def amx_process_weight_after_loading(weight, is_conv=False):
     if weight.device != torch.device("cpu"):
         return weight
     if not cpu_has_amx_support():
         return weight
-
-    return torch.ops.sgl_kernel.convert_weight_packed(weight)
+    if is_conv:
+        return torch.ops.sgl_kernel.causal_conv1d_weight_pack(
+            weight.view(-1, weight.size(-1))
+        )
+    else:
+        return torch.ops.sgl_kernel.convert_weight_packed(weight)
 
 
 # TODO: currently gemm kernel has the below requirements:
@@ -30,6 +34,36 @@ def dim_is_supported(weight):
     return is_oc_support and is_ic_support
 
 
+def dtype_is_supported(weight):
+    return weight.dtype in [
+        torch.float16,
+        torch.bfloat16,
+        torch.int8,
+        torch.float8_e4m3fn,
+    ]
+
+
+def is_dim_conv_weight(weight):
+    return weight.dim() == 3 and weight.size(1) == 1
+
+
+def _init_amx_conv_state(conv_state):
+    # CPU AMX layout for conv_state kernel optimization
+    conv_state_cpu = []
+    for conv_shape_t in conv_state:
+        conv_shape_new = conv_shape_t.as_strided_(
+            conv_shape_t.size(),
+            (
+                conv_shape_t.stride(0),
+                conv_shape_t.stride(1),
+                1,
+                conv_shape_t.size(2),
+            ),
+        )
+        conv_state_cpu.append(conv_shape_new)
+    return conv_state_cpu
+
+
 def _amx_process_weight_after_loading(
     module, weight_names, transpose_dims=None
 ) -> None:
@@ -48,22 +82,30 @@ def _amx_process_weight_after_loading(
         if transpose_dims and transpose_dims[i]:
             weight_tensor = weight_tensor.transpose(*transpose_dims[i])
-
+        is_conv_weight = is_dim_conv_weight(weight_tensor)
         # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
-        if not dim_is_supported(weight_tensor):
+        if (
+            (not dim_is_supported(weight_tensor))
+            or not dtype_is_supported(weight_tensor)
+        ) and (not is_conv_weight):
             logger.warning(
-                f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. "
+                f"Unsupported dimension or dtype for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} and dtype {weight_tensor.dtype} in {module}. "
                 f"The derived (OC, IC) dimensions must be divisible by (16, 32). "
             )
            module.use_intel_amx_backend = False
             return
 
         packed_weight = torch.nn.Parameter(
-            amx_process_weight_after_loading(weight_tensor),
+            amx_process_weight_after_loading(weight_tensor, is_conv_weight),
             requires_grad=False,
         )
         packed_weight.__dict__ = weight_tensor.__dict__
         setattr(module, weight_name, packed_weight)
+        if is_conv_weight:
+            # Copy the packed data back into the original tensor in place, since
+            # radix_linear_attention still reads the original conv weight buffer.
+            weight_tensor = weight_tensor.view(-1, weight_tensor.size(-1))
+            weight_tensor.copy_(packed_weight)
 
     module.use_intel_amx_backend = (
         device == torch.device("cpu") and cpu_has_amx_support()
diff --git a/python/sglang/srt/layers/attention/fla/layernorm_gated.py b/python/sglang/srt/layers/attention/fla/layernorm_gated.py
index 5d55247da3f5..7bc7b9f47c48 100644
--- a/python/sglang/srt/layers/attention/fla/layernorm_gated.py
+++ b/python/sglang/srt/layers/attention/fla/layernorm_gated.py
@@ -14,9 +14,17 @@
 import triton.language as tl
 from einops import rearrange
 
-from sglang.srt.utils import cdiv, device_context, is_npu, next_power_of_2
+from sglang.srt.utils import (
+    cdiv,
+    cpu_has_amx_support,
+    device_context,
+    is_cpu,
+    is_npu,
+    next_power_of_2,
+)
 
 _is_npu = is_npu()
+_use_cpu = is_cpu() and cpu_has_amx_support()
 
 
 def rms_norm_ref(
@@ -392,13 +400,21 @@ def reset_parameters(self):
 
     def forward(self, x, z=None):
         """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
-        return layernorm_fn(
-            x,
-            self.weight,
-            self.bias,
-            z=z,
-            eps=self.eps,
-            group_size=self.group_size,
-            norm_before_gate=self.norm_before_gate,
-            is_rms_norm=True,
-        )
+        if _use_cpu:
+            assert (
+                self.norm_before_gate and self.group_size is None
+            ), "CPU rmsnorm_gated currently only supports norm before gate without group size"
+            return torch.ops.sgl_kernel.fused_rmsnorm_gated_cpu(
+                x, self.weight, z, self.eps
+            )
+        else:
+            return layernorm_fn(
+                x,
+                self.weight,
+                self.bias,
+                z=z,
+                eps=self.eps,
+                group_size=self.group_size,
+                norm_before_gate=self.norm_before_gate,
+                is_rms_norm=True,
+            )
diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
index 397427826038..afbc840171fc 100644
--- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
+++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
@@ -5,11 +5,8 @@
 import triton.language as tl
 from einops import rearrange
 
-from sglang.jit_kernel.cutedsl_gdn import cutedsl_fused_sigmoid_gating_delta_rule_update
 from sglang.srt.environ import Envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
-from sglang.srt.layers.attention.fla.chunk_delta_h import CHUNK_SIZE as FLA_CHUNK_SIZE
 from sglang.srt.layers.attention.fla.fused_gdn_gating import fused_gdn_gating
 from sglang.srt.layers.attention.fla.fused_recurrent import (
     fused_recurrent_gated_delta_rule_update,
@@ -17,7 +14,6 @@
 from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
     fused_sigmoid_gating_delta_rule_update,
 )
-from sglang.srt.layers.attention.fla.kda import chunk_kda
 from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
     PAD_SLOT_ID,
     causal_conv1d_fn,
@@ -36,9 +32,20 @@
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
 from sglang.srt.speculative.spec_info import SpecInput
-from sglang.srt.utils import is_cuda, is_npu
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
 from sglang.srt.utils.common import rank0_log
 
+if not is_cpu():
+    # Avoid import errors on CPU devices; no impact on non-CPU paths.
+    from sglang.jit_kernel.cutedsl_gdn import (
+        cutedsl_fused_sigmoid_gating_delta_rule_update,
+    )
+    from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
+    from sglang.srt.layers.attention.fla.chunk_delta_h import (
+        CHUNK_SIZE as FLA_CHUNK_SIZE,
+    )
+    from sglang.srt.layers.attention.fla.kda import chunk_kda
+
 if is_cuda():
     from sglang.srt.layers.attention.mamba.causal_conv1d import (
         causal_conv1d_fn as causal_conv1d_fn_cuda,
@@ -59,6 +66,23 @@
     fused_sigmoid_gating_delta_rule_update = fused_sigmoid_gating_delta_rule_update_npu
     causal_conv1d_fn = causal_conv1d_fn_npu
     causal_conv1d_update = causal_conv1d_update_npu
+elif is_cpu():
+    assert (
+        cpu_has_amx_support()
+    ), "CPU requires AMX support for hybrid linear attn backend"
+    from sgl_kernel.mamba import (
+        causal_conv1d_fn_cpu,
+        causal_conv1d_update_cpu,
+        chunk_gated_delta_rule_cpu,
+    )
+
+    chunk_gated_delta_rule = chunk_gated_delta_rule_cpu
+    causal_conv1d_fn = causal_conv1d_fn_cpu
+    causal_conv1d_update = causal_conv1d_update_cpu
+    fused_sigmoid_gating_delta_rule_update = (
+        torch.ops.sgl_kernel.fused_sigmoid_gating_delta_rule_update_cpu
+    )
+    fused_gdn_gating = torch.ops.sgl_kernel.fused_gdn_gating_cpu
 
 
 # Kernel to track mamba states if needed based on track mask
@@ -790,9 +814,10 @@ def __init__(self, model_runner: ModelRunner):
         self.conv_states_shape = (
             model_runner.req_to_token_pool.mamba_pool.mamba_cache.conv[0].shape
         )
-        assert (
-            self.conv_states_shape[-1] < FLA_CHUNK_SIZE
-        ), f"{self.conv_states_shape[-1]=} should be less than {FLA_CHUNK_SIZE}"
+        if not is_cpu():
+            assert (
+                self.conv_states_shape[-1] < FLA_CHUNK_SIZE
+            ), f"{self.conv_states_shape[-1]=} should be less than {FLA_CHUNK_SIZE}"
 
         use_cutedsl = Envs.SGLANG_USE_CUTEDSL_GDN_DECODE.get()
         rank0_log(f"CuTe DSL GDN decode enabled: {use_cutedsl}")
@@ -983,7 +1008,7 @@ def forward_extend(
             # Only cuda env uses fuse ssm_states update
             recurrent_state = ssm_states
             recurrent_state_indices_args = {"initial_state_indices": cache_indices}
-            if is_npu():
+            if is_npu() or is_cpu():
                 recurrent_state = ssm_states[cache_indices]
                 recurrent_state_indices_args = {}
             core_attn_out, last_recurrent_state, h = chunk_gated_delta_rule(
@@ -998,7 +1023,7 @@
                 use_qk_l2norm_in_kernel=True,
                 **recurrent_state_indices_args,
             )
-            if is_npu():
+            if is_npu() or is_cpu():
                 last_recurrent_state = last_recurrent_state.to(
                     ssm_states.dtype, copy=False
                 )
diff --git a/python/sglang/srt/layers/attention/intel_amx_backend.py b/python/sglang/srt/layers/attention/intel_amx_backend.py
index 4b2974c44e0d..7ab2753741c3 100644
--- a/python/sglang/srt/layers/attention/intel_amx_backend.py
+++ b/python/sglang/srt/layers/attention/intel_amx_backend.py
@@ -24,8 +24,16 @@ def __init__(self, model_runner: ModelRunner):
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
 
-        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
-
+        # NB: `layer_id` defaults to 0. For qwen3-next models not every attention layer
+        # has a KV pool, so use `full_attention_layer_id_mapping` to pick a layer that does.
+        layer_id = 0
+        if hasattr(model_runner.token_to_kv_pool, "full_attention_layer_id_mapping"):
+            layer_id = [*model_runner.token_to_kv_pool.full_attention_layer_id_mapping][
+                0
+            ]
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(
+            layer_id
+        ).shape[-1]
         self.decode_attention_fwd = torch.ops.sgl_kernel.decode_attention_cpu
         self.extend_attention_fwd = torch.ops.sgl_kernel.extend_attention_cpu
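The CPU branch added to `RMSNormGated.forward` above only covers the norm-before-gate, no-group-size case. A minimal reference of those semantics (mirroring the docstring's `norm(x) * silu(z)`), useful for sanity-checking the fused `fused_rmsnorm_gated_cpu` op numerically; the float32 upcast here is an assumption, not taken from the kernel:

```python
import torch
import torch.nn.functional as F


def rmsnorm_gated_ref(x, weight, z=None, eps=1e-6):
    # RMS-normalize x over the last dim, scale by weight, then gate with silu(z).
    x_f = x.float()
    rstd = torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = (x_f * rstd) * weight.float()
    if z is not None:
        out = out * F.silu(z.float())
    return out.to(x.dtype)
```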
diff --git a/python/sglang/srt/layers/attention/mamba/mamba.py b/python/sglang/srt/layers/attention/mamba/mamba.py
index 286c51de945e..46d0d5b3f951 100644
--- a/python/sglang/srt/layers/attention/mamba/mamba.py
+++ b/python/sglang/srt/layers/attention/mamba/mamba.py
@@ -29,7 +29,7 @@
     composed_weight_loader,
     sharded_weight_loader,
 )
-from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
 
 if is_cuda():
     from sglang.srt.layers.attention.mamba.causal_conv1d import (
@@ -69,6 +69,19 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
         # - track boundary of (sharded) param, and loaded_weight, respectively
         boundary, loaded_boundary = 0, 0
 
+        # On CPU, precompute the unpadded per-split sizes of the loaded weight for the padding below
+        if is_cpu():
+            full_dim_sum = 0
+            full_dim_list = []
+            weight_full_dim_list = []
+            for full_dim, _, _ in shard_spec:
+                full_dim_sum = full_dim_sum + full_dim
+                full_dim_list.append(full_dim)
+            for full_dim in full_dim_list:
+                weight_full_dim_list.append(
+                    int(full_dim / full_dim_sum * loaded_weight.size(0))
+                )
+
         # - iterate over the shard specs
         for full_dim, extra, duplicate_groups in shard_spec:
             # - full dim is the model dim (before TP).
@@ -95,6 +108,33 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
             # - take these many dims from the loaded weight.
             take = min(shard_size, full_dim - extra - loaded_skip)
 
+            # CPU-only zero-padding for qwen3-next when heads are not divisible by tp_size.
+            # TODO: make this common for all mamba models.
+            if is_cpu() and loaded_weight.size(0) % tp_size != 0:
+                import copy
+
+                loaded_weight_ = copy.deepcopy(loaded_weight)
+                q, k, v = torch.split(
+                    loaded_weight_,
+                    weight_full_dim_list,
+                    dim=0,
+                )
+                pad_qk = torch.zeros(
+                    full_dim_list[0] - weight_full_dim_list[0],
+                    loaded_weight.size(1),
+                    loaded_weight.size(2),
+                ).to(loaded_weight.dtype)
+                pad_v = torch.zeros(
+                    full_dim_list[2] - weight_full_dim_list[2],
+                    loaded_weight.size(1),
+                    loaded_weight.size(2),
+                ).to(loaded_weight.dtype)
+                q = torch.cat((q, pad_qk), dim=0)
+                k = torch.cat((k, pad_qk), dim=0)
+                v = torch.cat((v, pad_v), dim=0)
+                loaded_weight_qk = torch.cat((q, k), dim=0)
+                loaded_weight = torch.cat((loaded_weight_qk, v), dim=0)
+
             # - always shard on dim 0
             # - the ignore is for a mundane mypy error as it does not
             #   seem to handle slices well.
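The CPU branch in the sharded weight loader above zero-pads the q/k/v row blocks so each block's row count matches its TP-padded size. A generalized sketch of that padding with illustrative shapes (the real loader works on 3-D conv1d weights and derives the sizes from `shard_spec`):

```python
import torch


def pad_qkv_rows(loaded_weight, orig_dims, padded_dims):
    """Zero-pad each q/k/v row block of a fused weight to its TP-padded row count.

    orig_dims / padded_dims are (q_rows, k_rows, v_rows) before and after padding.
    """
    q, k, v = torch.split(loaded_weight, list(orig_dims), dim=0)
    padded = []
    for block, orig, target in zip((q, k, v), orig_dims, padded_dims):
        pad_rows = target - orig
        if pad_rows:
            pad = block.new_zeros((pad_rows,) + tuple(block.shape[1:]))
            block = torch.cat((block, pad), dim=0)
        padded.append(block)
    return torch.cat(padded, dim=0)


# e.g. a dummy (16+16+32) x 8 weight padded to (18+18+36) x 8 for tp_size=6
w = torch.randn(64, 8)
print(pad_qkv_rows(w, (16, 16, 32), (18, 18, 36)).shape)  # torch.Size([72, 8])
```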
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 0afbb15fd7e8..5363abcf13e9 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -52,7 +52,14 @@
     set_mla_kv_buffer_triton,
     set_mla_kv_scale_buffer_triton,
 )
-from sglang.srt.utils import is_cuda, is_hip, is_npu, next_power_of_2
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+    next_power_of_2,
+)
 from sglang.srt.utils.custom_op import register_custom_op
 from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -68,6 +75,8 @@
 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
 _is_npu = is_npu()
+_is_cpu = is_cpu()
+_cpu_has_amx_support = cpu_has_amx_support()
 _is_hip = is_hip()
@@ -230,6 +239,13 @@ def __init__(
             )
             for conv_shape in conv_state_shape
         ]
+
+        if _is_cpu and _cpu_has_amx_support:
+            from sglang.srt.layers.amx_utils import _init_amx_conv_state
+
+            # CPU uses a different conv_state memory layout for kernel optimization
+            conv_state = _init_amx_conv_state(conv_state)
+
         temporal_state = torch.zeros(
             size=(num_mamba_layers, size + 1) + temporal_state_shape,
             dtype=ssm_dtype,
diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py
index a46b8ffc6799..c7904c47d67b 100644
--- a/python/sglang/srt/model_loader/weight_utils.py
+++ b/python/sglang/srt/model_loader/weight_utils.py
@@ -33,7 +33,10 @@
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
 from sglang.srt.layers.quantization.fp8 import Fp8Config
@@ -48,6 +51,7 @@
 from sglang.srt.utils import (
     BAR_FORMAT,
     find_local_repo_dir,
+    is_cpu,
     log_info_on_rank0,
     print_warning_once,
 )
@@ -1040,9 +1044,25 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
         shard_size = param.data.shape[shard_axis]
         start_idx = tp_rank * shard_size
 
-        loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size)
-        return default_weight_loader(param, loaded_weight)
+        if (
+            is_cpu()
+            and loaded_weight.size(0) % get_tensor_model_parallel_world_size() != 0
+            and loaded_weight.dim() == 1
+        ):
+            param_data = param.data  # work on param.data so the uneven shard can be padded in place
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                start_idx,
+                shard_axis,
+                shard_size,
+            )
+            return default_weight_loader(param_data, loaded_weight)
+        else:
+            loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size)
+            return default_weight_loader(param, loaded_weight)
 
     return loader
diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py
index dc85e9c52ad9..9e96797c054d 100644
--- a/python/sglang/srt/models/qwen3_next.py
+++ b/python/sglang/srt/models/qwen3_next.py
@@ -46,6 +46,8 @@
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,
+    cpu_has_amx_support,
+    is_cpu,
     is_cuda,
     is_npu,
     make_layers,
@@ -56,6 +58,8 @@
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
 _is_npu = is_npu()
+_is_cpu = is_cpu()
+_is_amx_available = cpu_has_amx_support()
 
 import triton
 
@@ -209,8 +213,16 @@ def __init__(
         self.attn_tp_rank = get_attention_tp_rank()
         self.attn_tp_size = get_attention_tp_size()
         self.hidden_size = config.hidden_size
-        self.num_v_heads = config.linear_num_value_heads
-        self.num_k_heads = config.linear_num_key_heads
+        self.num_v_heads = (
+            config.linear_num_value_heads
+            if not _is_cpu
+            else config.linear_num_value_heads_cpu
+        )
+        self.num_k_heads = (
+            config.linear_num_key_heads
+            if not _is_cpu
+            else config.linear_num_key_heads_cpu
+        )
         self.head_k_dim = config.linear_key_head_dim
         self.head_v_dim = config.linear_value_head_dim
         self.key_dim = self.head_k_dim * self.num_k_heads
@@ -366,7 +378,7 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
         return query, key, value, z, b, a
 
     def _forward_input_proj(self, hidden_states: torch.Tensor):
-        if _is_npu or get_global_server_args().enable_piecewise_cuda_graph:
+        if _is_cpu or _is_npu or get_global_server_args().enable_piecewise_cuda_graph:
            DUAL_STREAM_TOKEN_THRESHOLD = 0
         else:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -416,7 +428,11 @@ def _forward(
             hidden_states
         )
 
-        if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph:
+        if (
+            self.num_v_heads // self.num_k_heads in [1, 2, 4]
+            and is_cuda_graph
+            and not _is_cpu
+        ):
             mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
                 projected_states_qkvz,
                 projected_states_ba,
@@ -425,6 +441,17 @@ def _forward(
                 self.head_k_dim,
                 self.head_v_dim,
             )
+        elif _is_cpu and _is_amx_available:
+            mixed_qkv, z, b, a = (
+                torch.ops.sgl_kernel.fused_qkvzba_split_reshape_cat_cpu(
+                    projected_states_qkvz,
+                    projected_states_ba,
+                    self.num_k_heads // self.attn_tp_size,
+                    self.num_v_heads // self.attn_tp_size,
+                    self.head_k_dim,
+                    self.head_v_dim,
+                )
+            )
         else:
             query, key, value, z, b, a = self.fix_query_key_value_ordering(
                 projected_states_qkvz, projected_states_ba
@@ -433,7 +460,6 @@ def _forward(
                 lambda x: x.reshape(x.shape[0], -1), (query, key, value)
             )
             mixed_qkv = torch.cat((query, key, value), dim=-1)
-
         core_attn_out = self.linear_attn(
             forward_batch,
             mixed_qkv=mixed_qkv,
diff --git a/sgl-kernel/csrc/cpu/topk.cpp b/sgl-kernel/csrc/cpu/topk.cpp
index b4bdcd0b7b37..0471661e58a7 100644
--- a/sgl-kernel/csrc/cpu/topk.cpp
+++ b/sgl-kernel/csrc/cpu/topk.cpp
@@ -525,6 +525,12 @@ topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t t
     case 256:
       LAUNCH_TOPK_SOFTMAX_KERNEL(256);
       break;
+    case 384:
+      LAUNCH_TOPK_SOFTMAX_KERNEL(384);
+      break;
+    case 512:
+      LAUNCH_TOPK_SOFTMAX_KERNEL(512);
+      break;
     default:
       TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
   }
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 8e8994e04c95..a24d3573be4a 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -78,7 +78,13 @@
     transfer_kv_per_layer,
     transfer_kv_per_layer_mla,
 )
-from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
+from sgl_kernel.mamba import (
+    causal_conv1d_fn_cpu,
+    causal_conv1d_fwd,
+    causal_conv1d_update,
+    causal_conv1d_update_cpu,
+    chunk_gated_delta_rule_cpu,
+)
 from sgl_kernel.marlin import (
     awq_marlin_moe_repack,
     awq_marlin_repack,
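For orientation, the `causal_conv1d_update_cpu` wrapper added below drives a single-token depthwise causal conv step. A simplified reference of that step, assuming the state window length equals the filter width and ignoring the cache-index and transposed-layout details the CPU kernel handles:

```python
import torch
import torch.nn.functional as F


def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
    """Single-token reference for a depthwise causal conv1d update.

    x:          (batch, dim)          new token activations
    conv_state: (batch, dim, width)   rolling window of the last `width` inputs
    weight:     (dim, width)          one depthwise filter per channel
    """
    # Shift the window left and append the new token (in place, like the kernels).
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
    conv_state[:, :, -1] = x
    out = torch.einsum("bdw,dw->bd", conv_state.to(weight.dtype), weight)
    if bias is not None:
        out = out + bias
    return F.silu(out) if activation == "silu" else out
```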
diff --git a/sgl-kernel/python/sgl_kernel/mamba.py b/sgl-kernel/python/sgl_kernel/mamba.py
index 85aa5b9479e1..a9ffbfcb5418 100644
--- a/sgl-kernel/python/sgl_kernel/mamba.py
+++ b/sgl-kernel/python/sgl_kernel/mamba.py
@@ -48,3 +48,73 @@ def causal_conv1d_update(
         conv_state_indices,
         pad_slot_id,
     )
+
+
+def causal_conv1d_fn_cpu(
+    mixed_qkv_transposed,
+    conv_weights,
+    bias,
+    activation,
+    conv_states,
+    has_initial_state,
+    cache_indices,
+    query_start_loc,
+    seq_lens_cpu,
+):
+    return torch.ops.sgl_kernel.causal_conv1d_fwd_cpu(
+        mixed_qkv_transposed,
+        conv_weights,
+        bias,
+        conv_states,
+        query_start_loc,
+        cache_indices,
+        has_initial_state,
+        activation == "silu",
+        -1,
+        True,
+    )
+
+
+def causal_conv1d_update_cpu(
+    mixed_qkv, conv_states, conv_weights, bias, activation, conv_state_indices
+):
+    return torch.ops.sgl_kernel.causal_conv1d_update_cpu(
+        mixed_qkv,
+        conv_states,
+        conv_weights,
+        bias,
+        activation == "silu",
+        None,
+        conv_state_indices,
+        -1,
+        True,
+    )
+
+
+def chunk_gated_delta_rule_cpu(
+    q,
+    k,
+    v,
+    g,
+    beta,
+    initial_state,
+    cu_seqlens,
+    head_first,
+    use_qk_l2norm_in_kernel,
+):
+    core_attn_out, last_recurrent_state = (
+        torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu(
+            q,
+            k,
+            v,
+            g,
+            beta,
+            initial_state,
+            True,  # output_final_state
+            cu_seqlens,
+            head_first,
+            use_qk_l2norm_in_kernel,
+        )
+    )
+    h = None  # TODO: add return h support
+    return core_attn_out, last_recurrent_state, h
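For reference, the `chunk_gated_delta_rule_cpu` wrapper above ultimately evaluates a gated delta-rule recurrence. A sequential, unchunked sketch of one common formulation (single head, no batching, no q/k L2 normalization); the kernels' exact parameterization and layout may differ:

```python
import torch


def gated_delta_rule_ref(q, k, v, g, beta, S=None):
    """Sequential reference of a gated delta-rule recurrence.

    q, k: (T, d_k)   v: (T, d_v)   g: (T,) log-decay   beta: (T,) write strength
    Returns per-step outputs (T, d_v) and the final state (d_k, d_v).
    """
    T, d_k = k.shape
    d_v = v.shape[-1]
    if S is None:
        S = k.new_zeros(d_k, d_v)
    outs = []
    for t in range(T):
        S = torch.exp(g[t]) * S                             # gated decay of the state
        pred = S.t() @ k[t]                                 # what the state currently recalls for k_t
        S = S + beta[t] * torch.outer(k[t], v[t] - pred)    # delta-rule correction
        outs.append(S.t() @ q[t])                           # read out with the query
    return torch.stack(outs), S
```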