diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 7c9ec61da990..00244f90f36b 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -64,7 +64,10 @@
 )
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.patch_torch import monkey_patch_torch_reductions
+from sglang.srt.patch_torch import (
+    monkey_patch_torch_compile,
+    monkey_patch_torch_reductions,
+)
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -88,6 +91,8 @@
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 
+monkey_patch_torch_compile()
+
 
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
diff --git a/python/sglang/srt/patch_torch.py b/python/sglang/srt/patch_torch.py
index 32034b7044bc..8d90ce4c07e2 100644
--- a/python/sglang/srt/patch_torch.py
+++ b/python/sglang/srt/patch_torch.py
@@ -14,6 +14,7 @@
 from typing import Callable, Union
 
 import torch
+from packaging import version
 from torch.multiprocessing import reductions
 
 
@@ -69,3 +70,26 @@ def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
 
 def _modify_tuple(t, index: int, modifier: Callable):
     return *t[:index], modifier(t[index]), *t[index + 1 :]
+
+
+def monkey_patch_torch_compile():
+    """Mark the auto-functionalize higher-order ops as cacheable on PyTorch < 2.8.
+
+    torch.compile can cache these ops but does not know it before PyTorch 2.8
+    (fixed upstream in 2.8), so the private ``_cacheable`` flag is set by hand.
+    The patch is best-effort: it runs at import time and only touches private
+    torch internals, so anything that is missing is skipped instead of raising.
+    """
+    if version.parse(torch.__version__) >= version.parse("2.8.0"):
+        return
+
+    try:
+        import torch._higher_order_ops.auto_functionalize as af
+    except ImportError:
+        # Private module; without it there is nothing to patch.
+        return
+
+    for op_name in ("auto_functionalized_v2", "auto_functionalized"):
+        op = getattr(af, op_name, None)
+        if op is not None:
+            op._cacheable = True