5 changes: 4 additions & 1 deletion python/sglang/srt/compilation/backend.py
@@ -20,6 +20,7 @@
from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
from sglang.srt.compilation.pass_manager import PostGradPassManager
from sglang.srt.utils.common import rank0_log

logger = logging.getLogger(__name__)

@@ -357,6 +358,7 @@ def __init__(
config: CompilationConfig,
graph_pool: Any,
):
rank0_log(f"Initializing SGLangBackend")
assert graph_pool is not None
self.graph_pool = graph_pool

@@ -375,6 +377,7 @@ def configure_post_pass(self):
self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager

def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
rank0_log(f"SGLangBackend __call__")
base_cache_dir = os.path.expanduser(
os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/")
)
@@ -441,7 +444,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
with open(graph_path, "w") as f:
f.write(src)

logger.debug("Computation graph saved to %s", graph_path)
rank0_log(f"Computation graph saved to {graph_path}")

self._called = True
return self.split_gm
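
Note on rank0_log: it is imported from sglang.srt.utils.common; the sketch below is only an illustrative approximation of the behavior this change relies on (log once from the first rank instead of once per tensor-parallel worker), not the actual SGLang implementation.

import logging

import torch.distributed as dist

logger = logging.getLogger(__name__)


def rank0_log_sketch(msg: str) -> None:
    # Emit the message only from rank 0 (or when running single-process) so that
    # tensor-parallel workers do not repeat the same line once per rank.
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        logger.info(msg)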
2 changes: 2 additions & 0 deletions python/sglang/srt/compilation/compile.py
@@ -11,6 +11,7 @@
import torch

from sglang.srt.compilation.compilation_config import CompilationConfig
from sglang.srt.utils.common import rank0_log

logger = logging.getLogger(__name__)

@@ -129,6 +130,7 @@ def install_torch_compiled(
fullgraph: bool = True,
graph_pool: Any = None,
):
rank0_log(f"install_torch_compiled")
unbound_fwd = module.__class__.forward
if not callable(unbound_fwd):
raise TypeError("module.__class__.forward must be callable")
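The hunk above only shows the start of install_torch_compiled (grabbing the unbound forward and validating it). A rough, hypothetical sketch of the general pattern such a helper follows is below; the helper name, the rebinding detail, and the Tiny module are illustrative assumptions, not the SGLang code.

import types

import torch
import torch.nn as nn


class Tiny(nn.Module):
    def forward(self, x):
        return x * 2


def install_compiled_forward_sketch(module: nn.Module, fullgraph: bool = True) -> None:
    # Same checks as the visible lines: take the unbound forward off the class and
    # validate it, then compile it and bind the compiled callable onto this instance.
    unbound_fwd = module.__class__.forward
    if not callable(unbound_fwd):
        raise TypeError("module.__class__.forward must be callable")
    compiled_fwd = torch.compile(unbound_fwd, fullgraph=fullgraph)
    module.forward = types.MethodType(compiled_fwd, module)


m = Tiny()
install_compiled_forward_sketch(m)
print(m.forward(torch.ones(2)))  # compiled on the first call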
44 changes: 29 additions & 15 deletions python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
@@ -59,6 +59,27 @@
from sglang.srt.model_executor.model_runner import ModelRunner


@contextmanager
def disable_ca_comm(tp_group):
"""
Context manager to temporarily disable custom allreduce communication.

This is used during Piecewise CUDA graph capture to avoid custom allreduce operations
that may not be compatible with graph capture.

TODO(yuwei): Fix this
"""
old_disabled = None
try:
if tp_group.ca_comm is not None:
old_disabled = tp_group.ca_comm.disabled
tp_group.ca_comm.disabled = True
yield
finally:
if tp_group.ca_comm is not None and old_disabled is not None:
tp_group.ca_comm.disabled = old_disabled


@contextmanager
def freeze_gc(enable_cudagraph_gc: bool):
"""
@@ -207,7 +228,7 @@ def __init__(self, model_runner: ModelRunner):
)

with set_compiled(True):
self.warmup_and_capture()
self.warmup_torch_compile()

# Capture
try:
@@ -219,7 +240,8 @@ def __init__(self, model_runner: ModelRunner):

self.raw_num_tokens = 0

def warmup_and_capture(self):
def warmup_torch_compile(self):
"""Warmup the model with a simple forward pass before CUDA graph capture."""
num_tokens = 2
with torch.device(self.device):
forward_batch = ForwardBatch(
@@ -283,7 +305,7 @@ def warmup_and_capture(self):

with set_forward_context(
forward_batch, self.attention_layers, self.quant_config
):
), disable_ca_comm(self.model_runner.tp_group):
_ = self.model_runner.model.forward(
forward_batch.input_ids,
forward_batch.positions,
@@ -311,10 +333,9 @@ def capture(self) -> None:
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc):
if self.model_runner.tp_group.ca_comm is not None:
old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
self.model_runner.tp_group.ca_comm.disabled = True
with freeze_gc(
self.model_runner.server_args.enable_cudagraph_gc
), disable_ca_comm(self.model_runner.tp_group):
avail_mem = get_available_gpu_memory(
self.model_runner.device,
self.model_runner.gpu_id,
@@ -342,8 +363,6 @@

# Save gemlite cache after each capture
save_gemlite_cache()
if self.model_runner.tp_group.ca_comm is not None:
self.model_runner.tp_group.ca_comm.disabled = old_ca_disable

def capture_one_batch_size(self, num_tokens: int):
bs = 1
@@ -565,10 +584,7 @@ def replay(
forward_batch: ForwardBatch,
**kwargs,
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
with enable_piecewise_cuda_graph():
if self.model_runner.tp_group.ca_comm is not None:
old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
self.model_runner.tp_group.ca_comm.disabled = True
with enable_piecewise_cuda_graph(), disable_ca_comm(self.model_runner.tp_group):
self.model_runner.attn_backend.init_forward_metadata(forward_batch)
static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
# Replay
@@ -599,8 +615,6 @@ def replay(
raise NotImplementedError(
"PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet."
)
if self.model_runner.tp_group.ca_comm is not None:
self.model_runner.tp_group.ca_comm.disabled = old_ca_disable

def get_spec_info(self, num_tokens: int):
spec_info = None
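The disable_ca_comm context manager added above replaces three copies of the manual save/disable/restore pattern (in warmup, capture, and replay). A minimal standalone check of its behavior, using a dummy tp_group-like object (the SimpleNamespace stand-in is illustrative, not an SGLang type):

from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def disable_ca_comm(tp_group):
    # Same logic as the helper above: remember the old flag, force-disable
    # custom allreduce for the duration of the block, then restore it.
    old_disabled = None
    try:
        if tp_group.ca_comm is not None:
            old_disabled = tp_group.ca_comm.disabled
            tp_group.ca_comm.disabled = True
        yield
    finally:
        if tp_group.ca_comm is not None and old_disabled is not None:
            tp_group.ca_comm.disabled = old_disabled


tp_group = SimpleNamespace(ca_comm=SimpleNamespace(disabled=False))
with disable_ca_comm(tp_group):
    assert tp_group.ca_comm.disabled is True   # custom allreduce off inside the block
assert tp_group.ca_comm.disabled is False      # previous state restored on exit

A tp_group whose ca_comm is None passes through unchanged, which is why both the body and the finally clause re-check for None.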
1 change: 1 addition & 0 deletions python/sglang/srt/tokenizer/tiktoken_tokenizer.py
@@ -127,6 +127,7 @@ def apply_chat_template(
add_generation_prompt,
tools=None,
reasoning_effort=None,
**kwargs, # Accept additional parameters (e.g., return_dict) for compatibility
):
ret = self.chat_template_jinja.render(
messages=messages, add_generation_prompt=add_generation_prompt
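The added **kwargs keeps this tokenizer's apply_chat_template compatible with callers that pass extra HuggingFace-style arguments such as return_dict. A toy illustration of the difference (function names and bodies here are placeholders, not the SGLang implementation):

def apply_chat_template_strict(messages, add_generation_prompt, tools=None, reasoning_effort=None):
    return "prompt"


def apply_chat_template_lenient(messages, add_generation_prompt, tools=None,
                                reasoning_effort=None, **kwargs):
    # Unknown keyword arguments (e.g. return_dict) are accepted and ignored.
    return "prompt"


apply_chat_template_lenient([], True, return_dict=True)      # works
# apply_chat_template_strict([], True, return_dict=True)     # would raise TypeError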