## Checklist

- 1. If the issue you raised is not a feature but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed.
- 2. Please use English, otherwise it will be closed.
## Motivation

Share our optimization methods for Qwen.
## Qwen3-32B

### Decode optimization
#### Fused OPs

| OPs | Ascend OPs | Location |
| --- | --- | --- |
| RMSNorm | `torch_npu.npu_rms_norm`, `torch_npu.npu_add_rms_norm` | `layernorm.py::RMSNorm::forward_npu` |
| RoPE | `torch_npu.npu_mrope` (including cos_sin_cache) | `layernorm.py::RotaryEmbedding::forward_npu` |
| RadixAttention | prefill: `torch_npu._npu_flash_attention_qlens`; decode: `torch_npu._npu_paged_attention` | `ascend_backend.py::AscendAttnBackend::forward_extend`, `forward_decode` |
| SiluAndMul | `torch_npu.npu_swiglu` | `activation.py::SiluAndMul::forward_npu` |
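As an illustration of how such a fused op slots in, here is a minimal sketch of an RMSNorm forward pass built on `torch_npu.npu_rms_norm` / `torch_npu.npu_add_rms_norm`. The eps value, shapes, and return-value handling are assumptions based on the public torch_npu API, not the exact code in `layernorm.py::RMSNorm::forward_npu`.

```python
import torch
import torch_npu


def rms_norm_forward_npu(x: torch.Tensor, weight: torch.Tensor,
                         eps: float = 1e-6, residual: torch.Tensor = None):
    """Fused RMSNorm on Ascend, with an optional fused residual add."""
    if residual is None:
        # npu_rms_norm returns (normalized_output, rstd); only the output is used here.
        out, _ = torch_npu.npu_rms_norm(x, weight, eps)
        return out
    # npu_add_rms_norm fuses (x + residual) with the normalization and
    # returns (normalized_output, rstd, x_plus_residual).
    out, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, weight, eps)
    return out, new_residual
```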
#### Summary of other key features

**W8A8 quantization**

**Enable ACLGraph**
The aim is to improve the performance of the launch task during the decode phase. See issue 8030.
Notice:

- The `actual_seq_lengths` argument of `torch_npu.npu_fused_infer_attention_score` needs special handling during `npu_graph_runner.py::NPUGraphRunner::replay`. The background is that IFA's tiling depends on the host value of `actual_seq_len`, which is the origin of the issue (see the sketch after this list).
- We found that `torch_npu._npu_paged_attention` is currently a bit faster than `torch_npu.npu_fused_infer_attention_score`, but it could not be captured into ACLGraph, so after discussing with the torch_npu team we submitted PR24572 to support it; see PR24572 for more details. In this case, the `context_lens` argument needs special handling.
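Below is a minimal sketch of the capture-and-replay pattern behind ACLGraph, assuming torch_npu mirrors the CUDA Graph API under `torch.npu` (`NPUGraph`, `torch.npu.graph`, `replay()`); the buffer names and the model wrapper are illustrative, not sglang's actual `NPUGraphRunner`. The point of the note above is that device tensors can be refreshed in place before replay, whereas host-side arguments such as `actual_seq_lengths` are consumed at capture time for tiling and therefore need special handling.

```python
import torch
import torch_npu


def capture_decode_graph(model, batch_size, hidden_size, dtype=torch.float16):
    # Static buffers: replay always reads/writes these exact memory addresses.
    static_input = torch.zeros(batch_size, hidden_size, device="npu", dtype=dtype)
    graph = torch.npu.NPUGraph()
    with torch.npu.graph(graph):
        static_output = model(static_input)
    return graph, static_input, static_output


def replay_decode_graph(graph, static_input, static_output, new_input):
    # Device-side inputs can be updated in place before replay ...
    static_input.copy_(new_input)
    # ... but host-side values (e.g. actual_seq_lengths of
    # npu_fused_infer_attention_score) were read at capture time for tiling,
    # which is why NPUGraphRunner.replay needs the special handling noted above.
    graph.replay()
    return static_output
```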
**[CMO weight prefetch] Prefetch the weight of matmul when running the AIV kernels**

Use `torch_npu.npu_prefetch` to prefetch the weight of the matmul (gate_up, down proj) while other AIV kernels are running, aiming to overlap the memory-access time. Weight prefetching preloads the right-hand matrix of the matmul (gate_up, down proj) into the L2 cache before the matmul is computed, which reduces the memory-access overhead during kernel execution and shortens the matmul execution time.
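A hedged sketch of this pattern, assuming the documented `torch_npu.npu_prefetch(input, dependency, max_size)` interface; the MLP layer names and the 24 MB prefetch cap are illustrative, not the actual sglang code.

```python
import torch
import torch_npu

# Illustrative cap on how many bytes of the weight are pulled into L2 cache.
PREFETCH_MAX_BYTES = 24 * 1024 * 1024


def mlp_forward_with_prefetch(hidden_states, gate_up_proj, down_proj, act_fn):
    # Issue an asynchronous prefetch of the down_proj weight (the right-hand
    # matrix of the next matmul) so that its memory traffic overlaps with the
    # gate_up matmul and the AIV activation kernel below.
    torch_npu.npu_prefetch(down_proj.weight, hidden_states, PREFETCH_MAX_BYTES)
    x = gate_up_proj(hidden_states)
    x = act_fn(x)          # AIV kernel; the prefetch overlaps with this work
    return down_proj(x)
```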
### Prefill optimization -- sequence parallelism

See #10519 for details.

### Host optimization

Will be updated later.
## Qwen2.5-VL -- Vision part only

Please see #9189, #10556, #11047.
**VisionAttention**

Use `torch_npu._npu_flash_attention_unpad` for attention acceleration, including the sin/cos cache and `torch_npu.npu_rotary_mul`.
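A minimal sketch of applying the cached rotary embedding with `torch_npu.npu_rotary_mul`; the tensor shapes and broadcasting are assumptions, and the `_npu_flash_attention_unpad` call itself is omitted since it is an internal op.

```python
import torch
import torch_npu


def apply_vision_rope(q: torch.Tensor, k: torch.Tensor,
                      cos: torch.Tensor, sin: torch.Tensor):
    # q, k: e.g. [1, seq_len, num_heads, head_dim]; cos/sin come from the
    # precomputed cache and broadcast over the head dimension.
    q = torch_npu.npu_rotary_mul(q, cos, sin)
    k = torch_npu.npu_rotary_mul(k, cos, sin)
    return q, k
```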
**VisionPatchEmbed**

Use `matmul` instead of `Conv3D` for patch acceleration. How does it work?

- `feature_map.shape = (N, C, D, H, W)`
- kernel (weight) shape: `(Cout, C, d, h, w)`
- For Qwen2.5-VL: `kernel_size == stride`, and `D = d`, `H = h`, `W = w`
- Number of sliding-window positions: `S = (D/d) * (H/h) * (W/w) = 1`
- Conv3D result: `Hidden_state = (N, S, C*d*h*w) x (Cout, C*d*h*w)^T = (N, 1, C*d*h*w) x (Cout, C*d*h*w)^T`
- which equals: `Hidden_state = (N, C*d*h*w) x (Cout, C*d*h*w)^T`
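The rewrite can be checked numerically with plain PyTorch (CPU is fine); the concrete sizes below are illustrative, chosen so that the kernel covers the whole input as in Qwen2.5-VL's patch embedding.

```python
import torch

N, C, D, H, W = 4, 3, 2, 14, 14        # illustrative patch-sized inputs
Cout, d, h, w = 1280, 2, 14, 14        # kernel == stride, so D=d, H=h, W=w

x = torch.randn(N, C, D, H, W)
conv = torch.nn.Conv3d(C, Cout, kernel_size=(d, h, w), stride=(d, h, w), bias=False)

# Conv3D path: a single sliding-window position (S = 1), output [N, Cout, 1, 1, 1].
ref = conv(x).reshape(N, Cout)

# Matmul path: flatten each sample to (C*d*h*w,) and multiply by the flattened kernel.
weight = conv.weight.reshape(Cout, C * d * h * w)   # (Cout, C*d*h*w)
out = x.reshape(N, C * d * h * w) @ weight.t()      # (N, Cout)

torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
```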
**VisionTransformer**

Attention padding: because the Ascend Cube unit reaches its highest performance when the input shape is divisible by 16, we pad the attention head shape from [40, 40] to [64, 64] in `Qwen2_5_VLForConditionalGeneration.load_weights`.
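A hedged, illustrative sketch of the zero-padding idea at weight-load time, assuming the padding targets a head-related weight dimension of size 40; the tensor layout and the helper below are assumptions, not the actual `load_weights` logic.

```python
import torch
import torch.nn.functional as F


def pad_last_dim(weight: torch.Tensor, padded_size: int = 64) -> torch.Tensor:
    """Zero-pad the trailing dimension (e.g. 40 -> 64) so the Cube unit sees aligned shapes."""
    pad = padded_size - weight.shape[-1]
    return F.pad(weight, (0, pad)) if pad > 0 else weight


# Illustrative: a [num_heads, 40] slice of an attention weight becomes [num_heads, 64].
w = torch.randn(16, 40)
print(pad_last_dim(w).shape)  # torch.Size([16, 64])
```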