Merged
50 commits
626ec0a
add CPU optimized frontend for qwen3-next
jianan-gu Nov 3, 2025
7d7fa12
minor fix
jianan-gu Nov 3, 2025
b1472a1
memory pool changes for amx conv
jianan-gu Nov 3, 2025
6be8b13
add TP padding for qwen3-next on CPU
jianan-gu Oct 31, 2025
5564da4
Merge branch 'main' into qwen-next-cpu-frontend
jianan-gu Nov 14, 2025
7ee14bb
Merge branch 'main' into qwen-next-cpu-frontend
jianan-gu Nov 19, 2025
13571bd
fix lint
jianan-gu Nov 19, 2025
fef27aa
Merge remote-tracking branch 'origin/main' into qwen-next-cpu-frontend
jianan-gu Dec 1, 2025
167a01d
Merge branch 'main' into qwen-next-cpu-frontend
jianan-gu Dec 4, 2025
bf1e05d
Merge remote-tracking branch 'origin/main' into qwen-next-cpu-frontend
jianan-gu Dec 5, 2025
0d1559d
rebase with latest kernels
jianan-gu Dec 5, 2025
4ed825a
Merge remote-tracking branch 'upstream/main' into qwen-next-cpu-frontend
jianan-gu Dec 8, 2025
eeebbb2
Update python/sglang/srt/layers/attention/intel_amx_backend.py
jianan-gu Dec 8, 2025
c032b1e
refine codes
jianan-gu Dec 8, 2025
ad1d6e2
Merge branch 'main' into qwen-next-cpu-frontend
jianan-gu Dec 8, 2025
51fc77d
Merge branch 'main' into qwen-next-cpu-frontend
jianan-gu Dec 8, 2025
dd8aa35
Merge branch 'main' into qwen-next-cpu-frontend
FlamingoPg Dec 8, 2025
ffb443b
Merge branch 'main' into qwen-next-cpu-frontend
FlamingoPg Dec 8, 2025
d476fcd
Merge branch 'main' into qwen-next-cpu-frontend
jianan-gu Dec 9, 2025
8a6dae6
minor fix after rebase
jianan-gu Dec 9, 2025
742ea26
refine mamba apis
jianan-gu Dec 9, 2025
188db21
Merge branch 'main' into qwen-next-cpu-frontend
jianan-gu Dec 9, 2025
855d43e
Merge branch 'main' into qwen-next-cpu-frontend
FlamingoPg Dec 9, 2025
82ee88c
Merge branch 'main' into qwen-next-cpu-frontend
FlamingoPg Dec 10, 2025
75ef3be
Merge remote-tracking branch 'upstream/main' into qwen-next-cpu-frontend
jianan-gu Dec 11, 2025
24d3543
Merge remote-tracking branch 'upstream/main' into qwen-next-cpu-frontend
jianan-gu Dec 15, 2025
421dbaa
Merge remote-tracking branch 'upstream/main' into qwen-next-cpu-frontend
jianan-gu Dec 15, 2025
c40ae9d
adjust mamba cache after rebase
jianan-gu Dec 15, 2025
c777aa3
minor refinements
jianan-gu Dec 15, 2025
cbf8adb
final minor refinements
jianan-gu Dec 15, 2025
b987530
format
jianan-gu Dec 15, 2025
210adff
Merge branch 'main' into qwen-next-cpu-frontend
jianan-gu Dec 16, 2025
79b4a70
Merge remote-tracking branch 'upstream/main' into qwen-next-cpu-frontend
jianan-gu Dec 16, 2025
5f89a7c
Merge remote-tracking branch 'origin/main' into qwen-next-cpu-frontend
jianan-gu Jan 13, 2026
9520d90
rebase api
jianan-gu Jan 13, 2026
e008c0a
refine api
jianan-gu Jan 13, 2026
9fb3ddf
Merge remote-tracking branch 'origin/main' into qwen-next-cpu-frontend
jianan-gu Jan 14, 2026
5b80c61
Merge remote-tracking branch 'origin/main' into qwen-next-cpu-frontend
jianan-gu Jan 19, 2026
9cc7b92
format after rebase
jianan-gu Jan 19, 2026
2c5309a
minor refinements
jianan-gu Jan 20, 2026
242f479
Merge remote-tracking branch 'origin/main' into qwen-next-cpu-frontend
jianan-gu Jan 20, 2026
1a97f90
refinements per reviews
jianan-gu Jan 21, 2026
2b19634
Merge branch 'main' into qwen-next-cpu-frontend
jianan-gu Jan 21, 2026
37fa1fb
Merge remote-tracking branch 'origin/main' into qwen-next-cpu-frontend
jianan-gu Jan 21, 2026
fdfac34
Merge branch 'main' into qwen-next-cpu-frontend
jianan-gu Jan 22, 2026
9837203
Merge remote-tracking branch 'origin/main' into qwen-next-cpu-frontend
jianan-gu Jan 22, 2026
7be31cb
minor refine after rebase
jianan-gu Jan 22, 2026
19f1eda
Merge branch 'main' into qwen-next-cpu-frontend
jianan-gu Jan 23, 2026
2e5a1c1
Merge branch 'main' into qwen-next-cpu-frontend
jianan-gu Jan 28, 2026
fd1eb27
Merge branch 'main' into qwen-next-cpu-frontend
jianan-gu Jan 30, 2026
7 changes: 7 additions & 0 deletions python/sglang/srt/configs/qwen3_next.py
@@ -20,8 +20,11 @@
from transformers.utils import logging

from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.configs.update_config import adjust_tp_num_heads_if_necessary
from sglang.srt.utils import is_cpu

logger = logging.get_logger(__name__)
_is_cpu = is_cpu()


class HybridLayerType(enum.Enum):
@@ -276,6 +279,10 @@ def full_attention_layer_ids(self):
def mamba2_cache_params(self) -> Mamba2CacheParams:
from sglang.srt.layers.dp_attention import get_attention_tp_size

if _is_cpu:
world_size = get_attention_tp_size()
adjust_tp_num_heads_if_necessary(self, world_size, False)

shape = Mamba2StateShape.create(
tp_world_size=get_attention_tp_size(),
intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
43 changes: 43 additions & 0 deletions python/sglang/srt/configs/update_config.py
@@ -54,6 +54,44 @@ def get_num_heads_padding_size(tp_size, weight_block_size, head_dim):
return pad_size


def adjust_tp_num_heads_if_necessary(model_config, tp_size, is_post_update):
# is_post_update: if True, store the padded values in new *_cpu attributes
# instead of overwriting the original head counts on the config
from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

# Linear attn check logic
if hasattr(model_config, "linear_num_key_heads") and hasattr(
model_config, "linear_num_value_heads"
):
if (
model_config.linear_num_key_heads % tp_size != 0
or model_config.linear_num_value_heads % tp_size != 0
):
pad_size = tp_size
linear_num_key_heads_cpu = pad_vocab_size(
model_config.linear_num_key_heads, pad_size
)
linear_num_value_heads_cpu = (
linear_num_key_heads_cpu
* model_config.linear_num_value_heads
// model_config.linear_num_key_heads
)
if is_post_update:
model_config.linear_num_key_heads_cpu = linear_num_key_heads_cpu
model_config.linear_num_value_heads_cpu = linear_num_value_heads_cpu
else:
model_config.linear_num_key_heads = linear_num_key_heads_cpu
model_config.linear_num_value_heads = linear_num_value_heads_cpu

else:
if is_post_update:
model_config.linear_num_key_heads_cpu = (
model_config.linear_num_key_heads
)
model_config.linear_num_value_heads_cpu = (
model_config.linear_num_value_heads
)


def update_intermediate_size(model_config, attr_name, intermediate_padding_size):
attr_value = intermediate_padding_size
if hasattr(model_config, "hf_config") and hasattr(
@@ -137,6 +175,8 @@ def adjust_config_with_unaligned_cpu_tp(
model_config.hf_config.num_attention_heads = num_attention_heads
model_config.hf_text_config.num_attention_heads = num_attention_heads

adjust_tp_num_heads_if_necessary(model_config.hf_config, tp_size, True)

intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size)
model_config = update_intermediate_size(
model_config, "moe_intermediate_size", intermediate_padding_size
@@ -147,6 +187,9 @@
model_config = update_intermediate_size(
model_config, "intermediate_size_mlp", intermediate_padding_size
)
model_config = update_intermediate_size(
model_config, "shared_expert_intermediate_size", intermediate_padding_size
)
if (
hasattr(model_config.hf_config, "vision_config")
and model_config.hf_config.vision_config.model_type == "siglip_vision_model"
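For orientation, a minimal sketch (not part of the PR) of the padding arithmetic that adjust_tp_num_heads_if_necessary applies when the linear-attention head counts do not divide evenly by the TP size. The head counts, TP size, and the local pad_to_multiple helper below are hypothetical stand-ins; the rounding mirrors what pad_vocab_size does, and the value heads are scaled to preserve the key/value head ratio.

from types import SimpleNamespace


def pad_to_multiple(value: int, multiple: int) -> int:
    # Round up to the next multiple, mirroring pad_vocab_size's behavior.
    return ((value + multiple - 1) // multiple) * multiple


cfg = SimpleNamespace(linear_num_key_heads=16, linear_num_value_heads=32)
tp_size = 6  # 16 % 6 != 0, so padding is needed

padded_key_heads = pad_to_multiple(cfg.linear_num_key_heads, tp_size)  # 18
# Scale value heads by the same factor so the key/value head ratio stays intact.
padded_value_heads = (
    padded_key_heads * cfg.linear_num_value_heads // cfg.linear_num_key_heads
)  # 36

print(padded_key_heads, padded_value_heads)  # 18 36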
56 changes: 49 additions & 7 deletions python/sglang/srt/layers/amx_utils.py
@@ -7,13 +7,17 @@
logger = logging.getLogger(__name__)


def amx_process_weight_after_loading(weight):
def amx_process_weight_after_loading(weight, is_conv=False):
if weight.device != torch.device("cpu"):
return weight
if not cpu_has_amx_support():
return weight

return torch.ops.sgl_kernel.convert_weight_packed(weight)
if is_conv:
return torch.ops.sgl_kernel.causal_conv1d_weight_pack(
weight.view(-1, weight.size(-1))
)
else:
return torch.ops.sgl_kernel.convert_weight_packed(weight)


# TODO: currently gemm kernel has the below requirements:
@@ -30,6 +34,36 @@ def dim_is_supported(weight):
return is_oc_support and is_ic_support


def dtype_is_supported(weight):
return weight.dtype in [
torch.float16,
torch.bfloat16,
torch.int8,
torch.float8_e4m3fn,
]


def is_dim_conv_weight(weight):
return weight.dim() == 3 and weight.size(1) == 1


def _init_amx_conv_state(conv_state):
# CPU AMX layout for conv_state kernel optimization
conv_state_cpu = []
for conv_shape_t in conv_state:
conv_shape_new = conv_shape_t.as_strided_(
conv_shape_t.size(),
(
conv_shape_t.stride(0),
conv_shape_t.stride(1),
1,
conv_shape_t.size(2),
),
)
conv_state_cpu.append(conv_shape_new)
return conv_state_cpu


def _amx_process_weight_after_loading(
module, weight_names, transpose_dims=None
) -> None:
@@ -48,22 +82,30 @@ def _amx_process_weight_after_loading(

if transpose_dims and transpose_dims[i]:
weight_tensor = weight_tensor.transpose(*transpose_dims[i])

is_conv_weight = is_dim_conv_weight(weight_tensor)
# We don't pack weight or use intel amx backend if any non-conv weight of this module has an unsupported dim or dtype.
if not dim_is_supported(weight_tensor):
if (
(not dim_is_supported(weight_tensor))
or not dtype_is_supported(weight_tensor)
) and (not is_conv_weight):
logger.warning(
f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. "
f"Unsupported dimension or dtype for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} and dtype {weight_tensor.dtype} in {module}. "
f"The derived (OC, IC) dimensions must be divisible by (16, 32). "
)
module.use_intel_amx_backend = False
return

packed_weight = torch.nn.Parameter(
amx_process_weight_after_loading(weight_tensor),
amx_process_weight_after_loading(weight_tensor, is_conv_weight),
requires_grad=False,
)
packed_weight.__dict__ = weight_tensor.__dict__
setattr(module, weight_name, packed_weight)
if is_conv_weight:
# Copy the packed data back into the original conv weight tensor in place:
# radix_linear_attention still references the original weight.
weight_tensor = weight_tensor.view(-1, weight_tensor.size(-1))
weight_tensor.copy_(packed_weight)

module.use_intel_amx_backend = (
device == torch.device("cpu") and cpu_has_amx_support()
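A rough usage sketch of the new conv-aware packing path follows. The shapes and dtypes are illustrative only; on a machine without AMX support the helper simply returns the input weight unchanged.

import torch

from sglang.srt.layers.amx_utils import (
    amx_process_weight_after_loading,
    is_dim_conv_weight,
)

# (channels, 1, kernel_size): a 3-D weight with a singleton middle dim is treated as a
# causal-conv1d weight and routed to the causal_conv1d_weight_pack kernel.
conv_weight = torch.randn(4096, 1, 4, dtype=torch.bfloat16)
# (OC, IC) divisible by (16, 32): packed with convert_weight_packed as before.
dense_weight = torch.randn(1024, 512, dtype=torch.bfloat16)

packed_conv = amx_process_weight_after_loading(conv_weight, is_dim_conv_weight(conv_weight))
packed_dense = amx_process_weight_after_loading(dense_weight)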
38 changes: 27 additions & 11 deletions python/sglang/srt/layers/attention/fla/layernorm_gated.py
@@ -14,9 +14,17 @@
import triton.language as tl
from einops import rearrange

from sglang.srt.utils import cdiv, device_context, is_npu, next_power_of_2
from sglang.srt.utils import (
cdiv,
cpu_has_amx_support,
device_context,
is_cpu,
is_npu,
next_power_of_2,
)

_is_npu = is_npu()
_use_cpu = is_cpu() and cpu_has_amx_support()


def rms_norm_ref(
@@ -392,13 +400,21 @@ def reset_parameters(self):

def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return layernorm_fn(
x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
is_rms_norm=True,
)
if _use_cpu:
assert (
self.norm_before_gate and self.group_size is None
), "CPU rmsnorm_gated currently only supports norm before gate without group size"
return torch.ops.sgl_kernel.fused_rmsnorm_gated_cpu(
x, self.weight, z, self.eps
)
else:
return layernorm_fn(
x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
is_rms_norm=True,
)
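For readers without the CPU kernel at hand, here is a plain-PyTorch reference (a sketch derived from the docstring above, not the kernel implementation) of what fused_rmsnorm_gated_cpu is expected to compute under the asserted constraints, i.e. norm before gate with no group size.

import torch
import torch.nn.functional as F


def rmsnorm_gated_ref(x, weight, z, eps=1e-6):
    # RMS-normalize x, scale by weight, then gate with silu(z): norm(x) * silu(z).
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    y = x.float() * torch.rsqrt(var + eps) * weight.float()
    return (y * F.silu(z.float())).to(x.dtype)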
45 changes: 35 additions & 10 deletions python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
@@ -5,19 +5,15 @@
import triton.language as tl
from einops import rearrange

from sglang.jit_kernel.cutedsl_gdn import cutedsl_fused_sigmoid_gating_delta_rule_update
from sglang.srt.environ import Envs
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
from sglang.srt.layers.attention.fla.chunk_delta_h import CHUNK_SIZE as FLA_CHUNK_SIZE
from sglang.srt.layers.attention.fla.fused_gdn_gating import fused_gdn_gating
from sglang.srt.layers.attention.fla.fused_recurrent import (
fused_recurrent_gated_delta_rule_update,
)
from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
fused_sigmoid_gating_delta_rule_update,
)
from sglang.srt.layers.attention.fla.kda import chunk_kda
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
PAD_SLOT_ID,
causal_conv1d_fn,
@@ -36,9 +32,20 @@
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import is_cuda, is_npu
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
from sglang.srt.utils.common import rank0_log

if not is_cpu():
# avoid import errors on CPU devices; no impact on the non-CPU path
from sglang.jit_kernel.cutedsl_gdn import (
cutedsl_fused_sigmoid_gating_delta_rule_update,
)
from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
from sglang.srt.layers.attention.fla.chunk_delta_h import (
CHUNK_SIZE as FLA_CHUNK_SIZE,
)
from sglang.srt.layers.attention.fla.kda import chunk_kda

if is_cuda():
from sglang.srt.layers.attention.mamba.causal_conv1d import (
causal_conv1d_fn as causal_conv1d_fn_cuda,
@@ -59,6 +66,23 @@
fused_sigmoid_gating_delta_rule_update = fused_sigmoid_gating_delta_rule_update_npu
causal_conv1d_fn = causal_conv1d_fn_npu
causal_conv1d_update = causal_conv1d_update_npu
elif is_cpu():
assert (
cpu_has_amx_support()
), "CPU requires AMX support for hybrid linear attn backend"
from sgl_kernel.mamba import (
causal_conv1d_fn_cpu,
causal_conv1d_update_cpu,
chunk_gated_delta_rule_cpu,
)

chunk_gated_delta_rule = chunk_gated_delta_rule_cpu
causal_conv1d_fn = causal_conv1d_fn_cpu
causal_conv1d_update = causal_conv1d_update_cpu
fused_sigmoid_gating_delta_rule_update = (
torch.ops.sgl_kernel.fused_sigmoid_gating_delta_rule_update_cpu
)
fused_gdn_gating = torch.ops.sgl_kernel.fused_gdn_gating_cpu


# Kernel to track mamba states if needed based on track mask
@@ -790,9 +814,10 @@ def __init__(self, model_runner: ModelRunner):
self.conv_states_shape = (
model_runner.req_to_token_pool.mamba_pool.mamba_cache.conv[0].shape
)
assert (
self.conv_states_shape[-1] < FLA_CHUNK_SIZE
), f"{self.conv_states_shape[-1]=} should be less than {FLA_CHUNK_SIZE}"
if not is_cpu():
assert (
self.conv_states_shape[-1] < FLA_CHUNK_SIZE
), f"{self.conv_states_shape[-1]=} should be less than {FLA_CHUNK_SIZE}"

use_cutedsl = Envs.SGLANG_USE_CUTEDSL_GDN_DECODE.get()
rank0_log(f"CuTe DSL GDN decode enabled: {use_cutedsl}")
@@ -983,7 +1008,7 @@ def forward_extend(
# Only cuda env uses fuse ssm_states update
recurrent_state = ssm_states
recurrent_state_indices_args = {"initial_state_indices": cache_indices}
if is_npu():
if is_npu() or is_cpu():
recurrent_state = ssm_states[cache_indices]
recurrent_state_indices_args = {}
core_attn_out, last_recurrent_state, h = chunk_gated_delta_rule(
Expand All @@ -998,7 +1023,7 @@ def forward_extend(
use_qk_l2norm_in_kernel=True,
**recurrent_state_indices_args,
)
if is_npu():
if is_npu() or is_cpu():
last_recurrent_state = last_recurrent_state.to(
ssm_states.dtype, copy=False
)
12 changes: 10 additions & 2 deletions python/sglang/srt/layers/attention/intel_amx_backend.py
@@ -24,8 +24,16 @@ def __init__(self, model_runner: ModelRunner):
model_runner.model_config.num_attention_heads // model_runner.tp_size
)

self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

# [NB]: default `layer_id` to 0; for hybrid models such as qwen3-next not every attn layer
# has a kv pool entry, so use "full_attention_layer_id_mapping" to pick the first layer that does
layer_id = 0
if hasattr(model_runner.token_to_kv_pool, "full_attention_layer_id_mapping"):
layer_id = [*model_runner.token_to_kv_pool.full_attention_layer_id_mapping][
0
]
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(
layer_id
).shape[-1]
self.decode_attention_fwd = torch.ops.sgl_kernel.decode_attention_cpu
self.extend_attention_fwd = torch.ops.sgl_kernel.extend_attention_cpu

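A tiny illustration of the layer_id lookup above (the mapping values are made up): unpacking a dict yields its keys, so the backend sizes v_head_dim from the first full-attention layer in the mapping.

# Hypothetical mapping for a hybrid model: full-attention layer id -> KV-pool slot.
full_attention_layer_id_mapping = {3: 0, 7: 1, 11: 2}

layer_id = [*full_attention_layer_id_mapping][0]  # -> 3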