From e519e75e0ab3842d1e77ab1a002e5934a6240f60 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Mon, 10 Nov 2025 07:09:04 +0000 Subject: [PATCH 1/3] Support spec decoding for lora. --- .../sglang/srt/lora/backend/base_backend.py | 2 + .../srt/lora/backend/chunked_backend.py | 25 ++- .../sglang/srt/lora/backend/triton_backend.py | 8 +- python/sglang/srt/lora/lora_manager.py | 11 +- python/sglang/srt/lora/utils.py | 6 +- .../srt/model_executor/cuda_graph_runner.py | 5 +- python/sglang/test/runners.py | 13 ++ test/srt/lora/test_lora.py | 127 +------------ test/srt/lora/test_lora_qwen3.py | 165 +---------------- test/srt/lora/test_lora_spec_decoding.py | 63 +++++++ test/srt/lora/utils.py | 174 ++++++++++++++++++ test/srt/run_suite.py | 3 +- 12 files changed, 299 insertions(+), 303 deletions(-) create mode 100644 test/srt/lora/test_lora_spec_decoding.py diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index 4d241f931682..d59af9ea34d0 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -99,6 +99,7 @@ def init_cuda_graph_batch_info( self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int, + num_tokens_per_bs: int, ): """Initialize the batch info for CUDA Graph mode. @@ -108,6 +109,7 @@ def init_cuda_graph_batch_info( Args: cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode + num_tokens_per_bs: number of tokens per sequence (1 for decoding, >1 for target_verify) """ pass diff --git a/python/sglang/srt/lora/backend/chunked_backend.py b/python/sglang/srt/lora/backend/chunked_backend.py index 2c460d7c1f70..1e4fb1ab9e91 100644 --- a/python/sglang/srt/lora/backend/chunked_backend.py +++ b/python/sglang/srt/lora/backend/chunked_backend.py @@ -52,7 +52,7 @@ def run_lora_b_sgemm( output_offset: torch.Tensor, base_output: torch.Tensor = None, *args, - **kwargs + **kwargs, ) -> torch.Tensor: # For simple lora B, we use slice offsets [0, output_dim] output_dim = weights.shape[-2] @@ -75,7 +75,7 @@ def run_qkv_lora( max_qkv_out_dim: int, base_output: torch.Tensor = None, *args, - **kwargs + **kwargs, ) -> torch.Tensor: # x: (s, input_dim) @@ -107,7 +107,7 @@ def run_gate_up_lora( output_offset: torch.Tensor, base_output: torch.Tensor = None, *args, - **kwargs + **kwargs, ) -> torch.Tensor: # x: (s, input_dim) @@ -262,14 +262,23 @@ def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch): with torch.device("cpu"): seq_weight_indices = torch.tensor(seq_weight_indices, dtype=torch.int32) - seg_lens_cpu = ( - torch.tensor( + if forward_batch.forward_mode.is_decode(): + seg_lens_cpu = torch.ones(forward_batch.batch_size, dtype=torch.int32) + elif forward_batch.forward_mode.is_target_verify(): + seg_lens_cpu = torch.full( + size=(forward_batch.batch_size,), + fill_value=forward_batch.spec_info.draft_token_num, + dtype=torch.int32, + ) + elif forward_batch.forward_mode.is_extend(): + seg_lens_cpu = torch.tensor( forward_batch.extend_seq_lens_cpu, dtype=torch.int32, ) - if forward_batch.forward_mode.is_extend() - else torch.ones(forward_batch.batch_size, dtype=torch.int32) - ) + else: + raise ValueError( + f"Unsupported forward mode: {forward_batch.forward_mode}" + ) row_weight_indices = torch.repeat_interleave( seq_weight_indices, seg_lens_cpu diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index 722915efc51e..b11a46b6d854 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -97,11 +97,15 @@ def run_gate_up_lora( return lora_output def init_cuda_graph_batch_info( - self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int + self, + cuda_graph_batch_info: LoRABatchInfo, + max_bs_in_cuda_graph: int, + num_tokens_per_bs: int, ): # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant # across batches. - cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1) + cuda_graph_batch_info.max_len = num_tokens_per_bs + cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(num_tokens_per_bs) torch.cumsum( cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph], dim=0, diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 19ff874dc1da..f4e6013d69ab 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -87,7 +87,9 @@ def __init__( lora_paths=lora_paths, ) - def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int): + def init_cuda_graph_batch_info( + self, max_bs_in_cuda_graph: int, num_tokens_per_bs: int + ): self.max_bs_in_cuda_graph = max_bs_in_cuda_graph with torch.device("cuda"): self.cuda_graph_batch_info = LoRABatchInfo( @@ -96,16 +98,19 @@ def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int): num_segments=None, seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32), - max_len=1, + max_len=None, weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), - permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32), scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float), + permutation=torch.zeros( + max_bs_in_cuda_graph * num_tokens_per_bs, dtype=torch.int32 + ), ) self.lora_backend.init_cuda_graph_batch_info( cuda_graph_batch_info=self.cuda_graph_batch_info, max_bs_in_cuda_graph=max_bs_in_cuda_graph, + num_tokens_per_bs=num_tokens_per_bs, ) def create_lora_update_result( diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 7037fc4a686c..b0ed5bfc4a99 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -19,9 +19,6 @@ class LoRABatchInfo: # Number of segments. For triton backend, it is equal to batch size. num_segments: int - # Maximum segment length of current batch - max_len: int - # Indice pointers of each segment in shape (num_segments + 1, ) seg_indptr: torch.Tensor @@ -34,6 +31,9 @@ class LoRABatchInfo: # scaling of each lora adapter, in shape (lora_num,) scalings: torch.Tensor + # Maximum segment length of current batch + max_len: Optional[int] + # Lengths of each segments in shape (num_segments,) seg_lens: Optional[torch.Tensor] diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 243b9a84b286..a518b0d89ea9 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -299,7 +299,10 @@ def __init__(self, model_runner: ModelRunner): set_torch_compile_config() if self.model_runner.server_args.enable_lora: - self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs) + self.model_runner.lora_manager.init_cuda_graph_batch_info( + max_bs_in_cuda_graph=self.max_bs, + num_tokens_per_bs=self.num_tokens_per_bs, + ) # Graph inputs with torch.device(self.device): diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index e469a3c035a6..a7205d4c0bc8 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -528,6 +528,8 @@ def __init__( speculative_num_steps: Optional[int] = None, speculative_eagle_topk: Optional[int] = None, speculative_num_draft_tokens: Optional[int] = None, + speculative_ngram_min_match_window_size: Optional[int] = None, + speculative_ngram_max_match_window_size: Optional[int] = None, disable_overlap_schedule: bool = False, disable_custom_all_reduce: bool = False, torchao_config: Optional[str] = None, @@ -539,6 +541,7 @@ def __init__( max_loaded_loras: Optional[int] = None, json_model_override_args: Optional[dict[str, Any]] = None, lora_eviction_policy: str = "lru", + enable_deterministic_inference: bool = False, ): self.model_type = model_type self.is_generation = model_type == "generation" @@ -555,6 +558,15 @@ def __init__( spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens + if speculative_algorithm == "NGRAM": + spec_kwargs["speculative_algorithm"] = speculative_algorithm + spec_kwargs["speculative_ngram_min_match_window_size"] = ( + speculative_ngram_min_match_window_size + ) + spec_kwargs["speculative_ngram_max_match_window_size"] = ( + speculative_ngram_max_match_window_size + ) + self.engine = Engine( model_path=model_path, tp_size=tp_size, @@ -594,6 +606,7 @@ def __init__( else "{}" ), lora_eviction_policy=lora_eviction_policy, + enable_deterministic_inference=enable_deterministic_inference, **spec_kwargs, ) diff --git a/test/srt/lora/test_lora.py b/test/srt/lora/test_lora.py index 3ab7b624d676..d918febb2a4c 100644 --- a/test/srt/lora/test_lora.py +++ b/test/srt/lora/test_lora.py @@ -14,139 +14,20 @@ import multiprocessing as mp import os -import random import unittest -from typing import List -import torch from utils import ( ALL_OTHER_MULTI_LORA_MODELS, CI_MULTI_LORA_MODELS, - TORCH_DTYPES, - LoRAModelCase, - ensure_reproducibility, + run_lora_multiple_batch_on_model_cases, ) -from sglang.test.runners import HFRunner, SRTRunner -from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci - -TEST_MULTIPLE_BATCH_PROMPTS = [ - """ - ### Instruction: - Tell me about llamas and alpacas - ### Response: - Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing. - ### Question 2: - What do you know about llamas? - ### Answer: - """, - """ - ### Instruction: - Write a poem about the transformers Python library. - Mention the word "large language models" in that poem. - ### Response: - The Transformers are large language models, - They're used to make predictions on text. - """, - "AI is a field of computer science focused on", - "Computer science is the study of", - "Write a short story.", - "What are the main components of a computer?", -] +from sglang.test.test_utils import CustomTestCase, is_in_ci class TestLoRA(CustomTestCase): - def _create_test_samples( - self, lora_adapter_paths: List[str], repeated_trials: int = 3 - ): - random.seed(42) # Ensure reproducibility - - patterns = [ - [None, lora_adapter_paths[0], lora_adapter_paths[1]], - [lora_adapter_paths[0], None, lora_adapter_paths[1]], - [lora_adapter_paths[0], lora_adapter_paths[1], None], - [None, lora_adapter_paths[1], None], - [None, None, None], - ] - - batches = [ - [random.choice(pattern) for _ in range(3)] - for pattern in patterns - for _ in range(repeated_trials) - ] - - return batches - - def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]): - for model_case in model_cases: - for torch_dtype in TORCH_DTYPES: - max_new_tokens = 32 - base_path = model_case.base - lora_adapter_paths = [a.name for a in model_case.adaptors] - assert len(lora_adapter_paths) >= 2 - - print( - f"\n========== Testing multiple batches on base '{base_path}', dtype={torch_dtype} ---" - ) - - # Initialize runners - srt_runner = SRTRunner( - base_path, - torch_dtype=torch_dtype, - model_type="generation", - lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]], - max_loras_per_batch=len(lora_adapter_paths) + 1, - sleep_on_idle=True, # Eliminate non-determinism by forcing all requests to be processed in one batch. - attention_backend="torch_native", - ) - hf_runner = HFRunner( - base_path, torch_dtype=torch_dtype, model_type="generation" - ) - - batches = self._create_test_samples(lora_adapter_paths) - with srt_runner, hf_runner: - for i, lora_paths in enumerate(batches, start=1): - prompts = [ - random.choice(TEST_MULTIPLE_BATCH_PROMPTS) for _ in range(3) - ] - print( - f"\n--- Running Batch {i} --- prompts: {prompts}, lora_paths: {lora_paths}" - ) - - ensure_reproducibility() - srt_outputs = srt_runner.batch_forward( - prompts, - max_new_tokens=max_new_tokens, - lora_paths=lora_paths, - ) - - ensure_reproducibility() - hf_outputs = hf_runner.forward( - prompts, - max_new_tokens=max_new_tokens, - lora_paths=lora_paths, - ) - - print("SRT outputs:", [s for s in srt_outputs.output_strs]) - print("HF outputs:", [s for s in hf_outputs.output_strs]) - - for srt_out, hf_out in zip( - srt_outputs.output_strs, hf_outputs.output_strs - ): - srt_str = srt_out.strip() - hf_str = hf_out.strip() - rouge_tol = model_case.rouge_l_tolerance - rouge_score = calculate_rouge_l([srt_str], [hf_str])[0] - if rouge_score < rouge_tol: - raise AssertionError( - f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} " - f"for base '{base_path}', adaptor '{lora_paths}', prompt: '{prompts}...'" - ) - - print(f"--- Batch {i} Comparison Passed --- ") - def test_ci_lora_models(self): - self._run_lora_multiple_batch_on_model_cases(CI_MULTI_LORA_MODELS) + run_lora_multiple_batch_on_model_cases(CI_MULTI_LORA_MODELS) def test_all_lora_models(self): if is_in_ci(): @@ -158,7 +39,7 @@ def test_all_lora_models(self): continue filtered_models.append(model_case) - self._run_lora_multiple_batch_on_model_cases(filtered_models) + run_lora_multiple_batch_on_model_cases(filtered_models) if __name__ == "__main__": diff --git a/test/srt/lora/test_lora_qwen3.py b/test/srt/lora/test_lora_qwen3.py index beab18cf4a73..50904e5a83be 100644 --- a/test/srt/lora/test_lora_qwen3.py +++ b/test/srt/lora/test_lora_qwen3.py @@ -13,15 +13,11 @@ # ============================================================================== import multiprocessing as mp -import os -import random import unittest -from typing import List -from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase, ensure_reproducibility +from utils import LoRAAdaptor, LoRAModelCase, run_lora_multiple_batch_on_model_cases -from sglang.test.runners import HFRunner, SRTRunner -from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci +from sglang.test.test_utils import CustomTestCase LORA_MODELS_QWEN3 = [ LoRAModelCase( @@ -41,164 +37,9 @@ ] -TEST_MULTIPLE_BATCH_PROMPTS = [ - """ - ### Instruction: - Tell me about llamas and alpacas - ### Response: - Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing. - ### Question 2: - What do you know about llamas? - ### Answer: - """, - """ - ### Instruction: - Write a poem about the transformers Python library. - Mention the word "large language models" in that poem. - ### Response: - The Transformers are large language models, - They're used to make predictions on text. - """, - "AI is a field of computer science focused on", - "Computer science is the study of", - "Write a short story.", - "What are the main components of a computer?", -] - - class TestLoRAQwen3(CustomTestCase): - def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]): - for model_case in model_cases: - for torch_dtype in TORCH_DTYPES: - max_new_tokens = 32 - base_path = model_case.base - lora_adapter_paths = [a.name for a in model_case.adaptors] - assert len(lora_adapter_paths) >= 2 - - batches = [ - ( - [ - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - ], - [ - None, - lora_adapter_paths[0], - lora_adapter_paths[1], - ], - ), - ( - [ - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - ], - [ - lora_adapter_paths[0], - None, - lora_adapter_paths[1], - ], - ), - ( - [ - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - ], - [lora_adapter_paths[0], lora_adapter_paths[1], None], - ), - ( - [ - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - ], - [None, lora_adapter_paths[1], None], - ), - ( - [ - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - random.choice(TEST_MULTIPLE_BATCH_PROMPTS), - ], - [None, None, None], - ), - ] - - print( - f"\n========== Testing multiple batches on base '{base_path}', dtype={torch_dtype} ---" - ) - - # Initialize runners - ensure_reproducibility() - srt_runner = SRTRunner( - base_path, - torch_dtype=torch_dtype, - model_type="generation", - lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]], - max_loras_per_batch=len(lora_adapter_paths) + 1, - sleep_on_idle=True, # Eliminate non-determinism by forcing all requests to be processed in one batch. - attention_backend="torch_native", - ) - - ensure_reproducibility() - hf_runner = HFRunner( - base_path, - torch_dtype=torch_dtype, - model_type="generation", - patch_model_do_sample_false=True, - ) - - with srt_runner, hf_runner: - for i, (prompts, lora_paths) in enumerate(batches): - print( - f"\n--- Running Batch {i+1} --- prompts: {prompts}, lora_paths: {lora_paths}" - ) - - srt_outputs = srt_runner.batch_forward( - prompts, - max_new_tokens=max_new_tokens, - lora_paths=lora_paths, - ) - - hf_outputs = hf_runner.forward( - prompts, - max_new_tokens=max_new_tokens, - lora_paths=lora_paths, - ) - - print("SRT outputs:", [s for s in srt_outputs.output_strs]) - print("HF outputs:", [s for s in hf_outputs.output_strs]) - - for srt_out, hf_out in zip( - srt_outputs.output_strs, hf_outputs.output_strs - ): - srt_str = srt_out.strip() - hf_str = hf_out.strip() - rouge_tol = model_case.rouge_l_tolerance - rouge_score = calculate_rouge_l([srt_str], [hf_str])[0] - if rouge_score < rouge_tol: - raise AssertionError( - f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} " - f"for base '{base_path}', adaptor '{lora_paths}', prompt: '{prompts}...'" - ) - - print(f"--- Batch {i+1} Comparison Passed --- ") - def test_ci_lora_models(self): - self._run_lora_multiple_batch_on_model_cases(LORA_MODELS_QWEN3) - - def test_all_lora_models(self): - if is_in_ci(): - return - qwen_filtered_models = [] - for model_case in LORA_MODELS_QWEN3: - if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base: - continue - qwen_filtered_models.append(model_case) - - self._run_lora_multiple_batch_on_model_cases(qwen_filtered_models) + run_lora_multiple_batch_on_model_cases(LORA_MODELS_QWEN3) if __name__ == "__main__": diff --git a/test/srt/lora/test_lora_spec_decoding.py b/test/srt/lora/test_lora_spec_decoding.py new file mode 100644 index 000000000000..06faed0a9b4e --- /dev/null +++ b/test/srt/lora/test_lora_spec_decoding.py @@ -0,0 +1,63 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import multiprocessing as mp +import unittest + +from utils import ( + CI_MULTI_LORA_MODELS, + LoRAAdaptor, + LoRAModelCase, + run_lora_multiple_batch_on_model_cases, +) + +from sglang.test.test_utils import CustomTestCase + +LORA_MODELS_QWEN3 = [ + LoRAModelCase( + base="Qwen/Qwen3-4B", + adaptors=[ + LoRAAdaptor( + name="nissenj/Qwen3-4B-lora-v2", + prefill_tolerance=3e-1, + ), + LoRAAdaptor( + name="y9760210/Qwen3-4B-lora_model", + prefill_tolerance=3e-1, + ), + ], + max_loras_per_batch=2, + ), +] + + +class TestLoRASpecDecoding(CustomTestCase): + def test_qwen(self): + run_lora_multiple_batch_on_model_cases( + LORA_MODELS_QWEN3, use_spec_decoding=True + ) + + def test_llama(self): + run_lora_multiple_batch_on_model_cases( + CI_MULTI_LORA_MODELS, use_spec_decoding=True + ) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main(warnings="ignore") diff --git a/test/srt/lora/utils.py b/test/srt/lora/utils.py index 95089d33c85e..ff1b190e4c9b 100644 --- a/test/srt/lora/utils.py +++ b/test/srt/lora/utils.py @@ -395,3 +395,177 @@ def ensure_reproducibility(): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.use_deterministic_algorithms(True) + + +TEST_MULTIPLE_BATCH_PROMPTS = [ + """ + ### Instruction: + Tell me about llamas and alpacas + ### Response: + Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing. + ### Question 2: + What do you know about llamas? + ### Answer: + """, + """ + ### Instruction: + Write a poem about the transformers Python library. + Mention the word "large language models" in that poem. + ### Response: + The Transformers are large language models, + They're used to make predictions on text. + """, + "AI is a field of computer science focused on", + "Computer science is the study of", + "Write a short story.", + "What are the main components of a computer?", +] + + +def create_multiple_batch_test_samples( + prompts: List[str], lora_adapter_paths: List[str] +): + random.seed(42) + + return [ + ( + [ + random.choice(prompts), + random.choice(prompts), + random.choice(prompts), + ], + [ + None, + lora_adapter_paths[0], + lora_adapter_paths[1], + ], + ), + ( + [ + random.choice(prompts), + random.choice(prompts), + random.choice(prompts), + ], + [ + lora_adapter_paths[0], + None, + lora_adapter_paths[1], + ], + ), + ( + [ + random.choice(prompts), + random.choice(prompts), + random.choice(prompts), + ], + [lora_adapter_paths[0], lora_adapter_paths[1], None], + ), + ( + [ + random.choice(prompts), + random.choice(prompts), + random.choice(prompts), + ], + [None, lora_adapter_paths[1], None], + ), + ( + [ + random.choice(prompts), + random.choice(prompts), + random.choice(prompts), + ], + [None, None, None], + ), + ] + + +def run_lora_multiple_batch_on_model_cases( + model_cases: List[LoRAModelCase], + use_spec_decoding: bool = False, + attention_backend: str = "triton", + disable_cuda_graph: bool = True, + enable_deterministic_inference: bool = True, +): + for model_case in model_cases: + for torch_dtype in TORCH_DTYPES: + max_new_tokens = 32 + base_path = model_case.base + lora_adapter_paths = [a.name for a in model_case.adaptors] + assert len(lora_adapter_paths) >= 2 + + batches = create_multiple_batch_test_samples( + TEST_MULTIPLE_BATCH_PROMPTS, lora_adapter_paths + ) + + print( + f"\n========== Testing multiple batches on base '{base_path}', dtype={torch_dtype} ---" + ) + + # Initialize runners + ensure_reproducibility() + spec_args = ( + {} + if not use_spec_decoding + else { + "speculative_algorithm": "NGRAM", + "speculative_num_draft_tokens": 5, + "speculative_ngram_min_match_window_size": 2, + "speculative_ngram_max_match_window_size": 15, + } + ) + srt_runner = SRTRunner( + base_path, + torch_dtype=torch_dtype, + model_type="generation", + lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]], + max_loras_per_batch=len(lora_adapter_paths) + 1, + sleep_on_idle=True, # Eliminate non-determinism by forcing all requests to be processed in one batch. + attention_backend=attention_backend, + enable_deterministic_inference=enable_deterministic_inference, + disable_cuda_graph=disable_cuda_graph, + **spec_args, + ) + + ensure_reproducibility() + hf_runner = HFRunner( + base_path, + torch_dtype=torch_dtype, + model_type="generation", + patch_model_do_sample_false=True, + ) + + with srt_runner, hf_runner: + for i, (prompts, lora_paths) in enumerate(batches): + print( + f"\n--- Running Batch {i+1} --- prompts: {prompts}, lora_paths: {lora_paths}" + ) + + srt_outputs = srt_runner.batch_forward( + prompts, + max_new_tokens=max_new_tokens, + lora_paths=lora_paths, + ) + + hf_outputs = hf_runner.forward( + prompts, + max_new_tokens=max_new_tokens, + lora_paths=lora_paths, + ) + + print("SRT outputs:", [s for s in srt_outputs.output_strs]) + print("HF outputs:", [s for s in hf_outputs.output_strs]) + + for srt_out, hf_out in zip( + srt_outputs.output_strs, hf_outputs.output_strs + ): + srt_str = srt_out.strip() + hf_str = hf_out.strip() + rouge_tol = model_case.rouge_l_tolerance + rouge_score = calculate_rouge_l([srt_str], [hf_str])[0] + if rouge_score < rouge_tol: + raise AssertionError( + f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} " + f"for base '{base_path}', adaptor '{lora_paths}', prompt: '{prompts}...'" + ) + + print(f"--- Batch {i+1} Comparison Passed --- ") diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 9e44604171d5..415e65120353 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -28,6 +28,7 @@ class TestFile: TestFile("lora/test_lora_eviction.py", 240), TestFile("lora/test_lora_update.py", 600), TestFile("lora/test_lora_backend.py", 99), + TestFile("lora/test_lora_spec_decoding.py", 150), TestFile("lora/test_multi_lora_backend.py", 60), TestFile("models/test_compressed_tensors_models.py", 42), TestFile("models/test_cross_encoder_models.py", 100), @@ -35,7 +36,7 @@ class TestFile: TestFile("models/test_encoder_embedding_models.py", 460), TestFile("models/test_generation_models.py", 103), TestFile("models/test_nvidia_nemotron_nano_v2.py", 160), - TestFile("models/test_qwen_models.py", 82), + TestFile("models/test_qwen_models.py", 150), TestFile("models/test_reward_models.py", 132), TestFile("models/test_transformers_models.py", 320), TestFile("models/test_vlm_models.py", 741), From 9a4c23ff3cb8c0019b1c46c723eebe98a76689ee Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Fri, 14 Nov 2025 04:22:46 +0000 Subject: [PATCH 2/3] Fix --- test/srt/lora/test_lora_spec_decoding.py | 12 ++++++++++-- test/srt/lora/utils.py | 4 ++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/test/srt/lora/test_lora_spec_decoding.py b/test/srt/lora/test_lora_spec_decoding.py index 06faed0a9b4e..8e8ddee2c1d2 100644 --- a/test/srt/lora/test_lora_spec_decoding.py +++ b/test/srt/lora/test_lora_spec_decoding.py @@ -45,12 +45,20 @@ class TestLoRASpecDecoding(CustomTestCase): def test_qwen(self): run_lora_multiple_batch_on_model_cases( - LORA_MODELS_QWEN3, use_spec_decoding=True + LORA_MODELS_QWEN3, + attention_backend="triton", + use_spec_decoding=True, + disable_cuda_graph=True, + enable_deterministic_inference=True, ) def test_llama(self): run_lora_multiple_batch_on_model_cases( - CI_MULTI_LORA_MODELS, use_spec_decoding=True + CI_MULTI_LORA_MODELS, + attention_backend="triton", + use_spec_decoding=True, + disable_cuda_graph=True, + enable_deterministic_inference=True, ) diff --git a/test/srt/lora/utils.py b/test/srt/lora/utils.py index ff1b190e4c9b..a88f41946ffc 100644 --- a/test/srt/lora/utils.py +++ b/test/srt/lora/utils.py @@ -482,9 +482,9 @@ def create_multiple_batch_test_samples( def run_lora_multiple_batch_on_model_cases( model_cases: List[LoRAModelCase], use_spec_decoding: bool = False, - attention_backend: str = "triton", + attention_backend: str = "torch_native", disable_cuda_graph: bool = True, - enable_deterministic_inference: bool = True, + enable_deterministic_inference: bool = False, ): for model_case in model_cases: for torch_dtype in TORCH_DTYPES: From cc3989a373cf71b2a4e1a0f141a9443610a0a5a0 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Sun, 16 Nov 2025 00:46:26 +0000 Subject: [PATCH 3/3] Address comments & refactor code. --- .../sglang/srt/lora/backend/base_backend.py | 9 ++-- .../srt/lora/backend/chunked_backend.py | 30 +++++++++++-- .../sglang/srt/lora/backend/triton_backend.py | 44 ++++++++++++------- python/sglang/srt/lora/lora_manager.py | 20 +-------- python/sglang/srt/server_args.py | 7 +++ python/sglang/test/runners.py | 3 +- 6 files changed, 66 insertions(+), 47 deletions(-) diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index e63826386e30..77654c4b2d32 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -1,8 +1,7 @@ -from typing import Optional, Tuple, Union +from typing import Tuple, Union import torch -from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -97,7 +96,6 @@ def run_gate_up_lora( def init_cuda_graph_batch_info( self, - cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int, num_tokens_per_bs: int, ): @@ -119,7 +117,7 @@ def prepare_lora_batch( weight_indices: list[int], lora_ranks: list[int], scalings: list[float], - batch_info: Optional[LoRABatchInfo] = None, + use_cuda_graph: bool, ): """Prepare the lora weights and batch info for current forward batch. @@ -131,7 +129,6 @@ def prepare_lora_batch( weight_indices: list of indices of lora weights to be applied for current batch lora_ranks: list of lora ranks corresponding to weight_indices scalings: list of scaling factors corresponding to weight_indices - batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own - internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode) + use_cuda_graph: whether to use CUDA Graph for this batch """ pass diff --git a/python/sglang/srt/lora/backend/chunked_backend.py b/python/sglang/srt/lora/backend/chunked_backend.py index 1e4fb1ab9e91..f17f473cbdfd 100644 --- a/python/sglang/srt/lora/backend/chunked_backend.py +++ b/python/sglang/srt/lora/backend/chunked_backend.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from sglang.srt.lora.backend.base_backend import BaseLoRABackend @@ -160,13 +158,36 @@ def _determine_chunk_size(self, forward_batch: ForwardBatch) -> int: chunk_size = 16 return min(self.max_chunk_size, chunk_size) + def init_cuda_graph_batch_info( + self, + max_bs_in_cuda_graph: int, + num_tokens_per_bs: int, + ): + max_num_segments = ( + (num_tokens_per_bs + MIN_CHUNK_SIZE - 1) // MIN_CHUNK_SIZE + ) * max_bs_in_cuda_graph + max_num_tokens = max_bs_in_cuda_graph * num_tokens_per_bs + with torch.device("cuda"): + self.cuda_graph_batch_info = LoRABatchInfo( + bs=max_bs_in_cuda_graph, + use_cuda_graph=True, + seg_lens=torch.zeros(max_num_segments, dtype=torch.int32), + seg_indptr=torch.zeros(max_num_segments + 1, dtype=torch.int32), + weight_indices=torch.zeros(max_num_segments, dtype=torch.int32), + permutation=torch.zeros(max_num_tokens, dtype=torch.int32), + lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32), + scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float), + num_segments=None, # Set per batch + max_len=None, # Not used in CSGMV backend + ) + def prepare_lora_batch( self, forward_batch: ForwardBatch, weight_indices: list[int], lora_ranks: list[int], scalings: list[float], - batch_info: Optional[LoRABatchInfo] = None, + use_cuda_graph: bool, ): chunk_size = self._determine_chunk_size(forward_batch) @@ -188,7 +209,7 @@ def prepare_lora_batch( scalings, dtype=torch.float, pin_memory=True, device="cpu" ) - if batch_info is None: + if not use_cuda_graph: batch_info = LoRABatchInfo( bs=forward_batch.batch_size, num_segments=num_segments, @@ -213,6 +234,7 @@ def prepare_lora_batch( seg_lens=None, ) else: + batch_info = self.cuda_graph_batch_info batch_info.bs = forward_batch.batch_size batch_info.num_segments = num_segments batch_info.max_len = chunk_size diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index b11a46b6d854..1c2e319dd397 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from sglang.srt.lora.backend.base_backend import BaseLoRABackend @@ -98,19 +96,32 @@ def run_gate_up_lora( def init_cuda_graph_batch_info( self, - cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int, num_tokens_per_bs: int, ): - # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant - # across batches. - cuda_graph_batch_info.max_len = num_tokens_per_bs - cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(num_tokens_per_bs) - torch.cumsum( - cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph], - dim=0, - out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1], - ) + with torch.device("cuda"): + self.cuda_graph_batch_info = LoRABatchInfo( + bs=max_bs_in_cuda_graph, + use_cuda_graph=True, + num_segments=None, + seg_lens=torch.full( + (max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32 + ), + seg_indptr=torch.empty(max_bs_in_cuda_graph + 1, dtype=torch.int32), + max_len=num_tokens_per_bs, + weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), + lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32), + scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float), + permutation=None, + ) + + # Initialize seg_indptr for CUDA graph as they remain constant + # across batches. + torch.cumsum( + self.cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph], + dim=0, + out=self.cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1], + ) def prepare_lora_batch( self, @@ -118,7 +129,7 @@ def prepare_lora_batch( weight_indices: list[int], lora_ranks: list[int], scalings: list[float], - batch_info: Optional[LoRABatchInfo] = None, + use_cuda_graph: bool, ): # Use pinned memory to avoid synchronizations during host-to-device transfer weight_indices_tensor = torch.tensor( @@ -133,10 +144,11 @@ def prepare_lora_batch( bs = forward_batch.batch_size - if batch_info is not None: + if use_cuda_graph: assert ( - batch_info.use_cuda_graph - ), "batch_info.use_cuda_graph must be True when batch_info is provided" + self.cuda_graph_batch_info is not None + ), "CUDA Graph batch info is not initialized." + batch_info = self.cuda_graph_batch_info batch_info.bs = forward_batch.batch_size batch_info.num_segments = forward_batch.batch_size else: diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 60ca99b1bb1d..5d0d68d51fcc 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -29,7 +29,6 @@ from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.lora.mem_pool import LoRAMemoryPool from sglang.srt.lora.utils import ( - LoRABatchInfo, LoRAType, get_layer_id, get_normalized_target_modules, @@ -99,24 +98,7 @@ def init_cuda_graph_batch_info( self, max_bs_in_cuda_graph: int, num_tokens_per_bs: int ): self.max_bs_in_cuda_graph = max_bs_in_cuda_graph - with torch.device("cuda"): - self.cuda_graph_batch_info = LoRABatchInfo( - bs=max_bs_in_cuda_graph, - use_cuda_graph=True, - num_segments=None, - seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), - seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32), - max_len=None, - weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), - lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32), - scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float), - permutation=torch.zeros( - max_bs_in_cuda_graph * num_tokens_per_bs, dtype=torch.int32 - ), - ) - self.lora_backend.init_cuda_graph_batch_info( - cuda_graph_batch_info=self.cuda_graph_batch_info, max_bs_in_cuda_graph=max_bs_in_cuda_graph, num_tokens_per_bs=num_tokens_per_bs, ) @@ -302,7 +284,7 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): weight_indices=weight_indices, lora_ranks=lora_ranks, scalings=scalings, - batch_info=self.cuda_graph_batch_info if use_cuda_graph else None, + use_cuda_graph=use_cuda_graph, ) def update_lora_info(self): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6711ae5f26f8..0623af2b2fef 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -3844,6 +3844,13 @@ def check_lora_server_args(self): ) if self.enable_lora: + # Validate compatibility with speculative decoding + if self.speculative_algorithm not in ["NGRAM", None]: + raise ValueError( + "Currently LoRA is only compatible with NGRAM speculative decoding." + ) + + # Parse lora_paths if isinstance(self.lora_paths, list): lora_paths = self.lora_paths self.lora_paths = [] diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index a7205d4c0bc8..564018cdefa6 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -557,8 +557,7 @@ def __init__( spec_kwargs["speculative_num_steps"] = speculative_num_steps spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens - - if speculative_algorithm == "NGRAM": + elif speculative_algorithm == "NGRAM": spec_kwargs["speculative_algorithm"] = speculative_algorithm spec_kwargs["speculative_ngram_min_match_window_size"] = ( speculative_ngram_min_match_window_size