34 commits
028ec9e
add int4 cpu frontend (awq)
jianan-gu Jul 21, 2025
4b168a1
remove vllm dependency for cpu
jianan-gu Jul 22, 2025
130c484
minor refine for dtype
jianan-gu Jul 23, 2025
7db42d6
Add CPU int4 moe frontend (#48)
jianan-gu Jul 30, 2025
e87da73
minor fix for da8w4
jianan-gu Jul 30, 2025
7842a9d
refinements
jianan-gu Aug 4, 2025
f92c440
refine block size check
jianan-gu Aug 4, 2025
7be3bb6
frontend refinements
jianan-gu Aug 24, 2025
fc63390
refinements for quant method
jianan-gu Aug 25, 2025
61a646a
Merge branch 'main' into cpu_int4_frontend
jianan-gu Aug 25, 2025
07beb4e
minor fix
jianan-gu Aug 26, 2025
063c3c7
typo for w4a16
jianan-gu Aug 26, 2025
200faab
add moe a8w4 path
jianan-gu Sep 9, 2025
9e6f12b
refine for int4 choice
jianan-gu Oct 27, 2025
d868d2d
Merge remote-tracking branch 'up/main' into cpu_int4_frontend
jianan-gu Oct 27, 2025
8816566
Merge branch 'main' into cpu_int4_frontend
jianan-gu Nov 13, 2025
8a243fa
refine CPUQuantMethod
jianan-gu Nov 13, 2025
289dbd5
refine naming
jianan-gu Nov 13, 2025
3c5cfac
Merge branch 'main' into cpu_int4_frontend
jianan-gu Nov 13, 2025
72b0734
Merge branch 'main' into cpu_int4_frontend
jianan-gu Nov 14, 2025
b35fbb2
Merge remote-tracking branch 'up/main' into cpu_int4_frontend
jianan-gu Nov 19, 2025
500a94f
Merge branch 'main' into cpu_int4_frontend
jianan-gu Dec 5, 2025
4f4cda1
Merge branch 'main' into cpu_int4_frontend
jianan-gu Dec 5, 2025
2d09997
rebase with fix
jianan-gu Dec 5, 2025
ecf97ab
Merge remote-tracking branch 'upstream/main' into cpu_int4_frontend
jianan-gu Dec 15, 2025
1c06950
Merge branch 'main' into cpu_int4_frontend
jianan-gu Dec 18, 2025
79bc5ca
Merge remote-tracking branch 'origin/main' into cpu_int4_frontend
jianan-gu Jan 19, 2026
9496518
Merge branch 'main' into cpu_int4_frontend
jianan-gu Jan 21, 2026
626e197
Merge branch 'main' into cpu_int4_frontend
jianan-gu Jan 21, 2026
d85d428
Merge remote-tracking branch 'origin/main' into cpu_int4_frontend
jianan-gu Jan 30, 2026
607eb96
Merge branch 'main' into cpu_int4_frontend
jianan-gu Jan 30, 2026
c630bc9
Merge branch 'main' into cpu_int4_frontend
jianan-gu Feb 2, 2026
205492c
minor refine for awq pack after rebase.
jianan-gu Feb 2, 2026
18165fa
Merge branch 'main' into cpu_int4_frontend
jianan-gu Feb 5, 2026
19 changes: 12 additions & 7 deletions python/sglang/srt/configs/update_config.py
@@ -23,19 +23,24 @@ def may_get_weight_block_size(model_config, load_config):

if quant_config is not None and hasattr(quant_config, "weight_block_size"):
return getattr(quant_config, "weight_block_size")

if quant_config is not None and hasattr(quant_config, "group_size"):
return [getattr(quant_config, "group_size")]

return None


def get_moe_padding_size(weight_block_size):
if weight_block_size is not None:
# See NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
assert (
len(weight_block_size) == 2
), "Only len(weight_block_size) == 2 is supported"
assert (
weight_block_size[0] == weight_block_size[1]
), "Only weight_block_size[0] == weight_block_size[1] is supported"

assert len(weight_block_size) in [
1,
2,
], "Only len(weight_block_size) in [1, 2] is supported"
if len(weight_block_size) == 2:
assert (
weight_block_size[0] == weight_block_size[1]
), "Only weight_block_size[0] == weight_block_size[1] is supported"
return weight_block_size[0]

return DEFAULT_MOE_PADDING_SIZE
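
The relaxed check above means get_moe_padding_size now also accepts the one-element weight_block_size that may_get_weight_block_size derives from a group_size-only quant config such as AWQ. A minimal standalone sketch of the resulting behavior; the DEFAULT_MOE_PADDING_SIZE value here is an assumed placeholder, not the real constant:

    # Standalone sketch; DEFAULT_MOE_PADDING_SIZE = 32 is an assumed placeholder.
    DEFAULT_MOE_PADDING_SIZE = 32

    def get_moe_padding_size(weight_block_size):
        if weight_block_size is not None:
            assert len(weight_block_size) in (1, 2)
            if len(weight_block_size) == 2:
                # block-wise quantization: block_n must equal block_k here
                assert weight_block_size[0] == weight_block_size[1]
            return weight_block_size[0]
        return DEFAULT_MOE_PADDING_SIZE

    print(get_moe_padding_size([128, 128]))  # 128, block-wise fp8 style config
    print(get_moe_padding_size([128]))       # 128, new group_size-only (AWQ) path
    print(get_moe_padding_size(None))        # 32, falls back to the assumed default
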
88 changes: 57 additions & 31 deletions python/sglang/srt/layers/amx_utils.py
100644 → 100755
@@ -74,7 +74,7 @@ def _init_amx_conv_state(conv_state):


def _amx_process_weight_after_loading(
module, weight_names, transpose_dims=None
module, weight_names, transpose_dims=None, qweight_packed_method=None
) -> None:
# Pack weight to get better performance on CPU
devices = {getattr(module, weight_name).device for weight_name in weight_names}
@@ -86,40 +86,66 @@ def _amx_process_weight_after_loading(
transpose_dims
), "len(weight_names) should be equal to len(transpose_dims)"

for i, weight_name in enumerate(weight_names):
weight_tensor = getattr(module, weight_name)

if transpose_dims and transpose_dims[i]:
weight_tensor = weight_tensor.transpose(*transpose_dims[i])
is_conv_weight = is_dim_conv_weight(weight_tensor)
# We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
if (
(not dim_is_supported(weight_tensor))
or not dtype_is_supported(weight_tensor)
) and (not is_conv_weight):
logger.warning(
f"Unsupported dimension or dtype for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} and dtype {weight_tensor.dtype} in {module}. "
f"The derived (OC, IC) dimensions must be divisible by (16, 32). "
)
module.use_intel_amx_backend = False
return

packed_weight = torch.nn.Parameter(
amx_process_weight_after_loading(weight_tensor, is_conv_weight),
requires_grad=False,
)
packed_weight.__dict__ = weight_tensor.__dict__
setattr(module, weight_name, packed_weight)
if is_conv_weight:
# need to use inplace copy for conv weight amx packing,
# as its usage in radix_linear_attention will use the original conv weight.
weight_tensor = weight_tensor.view(-1, weight_tensor.size(-1))
weight_tensor.copy_(packed_weight)

module.use_intel_amx_backend = (
device == torch.device("cpu") and cpu_has_amx_support()
)

if qweight_packed_method is None:
for i, weight_name in enumerate(weight_names):
weight_tensor = getattr(module, weight_name)

if transpose_dims and transpose_dims[i]:
weight_tensor = weight_tensor.transpose(*transpose_dims[i])
is_conv_weight = is_dim_conv_weight(weight_tensor)
# We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
if (
(not dim_is_supported(weight_tensor))
or not dtype_is_supported(weight_tensor)
) and (not is_conv_weight):
logger.warning(
f"Unsupported dimension or dtype for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} and dtype {weight_tensor.dtype} in {module}. "
f"The derived (OC, IC) dimensions must be divisible by (16, 32). "
)
module.use_intel_amx_backend = False
return

packed_weight = torch.nn.Parameter(
amx_process_weight_after_loading(weight_tensor, is_conv_weight),
requires_grad=False,
)
packed_weight.__dict__ = weight_tensor.__dict__
setattr(module, weight_name, packed_weight)
if is_conv_weight:
# need to use inplace copy for conv weight amx packing,
# as its usage in radix_linear_attention will use the original conv weight.
weight_tensor = weight_tensor.view(-1, weight_tensor.size(-1))
weight_tensor.copy_(packed_weight)
else:
assert qweight_packed_method in ["awq"] # TODO: add GPTQ, etc.
qweight_tensor = getattr(module, weight_names[0])
qzeros_tensor = getattr(module, weight_names[1])
scales_tensor = getattr(module, weight_names[2])
qweight, qzeros, scales = torch.ops.sgl_kernel.convert_weight_packed_scale_zp(
qweight_tensor, qzeros_tensor, scales_tensor
)
packed_qweight = torch.nn.Parameter(
qweight.detach(),
requires_grad=False,
)
packed_qzeros = torch.nn.Parameter(
qzeros.detach(),
requires_grad=False,
)
packed_scales = torch.nn.Parameter(
scales.detach(),
requires_grad=False,
)
packed_qweight.__dict__ = qweight_tensor.__dict__
packed_qzeros.__dict__ = qzeros_tensor.__dict__
packed_scales.__dict__ = scales_tensor.__dict__
setattr(module, weight_names[0], packed_qweight)
setattr(module, weight_names[1], packed_qzeros)
setattr(module, weight_names[2], packed_scales)
if (
module.use_intel_amx_backend
and hasattr(module, "bias")
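
With the new qweight_packed_method argument, _amx_process_weight_after_loading has two modes: the original bf16/fp16 prepack loop, and an AWQ branch that treats the first three weight_names as (qweight, qzeros, scales) and repacks them through torch.ops.sgl_kernel.convert_weight_packed_scale_zp. A hedged usage sketch; the surrounding layer object is hypothetical, and sgl_kernel must be built with the CPU int4 ops:

    import torch
    from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading

    def repack_awq_linear_for_cpu(layer: torch.nn.Module) -> None:
        # Passing "awq" switches the helper from the plain AMX prepack loop to the
        # int4 path: the three parameters are replaced with packed versions in the
        # layout consumed by int4_scaled_mm_cpu / fused_experts_cpu.
        _amx_process_weight_after_loading(
            layer, ["qweight", "qzeros", "scales"], None, "awq"
        )
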
18 changes: 18 additions & 0 deletions python/sglang/srt/layers/quantization/__init__.py
@@ -85,6 +85,14 @@ def override_quantization_method(self, *args, **kwargs):
}
)

# subset of above quant methods, supported on CPU
CPU_QUANTIZATIPON_METHODS = {
"fp8": Fp8Config,
"w8a8_int8": W8A8Int8Config,
"compressed-tensors": CompressedTensorsConfig,
"awq": AWQConfig,
}

QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS}


@@ -94,6 +102,16 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
f"Invalid quantization method: {quantization}. "
f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
)
from sglang.srt.utils import is_cpu

if is_cpu():
if quantization not in CPU_QUANTIZATIPON_METHODS:
raise ValueError(
f"Invalid quantization method on CPU: {quantization}. "
f"Available methods on CPU: {list(QUANTIZATION_METHODS.keys())}"
)
else:
return CPU_QUANTIZATIPON_METHODS[quantization]

return QUANTIZATION_METHODS[quantization]

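
The CPU gate in get_quantization_config restricts the lookup to the subset of methods that have CPU kernels. A minimal sketch of the intended dispatch with a stand-in registry (strings instead of the real config classes, and an assumed set of method names):

    QUANTIZATION_METHODS = {"fp8": "Fp8Config", "awq": "AWQConfig", "gptq": "GPTQConfig"}
    CPU_QUANTIZATION_METHODS = {"fp8": "Fp8Config", "awq": "AWQConfig"}  # CPU subset

    def get_quantization_config(quantization, on_cpu):
        if quantization not in QUANTIZATION_METHODS:
            raise ValueError(f"Invalid quantization method: {quantization}")
        if on_cpu:
            if quantization not in CPU_QUANTIZATION_METHODS:
                raise ValueError(f"Invalid quantization method on CPU: {quantization}")
            return CPU_QUANTIZATION_METHODS[quantization]
        return QUANTIZATION_METHODS[quantization]

    print(get_quantization_config("awq", on_cpu=True))    # AWQConfig
    print(get_quantization_config("gptq", on_cpu=False))  # GPTQConfig
    # get_quantization_config("gptq", on_cpu=True) would raise: no CPU path yet.
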
95 changes: 89 additions & 6 deletions python/sglang/srt/layers/quantization/awq.py
100644 → 100755
@@ -49,10 +49,19 @@
StandardDispatchOutput,
)

from sglang.srt.utils import is_cuda, is_hip, is_npu, is_xpu
from sglang.srt.utils import (
cpu_has_amx_support,
is_cpu,
is_cuda,
is_hip,
is_npu,
is_xpu,
use_intel_amx_backend,
)

_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu() and cpu_has_amx_support()
_is_xpu = is_xpu()
_is_npu = is_npu()

@@ -68,12 +77,18 @@
awq_dequantize_triton as awq_dequantize,
)

elif _is_cpu:
from sglang.srt.layers.amx_utils import (
CPUQuantMethod,
_amx_process_weight_after_loading,
)

elif _is_xpu:
from sgl_kernel import awq_dequantize

warnings.warn(f"XPU does not support fused_marlin_moe currently.")
else:
warnings.warn(f"Only CUDA, HIP and XPU support AWQ currently.")
warnings.warn(f"Only CUDA, HIP, CPU and XPU support AWQ currently.")

logger = logging.getLogger(__name__)

@@ -126,7 +141,10 @@ def get_name(self) -> str:
return "awq"

def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.float16] if not _is_npu else [torch.float16, torch.bfloat16]
if _is_npu or _is_cpu:
return [torch.float16, torch.bfloat16]
else:
return [torch.float16]

@classmethod
def get_min_capability(cls) -> int:
@@ -175,6 +193,8 @@ def get_quant_method(
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
return AWQLinearMethod(self)
elif isinstance(layer, FusedMoE):
return AWQMoEMethod(self) if not _is_cpu else AWQMoEIntelAMXMethod(self)
return None


@@ -428,16 +448,30 @@ def create_weights(
layer.register_parameter("scales", scales)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
if _is_cpu:
_amx_process_weight_after_loading(
layer, ["qweight", "qzeros", "scales"], None, "awq"
)
else:
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if use_intel_amx_backend(layer):
return torch.ops.sgl_kernel.int4_scaled_mm_cpu(
x,
layer.qweight,
layer.qzeros,
layer.scales,
bias,
)

qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
@@ -945,6 +979,55 @@ def apply(
return StandardCombineInput(hidden_states=output)


class AWQMoEIntelAMXMethod(AWQMoEMethod):
def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
_amx_process_weight_after_loading(
layer, ["w13_qweight", "w13_qzeros", "w13_scales"], None, "awq"
)
_amx_process_weight_after_loading(
layer, ["w2_qweight", "w2_qzeros", "w2_scales"], None, "awq"
)

def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config

def apply(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
) -> torch.Tensor:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

assert (
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."

x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output
output = torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights,
topk_ids,
False, # inplace See [Note] inplace should be False in fused_experts.
CPUQuantMethod.INT4_W4A8,
layer.w13_scales, # w1_scale
layer.w2_scales, # w2_scale
layer.w13_qzeros,
layer.w2_qzeros,
None, # block_size
True, # is_vnni
)
return StandardCombineInput(hidden_states=output)


# Register fake implementations for torch.compile support
if _is_cuda:

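
At inference time the two new CPU code paths in this file both collapse into a single fused kernel call: int4_scaled_mm_cpu for dense AWQ linears and fused_experts_cpu with CPUQuantMethod.INT4_W4A8 for MoE layers. A hedged sketch of the dense path, assuming the layer's parameters have already been repacked by _amx_process_weight_after_loading(..., "awq") and that sgl_kernel is built with the CPU ops:

    from typing import Optional
    import torch

    def awq_linear_cpu(
        x: torch.Tensor,             # activations, bf16/fp16, shape (tokens, K)
        qweight: torch.Tensor,       # packed int4 weights (post-repack layout)
        qzeros: torch.Tensor,        # packed zero points
        scales: torch.Tensor,        # per-group scales
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Mirrors AWQLinearMethod.apply on the use_intel_amx_backend branch above.
        return torch.ops.sgl_kernel.int4_scaled_mm_cpu(x, qweight, qzeros, scales, bias)
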
15 changes: 6 additions & 9 deletions sgl-kernel/csrc/cpu/gemm_int4.cpp
@@ -594,19 +594,16 @@ std::tuple<at::Tensor, at::Tensor> autoawq_to_int4pack(
at::Tensor qweight, // (*, K, N / 8), int32
at::Tensor qzeros) // (*, K / group_size, N / 8), int32
{
qweight = qweight.contiguous();
qzeros = qzeros.contiguous();
// bitshifts: [0, 4, 1, 5, 2, 6, 3, 7] * 4
auto bitshifts = at::tensor({0, 4, 1, 5, 2, 6, 3, 7}, at::kInt) * 4;
// qweight: assumed shape [..., K, N/8] (int32)
auto qweight_unsq = qweight.unsqueeze(-1); // [..., K, N/8, 1]
auto shape = qweight_unsq.sizes().vec(); // shape: [A, B, C, 1]
shape[3] = 8;
auto unpacked = at::bitwise_right_shift(qweight_unsq, bitshifts) & 0xF;
auto qweight_final = unpacked.flatten(-2).transpose(-1, -2).to(at::kByte);

auto unpacked = (at::bitwise_right_shift(qweight_unsq, bitshifts) & 0xF).contiguous();
auto qweight_final = unpacked.flatten(-2).transpose(-1, -2).to(at::kByte).clone();
auto qzeros_unsq = qzeros.unsqueeze(-1);
auto qzeros_unpacked = at::bitwise_right_shift(qzeros_unsq, bitshifts) & 0xF;
auto qzeros_final = qzeros_unpacked.flatten(-2).to(at::kByte);

auto qzeros_unpacked = (at::bitwise_right_shift(qzeros_unsq, bitshifts) & 0xF).contiguous();
auto qzeros_final = qzeros_unpacked.flatten(-2).to(at::kByte).clone();
return std::make_tuple(qweight_final, qzeros_final);
}

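
The [0, 4, 1, 5, 2, 6, 3, 7] shift table above is the interleaved nibble order AutoAWQ uses when packing eight 4-bit values into one int32; making the intermediates contiguous and cloning the transposed/flattened views materializes the tensors before they are handed on. A small PyTorch sketch of just the nibble unpack, omitting the final transpose to the (N, K) byte layout:

    import torch

    def unpack_awq_int32(qweight: torch.Tensor) -> torch.Tensor:
        # qweight: (..., K, N // 8) int32  ->  (..., K, N) uint8 values in [0, 15]
        shifts = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], dtype=torch.int32) * 4
        unpacked = (qweight.unsqueeze(-1) >> shifts) & 0xF
        return unpacked.flatten(-2).to(torch.uint8)

    # Round-trip check on one hand-packed word: the value that should appear at
    # unpacked position i is stored at bit offset order[i] * 4, matching the
    # shift table in autoawq_to_int4pack.
    vals = [1, 2, 3, 4, 5, 6, 7, 7]
    order = [0, 4, 1, 5, 2, 6, 3, 7]
    word = sum(v << (o * 4) for v, o in zip(vals, order))
    packed = torch.tensor([[word]], dtype=torch.int32)   # toy (K=1, N//8=1) tensor
    print(unpack_awq_int32(packed))                      # [[1, 2, 3, 4, 5, 6, 7, 7]]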