# [NPU] Piecewise Graph for decode with PassManager & fuses #15332
@@ -0,0 +1,30 @@

## How to transform model instances with the PyTorch FX toolkit in SGLang for NPU

### PassManager
`PassManager` is implemented here: [PassManager](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/pass_manager.py).

You can explore `PassManager` usage in the [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) compiler backend. The [`PiecewiseNpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py) compiler backend also uses `PassManager`, through inheritance from `NpuGraphCompilerBackend`.
### Pass development
There are two approaches to developing passes for the SGLang NPU `PassManager`:

1. Subgraph rewriting: match all possible non-overlapping sets of operators and their data dependencies with the `torch.fx.replace_pattern` API.
   Pass example: [NpuAddRmsNormQuantFuse](https://github.com/eshoguli/sglang/blob/3365d711fd5aa0d6191c32769163320fe41e27f2/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py#L82).
   Details are in the official FX toolkit documentation: https://docs.pytorch.org/docs/stable/fx.html#subgraph-rewriting-with-replace-pattern

2. Direct graph manipulation: edit the nodes of `graph_module.graph` by hand.
   Pass example: [EraseCopy](https://github.com/eshoguli/sglang/blob/3365d711fd5aa0d6191c32769163320fe41e27f2/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py#L28).
   Details are in the official FX toolkit documentation: https://docs.pytorch.org/docs/stable/fx.html#direct-graph-manipulation
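Both approaches can be tried on a toy module using only public `torch.fx` APIs; the fused NPU passes from this PR are not involved, and the pattern/replacement pair below is invented purely for illustration:

```python
import torch
import torch.fx


class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(torch.add(x, y))


gm = torch.fx.symbolic_trace(Toy())


# Approach 1: subgraph rewriting — swap every add(a, b) for sub(a, b).
def pattern(a, b):
    return torch.add(a, b)


def replacement(a, b):
    return torch.sub(a, b)


torch.fx.replace_pattern(gm, pattern, replacement)

# Approach 2: direct graph manipulation — erase the relu node by hand.
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.relu:
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
gm.graph.lint()
gm.recompile()

out = gm(torch.tensor([1.0]), torch.tensor([4.0]))  # now computes 1 - 4
```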
### Compiler backend update
After developing a pass, create a `PassManager` instance, add the pass, and call its `apply` method:
```python
def apply_passes(self, graph_module: torch.fx.GraphModule):
    pass_manager = PassManager(graph_module)
    pass_manager.add(NpuAddRmsNormQuantFuse)
    pass_manager.apply()
    graph_module.recompile()
```

You can explore [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) as an example.
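The real `PassManager` lives in `pass_manager.py` (linked above); the usage snippet only shows that `add` takes a pass class and `apply` runs the passes. A minimal pure-Python sketch of that shape — the constructor/`add`/`apply` signatures and the toy pass are assumptions, not the real API:

```python
class PassManager:
    """Sketch: collect pass classes, then apply each one to the graph module."""

    def __init__(self, graph_module):
        self.graph_module = graph_module
        self.passes = []

    def add(self, pass_cls):
        # Mirrors `passManager.add(NpuAddRmsNormQuantFuse)`: a class, not an instance.
        self.passes.append(pass_cls)

    def apply(self):
        for pass_cls in self.passes:
            pass_cls()(self.graph_module)  # instantiate, then run on the graph


class UppercasePass:  # toy stand-in for a real FX pass
    def __call__(self, gm):
        gm.append(gm.pop().upper())


gm = ["node"]  # toy stand-in for a GraphModule
pm = PassManager(gm)
pm.add(UppercasePass)
pm.apply()
```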
@@ -1,20 +1,23 @@

```python
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py

import json
from typing import List


# TODO(Yuwei): support better compile config support
class CompilationConfig:
    def __init__(
        self,
        capture_sizes: List[int] = [],
        compiler: str = "eager",
        enable_debug_mode: bool = False,
        splitting_ops: List[str] = [],
    ):
        self.traced_files = set()
        self.capture_sizes = capture_sizes
        self.compiler = compiler
        self.enable_debug_mode = enable_debug_mode
        self.splitting_ops = splitting_ops
```
Comment on lines 9 to +20:

> **Contributor:** Using mutable default arguments like `capture_sizes: List[int] = []` and `splitting_ops: List[str] = []` means a single list object is shared by every call that relies on the default; the safer idiom is `Optional[List[...]] = None` with the list created inside `__init__`.
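The pitfall the reviewer flags is easy to demonstrate in isolation; the function names below are illustrative, not from the PR:

```python
from typing import List, Optional


def add_size_bad(size: int, sizes: List[int] = []) -> List[int]:
    sizes.append(size)  # mutates the one default list shared by all calls
    return sizes


def add_size_good(size: int, sizes: Optional[List[int]] = None) -> List[int]:
    if sizes is None:  # a fresh list is created on every call
        sizes = []
    sizes.append(size)
    return sizes


a = add_size_bad(1)
b = add_size_bad(2)   # a and b are the SAME list, now [1, 2]
c = add_size_good(1)  # [1]
d = add_size_good(2)  # [2], independent of c
```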
```python
    def add_traced_file(self, file_path: str):
        self.traced_files.add(file_path)
```

@@ -25,5 +28,13 @@ def get_traced_files(self):

```python
    def get_capture_sizes(self):
        return self.capture_sizes

    @classmethod
    def from_cli(cls, args) -> "CompilationConfig":
        args_dict = json.loads(args)
        return CompilationConfig(**args_dict)

    def get_enable_debug_mode(self):
        return self.enable_debug_mode

    def get_splitting_ops(self):
        return self.splitting_ops
```
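`from_cli` expects the compilation config as a JSON string on the command line. A self-contained sketch of that round trip, using a trimmed copy of the class rather than an import from sglang (mutable defaults kept to mirror the PR; see the review comment above):

```python
import json
from typing import List


class CompilationConfig:  # trimmed copy for illustration only
    def __init__(self, capture_sizes: List[int] = [], compiler: str = "eager",
                 enable_debug_mode: bool = False, splitting_ops: List[str] = []):
        self.capture_sizes = capture_sizes
        self.compiler = compiler
        self.enable_debug_mode = enable_debug_mode
        self.splitting_ops = splitting_ops

    @classmethod
    def from_cli(cls, args: str) -> "CompilationConfig":
        # The CLI passes a JSON object whose keys match __init__'s parameters.
        return cls(**json.loads(args))


cfg = CompilationConfig.from_cli('{"capture_sizes": [1, 8, 16], "compiler": "eager"}')
```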
@@ -0,0 +1,52 @@

```python
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import List, Optional

import torch

import sglang.srt.layers.dp_attention


@torch.library.custom_op("sglang::_set_dp_buffer_len", mutates_args=())
def _set_dp_buffer_len(
    global_dp_buffer_len: Optional[int],
    num_tokens: Optional[int],
    is_max_len: bool,
    global_num_tokens: Optional[List[int]] = None,
) -> None:
    global set_dp_buffer_len_original
    sglang.srt.layers.dp_attention.set_dp_buffer_len(
        global_dp_buffer_len, num_tokens, is_max_len, global_num_tokens
    )


@_set_dp_buffer_len.register_fake
def _set_dp_buffer_len_fake(
    global_dp_buffer_len: Optional[int],
    num_tokens: Optional[int],
    is_max_len: bool,
    global_num_tokens: Optional[List[int]] = None,
) -> None:
    pass


@torch.library.custom_op("sglang::_set_is_extend_in_batch", mutates_args=())
def _set_is_extend_in_batch(is_extend_in_batch: bool) -> None:
    sglang.srt.layers.dp_attention.set_is_extend_in_batch(is_extend_in_batch)


@_set_is_extend_in_batch.register_fake
def _set_is_extend_in_batch_fake(is_extend_in_batch: bool) -> None:
    pass
```
@@ -231,6 +231,9 @@ class AscendAttnBackend(AttentionBackend):

```diff
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
+        self.enable_piecewise_npu_graph_decode = (
+            model_runner.server_args.enable_piecewise_npu_graph_decode
+        )
         self.forward_metadata = None
         self.device = model_runner.device
         self.page_size = model_runner.page_size
```

@@ -248,7 +251,6 @@ def __init__(self, model_runner: ModelRunner):

```diff
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.graph_mode = False
         self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False")
-        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.speculative_num_draft_tokens = (
             model_runner.server_args.speculative_num_draft_tokens
         )
```

@@ -264,6 +266,11 @@ def __init__(self, model_runner: ModelRunner):

```diff
         if self.use_mla:
             self.ringmla_mask = self.ascend_attn_mask_builder.ringmla_mask

+        self.enable_torchair_compile = model_runner.server_args.enable_torchair_compile
+        if self.enable_torchair_compile:
+            max_total_tokens = model_runner.max_total_num_tokens
+            self.max_seqlen_pad = max_total_tokens // model_runner.server_args.page_size
+
     def get_verify_buffers_to_fill_after_draft(self):
         """
         Return buffers for verify attention kernels that needs to be filled after draft.
```

@@ -283,12 +290,29 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

```diff
         seq_lens_max = forward_batch.seq_lens.max()
         if forward_batch.forward_mode.is_target_verify():
             seq_lens_max += self.speculative_num_draft_tokens
-        self.forward_metadata.block_tables = (
+        block_tables = (
             forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, :seq_lens_max
             ][:, :: self.page_size]
             // self.page_size
         )
+
+        if (
+            self.enable_torchair_compile
+            and forward_batch.forward_mode.is_decode_or_idle()
+        ):
+            bs = forward_batch.input_ids.size(0)
+            device = forward_batch.input_ids.device
+            self.forward_metadata.block_tables = torch.full(
+                (bs, self.max_seqlen_pad), -1, dtype=torch.int32, device=device
+            )
+            self.forward_metadata.block_tables[:, : block_tables.size(1)].copy_(
+                block_tables
+            )
+        else:
+            self.forward_metadata.block_tables = block_tables

         if forward_batch.extend_seq_lens is not None:
             self.forward_metadata.extend_seq_lens_cpu_int = (
                 forward_batch.extend_seq_lens.cpu().int()
```

@@ -1145,6 +1169,17 @@ def forward_decode_graph(

```diff
         else:
             actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_int

+        if (
+            self.enable_piecewise_npu_graph_decode
+            and torch.compiler.is_dynamo_compiling()
+        ):
+            # input args for submodule forward
+            forward_batch.req_to_token_pool.req_to_token.add_(
+                forward_batch.req_to_token_pool.req_to_token
+            )
+            forward_batch.req_pool_indices.add_(forward_batch.req_pool_indices)
+            forward_batch.seq_lens.add_(forward_batch.seq_lens)
+
```

Comment on lines +1172 to +1181:

> **Contributor:** The use of in-place `add_` operations here doubles the values of `req_to_token`, `req_pool_indices`, and `seq_lens` during tracing, mutating shared batch state.

```diff
         torch_npu._npu_paged_attention(
             query=query,
             key_cache=k_cache,
```

@@ -1256,7 +1291,7 @@ def forward_decode(

```diff
                 topk_indices,
             )

-        if self.graph_mode and (not self.enable_torch_compile):
+        if self.graph_mode and (not self.enable_torchair_compile):
             return self.forward_decode_graph(
                 q,
                 k,
```
@@ -0,0 +1,39 @@

```python
from typing import List

import torch

import sglang.srt.hardware_backend.npu.cmo
from sglang.srt.utils import direct_register_custom_op


@torch.library.custom_op("sglang::wait_cmo_stream", mutates_args=())
def wait_cmo_stream() -> None:
    if sglang.srt.hardware_backend.npu.cmo.get_cmo_stream():
        sglang.srt.hardware_backend.npu.cmo.wait_cmo_stream()


@wait_cmo_stream.register_fake
def wait_cmo_stream_fake() -> None:
    pass


def get_cmo_stream() -> bool:
    return True


def prepare_weight_cache(handle: torch.Tensor, cache: List[torch.Tensor]) -> None:
    sglang.srt.hardware_backend.npu.cmo.prepare_weight_cache(handle, cache)


def prepare_weight_cache_register_fake(
    handle: torch.Tensor, cache: List[torch.Tensor]
) -> None:
    pass


direct_register_custom_op(
    op_name="prepare_weight_cache",
    op_func=prepare_weight_cache,
    mutates_args=["handle"],
    fake_impl=prepare_weight_cache_register_fake,
)
```
@@ -0,0 +1,20 @@

```python
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch_npu


class CompilationContext:
    graph_memory_pool = None
    stream: torch_npu.npu.Stream = None
```
@@ -0,0 +1,55 @@

```python
from typing import List

import sgl_kernel_npu.norm.split_qkv_rmsnorm_rope
import torch


@torch.library.custom_op("sglang::split_qkv_rmsnorm_rope", mutates_args=())
def split_qkv_rmsnorm_rope(
    input: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_hidden_size: int,
    kv_hidden_size: int,
    head_dim: int,
    eps: float,
    q_bias: torch.Tensor,
    k_bias: torch.Tensor,
) -> List[torch.Tensor]:
    q, k, v = sgl_kernel_npu.norm.split_qkv_rmsnorm_rope.split_qkv_rmsnorm_rope(
        input,
        sin,
        cos,
        q_weight,
        k_weight,
        q_hidden_size,
        kv_hidden_size,
        head_dim,
        eps,
        q_bias,
        k_bias,
    )
    return [q, k, v]


@split_qkv_rmsnorm_rope.register_fake
def split_qkv_rmsnorm_rope_fake(
    input: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_hidden_size: int,
    kv_hidden_size: int,
    head_dim: int,
    eps: float,
    q_bias: torch.Tensor,
    k_bias: torch.Tensor,
) -> List[torch.Tensor]:
    # TODO: generalize shape
    q = torch.empty((128, 2048), dtype=input.dtype, device=input.device)
    k = torch.empty((128, 256), dtype=input.dtype, device=input.device)
    v = torch.empty((128, 256), dtype=input.dtype, device=input.device)
    return [q, k, v]
```
> **Contributor:** The hyperlinks in this document point to a personal fork (`eshoguli/sglang`). For official documentation, these links should be updated to point to the main `sgl-project/sglang` repository or use relative paths to ensure they are valid and accessible for all users after the pull request is merged.