Description
Motivation
To achieve the best possible performance across diverse hardware, workloads, and scales, heavily optimized, hand-tuned kernels are often required. This makes it harder for model developers (who just want to focus on the model architecture) to integrate their models into inference engines like SGLang: models currently need to be reimplemented using the set of optimized kernels provided by sgl-kernel and other kernel libraries. Not only does this duplicate a lot of kernel-invocation code across different models, it also adds extra work for model authors who just want the best performance.
With torch.compile, model authors don't need to care about the specifics of optimized kernels and can let the compiler handle it for them.
Solution
For pure PyTorch models, torch.compile works great: all you need to do is decorate a function with @torch.compile and it will trace the tensor operations into a computation graph, which is then compiled to optimized Triton kernels (with the inductor backend). torch.compile on native PyTorch code gets you pretty far in terms of performance, but to squeeze out the last bit of performance, custom hardware kernels are still required. torch.compile can easily be extended with custom fusion passes, completely decoupling model architecture from kernel optimizations, and these passes can be reused across models.
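As a quick illustration of the baseline (a minimal sketch, not SGLang code): a plain PyTorch RMSNorm decorated with @torch.compile is traced and lowered by inductor, which will typically fuse the reduction and elementwise ops into a single generated Triton kernel, without any hand-written kernel.

```python
import torch


@torch.compile
def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # inductor typically fuses the mean, rsqrt and multiplies into one kernel
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


x = torch.randn(8, 4096, dtype=torch.half, device="cuda")
w = torch.ones(4096, dtype=torch.half, device="cuda")
y = rmsnorm(x, w)
```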
What does a custom fusion pass look like
torch.compile provides multiple ways to perform surgery on the FX graph. We can drop down to the node level and perform matching and replacement manually, but inductor ships with a nice pattern matcher (torch._inductor.pattern_matcher) that makes the implementation quite clean and simple. As an example, a fused SwiGLU pass would look like the following:
```python
import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized_v2

patterns = pm.PatternMatcherPass()


def fused_swiglu_pattern(x, w):
    mm = torch.ops.aten.mm.default(x, w)
    result = torch.ops.aten.empty.memory_format(
        [mm.shape[0], mm.shape[1] // 2],
        dtype=mm.dtype,
        device=mm.device,
        pin_memory=False,
    )
    at = auto_functionalized_v2(
        torch.ops.sgl_kernel.silu_and_mul.default,
        input=mm,
        _out_base_index=0,
        _all_bases=[result],
    )
    return at[1]


def fused_swiglu_replacement(x, w):
    return torch.ops.sglang.fused_swiglu.default(x, w)


example_inputs = [
    torch.empty(16, 16).half().cuda(),
    torch.empty(16, 16).half().cuda(),
]

pm.register_replacement(
    fused_swiglu_pattern,
    fused_swiglu_replacement,
    example_inputs,
    pm.fwd_only,
    patterns,
)
```

Patterns can be found by looking at the dynamo graph, and the replacement works as long as it is registered as a torch operator with torch.compile support. Registration for the above operator looks like the following:
```python
from typing import Optional

import torch

from sglang.srt.utils import direct_register_custom_op


def fused_swiglu(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scale: Optional[torch.Tensor] = None,
    w_scale: Optional[torch.Tensor] = None,
    o_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # actual implementation
    ...


def fused_swiglu_fake(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scale: Optional[torch.Tensor] = None,
    w_scale: Optional[torch.Tensor] = None,
    o_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # the fake implementation just informs the compiler of the output shape and dtype;
    # if this were an in-place operator this function would be empty
    M, N, K = x.shape[0], x.shape[1], w.shape[1] // 2
    return torch.empty((M, K), device=x.device, dtype=x.dtype)


direct_register_custom_op(
    op_name="fused_swiglu",
    op_func=fused_swiglu,
    mutates_args=[],
    fake_impl=fused_swiglu_fake,
)
```

Fake implementations are used by the compiler to check whether a particular fusion fits into the computation graph. Using a fake implementation avoids executing the actual operators during compilation, which would only slow compilation down.
Finally, you need to tell torch.compile about the pattern matcher you just created:

```python
def sglang_fusion_passes(graph):
    num_matches = patterns.apply(graph)
    return num_matches


# actual registration into the backend
torch._inductor.config.post_grad_custom_post_pass = sglang_fusion_passes
```

With this, any function or module passed to (or decorated with) torch.compile will run our fusion passes.
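As a minimal usage sketch (toy shapes, not SGLang code): once the hook above is installed, anything compiled goes through sglang_fusion_passes during inductor's post-grad phase. The toy forward below mirrors the mm + SwiGLU shape of fused_swiglu_pattern; in a real SGLang model the activation layer would emit the sgl_kernel op that the pattern actually matches.

```python
import torch


@torch.compile
def gate_up_swiglu(x: torch.Tensor, gate_up_weight: torch.Tensor) -> torch.Tensor:
    # packed gate/up projection followed by SwiGLU, analogous to the pattern above
    mm = torch.mm(x, gate_up_weight)
    gate, up = mm.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up


x = torch.randn(8, 4096, dtype=torch.half, device="cuda")
w = torch.randn(4096, 2 * 11008, dtype=torch.half, device="cuda")
y = gate_up_swiglu(x, w)  # sglang_fusion_passes runs on the traced graph
```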
Examples of fusion passes
Here are just a few examples of possible fusions; many more should be possible, especially with different quant schemes:
- SiLU and Mul (and other activations)
- SiLU and Mul + Quant (FP8)
- UpProject + SiLU and Mul (and other activations)
- UpProject + SiLU and Mul (and other activations) + Quant (FP8)
- FusedMOE + SiLU and Mul
- Attention + Quant (FP8)
- RMSNorm + Quant (FP8)
- Allreduce + RMSNorm
- Allreduce + RMSNorm + Quant (FP8)
Implementation (Proposed)
To get started we just need a PassManager where custom passes can be registered, but I would like to lay down the foundation for features that may be adopted down the road. Just off the top of my head, and looking at vLLM's implementation:
- Easily configurable: users should be able to tweak SGLang-specific or general torch compilation config as well as enable/disable specific passes
- Usable with or without cuda graphs; currently torch.compile is only supported together with cuda graphs
- Inductor cache management for faster starts; inductor can do a lot of this for us, need to look further into whether we need manual management
- Cuda graph integration / piecewise graphs; I understand SGLang manages cuda graphs for decode on its own (cuda_graph_runner.py), torch.compile could do this too, I need to look into this more to fully understand which would be the best way
- Observability and monitoring: we should easily be able to get stats on which passes ran and how many replacements were done
- Unit testing: writing unit tests for custom fusion passes and testing them on actual models should be easy
Compilation Config
Responsible for resolving the current config from server_args and environment variables, both for torch.compile and for the individual passes. This config object's hash would be used for cache invalidation if manual cache management turns out to be needed. Any required validation and model-specific tweaks to the config are done here.
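A hypothetical sketch of what this could look like (all names are illustrative, not an existing SGLang API):

```python
import hashlib
import json
from dataclasses import asdict, dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class CompilationConfig:
    enable_fusion: bool = True
    enabled_passes: Tuple[str, ...] = ("fused_swiglu",)
    use_cudagraph: bool = True
    inductor_cache_dir: Optional[str] = None

    def hash(self) -> str:
        # stable key for compile/inductor cache invalidation
        payload = json.dumps(asdict(self), sort_keys=True, default=str)
        return hashlib.sha256(payload.encode()).hexdigest()[:16]
```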
Compilation Manager
Entry point for the compilation; responsible for registering the pass manager with inductor's post-grad passes and for collecting compilation stats and metrics.
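A hedged sketch of the entry point, assuming the hypothetical CompilationConfig and PassManager sketched in the neighbouring sections:

```python
import torch


class CompilationManager:
    def __init__(self, config, pass_manager):
        self.config = config
        self.pass_manager = pass_manager
        self.stats = {"graphs_seen": 0, "replacements": 0}

    def install(self) -> None:
        # hook the pass manager into inductor's post-grad phase
        torch._inductor.config.post_grad_custom_post_pass = self._run

    def _run(self, graph: torch.fx.Graph) -> None:
        self.stats["graphs_seen"] += 1
        self.stats["replacements"] += self.pass_manager.apply(graph)
```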
Pass Manager
Responsible for registering the individual custom passes based on the compilation config.
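Again a hypothetical sketch (names illustrative only):

```python
import torch


class PassManager:
    def __init__(self, config):
        self.config = config
        self.passes = []  # instances of SGLangInductorPass (see next section)

    def register(self, fusion_pass) -> None:
        if fusion_pass.name in self.config.enabled_passes:
            self.passes.append(fusion_pass)

    def apply(self, graph: torch.fx.Graph) -> int:
        # returns the total number of replacements across all passes
        return sum(p(graph) for p in self.passes)
```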
SGLang Inductor Pass
Base class for defining passes. Ideally passes should be composable in nature, since logically similar passes may be registered in multiple variations (e.g. with or without quant, or with different quant types). The sole purpose of this class is to provide utilities that enable this composability and avoid code duplication.
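A possible shape for the base class; again only a sketch, built on the pattern matcher API used earlier:

```python
import torch
import torch._inductor.pattern_matcher as pm


class SGLangInductorPass:
    def __init__(self, name: str):
        self.name = name
        self.patterns = pm.PatternMatcherPass()

    def register(self, pattern, replacement, example_inputs) -> None:
        # thin wrapper so subclasses can register several variants of the same
        # logical fusion (e.g. a quantized and an unquantized pattern)
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, self.patterns
        )

    def __call__(self, graph: torch.fx.Graph) -> int:
        return self.patterns.apply(graph)
```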
Operator registration
As required, existing kernels and Triton ops would need to be registered as torch ops with torch.compile support, to avoid graph breaks and to facilitate pattern replacement. For example, registration for the Triton decode attention:
```python
from typing import Optional

import torch

from sglang.srt.utils import direct_register_custom_op

# decode_attention_fwd_normal / decode_attention_fwd_grouped are the existing
# Triton kernels from sglang's triton decode attention module


def decode_attention_fwd_impl(
    q: torch.Tensor,
    k_buffer: torch.Tensor,
    v_buffer: torch.Tensor,
    o: torch.Tensor,
    kv_indptr: torch.Tensor,
    kv_indices: torch.Tensor,
    attn_logits: torch.Tensor,
    attn_lse: torch.Tensor,
    num_kv_splits: torch.Tensor,
    max_kv_splits: int,
    sm_scale: float,
    logit_cap: float = 0.0,
    sinks: Optional[torch.Tensor] = None,
    xai_temperature_len: int = -1,
) -> None:
    assert max_kv_splits == attn_logits.shape[2]
    assert q.shape[0] <= kv_indptr.shape[0] - 1
    assert q.shape[0] <= attn_logits.shape[0]

    kv_group_num = q.shape[1] // v_buffer.shape[1]
    if kv_group_num == 1:
        # MHA
        decode_attention_fwd_normal(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits,
            attn_lse,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
            logit_cap=logit_cap,
            sinks=sinks,
            xai_temperature_len=xai_temperature_len,
        )
    else:
        # GQA/MQA/MLA
        decode_attention_fwd_grouped(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits,
            attn_lse,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
            logit_cap=logit_cap,
            sinks=sinks,
            xai_temperature_len=xai_temperature_len,
        )


def decode_attention_fwd_fake(
    q,
    k_buffer,
    v_buffer,
    o,
    kv_indptr,
    kv_indices,
    attn_logits,
    attn_lse,
    num_kv_splits,
    max_kv_splits,
    sm_scale,
    logit_cap=0.0,
    sinks=None,
    xai_temperature_len=-1,
):
    # in-place op: nothing to return, the fake impl only needs to exist
    pass


def decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    kv_indptr,
    kv_indices,
    attn_logits,
    attn_lse,
    num_kv_splits,
    max_kv_splits,
    sm_scale,
    logit_cap=0.0,
    sinks=None,
    xai_temperature_len=-1,
):
    torch.ops.sglang.decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o,
        kv_indptr,
        kv_indices,
        attn_logits,
        attn_lse,
        num_kv_splits,
        max_kv_splits,
        sm_scale,
        logit_cap,
        sinks,
        xai_temperature_len,
    )


direct_register_custom_op(
    op_name="decode_attention_fwd",
    op_func=decode_attention_fwd_impl,
    mutates_args=["o", "attn_lse"],
    fake_impl=decode_attention_fwd_fake,
)
```

Usage of auto functionalize v2
Functionalization is required for torch.compile to optimize the IR. For mutable operators, auto functionalize v1 achieved functionalization by inserting extra copies, and an extra defunctionalization pass was required to get rid of these copies after compilation (https://github.com/vllm-project/vllm/blob/main/vllm/compilation/fix_functionalization.py); auto functionalize v2 does this for us under the hood. For a better explanation refer to the following: https://dev-discuss.pytorch.org/t/a-new-strategy-for-automatic-custom-operators-functionalization/2733
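A small sketch for seeing this in practice: temporarily point the post-grad hook at a pass that just prints the graph, compile a function that calls a mutating custom op (such as torch.ops.sglang.decode_attention_fwd registered above), and the call shows up as an auto_functionalized_v2 node with its mutated buffers listed under _all_bases, which is the same form the SwiGLU pattern earlier matches against.

```python
import torch


def print_post_grad_graph(graph: torch.fx.Graph) -> None:
    # dump the functionalized post-grad graph to find patterns worth fusing
    graph.print_tabular()


torch._inductor.config.post_grad_custom_post_pass = print_post_grad_graph
```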
Overview of vLLM's implementation
This is the execution flow of the complete compilation process, based on the snooping around I have done:
- Starts with the `support_torch_compile` decorator, which is added to model architectures for which custom passes are supported
- This calls the custom dispatcher `TorchCompileWrapperWithCustomDispatcher`, which inits the `VllmBackend` based on the compilation config and calls torch.compile with the initialized backend
- `VllmBackend` on init creates the graph pool (probably for the piecewise graphs), the compilation manager (responsible for the actual compilation), and the pass manager (this is where post-grad passes get registered)
- When `VllmBackend` is executed it first computes the config hash; the compilation manager uses this to either load an existing compiled cache (skipping compilation) or compile and save a new cache. The backend then creates the piecewise graphs (each of these will be a compiled cuda graph), since not all graph breaks can be avoided
- Each graph is handled by a `PiecewiseCompileInterpreter` (a `torch.fx.Interpreter`), which internally calls the compilation manager
- `CompilerManager` is just a wrapper around the existing inductor backend, accessed through the `CompilerInterface`, which is implemented by `InductorStandaloneAdaptor` or `InductorAdaptor`. Besides executing the inductor compilation, `CompilerManager` also manages creating or loading the cache for the subgraph it is targeting
References
- https://blog.vllm.ai/2025/08/20/torch-compile.html
- https://arxiv.org/abs/2002.05202
- https://docs.pytorch.org/tutorials/advanced/cpp_custom_ops.html#adding-torch-compile-support-for-an-operator
- https://dev-discuss.pytorch.org/t/a-new-strategy-for-automatic-custom-operators-functionalization/2733
- https://github.com/vllm-project/vllm/tree/main/vllm/compilation