76 commits
4ce70e6
NPU Graph Compilation support and PassManager with AddRmsNorm & Quantize
eshoguli Sep 17, 2025
b974460
pre-commit & refactoring
eshoguli Oct 31, 2025
7a7bde7
pre-commit
qyqc731 Nov 1, 2025
d8e2dc3
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 5, 2025
9bb7751
Merge branch 'main' into eshogulin/pass_manager: fix - custom_ops.py
eshoguli Nov 5, 2025
7048005
cleanup & refactoring
eshoguli Nov 10, 2025
eb240d9
Pass Manager fix
eshoguli Nov 10, 2025
29c1d89
Compilation: refactoring
eshoguli Nov 11, 2025
3e98d17
NPU Piecewise Graph
eshoguli Nov 8, 2025
3d9516a
rollback
eshoguli Nov 11, 2025
2c1b6fe
linter
eshoguli Nov 11, 2025
55016b0
refactoring
eshoguli Nov 11, 2025
fbff08d
refactoring
eshoguli Nov 11, 2025
3e5db77
Compilation: refactoring
eshoguli Nov 11, 2025
99d4497
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 12, 2025
30da7fe
model_type check
eshoguli Nov 13, 2025
36ef7e7
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 14, 2025
1808479
PiecewiseNpuGraphCompilerBackend quick fix
Nov 14, 2025
bcfc2c5
CompilationConfig reusage
Nov 17, 2025
a6a159d
--torch-compile-max-bs support
Nov 18, 2025
c08d076
TorchAir compilation support
XDaoHong Nov 14, 2025
73f2ee9
runner selection fix: model forward usage
eshoguli Nov 19, 2025
2f97641
add test for torchair
XDaoHong Nov 19, 2025
7154cf4
TorchAir compilation support: refactoring
eshoguli Nov 19, 2025
dfaee00
NPU Piecewise Graph: refactoring
eshoguli Nov 19, 2025
bec1b28
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 19, 2025
253c14d
linter fix after merge commit
eshoguli Nov 19, 2025
85d808e
NPUGraph compilation (fp16) & NPU Piecewise Graph tests
eshoguli Nov 19, 2025
11074d9
TorchAir compilation support: refactoring 2
eshoguli Nov 19, 2025
51ac4b4
Merge branch 'main' into eshogulin/pass_manager
ping1jing2 Nov 20, 2025
e06675b
CompilationConfig comments fix + linter fix
eshoguli Nov 21, 2025
0c09c24
backend instantiation in get_compiler_backend
eshoguli Nov 21, 2025
00a0b9b
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 23, 2025
0b31746
Merge remote-tracking branch 'sglang/main' into eshogulin/pass_manage…
eshoguli Nov 25, 2025
3b5c83b
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 26, 2025
7eefeee
linter fix
eshoguli Nov 26, 2025
8c63980
dynamo patch removing
eshoguli Nov 26, 2025
2e02568
fix on main branch: compilation
eshoguli Nov 27, 2025
966bbf4
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 27, 2025
14092b3
auto merge fix
eshoguli Nov 27, 2025
f989147
tests suit update
eshoguli Nov 27, 2025
bf1251d
Add npu_add_rms_norm_dynamic_quant fuse
OrangeRedeng Nov 27, 2025
317174b
Merge branch 'eshogulin/pass_manager' of https://github.com/eshoguli/…
OrangeRedeng Nov 27, 2025
e6eb29c
NPU Graph compilation: attention architecture check
eshoguli Nov 27, 2025
caba95e
Add npu_add_rms_norm_dynamic_quant fuse: quick fix
eshoguli Nov 27, 2025
3f87879
Qwen3 MoE compilation support for NPU
eshoguli Nov 28, 2025
a2046c3
Merge branch 'main' into eshogulin/pass_manager
eshoguli Dec 2, 2025
faea888
linter quick fix
eshoguli Dec 2, 2025
85720d6
SlitQkvRmsnormRopeFuse fuse
eshoguli Nov 29, 2025
fd0e1e8
headers quick fix
eshoguli Dec 2, 2025
58966d6
Merge branch 'main' into eshogulin/pass_manager
ping1jing2 Dec 2, 2025
a85ab1f
Merge branch 'main' into eshogulin/pass_manager
eshoguli Dec 3, 2025
4a61b7e
lint after merge + Piecewise Graph fix
eshoguli Dec 3, 2025
17f0af5
enable_torch_compile update rollback
eshoguli Dec 3, 2025
ad76e3c
Merge branch 'main' into eshogulin/pass_manager
eshoguli Dec 4, 2025
752657c
Merge fixes: moving in accordance with refactoring & cleanup
eshoguli Dec 7, 2025
3ac87be
Merge fixes: 1) updated cache prefetch support 2) ModelWeightParamete…
eshoguli Dec 7, 2025
b79fc0b
Merge branch 'main' into eshogulin/pass_manager
ping1jing2 Dec 9, 2025
105050f
cleanup
eshoguli Dec 10, 2025
90caee2
Cleanup & ComilationConfig usage update & master merge refactoring
eshoguli Dec 10, 2025
a5e87f6
Compilation backends: model type quick fix
eshoguli Dec 11, 2025
daf81b2
TorchAir compilation backend: Ascend attention backend quick fix
eshoguli Dec 11, 2025
ea25b3f
torchair compilation test fix
eshoguli Dec 12, 2025
3365d71
Merge branch 'main' into eshogulin/pass_manager
eshoguli Dec 12, 2025
40389dd
Capturing compiled code issue: fix - dynamo patching
eshoguli Dec 12, 2025
f5424d8
comments fix
eshoguli Dec 12, 2025
55a1e06
Documentation
eshoguli Dec 12, 2025
fd28ac6
cleanup & fuse quick fix: compilation & piecewise
eshoguli Dec 15, 2025
97d654e
TorchAir support: inference fix & refactoring
eshoguli Dec 15, 2025
3ce92e8
Comment fixes + refactoring
eshoguli Dec 15, 2025
f4dfef3
Piecewise Graph Runner refactoring
eshoguli Dec 16, 2025
123e36c
PiecewiseGraph runner quick fix
eshoguli Dec 16, 2025
b3e2fe8
enable_torch_npugraph_ex_compile
XDaoHong Dec 16, 2025
ebcc846
linter fixes
eshoguli Dec 16, 2025
132581a
fix after merge torchair
eshoguli Dec 16, 2025
81a392c
Merge branch 'main' into eshogulin/pass_manager
eshoguli Dec 17, 2025
30 changes: 30 additions & 0 deletions docs/platforms/ascend_npu_pass_development.md
@@ -0,0 +1,30 @@
## How to transform model instances with the PyTorch FX toolkit in SGLang for NPU

### PassManager
`PassManager` is implemented here: [PassManager](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/pass_manager.py)
Reviewer comment (high): The hyperlinks in this document point to a personal fork (eshoguli/sglang). For official documentation, these links should be updated to point to the main sgl-project/sglang repository or use relative paths to ensure they are valid and accessible for all users after the pull request is merged.


You can explore `PassManager` usage in the [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) compiler backend. The [`PiecewiseNpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py) compiler backend also uses `PassManager`, via inheritance from `NpuGraphCompilerBackend`.
Reviewer comment (medium): There is a typo here: "compiler backed" should be "compiler backend".


### Pass development
There are two approaches to developing passes for the SGLang NPU `PassManager`:

1. Subgraph rewriting: match all non-overlapping sets of operators and their data dependencies with the `torch.fx.replace_pattern` API and replace them with a fused equivalent.
Pass example: [NpuAddRmsNormQuantFuse](https://github.com/eshoguli/sglang/blob/3365d711fd5aa0d6191c32769163320fe41e27f2/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py#L82).
Details are in the official FX toolkit documentation: https://docs.pytorch.org/docs/stable/fx.html#subgraph-rewriting-with-replace-pattern

2. Direct graph manipulation: walk the `torch.fx.Graph` and insert, erase, or rewire nodes directly.
Pass example: [EraseCopy](https://github.com/eshoguli/sglang/blob/3365d711fd5aa0d6191c32769163320fe41e27f2/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py#L28).
Details are in the official FX toolkit documentation: https://docs.pytorch.org/docs/stable/fx.html#direct-graph-manipulation
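The add/apply pipeline these two approaches plug into can be sketched without any torch dependency. The sketch below is illustrative only (`ToyPassManager`, `ToyPass`, `FuseAddRmsNorm`, and the node-list graph are hypothetical stand-ins): the real `PassManager` operates on `torch.fx.GraphModule` objects, but the register-passes-then-apply shape is the same.

```python
# Dependency-free sketch of the pass-pipeline pattern described above.
# All names here are illustrative; the real PassManager rewrites
# torch.fx graphs, not this toy list of op dicts.

class ToyPass:
    def run(self, nodes):
        raise NotImplementedError


class FuseAddRmsNorm(ToyPass):
    """Direct-manipulation style: walk the graph and fuse add -> rms_norm."""

    def run(self, nodes):
        out, i = [], 0
        while i < len(nodes):
            if (i + 1 < len(nodes)
                    and nodes[i]["op"] == "add"
                    and nodes[i + 1]["op"] == "rms_norm"):
                out.append({"op": "add_rms_norm"})  # fused replacement node
                i += 2
            else:
                out.append(nodes[i])
                i += 1
        return out


class ToyPassManager:
    def __init__(self, nodes):
        self.nodes = nodes
        self.passes = []

    def add(self, pass_cls):
        self.passes.append(pass_cls())

    def apply(self):
        for p in self.passes:  # run passes in registration order
            self.nodes = p.run(self.nodes)
        return self.nodes


graph = [{"op": "add"}, {"op": "rms_norm"}, {"op": "quantize"}]
pm = ToyPassManager(graph)
pm.add(FuseAddRmsNorm)
fused = pm.apply()
print([n["op"] for n in fused])  # -> ['add_rms_norm', 'quantize']
```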

### Compiler backend update
After developing a pass, create a `PassManager` instance, add the pass, and call the `apply` method:
```
def apply_passes(self, graph_module: torch.fx.GraphModule):
    passManager = PassManager(graph_module)
    passManager.add(NpuAddRmsNormQuantFuse)
    passManager.apply()
    graph_module.recompile()
```

Reviewer comment (medium): The variable name passManager does not follow the PEP 8 style guide, which recommends snake_case for variable names. It should be renamed to pass_manager for consistency.

You can explore [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) as an example.
1 change: 1 addition & 0 deletions docs/platforms/ascend_npu_support.rst
@@ -6,4 +6,5 @@ Ascend NPUs

ascend_npu.md
ascend_npu_deepseek_example.md
ascend_npu_pass_development.md
ascend_npu_qwen3_examples.md
13 changes: 12 additions & 1 deletion python/sglang/srt/compilation/compilation_config.py
@@ -1,20 +1,23 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py

import json
from typing import List


# TODO(Yuwei): support better compile config support
class CompilationConfig:
def __init__(
self,
capture_sizes: List[int],
capture_sizes: List[int] = [],
compiler: str = "eager",
enable_debug_mode: bool = False,
splitting_ops: List[str] = [],
):
self.traced_files = set()
self.capture_sizes = capture_sizes
self.compiler = compiler
self.enable_debug_mode = enable_debug_mode
self.splitting_ops = splitting_ops
Reviewer comment on lines 9 to 20 (high): Using mutable default arguments like [] is a common pitfall in Python. The same list instance will be shared across all CompilationConfig objects created without explicitly providing capture_sizes or splitting_ops. If one instance modifies its list, it will affect all other instances. To avoid this potential bug, you should use None as the default value and initialize a new list inside __init__ if the argument is None. You will also need to import Optional from typing. Suggested change:

    def __init__(
        self,
        capture_sizes: Optional[List[int]] = None,
        compiler: str = "eager",
        enable_debug_mode: bool = False,
        splitting_ops: Optional[List[str]] = None,
    ):
        self.traced_files = set()
        self.capture_sizes = capture_sizes if capture_sizes is not None else []
        self.compiler = compiler
        self.enable_debug_mode = enable_debug_mode
        self.splitting_ops = splitting_ops if splitting_ops is not None else []


def add_traced_file(self, file_path: str):
self.traced_files.add(file_path)
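The shared-list behavior the reviewer describes is easy to reproduce in isolation. A minimal standalone demonstration (`BadConfig` and `GoodConfig` are illustrative names, not SGLang classes):

```python
# Standalone demonstration of the shared-mutable-default pitfall
# flagged by the reviewer above. Class names are illustrative.

class BadConfig:
    def __init__(self, capture_sizes=[]):  # one list shared by every call
        self.capture_sizes = capture_sizes


class GoodConfig:
    def __init__(self, capture_sizes=None):  # None sentinel, fresh list each time
        self.capture_sizes = capture_sizes if capture_sizes is not None else []


a, b = BadConfig(), BadConfig()
a.capture_sizes.append(1)
print(b.capture_sizes)  # -> [1]  (b observes a's mutation)

c, d = GoodConfig(), GoodConfig()
c.capture_sizes.append(1)
print(d.capture_sizes)  # -> []   (each instance owns an independent list)
```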
@@ -25,5 +28,13 @@ def get_traced_files(self):
def get_capture_sizes(self):
return self.capture_sizes

@classmethod
def from_cli(cls, args) -> "CompilationConfig":
args_dict = json.loads(args)
return CompilationConfig(**args_dict)

def get_enable_debug_mode(self):
return self.enable_debug_mode

def get_splitting_ops(self):
return self.splitting_ops
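The `from_cli` classmethod added here turns a JSON string from the command line into constructor keyword arguments. A minimal standalone mirror of that flow (`MiniCompilationConfig` is a hypothetical stand-in, written with the None-default pattern the reviewer recommends rather than mutable defaults):

```python
import json


# Minimal standalone mirror of the from_cli flow in this diff:
# a JSON string from the CLI becomes keyword arguments.
class MiniCompilationConfig:
    def __init__(self, capture_sizes=None, compiler="eager",
                 enable_debug_mode=False, splitting_ops=None):
        self.capture_sizes = capture_sizes if capture_sizes is not None else []
        self.compiler = compiler
        self.enable_debug_mode = enable_debug_mode
        self.splitting_ops = splitting_ops if splitting_ops is not None else []

    @classmethod
    def from_cli(cls, args: str) -> "MiniCompilationConfig":
        # Parse the CLI string and expand it into keyword arguments.
        return cls(**json.loads(args))


cfg = MiniCompilationConfig.from_cli('{"capture_sizes": [1, 8], "compiler": "npu"}')
print(cfg.capture_sizes, cfg.compiler)  # -> [1, 8] npu
```

Note that `json.loads` raises `ValueError` on malformed input, so callers of a real `from_cli` would likely want to surface a clear error message for bad CLI strings.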
52 changes: 52 additions & 0 deletions python/sglang/srt/compilation/custom_ops.py
@@ -0,0 +1,52 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import List, Optional

import torch

import sglang.srt.layers.dp_attention


@torch.library.custom_op("sglang::_set_dp_buffer_len", mutates_args=())
def _set_dp_buffer_len(
global_dp_buffer_len: Optional[int],
num_tokens: Optional[int],
is_max_len: bool,
global_num_tokens: Optional[List[int]] = None,
) -> None:
global set_dp_buffer_len_original
Reviewer comment (medium): The global variable set_dp_buffer_len_original is declared but never used within this function or module. It appears to be leftover code and should be removed to improve clarity.


sglang.srt.layers.dp_attention.set_dp_buffer_len(
global_dp_buffer_len, num_tokens, is_max_len, global_num_tokens
)


@_set_dp_buffer_len.register_fake
def _set_dp_buffer_len_fake(
global_dp_buffer_len: Optional[int],
num_tokens: Optional[int],
is_max_len: bool,
global_num_tokens: Optional[List[int]] = None,
) -> None:
pass


@torch.library.custom_op("sglang::_set_is_extend_in_batch", mutates_args=())
def _set_is_extend_in_batch(is_extend_in_batch: bool) -> None:
sglang.srt.layers.dp_attention.set_is_extend_in_batch(is_extend_in_batch)


@_set_is_extend_in_batch.register_fake
def _set_is_extend_in_batch_fake(is_extend_in_batch: bool) -> None:
pass
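The pattern used throughout this file, a real implementation registered alongside a fake one, follows the `torch.library.custom_op` / `register_fake` contract: the real op runs eagerly (with side effects), while the fake implementation is invoked during tracing and must be side-effect free. A dependency-free analogy (`ToyCustomOp` and the explicit `tracing` flag are illustrative; the real API dispatches on the compilation context automatically):

```python
# Toy analogy of the torch.library.custom_op / register_fake contract.
# Pure Python: the "tracing" flag stands in for torch's dispatch logic.

class ToyCustomOp:
    def __init__(self, real_fn):
        self.real_fn = real_fn
        self.fake_fn = None

    def register_fake(self, fn):
        self.fake_fn = fn
        return fn

    def __call__(self, *args, tracing=False):
        # During "tracing" the fake impl runs; eagerly the real impl runs.
        fn = self.fake_fn if (tracing and self.fake_fn) else self.real_fn
        return fn(*args)


calls = []


@ToyCustomOp
def set_buffer_len(n):
    calls.append(n)  # real side effect


@set_buffer_len.register_fake
def _set_buffer_len_fake(n):
    pass  # no side effect while "tracing"


set_buffer_len(4)                # eager call: side effect happens
set_buffer_len(8, tracing=True)  # traced call: fake impl, no side effect
print(calls)  # -> [4]
```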
@@ -231,6 +231,9 @@ class AscendAttnBackend(AttentionBackend):

def __init__(self, model_runner: ModelRunner):
super().__init__()
self.enable_piecewise_npu_graph_decode = (
model_runner.server_args.enable_piecewise_npu_graph_decode
)
self.forward_metadata = None
self.device = model_runner.device
self.page_size = model_runner.page_size
@@ -248,7 +251,6 @@ def __init__(self, model_runner: ModelRunner):
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.graph_mode = False
self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False")
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.speculative_num_draft_tokens = (
model_runner.server_args.speculative_num_draft_tokens
)
@@ -264,6 +266,11 @@ def __init__(self, model_runner: ModelRunner):
if self.use_mla:
self.ringmla_mask = self.ascend_attn_mask_builder.ringmla_mask

self.enable_torchair_compile = model_runner.server_args.enable_torchair_compile
if self.enable_torchair_compile:
max_total_tokens = model_runner.max_total_num_tokens
self.max_seqlen_pad = max_total_tokens // model_runner.server_args.page_size

def get_verify_buffers_to_fill_after_draft(self):
"""
Return buffers for verify attention kernels that needs to be filled after draft.
@@ -283,12 +290,29 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
seq_lens_max = forward_batch.seq_lens.max()
if forward_batch.forward_mode.is_target_verify():
seq_lens_max += self.speculative_num_draft_tokens
self.forward_metadata.block_tables = (

block_tables = (
forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :seq_lens_max
][:, :: self.page_size]
// self.page_size
)

if (
self.enable_torchair_compile
and forward_batch.forward_mode.is_decode_or_idle()
):
bs = forward_batch.input_ids.size(0)
device = forward_batch.input_ids.device
self.forward_metadata.block_tables = torch.full(
(bs, self.max_seqlen_pad), -1, dtype=torch.int32, device=device
)
self.forward_metadata.block_tables[:, : block_tables.size(1)].copy_(
block_tables
)
else:
self.forward_metadata.block_tables = block_tables

if forward_batch.extend_seq_lens is not None:
self.forward_metadata.extend_seq_lens_cpu_int = (
forward_batch.extend_seq_lens.cpu().int()
@@ -1145,6 +1169,17 @@ def forward_decode_graph(
else:
actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_int

if (
self.enable_piecewise_npu_graph_decode
and torch.compiler.is_dynamo_compiling()
):
# input args for submodule forward
forward_batch.req_to_token_pool.req_to_token.add_(
forward_batch.req_to_token_pool.req_to_token
)
forward_batch.req_pool_indices.add_(forward_batch.req_pool_indices)
forward_batch.seq_lens.add_(forward_batch.seq_lens)
Reviewer comment on lines 1172 to 1181 (high): The use of in-place add_ operations on tensors to mark them as inputs for torch.compile is an obscure hack. This code doubles the tensor values, which could lead to correctness issues if not handled carefully downstream. While this might be a necessary workaround for Dynamo, it severely impacts code readability and maintainability. Please add a detailed comment explaining why this is necessary and what it achieves. A better long-term solution would be to find a more explicit way to register tensor dependencies for graph compilation.


torch_npu._npu_paged_attention(
query=query,
key_cache=k_cache,
@@ -1256,7 +1291,7 @@ def forward_decode(
topk_indices,
)

if self.graph_mode and (not self.enable_torch_compile):
if self.graph_mode and (not self.enable_torchair_compile):
return self.forward_decode_graph(
q,
k,
8 changes: 8 additions & 0 deletions python/sglang/srt/hardware_backend/npu/cmo.py
@@ -1,5 +1,7 @@
import torch

from sglang.srt.layers.parameter import ModelWeightParameter

cmo_stream = None


@@ -18,6 +20,12 @@ def set_cmo_stream(stream):
cmo_stream = stream


def get_weight_cache(layer):
if isinstance(layer.weight, ModelWeightParameter):
return layer.weight_data
return layer.weight


def prepare_weight_cache(handle, cache, PREFETCH_MAX_SIZE=1000000000):
"""
PREFETCH_MAX_SIZE: maximum size (bytes) for each prefetch operation.
39 changes: 39 additions & 0 deletions python/sglang/srt/hardware_backend/npu/cmo_custom_ops.py
@@ -0,0 +1,39 @@
from typing import List

import torch

import sglang.srt.hardware_backend.npu.cmo
from sglang.srt.utils import direct_register_custom_op


@torch.library.custom_op("sglang::wait_cmo_stream", mutates_args=())
def wait_cmo_stream() -> None:
if sglang.srt.hardware_backend.npu.cmo.get_cmo_stream():
sglang.srt.hardware_backend.npu.cmo.wait_cmo_stream()


@wait_cmo_stream.register_fake
def wait_cmo_stream_fake() -> None:
pass


def get_cmo_stream() -> bool:
return True


def prepare_weight_cache(handle: torch.Tensor, cache: List[torch.Tensor]) -> None:
sglang.srt.hardware_backend.npu.cmo.prepare_weight_cache(handle, cache)


def prepare_weight_cache_register_fake(
handle: torch.Tensor, cache: List[torch.Tensor]
) -> None:
pass


direct_register_custom_op(
op_name="prepare_weight_cache",
op_func=prepare_weight_cache,
mutates_args=["handle"],
fake_impl=prepare_weight_cache_register_fake,
)
@@ -0,0 +1,20 @@
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch_npu


class CompilationContext:
graph_memory_pool = None
stream: torch_npu.npu.Stream = None
@@ -0,0 +1,55 @@
from typing import List

import sgl_kernel_npu.norm.split_qkv_rmsnorm_rope
import torch


@torch.library.custom_op("sglang::split_qkv_rmsnorm_rope", mutates_args=())
def split_qkv_rmsnorm_rope(
input: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_hidden_size: int,
kv_hiddem_size: int,
Reviewer comment (medium): There is a typo in the parameter name kv_hiddem_size. It should be kv_hidden_size. This typo is present in the function definition, its usage, and the fake implementation. Correcting it will improve code clarity and consistency.

head_dim: int,
eps: float,
q_bias: torch.Tensor,
k_bias: torch.Tensor,
) -> List[torch.Tensor]:
q, k, v = sgl_kernel_npu.norm.split_qkv_rmsnorm_rope.split_qkv_rmsnorm_rope(
input,
sin,
cos,
q_weight,
k_weight,
q_hidden_size,
kv_hiddem_size,
head_dim,
eps,
q_bias,
k_bias,
)
return [q, k, v]


@split_qkv_rmsnorm_rope.register_fake
def split_qkv_rmsnorm_rope(
input: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_hidden_size: int,
kv_hiddem_size: int,
head_dim: int,
eps: float,
q_bias: torch.Tensor,
k_bias: torch.Tensor,
) -> List[torch.Tensor]:
# TODO: generalize shape
q = torch.empty((128, 2048), dtype=input.dtype, device=input.device)
k = torch.empty((128, 256), dtype=input.dtype, device=input.device)
v = torch.empty((128, 256), dtype=input.dtype, device=input.device)
return [q, k, v]