[NPU] ACLGraph Compilation support and PassManager with AddRmsNorm & Quantize fuse. TorchAir compiler backend support. #11104

Open

eshoguli wants to merge 98 commits into sgl-project:main from eshoguli:eshogulin/pass_manager

Conversation

@eshoguli eshoguli (Contributor) commented Sep 30, 2025

Motivation

This pull request is part of an architecture update that introduces a hardware-independence layer, using the Qwen3 model as the example; other models will be supported later. The changes make the model compilable. The PassManager and fusing passes introduced here (for fp16 and quantized models), built on torch.fx pattern replacement, remove hardware-specific logic from the Qwen3 model and allow the same Python model code to be reused across hardware, with hardware-specific optimizations applied as passes (SplitQkvRmsnormRopeFuse, NpuAddRmsNormQuantFuse, NpuAddRmsNormDynamicQuantFuse). The compiler backends implemented in this pull request select the optimal inference path for the chosen hardware. The pull request also brings a performance gain for both fp16 and quantized models.

  1. Performance gain:

sglang.bench_serving: Ascend 910A3, batch size = 32, 64, 128:

| Branch | Total token throughput (tok/s) | Median TTFT (ms) | Median ITL (ms) |
|---|---|---|---|
| Reference | 3561.37 | 3166.90 | 31.34 |
| Compilation (`--enable-torch-compile`) | 3677.94 | 3144.41 | 30.66 |

GSM8K, Ascend 910A2:

| Branch | batch size 16 | batch size 32 | batch size 64 |
|---|---|---|---|
| Reference | Accuracy: 0.829<br>Latency: 826.420 s<br>Output throughput: 454.471 token/s | Accuracy: 0.829<br>Latency: 536.163 s<br>Output throughput: 684.102 token/s | Accuracy: 0.835<br>Latency: 378.931 s<br>Output throughput: 944.312 token/s |
| Compilation (`--enable-torch-compile`) | Accuracy: 0.826<br>Latency: 798.009 s<br>Output throughput: 480.432 token/s | Accuracy: 0.826<br>Latency: 528.144 s<br>Output throughput: 720.576 token/s | Accuracy: 0.820<br>Latency: 367.730 s<br>Output throughput: 969.680 token/s |
  2. Support model compilation on NPU and a PassManager for current and future fuses, implemented in Python via torch.fx.replace_pattern. Fuses can easily be developed by external contributors.
  3. Improve performance by fusing the AddRmsNorm and AscendQuantV2 kernels into the AddRmsNormQuant kernel.
  4. Increase performance of the compiled model via NPU kernels and by avoiding torch.compile guards.
  5. TorchAir compilation backend support.
    Original comment: [feat] npu support enable_torch_compile #12371

TorchAir (Torch Ascend Intermediate Representation) is an extension library that provides graph-mode capabilities for torch_npu. It lets users run graph-mode inference on NPU with PyTorch and torch_npu. TorchAir exposes a torch.compile backend for NPU that interfaces with torch._dynamo. The features below provide performance optimization and capability enhancements for the torch FX graph.

*(TorchAir architecture diagram)*

TorchAir Main Features:

  1. Basic features:
  • Enable NPU kernels that depend on host-value tiling operators (e.g., FIA) to support npugraph
  • Graph input copy optimization
  • Memory reuse across multiple graphs
  2. FX passes:
  • In-place optimization
  • Redundant operator elimination
  • NPU fused-operator passes
  3. Advanced features:
  • Static-shape kernel compilation
  • Multi-stream within a single graph
  • Compilation caching
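
For orientation, TorchAir is normally wired into torch.compile roughly as below. This is a hedged sketch: the torchair API calls (CompilerConfig, get_npu_backend) follow the upstream torchair documentation and may differ between releases, and the toy model and device placement are illustrative, not taken from this PR.

```python
import torch
import torch_npu   # registers the "npu" device with PyTorch
import torchair

# TorchAir compiler configuration; defaults are used here, real deployments
# tune fields such as the execution mode.
config = torchair.CompilerConfig()

# torchair exposes a torch.compile backend for NPU that plugs into torch._dynamo.
npu_backend = torchair.get_npu_backend(compiler_config=config)

model = torch.nn.Linear(16, 16).to("npu")  # illustrative toy model
compiled_model = torch.compile(model, dynamic=False, backend=npu_backend)

out = compiled_model(torch.randn(4, 16).to("npu"))
```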

How to enable compilation and fuses for NPUGraph decode:

--enable-torch-compile

Optional customization:

--compilation-config '{"compiler": "npugraph"}'

How to enable TorchAir for decode:

--enable-npu-torchair-compile --torch-compile-max-bs <max-bs-value> --disable-overlap-schedule

Optional customization:

--compilation-config '{"compiler": "npugraph_ex"}' --disable-overlap-schedule

CANN version: 8.2
Torch NPU version: torch-npu 2.6.0.post3
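
The --compilation-config value shown above is a JSON-serialized instance of the existing CompilationConfig class (python/sglang/srt/compilation/compilation_config.py). The snippet below is a hypothetical sketch of how such a JSON string maps onto a config object; only the "compiler" field comes from this PR, and the class body and from_json helper are illustrative, not the actual sglang implementation.

```python
import json
from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in for sglang's CompilationConfig.
@dataclass
class CompilationConfig:
    compiler: str = "npugraph"

    @classmethod
    def from_json(cls, raw: Optional[str]) -> "CompilationConfig":
        # --compilation-config is passed as a JSON string,
        # e.g. '{"compiler": "npugraph_ex"}'.
        return cls(**json.loads(raw)) if raw else cls()

cfg = CompilationConfig.from_json('{"compiler": "npugraph_ex"}')
assert cfg.compiler == "npugraph_ex"
```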

NpuAddRmsNormQuantFuse pass

The NpuAddRmsNormQuantFuse pass fuses the AddRmsNorm and AscendQuant kernels into AddRmsNormQuant.

Before the fuse, the AddRmsNorm and AscendQuant NPU kernels take 29 microseconds:
*(profiling screenshot: before)*

After the fuse, the single AddRmsNormQuant NPU kernel takes 19 microseconds:
*(profiling screenshot: after)*
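
The pass matches NPU kernels, but the underlying mechanism is plain torch.fx pattern replacement. Below is a minimal, self-contained sketch of that mechanism using toy ops; fused_add_relu is a stand-in for a fused kernel and is not the real NPU operator or the actual pass code from this PR.

```python
import torch
import torch.fx as fx

def fused_add_relu(x, y):
    # Stand-in for a fused kernel (e.g. AddRmsNormQuant on NPU).
    return torch.relu(x + y)

# Keep fused_add_relu as a single opaque call_function node when traced.
fx.wrap("fused_add_relu")

class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(torch.add(x, y))

def pattern(x, y):
    # The multi-kernel sequence to find in the traced graph.
    return torch.relu(torch.add(x, y))

def replacement(x, y):
    # The single fused call substituted for every match.
    return fused_add_relu(x, y)

gm = fx.symbolic_trace(Toy())
fx.replace_pattern(gm, pattern, replacement)
gm.recompile()
print(gm.code)  # forward now calls fused_add_relu instead of add + relu
```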

SplitQkvRmsnormRopeFuse pass

Before the fuse, the RmsNorm, RmsNorm and RopeWithSinCosCache NPU kernels take 62 microseconds:
*(profiling screenshot: before)*

After the fuse, the single split_kqv_rmsnorm_rope NPU kernel takes 25 microseconds:
*(profiling screenshot: after)*

Modifications

  1. Model compilation support via torch.compile.
    Use --enable-torch-compile to enable compilation and the optional --torch-compile-max-bs argument to limit the maximum batch size for compilation.

  2. NpuGraphCompilerBackend compilation backend for NPU graph capturing. Implemented in python/sglang/srt/model_executor/compilation/npu_graph_compiler_backend.py (an illustrative sketch of how such a backend and the pass manager fit together appears after this list). Usage:

self.compiled_callable = torch.compile(
    model, fullgraph=True, dynamic=False, backend=NpuGraphCompilerBackend()
)
  3. PassManager (pass manager) and the passes in python/sglang/srt/model_executor/compilation/passes/w8a8_int8 and python/sglang/srt/compilation/npu/passes/fp16.py to optimize the model during compilation. Usage:
from sglang.srt.compilation.npu.pass_manager import PassManager
from sglang.srt.compilation.npu.passes.fp16 import SplitQkvRmsnormRopeFuse
from sglang.srt.compilation.npu.passes.w8a8_int8 import (
    DivFuse,
    EraseCopy,
    NpuAddRmsNormQuantFuse,
    NpuAddRmsNormDynamicQuantFuse,
)

def apply_passes(graph_module: torch.fx.GraphModule):
    # q_size, kv_size, head_dim, q_shape, k_shape and rms_norm_eps come from
    # the enclosing attention layer (referenced via self in this excerpt).
    passManager = PassManager(graph_module)
    passManager.add(
        SplitQkvRmsnormRopeFuse,
        q_size=self.q_size,
        kv_size=self.kv_size,
        head_dim=self.head_dim,
        q_shape=self.q_shape,
        k_shape=self.k_shape,
        variance_epsilon=self.rms_norm_eps,
    )
    passManager.add(NpuAddRmsNormQuantFuse)
    passManager.add(NpuAddRmsNormDynamicQuantFuse)
    passManager.add(DivFuse)
    passManager.add(EraseCopy)
    passManager.apply()
    graph_module.recompile()
  4. The RotaryEmbedding layer uses an NPU kernel in forward instead of the native implementation.
  5. torch.compile guards are ignored to improve forward performance.
  6. Ascend paged attention is used to enable compilation without custom ops: python/sglang/srt/layers/attention/ascend_backend.py
  7. TorchAir:
    7.1. Rewrite the capture function.
    7.2. Encapsulate the KV cache input (the input needs the whole KV cache).
    7.3. Pad the block table to the maximum length.
    7.4. TorchAir input preparation.

The calling process is as follows:
*(TorchAir call-flow diagram)*
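
To make items 2 and 3 above concrete, here is a minimal illustrative sketch of how a custom torch.compile backend can run graph passes before returning a callable to Dynamo. The class and method names are toy stand-ins, not the actual NpuGraphCompilerBackend or PassManager implementation from this PR.

```python
from typing import Callable, List

import torch
import torch.fx as fx

class PassManager:
    """Illustrative pass manager: collect pass callables, then run them in order."""

    def __init__(self, gm: fx.GraphModule):
        self.gm = gm
        self.passes: List[Callable[[fx.GraphModule], None]] = []

    def add(self, p: Callable[..., None], **kwargs) -> None:
        # Bind pass-specific arguments (shapes, eps, ...) at registration time.
        self.passes.append(lambda gm, p=p, kw=kwargs: p(gm, **kw))

    def apply(self) -> None:
        for p in self.passes:
            p(self.gm)

class ToyCompilerBackend:
    """A torch.compile backend is just a callable taking (GraphModule, example_inputs)."""

    def __call__(self, gm: fx.GraphModule, example_inputs) -> Callable:
        pm = PassManager(gm)
        # pm.add(SomeFusePass, ...)   # real fuse passes would be registered here
        pm.apply()
        gm.recompile()
        return gm.forward  # the callable Dynamo will invoke

compiled = torch.compile(
    torch.nn.Linear(8, 8), backend=ToyCompilerBackend(), fullgraph=True, dynamic=False
)
out = compiled(torch.randn(2, 8))
```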

Accuracy Tests

Collected on the GSM8K dataset for statically quantized Qwen3-32B:

| Version | Accuracy |
|---|---|
| Reference | 85.7% |
| Compilation | 85.6% |
| TorchAir | 85.1% |

TorchAir

python3 few_shot_gsm8k.py --data-path "/path/to/model/test.jsonl.txt" --parallel 32 --num-questions 200

Accuracy: 0.865
Invalid: 0.000
Latency: 43.077 s
Output throughput: 795.877 token/s

Collected on the MMMU dataset for Qwen3-VL-30B-A3B-Instruct:

| Version | Overall accuracy |
|---|---|
| Reference | 0.592 |
| Compilation | 0.597 |

Benchmarking and Profiling (910A3)

Reference

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 64
Successful requests:                     64
Benchmark duration (s):                  36.80
Total input tokens:                      65536
Total input text tokens:                 65536
Total input vision tokens:               0
Total generated tokens:                  65536
Total generated tokens (retokenized):    65510
Request throughput (req/s):              1.74
Input token throughput (tok/s):          1780.68
Output token throughput (tok/s):         1780.68
Peak output token throughput (tok/s):    2176.00
Peak concurrent requests:                64
Total token throughput (tok/s):          3561.37
Concurrency:                             63.95
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   36777.89
Median E2E Latency (ms):                 36777.27
---------------Time to First Token----------------
Mean TTFT (ms):                          2851.34
Median TTFT (ms):                        3166.90
P99 TTFT (ms):                           4767.54
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          33.16
Median TPOT (ms):                        32.86
P99 TPOT (ms):                           35.11
---------------Inter-Token Latency----------------
Mean ITL (ms):                           33.16
Median ITL (ms):                         31.34
P95 ITL (ms):                            33.23
P99 ITL (ms):                            33.61
Max ITL (ms):                            4542.82
==================================================

Compilation

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 64
Successful requests:                     64
Benchmark duration (s):                  35.64
Total input tokens:                      65536
Total input text tokens:                 65536
Total input vision tokens:               0
Total generated tokens:                  65536
Total generated tokens (retokenized):    65527
Request throughput (req/s):              1.80
Input token throughput (tok/s):          1838.97
Output token throughput (tok/s):         1838.97
Peak output token throughput (tok/s):    2240.00
Peak concurrent requests:                64
Total token throughput (tok/s):          3677.94
Concurrency:                             63.95
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   35610.14
Median E2E Latency (ms):                 35610.78
---------------Time to First Token----------------
Mean TTFT (ms):                          2691.20
Median TTFT (ms):                        3144.41
P99 TTFT (ms):                           4235.79
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          32.18
Median TPOT (ms):                        31.74
P99 TPOT (ms):                           33.99
---------------Inter-Token Latency----------------
Mean ITL (ms):                           32.18
Median ITL (ms):                         30.66
P95 ITL (ms):                            32.61
P99 ITL (ms):                            33.06
Max ITL (ms):                            4036.15
==================================================

Future roadmap

  • In torch_npu 7.2.0, the reduce-overhead mode of the torchair backend will support torch.compile(model, dynamic=True). This mode will become the default in get_compile_backend(), enabling support for methods wrapped with the @torch.compile() decorator.
  • In torch_npu 7.3.0, the NPUGraph capture and replay currently integrated in the torchair backend will become optional. The torchair backend will then only perform optimizations such as FX pass optimization and static kernel compilation, while NPUGraph capture and replay will be implemented independently. This design is closer to the CudaGraphRunner implementation, decoupling FX graph optimization from graph offloading.

Checklist

@eshoguli changed the title from "[WIP] NPU Graph Compilation & PassManager" to "NPU Graph Compilation support and PassManager with AddRmsNorm & Quantize fuse" on Oct 30, 2025
help="Enable debug mode for torch compile",
)
parser.add_argument(
"--enable-npu-torchair-compile",
Collaborator: Can we reuse the --enable-torch-compile and set specific configs to args like --torch-compile-config?

Contributor (author): Thanks for the comment. TorchAir for NPU can be enabled by --enable-npu-torchair-compile and by the --compilation-config option:

--compilation-config '{"compiler": "npugraph_ex"}' --disable-overlap-schedule

The separate command-line argument --enable-npu-torchair-compile was discussed with @yuan-luo offline; we decided to have a separate argument for each option.

"--compilation-config",
type=str,
default=None,
help="Compilation config.",
Collaborator: Rename to --torch-compile-config? Can we make the description clearer?

Contributor (author): Thanks for your comment. The description has been extended.

Please note that --compilation-config represents a JSON-serialized instance of the existing CompilationConfig class: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/compilation/compilation_config.py. I can rename the argument if you really need it, but it would then differ from the class name. Should I rename the CompilationConfig class too?

Collaborator: I see, that's fine.

@ping1jing2 (Collaborator): /rerun-failed-ci

1 similar comment

@ping1jing2 (Collaborator): /rerun-failed-ci

hidden_states = tensor_model_parallel_all_reduce(hidden_states)
if _is_npu and context.cache is not None:
    _ = prepare_weight_cache(hidden_states, context.cache)
    _ = torch.ops.sglang.prepare_weight_cache(
Member: We are against the use of torch.ops.sglang. Please use the new API:

def register_custom_op(

Contributor (author): Done, thanks for the comment.

def __init__(
    self,
    capture_sizes: List[int],
    capture_sizes: List[int] = [],
Member: Using [] as a default argument is very bad practice: the default list is shared across calls and will be changed in place!

Contributor (author): Removed, thanks for the comment.
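
For readers unfamiliar with the pitfall the reviewer points out: a list default is created once at function-definition time and shared across all calls, so in-place mutation leaks between instances. A small illustration, unrelated to the actual sglang class:

```python
from typing import List, Optional

class BadRunner:
    def __init__(self, capture_sizes: List[int] = []):  # shared default list!
        self.capture_sizes = capture_sizes

a, b = BadRunner(), BadRunner()
a.capture_sizes.append(32)
print(b.capture_sizes)  # [32] -- b sees a's mutation

class GoodRunner:
    def __init__(self, capture_sizes: Optional[List[int]] = None):
        # Create a fresh list per instance instead of sharing the default.
        self.capture_sizes = capture_sizes if capture_sizes is not None else []
```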

# FIXME: hack to reduce ITL when decode bs is small
disaggregation_decode_polling_interval: int = 1

compilation_config: Optional[CompilationConfig] = None
Member: Put this under enable_torch_compile.

Contributor (author): Done, thanks for the comment. The compilation_config option is still here, but a validation check that enable_torch_compile is set was added.

@ping1jing2 (Collaborator): /rerun-failed-ci

@ping1jing2 (Collaborator): /rerun-failed-ci


Labels

documentation, npu, piecewise-cuda-graph, quant, run-ci
