
[NPU] Piecewise Graph for decode with PassManager & fuses #15332

Draft
eshoguli wants to merge 76 commits into sgl-project:main from eshoguli:eshoguli/pass_manager_piecewise

Conversation


@eshoguli eshoguli commented Dec 17, 2025

Motivation

Piecewise for NPU Graph, based on: #11104

This pull request is part of an architecture update that introduces a hardware-independence layer, using the Qwen3 model as an example; other models will be supported later. The changes make the model compilable. The PassManager and fusing passes introduced here (for fp16 and quantized models), based on torch.fx.replace_pattern, remove hardware-specific logic from the Qwen3 model and allow the same Python model to be reused on different hardware with hardware-specific optimizations (passes: SplitQkvRmsnormRopeFuse, NpuAddRmsNormQuantFuse, NpuAddRmsNormDynamicQuantFuse). The compiler backends implemented in this pull request select the optimal inference path for the chosen hardware. The pull request also brings performance gains for fp16 and quantized models.

  1. Performance gain:

sglang.bench_serving: Ascend 910A3, batch size = 32, 64, 128:

| Branch | Total token throughput (tok/s) | Median TTFT (ms) | Median ITL (ms) |
| --- | --- | --- | --- |
| Reference | 3561.37 | 3166.90 | 31.34 |
| Compilation (`--enable-torch-compile`) | 3677.94 | 3144.41 | 30.66 |

GSM8K Ascend 910A2:

| Branch | batch size 16 | batch size 32 | batch size 64 |
| --- | --- | --- | --- |
| Reference | Accuracy: 0.829<br>Latency: 826.420 s<br>Output throughput: 454.471 token/s | Accuracy: 0.829<br>Latency: 536.163 s<br>Output throughput: 684.102 token/s | Accuracy: 0.835<br>Latency: 378.931 s<br>Output throughput: 944.312 token/s |
| Compilation (`--enable-torch-compile`) | Accuracy: 0.826<br>Latency: 798.009 s<br>Output throughput: 480.432 token/s | Accuracy: 0.826<br>Latency: 528.144 s<br>Output throughput: 720.576 token/s | Accuracy: 0.820<br>Latency: 367.730 s<br>Output throughput: 969.680 token/s |
  2. Support model compilation on NPU and a PassManager for current and future fuses in Python via torch.fx.replace_pattern. Fuses can be easily developed by external contributors.
  3. Improve performance by fusing the AddRmsNorm and AscendQuantV2 kernels into the AddRmsNormQuant kernel.
  4. Increase performance for the compiled model via NPU kernels and by avoiding torch guards.
  5. Piecewise graph execution approach. Issue: [Feature] support ACLGraph #8030
  6. TorchAir compilation backend support.
    Original comment: [feat] npu support enable_torch_compile #12371

TorchAir (Torch Ascend Intermediate Representation) is an extension library that provides graph mode capabilities for torch_npu. It enables users to perform graph-mode inference on NPU using PyTorch and torch_npu. TorchAir externally offers a torch.compile backend for NPU, which interfaces with torch._dynamo. Through the following features, performance optimization and capability enhancement of the torch fx graph can be achieved.


TorchAir Main Features:

  1. Basic Features:
  • Enable NPU kernels that depend on host-value tiling operators (e.g., FIA) to support npugraph
  • Graph input copy optimization
  • Memory reuse across multi-graphs
  2. FX Pass:
  • In-place optimization
  • Redundant operator elimination
  • NPU fused operator passes
  3. Advanced Features:
  • Static shape kernel compilation
  • Multi-stream within single graphs
  • Compilation caching

How to enable compilation and fuses for NPUGraph decode:

Variant 1: similar to other options:

--enable-torch-compile

Variant 2: more general, with customization in the future:

--compilation-config {"compiler": "npugraph"}
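For context, a `torch.compile` backend such as `NpuGraphCompilerBackend` is simply a callable that receives the traced FX `GraphModule` and example inputs and returns a callable. A minimal sketch of that protocol (`toy_backend` is illustrative, not the PR's implementation; a real backend would run fusion passes and capture the graph here):

```python
import torch

def toy_backend(gm: torch.fx.GraphModule, example_inputs):
    # A real backend (e.g. NpuGraphCompilerBackend) would run its fusion
    # passes over gm.graph here and capture the result for replay.
    # Returning gm.forward keeps eager numerics unchanged.
    return gm.forward

@torch.compile(backend=toy_backend, fullgraph=True, dynamic=False)
def f(x):
    return torch.relu(x + 1)

out = f(torch.tensor([-2.0, 0.5]))  # backend is invoked on the first call
```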

How to enable piecewise graph and fuses for decode:

Variant 1: similar to other options, with default compilation parameters:

--enable-piecewise-npu-graph-decode

Variant 2: more general, with customization:

--compilation-config {"compiler": "piecewise", "splitting_ops": ["atb._npu_paged_attention"]}

The splitting_ops key is optional.
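The piecewise idea — keep the splitting ops (such as paged attention) out of the captured graph and capture the pieces between them — can be sketched with `torch.fx.passes.split_module`. This is a toy illustration, not the PR's code; `torch.relu` stands in for a real splitting op like `atb._npu_paged_attention`:

```python
import torch
import torch.fx as fx
from torch.fx.passes.split_module import split_module

class Model(torch.nn.Module):
    def forward(self, x):
        a = x + 1           # piece 1: capturable subgraph
        b = torch.relu(a)   # stand-in splitting op: runs eagerly
        return b * 2        # piece 2: capturable subgraph

gm = fx.symbolic_trace(Model())

part = {"i": 0}
def split_callback(node):
    # Give each splitting op its own partition so the subgraphs around it
    # can be captured and replayed independently.
    if node.op == "call_function" and node.target is torch.relu:
        part["i"] += 1
        own = part["i"]
        part["i"] += 1
        return own
    return part["i"]

split = split_module(gm, Model(), split_callback)
x = torch.tensor([-3.0, 2.0])
assert torch.equal(split(x), gm(x))  # same numerics, now in three submodules
```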

How to enable TorchAir for decode:

Variant 1: similar to other options:

--enable-torchair-compile --torch-compile-max-bs <max-bs-value> --disable-overlap-schedule

Variant 2: more general, with customization in the future:

--compilation-config {"compiler": "npugraph_ex"} --disable-overlap-schedule

CANN version: 8.2
Torch NPU version: torch-npu 2.6.0.post3

NpuAddRmsNormQuantFuse pass

The NpuAddRmsNormQuantFuse pass fuses AddRmsNorm and AscendQuant into AddRmsNormQuant.

Before fuse: the AddRmsNorm and AscendQuant NPU kernels take 29 microseconds:

After fuse: the AddRmsNormQuant NPU kernel takes 19 microseconds:

SplitQkvRmsnormRopeFuse pass

Before fuse: the RmsNorm, RmsNorm and RopeWithSinCosCache NPU kernels take 62 microseconds:

After fuse: the split_kqv_rmsnorm_rope NPU kernel takes 25 microseconds:
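Both fuses are built on `torch.fx.replace_pattern`. A minimal, runnable sketch of the mechanism, using plain torch ops to stand in for the NPU kernels (`fused_add_rms_norm` is a hypothetical stand-in, not the real torch_npu op):

```python
import torch
import torch.fx as fx

@fx.wrap  # keep this call as a single leaf node when tracing the replacement
def fused_add_rms_norm(x, residual, weight):
    # Stand-in for the fused NPU kernel (the real pass rewrites to an
    # AddRmsNormQuant-style torch_npu op); numerics match the pattern below.
    h = x + residual
    return h * torch.rsqrt(torch.mean(h * h, -1, True) + 1e-6) * weight

class Block(torch.nn.Module):
    def forward(self, x, residual, weight):
        h = x + residual  # residual add followed by RMS normalization
        return h * torch.rsqrt(torch.mean(h * h, -1, True) + 1e-6) * weight

def pattern(x, residual, weight):
    h = x + residual
    return h * torch.rsqrt(torch.mean(h * h, -1, True) + 1e-6) * weight

def replacement(x, residual, weight):
    return fused_add_rms_norm(x, residual, weight)

gm = fx.symbolic_trace(Block())
matches = fx.replace_pattern(gm, pattern, replacement)
assert len(matches) == 1  # the add + rmsnorm subgraph became one fused call
```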

Modifications

  1. Model compilation support via torch.compile
    Use --enable-torch-compile to enable compilation and the optional --torch-compile-max-bs argument to limit the max batch size for compilation.

  2. NpuGraphCompilerBackend compilation backend for NPU Graph capturing. Implemented in: python/sglang/srt/model_executor/compilation/npu_graph_compiler_backend.py, usage:

self.compiled_callable = torch.compile(
    model, fullgraph=True, dynamic=False, backend=NpuGraphCompilerBackend()
)
  3. PiecewiseNpuGraphCompilerBackend compilation backend for piecewise graph and partial NPU Graph capturing. Inherited from NpuGraphCompilerBackend to reuse fusing passes. Implemented in: python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py, usage:
self.compiled_callable = torch.compile(
    model, fullgraph=True, dynamic=False, backend=PiecewiseNpuGraphCompilerBackend()
)

You can use --enable-piecewise-npu-graph-decode to enable Piecewise Graph.
Optional command line arguments:

  • --compilation-config {"splitting_ops": ["atb._npu_paged_attention"]} to configure compilation backend,
  • --cuda-graph-bs to specify batch size,
  • --cuda-graph-max-bs to limit max batch size.
  4. PassManager and optimization passes (python/sglang/srt/model_executor/compilation/passes/w8a8_int8 and python/sglang/srt/compilation/npu/passes/fp16.py) to optimize the model during compilation. Usage:
from sglang.srt.compilation.npu.pass_manager import PassManager
from sglang.srt.compilation.npu.passes.fp16 import SplitQkvRmsnormRopeFuse
from sglang.srt.compilation.npu.passes.w8a8_int8 import (
    DivFuse,
    EraseCopy,
    NpuAddRmsNormQuantFuse,
    NpuAddRmsNormDynamicQuantFuse,
)

def apply_passes(self, graph_module: torch.fx.GraphModule):
    pass_manager = PassManager(graph_module)
    pass_manager.add(
        SplitQkvRmsnormRopeFuse,
        q_size=self.q_size,
        kv_size=self.kv_size,
        head_dim=self.head_dim,
        q_shape=self.q_shape,
        k_shape=self.k_shape,
        variance_epsilon=self.rms_norm_eps,
    )
    pass_manager.add(NpuAddRmsNormQuantFuse)
    pass_manager.add(NpuAddRmsNormDynamicQuantFuse)
    pass_manager.add(DivFuse)
    pass_manager.add(EraseCopy)
    pass_manager.apply()
    graph_module.recompile()
  5. The RotaryEmbedding layer uses an NPU kernel in forward instead of the native implementation
  6. torch.compile guards are ignored to improve forward performance
  7. Ascend paged attention is used to enable compilation without custom ops: python/sglang/srt/layers/attention/ascend_backend.py
  8. TorchAir
    8.1. Rewrite the capture function;
    8.2. Encapsulate the kvcache input (the input needs the whole kvcache);
    8.3. Pad the block table to the max length;
    8.4. TorchAir input preparation.

The calling process is illustrated in a diagram attached to the PR.
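The pass-manager pattern used in the modifications above can be sketched in a few lines. This is an illustrative stand-in (names and signatures are assumptions; the real class lives in the PR), shown with a toy pass that rewrites relu to gelu:

```python
import torch

class PassManager:
    """Minimal sketch of a pass manager: register passes, then apply them
    in order to one torch.fx.GraphModule. Illustrative, not the PR's code."""

    def __init__(self, graph_module: torch.fx.GraphModule):
        self.graph_module = graph_module
        self._passes = []

    def add(self, pass_cls, **kwargs):
        # Passes are registered as classes plus their constructor kwargs.
        self._passes.append(pass_cls(**kwargs))

    def apply(self):
        for p in self._passes:
            p(self.graph_module)  # each pass mutates the FX graph in place

class ReluToGelu:
    # Toy pass: retarget torch.relu call_function nodes to F.gelu.
    def __call__(self, gm):
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target is torch.relu:
                node.target = torch.nn.functional.gelu
        gm.graph.lint()

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

gm = torch.fx.symbolic_trace(Tiny())
pm = PassManager(gm)
pm.add(ReluToGelu)
pm.apply()
gm.recompile()
```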

Class Diagram

```mermaid
classDiagram
    class PiecewiseNpuGraphRunnerDecode
    class NPUCompileModelRunner
    class NPUGraphRunner
    class CudaGraphRunner
    class NpuGraphCompiler
    class NpuGraphCompilerBackend
    class PiecewiseNpuGraphCompiler
    class PiecewiseNpuGraphCompilerBackend

    NPUGraphRunner --|> CudaGraphRunner
    NPUGraphRunner --> NpuGraphCompiler
    NpuGraphCompiler --> NpuGraphCompilerBackend
    NPUCompileModelRunner --> CudaGraphRunner
    PiecewiseNpuGraphRunnerDecode --> CudaGraphRunner
    PiecewiseNpuGraphRunnerDecode --> PiecewiseNpuGraphCompiler
    PiecewiseNpuGraphCompiler --> PiecewiseNpuGraphCompilerBackend
    PiecewiseNpuGraphCompilerBackend --|> NpuGraphCompilerBackend
```

Accuracy Tests

Collected on gsm8k dataset for static quantized Qwen3-32B:

| Version | Accuracy |
| --- | --- |
| Reference | 85.7% |
| Compilation | 85.6% |
| Piecewise Graph | 85.7% |
| TorchAir | 85.1% |

TorchAir

python3 few_shot_gsm8k.py --data-path "/path/to/model/test.jsonl.txt" --parallel 32 --num-questions 200

Accuracy: 0.865
Invalid: 0.000
Latency: 43.077 s
Output throughput: 795.877 token/s

Collected on MMMU dataset for Qwen3-VL-30B-A3B-Instruct:

| Version | Overall accuracy |
| --- | --- |
| Reference | 0.592 |
| Compilation | 0.597 |
| Piecewise Graph | 0.591 |

Benchmarking and Profiling (910A3)

Reference

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 64
Successful requests:                     64
Benchmark duration (s):                  36.80
Total input tokens:                      65536
Total input text tokens:                 65536
Total input vision tokens:               0
Total generated tokens:                  65536
Total generated tokens (retokenized):    65510
Request throughput (req/s):              1.74
Input token throughput (tok/s):          1780.68
Output token throughput (tok/s):         1780.68
Peak output token throughput (tok/s):    2176.00
Peak concurrent requests:                64
Total token throughput (tok/s):          3561.37
Concurrency:                             63.95
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   36777.89
Median E2E Latency (ms):                 36777.27
---------------Time to First Token----------------
Mean TTFT (ms):                          2851.34
Median TTFT (ms):                        3166.90
P99 TTFT (ms):                           4767.54
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          33.16
Median TPOT (ms):                        32.86
P99 TPOT (ms):                           35.11
---------------Inter-Token Latency----------------
Mean ITL (ms):                           33.16
Median ITL (ms):                         31.34
P95 ITL (ms):                            33.23
P99 ITL (ms):                            33.61
Max ITL (ms):                            4542.82
==================================================

Compilation

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 64
Successful requests:                     64
Benchmark duration (s):                  35.64
Total input tokens:                      65536
Total input text tokens:                 65536
Total input vision tokens:               0
Total generated tokens:                  65536
Total generated tokens (retokenized):    65527
Request throughput (req/s):              1.80
Input token throughput (tok/s):          1838.97
Output token throughput (tok/s):         1838.97
Peak output token throughput (tok/s):    2240.00
Peak concurrent requests:                64
Total token throughput (tok/s):          3677.94
Concurrency:                             63.95
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   35610.14
Median E2E Latency (ms):                 35610.78
---------------Time to First Token----------------
Mean TTFT (ms):                          2691.20
Median TTFT (ms):                        3144.41
P99 TTFT (ms):                           4235.79
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          32.18
Median TPOT (ms):                        31.74
P99 TPOT (ms):                           33.99
---------------Inter-Token Latency----------------
Mean ITL (ms):                           32.18
Median ITL (ms):                         30.66
P95 ITL (ms):                            32.61
P99 ITL (ms):                            33.06
Max ITL (ms):                            4036.15
==================================================

Future roadmap

In torch_npu 7.2.0, the reduce-overhead mode of the torchair backend will support torch.compile(model, dynamic=True). This mode will be set as the default in get_compile_backend(), enabling support for methods wrapped by the @torch.compile() decorator.
In torch_npu 7.3.0, the capture and replay of NPUGraph, currently integrated in the torchair backend, will become optional. The torchair backend will only perform optimizations such as FX pass optimization and static kernel compilation, while NPUGraph capture and replay will be implemented independently. This design is closer to the CudaGraphRunner implementation, decoupling FX graph optimization from graph offloading.

Checklist

@github-actions github-actions bot added documentation Improvements or additions to documentation npu piecewise-cuda-graph labels Dec 17, 2025
@gemini-code-assist

Summary of Changes

Hello @eshoguli, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances NPU (Ascend NPU) inference performance by introducing a piecewise graph compilation framework. It allows for the decomposition of the computational graph into smaller, optimizable segments, which are then processed by a new PassManager with specialized fusion and optimization passes. The changes integrate seamlessly with torch.compile, introduce custom NPU operations, and update core components to support this advanced compilation flow, providing greater flexibility and efficiency for NPU-based model execution.

Highlights

  • New NPU Graph Compilation Strategy: Introduced a piecewise graph compilation strategy for NPU decode operations, leveraging torch.compile and a custom PassManager to optimize model execution.
  • PassManager and Graph Optimization Passes: Added a PassManager for applying graph optimization passes, including DivFuse, EraseCopy, NpuAddRmsNormQuantFuse, NpuAddRmsNormDynamicQuantFuse, and SplitQkvRmsnormRopeFuse.
  • New Compilation Backends and Custom Operations: Implemented NpuGraphCompilerBackend and PiecewiseNpuGraphCompilerBackend for NPU graph compilation, along with new custom PyTorch operations for dynamic padding, CMO, and fused QKV/RMSNorm/RoPE.
  • Dynamo Context Patching: Integrated torch.compile more deeply by patching torch._dynamo.eval_frame.DisableContext to better handle batch sizes within compiled functions.
  • Updated Quantization Logic: Refined NPU quantization methods to directly store weight data and handle aclnn_input_scale_reciprocal differently when torch.compile is enabled for specific linear layers.
  • Documentation and Testing: Added new documentation for NPU pass development and introduced new test cases for NPU graph compilation and piecewise graph functionality.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant new functionality for NPU graph compilation, including a piecewise graph runner and a PassManager for optimizations. The changes are extensive and touch many parts of the compilation and runtime stack. While the overall direction is good, there are several areas that need improvement in terms of correctness, maintainability, and clarity. Specifically, there are risky monkey-patching practices, obscure hacks for graph compilation, and hardcoded values that make the code brittle. Addressing these issues will be crucial for the long-term stability and maintainability of this new NPU backend.

Comment on lines +65 to +71
```python
torch.cuda.CUDAGraph = torch.npu.NPUGraph
torch.cuda.synchronize = torch.npu.synchronize
torch.cuda.graph = torch.npu.graph
torch.cuda.stream = torch.npu.stream
torch.cuda.Stream = torch.npu.Stream
torch.cuda.current_stream = torch.npu.current_stream
torch.cuda.graph_pool_handle = torch.npu.graph_pool_handle
```

critical

Monkey-patching the torch.cuda namespace to alias torch.npu is highly risky and can lead to unexpected behavior in other parts of the codebase that rely on torch.cuda for actual CUDA operations. This global side effect makes the code harder to reason about and maintain. Please consider a more localized approach, such as creating a compatibility module or using conditional imports where NPU-specific functions are needed, rather than patching the entire namespace.

Comment on lines +1172 to +1181
```python
if (
    self.enable_piecewise_npu_graph_decode
    and torch.compiler.is_dynamo_compiling()
):
    # input args for submodule forward
    forward_batch.req_to_token_pool.req_to_token.add_(
        forward_batch.req_to_token_pool.req_to_token
    )
    forward_batch.req_pool_indices.add_(forward_batch.req_pool_indices)
    forward_batch.seq_lens.add_(forward_batch.seq_lens)
```

high

The use of in-place add_ operations on tensors to mark them as inputs for torch.compile is an obscure hack. This code doubles the tensor values, which could lead to correctness issues if not handled carefully downstream. While this might be a necessary workaround for Dynamo, it severely impacts code readability and maintainability. Please add a detailed comment explaining why this is necessary and what it achieves. A better long-term solution would be to find a more explicit way to register tensor dependencies for graph compilation.

```python
if not self.compilation_context.stream:
    self.compilation_context.stream = torch_npu.npu.Stream()

torch.cuda.synchronize()
```

high

This file is in an NPU-specific path, but it calls torch.cuda.synchronize(). This should be torch.npu.synchronize() to ensure correct synchronization on NPU devices. Using the CUDA version here could lead to incorrect behavior or errors when running on NPUs.

Suggested change
```diff
- torch.cuda.synchronize()
+ torch.npu.synchronize()
```

```python
flatten = positions.flatten()
cos_sin = cos_sin_cache.index_select(0, flatten)

reshape = cos_sin.reshape(-1, 2, 64)
```

high

The value 64 is hardcoded in reshape. This appears to be related to the head dimension. Hardcoding this value makes the pass brittle and will likely cause it to fail for models with different head dimensions. This should be derived from a configuration parameter, such as self.head_dim // 2, to ensure the pass is general and robust.

Suggested change
```diff
- reshape = cos_sin.reshape(-1, 2, 64)
+ reshape = cos_sin.reshape(-1, 2, self.head_dim // 2)
```

Comment on lines 9 to +20
```diff
 def __init__(
     self,
-    capture_sizes: List[int],
+    capture_sizes: List[int] = [],
     compiler: str = "eager",
     enable_debug_mode: bool = False,
     splitting_ops: List[str] = [],
 ):
     self.traced_files = set()
     self.capture_sizes = capture_sizes
     self.compiler = compiler
     self.enable_debug_mode = enable_debug_mode
     self.splitting_ops = splitting_ops
```

high

Using mutable default arguments like [] is a common pitfall in Python. The same list instance will be shared across all CompilationConfig objects created without explicitly providing capture_sizes or splitting_ops. If one instance modifies its list, it will affect all other instances. To avoid this potential bug, you should use None as the default value and initialize a new list inside __init__ if the argument is None. You will also need to import Optional from typing.

Suggested change
```diff
 def __init__(
     self,
-    capture_sizes: List[int] = [],
+    capture_sizes: Optional[List[int]] = None,
     compiler: str = "eager",
     enable_debug_mode: bool = False,
-    splitting_ops: List[str] = [],
+    splitting_ops: Optional[List[str]] = None,
 ):
     self.traced_files = set()
-    self.capture_sizes = capture_sizes
+    self.capture_sizes = capture_sizes if capture_sizes is not None else []
     self.compiler = compiler
     self.enable_debug_mode = enable_debug_mode
-    self.splitting_ops = splitting_ops
+    self.splitting_ops = splitting_ops if splitting_ops is not None else []
```

```python
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_hidden_size: int,
kv_hiddem_size: int,
```

medium

There is a typo in the parameter name kv_hiddem_size. It should be kv_hidden_size. This typo is present in the function definition, its usage, and the fake implementation. Correcting it will improve code clarity and consistency.

Suggested change
```diff
- kv_hiddem_size: int,
+ kv_hidden_size: int,
```



```python
class Submodule(torch.nn.Module):
    block_tables = None
```

medium

The block_tables attribute is defined as a class variable and is modified in forward_with_calculation. This stateful design, where one method call (forward_with_calculation) sets a class-level state that a subsequent call (forward) depends on, is fragile and can lead to subtle bugs, especially in concurrent environments. It would be cleaner and safer to pass block_tables as an argument to the forward method or manage this state within an instance rather than at the class level.

Comment on lines +215 to +216
```python
ops_count = 3
ops_step = ops_count + 1
```

medium

The magic numbers ops_count = 3 and ops_step = 4 make this graph splitting logic hard to understand and maintain. It seems to be looking for a specific pattern of nodes. Please add comments to explain what this pattern is and why these specific numbers are used. Consider defining them as named constants with descriptive names if they represent a fixed pattern.

`PassManager` is implemented here: [PassManager](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/pass_manager.py)


You can explore `PassManager` usage in [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) compiler backend. [`PiecewiseNpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py) compiler backed uses `PassManager` too via `NpuGraphCompilerBackend` inheritance.

medium

There is a typo here. "compiler backed" should be "compiler backend".

Suggested change
You can explore `PassManager` usage in [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) compiler backend. [`PiecewiseNpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py) compiler backed uses `PassManager` too via `NpuGraphCompilerBackend` inheritance.
You can explore `PassManager` usage in [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) compiler backend. [`PiecewiseNpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py) compiler backend uses `PassManager` too via `NpuGraphCompilerBackend` inheritance.

After developing a pass, create a `PassManager` instance, add the pass, and call the `apply` method:
```python
def apply_passes(self, graph_module: torch.fx.GraphModule):
    passManager = PassManager(graph_module)
```

medium

The variable name passManager does not follow the PEP 8 style guide, which recommends snake_case for variable names. It should be renamed to pass_manager for consistency.

Suggested change
```diff
- passManager = PassManager(graph_module)
+ pass_manager = PassManager(graph_module)
```

@xueliangyang-oeuler

Can this PR work on A2 NPUs? If it works efficiently, please give examples for some models, e.g. Qwen-32B or Qwen3-Next-80B. Thank you.

@ping1jing2

