11 changes: 5 additions & 6 deletions python/sglang/srt/lora/backend/base_backend.py
@@ -1,8 +1,7 @@
from typing import Optional, Tuple, Union
from typing import Tuple, Union

import torch

from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


@@ -97,8 +96,8 @@ def run_gate_up_lora(

def init_cuda_graph_batch_info(
self,
cuda_graph_batch_info: LoRABatchInfo,
max_bs_in_cuda_graph: int,
num_tokens_per_bs: int,
):
"""Initialize the batch info for CUDA Graph mode.

@@ -108,6 +107,7 @@ def init_cuda_graph_batch_info(
Args:
cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager
max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
num_tokens_per_bs: number of tokens per sequence (1 for decoding, >1 for target_verify)
"""
pass

@@ -117,7 +117,7 @@ def prepare_lora_batch(
weight_indices: list[int],
lora_ranks: list[int],
scalings: list[float],
batch_info: Optional[LoRABatchInfo] = None,
use_cuda_graph: bool,
):
"""Prepare the lora weights and batch info for current forward batch.

@@ -129,7 +129,6 @@ def prepare_lora_batch(
weight_indices: list of indices of lora weights to be applied for current batch
lora_ranks: list of lora ranks corresponding to weight_indices
scalings: list of scaling factors corresponding to weight_indices
batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own
internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
use_cuda_graph: whether to use CUDA Graph for this batch
"""
pass
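
Note: the base-interface change above replaces the optional batch_info argument with a required use_cuda_graph flag and moves ownership of the CUDA-graph batch info into each backend. A minimal, self-contained sketch of that calling pattern (class and field names such as ToyBackend / ToyBatchInfo are illustrative, not part of this PR):

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyBatchInfo:
    bs: int
    use_cuda_graph: bool


class ToyBackend:
    def __init__(self) -> None:
        self.cuda_graph_batch_info: Optional[ToyBatchInfo] = None

    def init_cuda_graph_batch_info(
        self, max_bs_in_cuda_graph: int, num_tokens_per_bs: int
    ) -> None:
        # Allocate once at capture time; later replays mutate this object in place.
        self.cuda_graph_batch_info = ToyBatchInfo(
            bs=max_bs_in_cuda_graph, use_cuda_graph=True
        )

    def prepare_lora_batch(self, batch_size: int, *, use_cuda_graph: bool) -> ToyBatchInfo:
        if use_cuda_graph:
            assert self.cuda_graph_batch_info is not None, "not initialized"
            info = self.cuda_graph_batch_info
            info.bs = batch_size  # update fields in place, keep the same buffers
        else:
            info = ToyBatchInfo(bs=batch_size, use_cuda_graph=False)
        return info


backend = ToyBackend()
backend.init_cuda_graph_batch_info(max_bs_in_cuda_graph=8, num_tokens_per_bs=1)
print(backend.prepare_lora_batch(4, use_cuda_graph=True))
print(backend.prepare_lora_batch(4, use_cuda_graph=False))
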
55 changes: 43 additions & 12 deletions python/sglang/srt/lora/backend/chunked_backend.py
@@ -1,5 +1,3 @@
from typing import Optional

import torch

from sglang.srt.lora.backend.base_backend import BaseLoRABackend
@@ -52,7 +50,7 @@ def run_lora_b_sgemm(
output_offset: torch.Tensor,
base_output: torch.Tensor = None,
*args,
**kwargs
**kwargs,
) -> torch.Tensor:
# For simple lora B, we use slice offsets [0, output_dim]
output_dim = weights.shape[-2]
@@ -75,7 +73,7 @@ def run_qkv_lora(
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
*args,
**kwargs
**kwargs,
) -> torch.Tensor:

# x: (s, input_dim)
@@ -107,7 +105,7 @@ def run_gate_up_lora(
output_offset: torch.Tensor,
base_output: torch.Tensor = None,
*args,
**kwargs
**kwargs,
) -> torch.Tensor:

# x: (s, input_dim)
@@ -160,13 +158,36 @@ def _determine_chunk_size(self, forward_batch: ForwardBatch) -> int:
chunk_size = 16
return min(self.max_chunk_size, chunk_size)

def init_cuda_graph_batch_info(
self,
max_bs_in_cuda_graph: int,
num_tokens_per_bs: int,
):
max_num_segments = (
(num_tokens_per_bs + MIN_CHUNK_SIZE - 1) // MIN_CHUNK_SIZE
) * max_bs_in_cuda_graph
max_num_tokens = max_bs_in_cuda_graph * num_tokens_per_bs
with torch.device("cuda"):
self.cuda_graph_batch_info = LoRABatchInfo(
bs=max_bs_in_cuda_graph,
use_cuda_graph=True,
seg_lens=torch.zeros(max_num_segments, dtype=torch.int32),
seg_indptr=torch.zeros(max_num_segments + 1, dtype=torch.int32),
weight_indices=torch.zeros(max_num_segments, dtype=torch.int32),
permutation=torch.zeros(max_num_tokens, dtype=torch.int32),
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
num_segments=None, # Set per batch
max_len=None, # Not used in CSGMV backend
)

def prepare_lora_batch(
self,
forward_batch: ForwardBatch,
weight_indices: list[int],
lora_ranks: list[int],
scalings: list[float],
batch_info: Optional[LoRABatchInfo] = None,
use_cuda_graph: bool,
):
chunk_size = self._determine_chunk_size(forward_batch)

@@ -188,7 +209,7 @@ def prepare_lora_batch(
scalings, dtype=torch.float, pin_memory=True, device="cpu"
)

if batch_info is None:
if not use_cuda_graph:
batch_info = LoRABatchInfo(
bs=forward_batch.batch_size,
num_segments=num_segments,
@@ -213,6 +234,7 @@ def prepare_lora_batch(
seg_lens=None,
)
else:
batch_info = self.cuda_graph_batch_info
batch_info.bs = forward_batch.batch_size
batch_info.num_segments = num_segments
batch_info.max_len = chunk_size
@@ -262,14 +284,23 @@ def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch):
with torch.device("cpu"):
seq_weight_indices = torch.tensor(seq_weight_indices, dtype=torch.int32)

seg_lens_cpu = (
torch.tensor(
if forward_batch.forward_mode.is_decode():
seg_lens_cpu = torch.ones(forward_batch.batch_size, dtype=torch.int32)
elif forward_batch.forward_mode.is_target_verify():
seg_lens_cpu = torch.full(
size=(forward_batch.batch_size,),
fill_value=forward_batch.spec_info.draft_token_num,
dtype=torch.int32,
)
elif forward_batch.forward_mode.is_extend():
seg_lens_cpu = torch.tensor(
forward_batch.extend_seq_lens_cpu,
dtype=torch.int32,
)
if forward_batch.forward_mode.is_extend()
else torch.ones(forward_batch.batch_size, dtype=torch.int32)
)
else:
raise ValueError(
f"Unsupported forward mode: {forward_batch.forward_mode}"
)

row_weight_indices = torch.repeat_interleave(
seq_weight_indices, seg_lens_cpu
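
Note: the chunked (CSGMV) backend now pre-allocates its CUDA-graph buffers from both the capture batch size and the number of tokens per sequence, and _get_permutation picks segment lengths per forward mode (1 for decode, draft_token_num for target-verify, extend_seq_lens for extend). A small worked example of the buffer-size arithmetic, assuming MIN_CHUNK_SIZE is 16 (the constant's value is not shown in this diff):

MIN_CHUNK_SIZE = 16  # assumption for this sketch


def cuda_graph_buffer_sizes(max_bs_in_cuda_graph: int, num_tokens_per_bs: int):
    # Each sequence contributes at most ceil(num_tokens_per_bs / MIN_CHUNK_SIZE) segments.
    segments_per_seq = (num_tokens_per_bs + MIN_CHUNK_SIZE - 1) // MIN_CHUNK_SIZE
    max_num_segments = segments_per_seq * max_bs_in_cuda_graph
    max_num_tokens = max_bs_in_cuda_graph * num_tokens_per_bs
    return max_num_segments, max_num_tokens


# Decode: 1 token per sequence.
print(cuda_graph_buffer_sizes(160, 1))   # (160, 160)
# Target-verify with 4 draft tokens per sequence: still one segment each.
print(cuda_graph_buffer_sizes(160, 4))   # (160, 640)
# 40 draft tokens per sequence: ceil(40 / 16) = 3 segments each.
print(cuda_graph_buffer_sizes(160, 40))  # (480, 6400)
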
46 changes: 31 additions & 15 deletions python/sglang/srt/lora/backend/triton_backend.py
@@ -1,5 +1,3 @@
from typing import Optional

import torch

from sglang.srt.lora.backend.base_backend import BaseLoRABackend
@@ -97,24 +95,41 @@ def run_gate_up_lora(
return lora_output

def init_cuda_graph_batch_info(
self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
self,
max_bs_in_cuda_graph: int,
num_tokens_per_bs: int,
):
# Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
# across batches.
cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
torch.cumsum(
cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
dim=0,
out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
)
with torch.device("cuda"):
self.cuda_graph_batch_info = LoRABatchInfo(
bs=max_bs_in_cuda_graph,
use_cuda_graph=True,
num_segments=None,
seg_lens=torch.full(
(max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32
),
seg_indptr=torch.empty(max_bs_in_cuda_graph + 1, dtype=torch.int32),
max_len=num_tokens_per_bs,
weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
permutation=None,
)

# Initialize seg_indptr for CUDA graph as they remain constant
# across batches.
torch.cumsum(
self.cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
dim=0,
out=self.cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
)

def prepare_lora_batch(
self,
forward_batch: ForwardBatch,
weight_indices: list[int],
lora_ranks: list[int],
scalings: list[float],
batch_info: Optional[LoRABatchInfo] = None,
use_cuda_graph: bool,
):
# Use pinned memory to avoid synchronizations during host-to-device transfer
weight_indices_tensor = torch.tensor(
@@ -129,10 +144,11 @@ def prepare_lora_batch(

bs = forward_batch.batch_size

if batch_info is not None:
if use_cuda_graph:
assert (
batch_info.use_cuda_graph
), "batch_info.use_cuda_graph must be True when batch_info is provided"
self.cuda_graph_batch_info is not None
), "CUDA Graph batch info is not initialized."
batch_info = self.cuda_graph_batch_info
batch_info.bs = forward_batch.batch_size
batch_info.num_segments = forward_batch.batch_size
else:
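
Note: in the Triton backend, every captured sequence contributes a fixed num_tokens_per_bs tokens, so seg_lens and the prefix-sum seg_indptr can be filled once at initialization and reused across graph replays. A standalone sketch of that layout (CPU tensors and illustrative sizes, not the SGLang code):

import torch

max_bs_in_cuda_graph = 4
num_tokens_per_bs = 3  # e.g. draft_token_num for target-verify; 1 for plain decode

seg_lens = torch.full((max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32)
seg_indptr = torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32)
# Prefix sum computed once; segment boundaries stay valid for every replay of this shape.
torch.cumsum(seg_lens, dim=0, out=seg_indptr[1:])

print(seg_lens)    # [3, 3, 3, 3]
print(seg_indptr)  # [0, 3, 6, 9, 12]
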
23 changes: 5 additions & 18 deletions python/sglang/srt/lora/lora_manager.py
@@ -29,7 +29,6 @@
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import (
LoRABatchInfo,
LoRAType,
get_layer_id,
get_normalized_target_modules,
@@ -95,25 +94,13 @@ def __init__(
lora_paths=lora_paths,
)

def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
def init_cuda_graph_batch_info(
self, max_bs_in_cuda_graph: int, num_tokens_per_bs: int
):
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
with torch.device("cuda"):
self.cuda_graph_batch_info = LoRABatchInfo(
bs=max_bs_in_cuda_graph,
use_cuda_graph=True,
num_segments=None,
seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
max_len=1,
weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
)

self.lora_backend.init_cuda_graph_batch_info(
cuda_graph_batch_info=self.cuda_graph_batch_info,
max_bs_in_cuda_graph=max_bs_in_cuda_graph,
num_tokens_per_bs=num_tokens_per_bs,
)

def create_lora_update_result(
@@ -297,7 +284,7 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
weight_indices=weight_indices,
lora_ranks=lora_ranks,
scalings=scalings,
batch_info=self.cuda_graph_batch_info if use_cuda_graph else None,
use_cuda_graph=use_cuda_graph,
)

def update_lora_info(self):
6 changes: 3 additions & 3 deletions python/sglang/srt/lora/utils.py
@@ -19,9 +19,6 @@ class LoRABatchInfo:
# Number of segments. For triton backend, it is equal to batch size.
num_segments: int

# Maximum segment length of current batch
max_len: int

# Indice pointers of each segment in shape (num_segments + 1, )
seg_indptr: torch.Tensor

@@ -34,6 +31,9 @@ class LoRABatchInfo:
# scaling of each lora adapter, in shape (lora_num,)
scalings: torch.Tensor

# Maximum segment length of current batch
max_len: Optional[int]

# Lengths of each segments in shape (num_segments,)
seg_lens: Optional[torch.Tensor]

5 changes: 4 additions & 1 deletion python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -308,7 +308,10 @@ def __init__(self, model_runner: ModelRunner):
set_torch_compile_config()

if self.model_runner.server_args.enable_lora:
self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
self.model_runner.lora_manager.init_cuda_graph_batch_info(
max_bs_in_cuda_graph=self.max_bs,
num_tokens_per_bs=self.num_tokens_per_bs,
)

# Graph inputs
with torch.device(self.device):
7 changes: 7 additions & 0 deletions python/sglang/srt/server_args.py
@@ -3874,6 +3874,13 @@ def check_lora_server_args(self):
)

if self.enable_lora:
# Validate compatibility with speculative decoding
if self.speculative_algorithm not in ["NGRAM", None]:
raise ValueError(
"Currently LoRA is only compatible with NGRAM speculative decoding."
)

# Parse lora_paths
if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths
self.lora_paths = []
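
Note: the new server-args check restricts LoRA to either no speculative decoding or the NGRAM algorithm. A tiny standalone sketch of the same rule, for illustration only:

from typing import Optional


def check_lora_spec_compat(enable_lora: bool, speculative_algorithm: Optional[str]) -> None:
    if enable_lora and speculative_algorithm not in ("NGRAM", None):
        raise ValueError(
            "Currently LoRA is only compatible with NGRAM speculative decoding."
        )


check_lora_spec_compat(enable_lora=True, speculative_algorithm=None)     # ok
check_lora_spec_compat(enable_lora=True, speculative_algorithm="NGRAM")  # ok
try:
    check_lora_spec_compat(enable_lora=True, speculative_algorithm="EAGLE")
except ValueError as exc:
    print(exc)
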
12 changes: 12 additions & 0 deletions python/sglang/test/runners.py
@@ -528,6 +528,8 @@ def __init__(
speculative_num_steps: Optional[int] = None,
speculative_eagle_topk: Optional[int] = None,
speculative_num_draft_tokens: Optional[int] = None,
speculative_ngram_min_match_window_size: Optional[int] = None,
speculative_ngram_max_match_window_size: Optional[int] = None,
disable_overlap_schedule: bool = False,
disable_custom_all_reduce: bool = False,
torchao_config: Optional[str] = None,
@@ -539,6 +541,7 @@ def __init__(
max_loaded_loras: Optional[int] = None,
json_model_override_args: Optional[dict[str, Any]] = None,
lora_eviction_policy: str = "lru",
enable_deterministic_inference: bool = False,
):
self.model_type = model_type
self.is_generation = model_type == "generation"
@@ -554,6 +557,14 @@ def __init__(
spec_kwargs["speculative_num_steps"] = speculative_num_steps
spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens
elif speculative_algorithm == "NGRAM":
spec_kwargs["speculative_algorithm"] = speculative_algorithm
spec_kwargs["speculative_ngram_min_match_window_size"] = (
speculative_ngram_min_match_window_size
)
spec_kwargs["speculative_ngram_max_match_window_size"] = (
speculative_ngram_max_match_window_size
)

self.engine = Engine(
model_path=model_path,
@@ -594,6 +605,7 @@ def __init__(
else "{}"
),
lora_eviction_policy=lora_eviction_policy,
enable_deterministic_inference=enable_deterministic_inference,
**spec_kwargs,
)
