
[RFC] SGLang unified kernel fusion and torch compile optimisations #10118

@DevashishLal-CB


Motivation

To achieve the best possible performance across diverse hardware, workloads, and scale, heavily optimized, hand-tuned kernels are often required. This requirement makes it harder for model developers (who just want to focus on the model architecture) to easily integrate their models into inference engines like SGLang. Currently, models need to be reimplemented using the set of optimized kernels provided by sgl-kernel and other kernel libraries. Not only does this duplicate a lot of kernel-invocation code across different models, it also adds more work for model authors to get the best performance.

With torch.compile, model authors don't need to care about the specifics of optimized kernels and can let the compiler do it for them.

Solution

For pure PyTorch models torch.compile works great: all you need to do is decorate a function with @torch.compile and it will trace the tensor operations to generate a computation graph, which is then compiled into optimized Triton kernels (depending on the Inductor backend). torch.compile with just native PyTorch code gets you pretty far in terms of performance, but to squeeze out the last bit of performance, custom hardware kernels are still required. torch.compile can easily be extended with custom fusion passes, completely decoupling model architecture from kernel optimisations, and these passes can be reused across models.

What does a custom fusion pass look like?

torch.compile provides multiple ways to perform surgery on the FX graph. We can drop down to the node level and perform matching and replacement manually, but Inductor ships with a nice pattern matcher (torch._inductor.pattern_matcher) which can make the implementation quite clean and simple. As an example, a fused SwiGLU pass would look like the following:

import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized_v2

patterns = pm.PatternMatcherPass()


# pattern: up-projection matmul followed by sgl_kernel's in-place silu_and_mul,
# which appears in the post-grad graph wrapped in auto_functionalized_v2
def fused_swiglu_pattern(x, w):
    mm = torch.ops.aten.mm.default(x, w)
    result = torch.ops.aten.empty.memory_format(
        [mm.shape[0], mm.shape[1] // 2],
        dtype=mm.dtype,
        device=mm.device,
        pin_memory=False,
    )
    at = auto_functionalized_v2(
        torch.ops.sgl_kernel.silu_and_mul.default,
        input=mm,
        _out_base_index=0,
        _all_bases=[result],
    )
    return at[1]


# the replacement is a single fused op (matmul + silu_and_mul in one kernel)
def fused_swiglu_replacement(x, w):
    return torch.ops.sglang.fused_swiglu.default(x, w)


# example inputs are only used to trace the pattern/replacement into FX graphs
example_inputs = [
    torch.empty(16, 16).half().cuda(),
    torch.empty(16, 16).half().cuda(),
]

pm.register_replacement(
    fused_swiglu_pattern,
    fused_swiglu_replacement,
    example_inputs,
    pm.fwd_only,
    patterns,
)

Patterns can be found by looking at the Dynamo graph, and replacement works as long as the replacement is registered as a torch operator with torch.compile support. Registration for the above operator looks like the following:

from typing import Optional

import torch

from sglang.srt.utils import direct_register_custom_op


def fused_swiglu(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scale: Optional[torch.Tensor] = None,
    w_scale: Optional[torch.Tensor] = None,
    o_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # actual implementation
    ...


def fused_swiglu_fake(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scale: Optional[torch.Tensor] = None,
    w_scale: Optional[torch.Tensor] = None,
    o_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # the fake implementation only informs the compiler of the output shape/dtype;
    # if this were an in-place operator, this function would be empty
    M, N, K = x.shape[0], x.shape[1], w.shape[1] // 2
    return torch.empty((M, K), device=x.device, dtype=x.dtype)


direct_register_custom_op(
    op_name="fused_swiglu",
    op_func=fused_swiglu,
    mutates_args=[],
    fake_impl=fused_swiglu_fake,
)

Fake implementations are used by the compiler to check whether a particular fusion pattern can fit in the computation graph. Using a fake implementation avoids executing the actual operators during compilation, which would just slow the compilation down.
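
As a quick sanity check (a minimal sketch, assuming the fused_swiglu op from above is already registered), the fake implementation can be exercised directly under FakeTensorMode, which is essentially what the compiler does during tracing:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# under FakeTensorMode only shapes and dtypes are propagated via the fake impl,
# no real kernel is launched and no GPU memory is touched
with FakeTensorMode():
    x = torch.empty(16, 32, dtype=torch.float16)
    w = torch.empty(32, 64, dtype=torch.float16)
    out = torch.ops.sglang.fused_swiglu(x, w)
    assert out.shape == (16, 32) and out.dtype == torch.float16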

Finally, you need to tell torch.compile about the pattern matcher you just created:

def sglang_fusion_passes(graph):
    num_matches = patterns.apply(graph)
    return num_matches

# actual registration into the backend
torch._inductor.config.post_grad_custom_post_pass = sglang_fusion_passes

With this, any function or module passed to or decorated with torch.compile will run our fusion passes.
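
For example (a minimal sketch; SwiGLUMLP is a hypothetical toy module used only for illustration, it assumes a CUDA device, and whether the SwiGLU pattern actually fires depends on how its forward lowers in the post-grad graph):

import torch
import torch.nn as nn

class SwiGLUMLP(nn.Module):
    def __init__(self, hidden: int, intermediate: int):
        super().__init__()
        self.w = nn.Parameter(torch.randn(hidden, 2 * intermediate, dtype=torch.float16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = x @ self.w
        gate, up = gate_up.chunk(2, dim=-1)
        return nn.functional.silu(gate) * up

# compiling the module is enough for the registered post-grad passes to run on it
mlp = torch.compile(SwiGLUMLP(4096, 11008).cuda())
out = mlp(torch.randn(16, 4096, dtype=torch.float16, device="cuda"))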

Examples of fusion passes

Here are just a few examples of possible fusions; many more should be possible, especially with different quant schemes (a sketch of one such pattern follows the list).

  • SiLU and Mul (and other activations)
  • SiLU and Mul + Quant (FP8)
  • UpProject + SiLU and Mul (and other activations)
  • UpProject + SiLU and Mul (and other activations) + Quant (FP8)
  • FusedMOE + SiLU and Mul
  • Attention + Quant (FP8)
  • RMSNorm + Quant (FP8)
  • Allreduce + RMSNorm
  • Allreduce + RMSNorm + Quant (FP8)
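
As an illustration of the RMSNorm + Quant (FP8) entry, a pattern/replacement pair registered into the same patterns object could look roughly like this (a sketch: torch.ops.sglang.fused_rmsnorm_fp8_quant is a hypothetical fused op that would need to be registered via direct_register_custom_op just like fused_swiglu above, and the exact aten sequence depends on how the model's RMSNorm is written):

import torch
import torch._inductor.pattern_matcher as pm

def rmsnorm_fp8_pattern(x, weight, scale, eps):
    # plain PyTorch RMSNorm followed by an FP8 cast, as it appears pre-fusion
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
    return (normed / scale).to(torch.float8_e4m3fn)

def rmsnorm_fp8_replacement(x, weight, scale, eps):
    # hypothetical fused kernel doing norm + quant in one pass
    return torch.ops.sglang.fused_rmsnorm_fp8_quant.default(x, weight, scale, eps)

example_inputs = [
    torch.empty(16, 4096, dtype=torch.float16, device="cuda"),
    torch.empty(4096, dtype=torch.float16, device="cuda"),
    torch.empty((), dtype=torch.float32, device="cuda"),
]

pm.register_replacement(
    rmsnorm_fp8_pattern,
    rmsnorm_fp8_replacement,
    example_inputs,
    pm.fwd_only,
    patterns,  # the PatternMatcherPass created earlier
    scalar_workaround={"eps": 1e-6},  # eps is a python float, not a tensor input
)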

Implementation (Proposed)

To get started we just need a PassManager where custom passes can be registered, but I would like to lay down the foundation for features that may be adopted down the road. Off the top of my head, and looking at vLLM's implementation:

  • Easily configurable: users should be able to tweak SGLang-specific or general torch compilation config as well as enable/disable specific passes
  • Usable with or without CUDA graphs; currently torch.compile is only supported together with CUDA graphs
  • Inductor cache management for faster startup; Inductor can do a lot of this for us, need to look further into whether we need manual management
  • CUDA graph integration and piecewise graphs; I understand SGLang manages CUDA graphs for decode on its own (cuda_graph_runner.py) and torch.compile could do this too, I need to look more into this to fully understand which would be the best way
  • Observability and monitoring: we should easily be able to get stats on which passes ran and how many replacements were done
  • Unit testing: writing unit tests for custom fusion passes and testing them on actual models should be easy

Compilation Config

Responsible for resolving the current config from server_args and environment variables, for torch.compile as well as for individual passes. This config object's hash would be used for cache invalidation, if manual cache management is needed. Any required validation and model-specific tweaks to the config are done here.
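
A rough sketch of what this could look like (field, attribute, and env-var names here are illustrative, not a proposal of the final surface):

import hashlib
import json
import os
from dataclasses import asdict, dataclass, field
from typing import Dict, List

@dataclass
class CompilationConfig:
    enable_fusion: bool = True
    # names of individual passes that can be toggled on/off
    enabled_passes: List[str] = field(default_factory=lambda: ["swiglu", "rmsnorm_quant"])
    # free-form overrides forwarded to torch._inductor.config
    inductor_overrides: Dict[str, object] = field(default_factory=dict)

    @classmethod
    def resolve(cls, server_args) -> "CompilationConfig":
        # resolve from server_args and environment variables; the attribute and
        # env-var names used here are hypothetical
        passes = os.environ.get("SGLANG_FUSION_PASSES")
        return cls(
            enable_fusion=getattr(server_args, "enable_fusion_passes", True),
            enabled_passes=passes.split(",") if passes else ["swiglu", "rmsnorm_quant"],
        )

    def hash(self) -> str:
        # stable hash used for inductor cache invalidation
        return hashlib.sha256(json.dumps(asdict(self), sort_keys=True).encode()).hexdigest()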

Compilation Manager

Entry point for the compilation. Responsible for registering the pass manager with Inductor's post-grad passes and for collecting compilation stats and metrics.
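
Roughly (a sketch; PassManager is sketched in the next section, and the Inductor hook mirrors the post_grad_custom_post_pass registration shown earlier):

import torch
import torch._inductor.config

class CompilationManager:
    def __init__(self, config: CompilationConfig):
        self.config = config
        self.pass_manager = PassManager(config)
        self.stats = {"replacements": 0}

    def install(self) -> None:
        # hook the pass manager into inductor's post-grad custom passes
        torch._inductor.config.post_grad_custom_post_pass = self._run_passes

    def _run_passes(self, graph: torch.fx.Graph) -> None:
        # collect simple stats so we can report which fusions actually fired
        self.stats["replacements"] += self.pass_manager(graph)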

Pass Manager

Responsible for registering individual custom passes based on the compilation config.
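
For example (a sketch; SwigluFusionPass and RMSNormQuantFusionPass are illustrative subclasses of the base class described in the next section):

import torch

class PassManager:
    # name -> pass class; each pass subclasses SGLangInductorPass (next section)
    registry = {
        "swiglu": SwigluFusionPass,
        "rmsnorm_quant": RMSNormQuantFusionPass,
    }

    def __init__(self, config: CompilationConfig):
        self.passes = [
            self.registry[name](config)
            for name in config.enabled_passes
            if config.enable_fusion and name in self.registry
        ]

    def __call__(self, graph: torch.fx.Graph) -> int:
        # run every enabled pass and return the total number of replacements
        return sum(p(graph) for p in self.passes)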

SGLang Inductor Pass

Base class for defining passes. Ideally passes should be composable in nature, since logically similar passes could be registered with multiple variations (i.e. with or without quant, or with multiple types of quant); the sole purpose of this class is to provide utilities that enable this composability and avoid code duplication.
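
A possible shape for the base class (again just a sketch; the point is that one pass class can register several pattern variants against a single PatternMatcherPass):

import torch
import torch._inductor.pattern_matcher as pm

class SGLangInductorPass:
    """Base class for SGLang fusion passes; subclasses register pattern variants."""

    def __init__(self, config: CompilationConfig):
        self.config = config
        self.patterns = pm.PatternMatcherPass()
        self.register_patterns()

    def register_patterns(self) -> None:
        # subclasses call pm.register_replacement(...) here, typically once per
        # variant (e.g. bf16, fp8-quant, ...) so the composition lives in one place
        raise NotImplementedError

    def __call__(self, graph: torch.fx.Graph) -> int:
        # returns the number of replacements applied to this graph
        return self.patterns.apply(graph)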

Operator registration

As required, existing kernels and Triton ops would need to be registered as torch ops with torch.compile support to avoid graph breaks and facilitate pattern replacement, e.g. the registration for the Triton decode attention:

from typing import Optional

import torch

# decode_attention_fwd_normal / decode_attention_fwd_grouped are the existing
# Triton kernels from SGLang's triton_ops decode attention module
from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd_grouped,
    decode_attention_fwd_normal,
)
from sglang.srt.utils import direct_register_custom_op


def decode_attention_fwd_impl(
    q: torch.Tensor,
    k_buffer: torch.Tensor,
    v_buffer: torch.Tensor,
    o: torch.Tensor,
    kv_indptr: torch.Tensor,
    kv_indices: torch.Tensor,
    attn_logits: torch.Tensor,
    attn_lse: torch.Tensor,
    num_kv_splits: torch.Tensor,
    max_kv_splits: int,
    sm_scale: float,
    logit_cap: float = 0.0,
    sinks: Optional[torch.Tensor] = None,
    xai_temperature_len: int = -1,
) -> None:
    assert max_kv_splits == attn_logits.shape[2]
    assert q.shape[0] <= kv_indptr.shape[0] - 1
    assert q.shape[0] <= attn_logits.shape[0]

    kv_group_num = q.shape[1] // v_buffer.shape[1]

    if kv_group_num == 1:
        # MHA
        decode_attention_fwd_normal(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits,
            attn_lse,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
            logit_cap=logit_cap,
            sinks=sinks,
            xai_temperature_len=xai_temperature_len,
        )
    else:
        # GQA/MQA/MLA
        decode_attention_fwd_grouped(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits,
            attn_lse,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
            logit_cap=logit_cap,
            sinks=sinks,
            xai_temperature_len=xai_temperature_len,
        )


def decode_attention_fwd_fake(
    q,
    k_buffer,
    v_buffer,
    o,
    kv_indptr,
    kv_indices,
    attn_logits,
    attn_lse,
    num_kv_splits,
    max_kv_splits,
    sm_scale,
    logit_cap=0.0,
    sinks=None,
    xai_temperature_len=-1,
):
    # the real op writes its results in-place into the provided buffers,
    # so the fake implementation has nothing to compute or return
    pass


def decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    kv_indptr,
    kv_indices,
    attn_logits,
    attn_lse,
    num_kv_splits,
    max_kv_splits,
    sm_scale,
    logit_cap=0.0,
    sinks=None,
    xai_temperature_len=-1,
):
    torch.ops.sglang.decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o,
        kv_indptr,
        kv_indices,
        attn_logits,
        attn_lse,
        num_kv_splits,
        max_kv_splits,
        sm_scale,
        logit_cap,
        sinks,
        xai_temperature_len,
    )


direct_register_custom_op(
    op_name="decode_attention_fwd",
    op_func=decode_attention_fwd_impl,
    mutates_args=["o", "attn_lse"],
    fake_impl=decode_attention_fwd_fake,
)

Usage of auto functionalize v2

Functionalization is required for torch.compile to optimise the IR. For mutable operators, auto functionalize v1 achieved functionalization using extra copies, and an extra defunctionalization pass was required to get rid of these copies post-compilation (https://github.com/vllm-project/vllm/blob/main/vllm/compilation/fix_functionalization.py). auto functionalize v2 does this for us under the hood. For a better explanation, refer to the following: https://dev-discuss.pytorch.org/t/a-new-strategy-for-automatic-custom-operators-functionalization/2733

Overview of vLLM's implementation

This is the execution flow of the complete compilation process, according to the snooping around I have done:

  • It starts with the support_torch_compile decorator, which is added to model architectures for which custom passes are supported
  • This calls the custom dispatcher TorchCompileWrapperWithCustomDispatcher, which inits the VllmBackend based on the compilation config and calls torch.compile with the initialized backend
  • VllmBackend on init creates the graph pool (probably for the piecewise graphs), the compilation manager (responsible for the actual compilation), and the pass manager (this is where post-grad passes get registered)
  • VllmBackend, when executed, first computes the config hash; this is used by the compilation manager to either load an existing compiled cache (skipping compilation) or compile and save a new cache. The backend then creates the piecewise graphs (each of these will be a compiled CUDA graph); this is done because not all graph breaks can be avoided
  • Each graph is handled by a PiecewiseCompileInterpreter (a torch.fx.Interpreter) which internally calls the compilation manager
  • CompilerManager is just a wrapper around the existing Inductor backend; this is done through CompilerInterface, which is implemented by InductorStandaloneAdaptor or InductorAdaptor. Other than executing the Inductor compilation, CompilerManager also manages creating or loading the cache for the sub-graph it's targeting
