2 changes: 2 additions & 0 deletions python/sglang/test/runners.py
@@ -503,6 +503,7 @@ def __init__(
disable_overlap_schedule: bool = False,
disable_custom_all_reduce: bool = False,
torchao_config: Optional[str] = None,
sleep_on_idle=False,
):
self.model_type = model_type
self.is_generation = model_type == "generation"
@@ -540,6 +541,7 @@ def __init__(
disable_overlap_schedule=disable_overlap_schedule,
cuda_graph_max_bs=4,
disable_custom_all_reduce=disable_custom_all_reduce,
sleep_on_idle=sleep_on_idle,
**spec_kwargs,
)

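For context, a minimal usage sketch of the new sleep_on_idle flag, mirroring how the LoRA test below constructs its runner; the base model path is an illustrative placeholder, not part of this PR:

    # Sketch only: arguments other than sleep_on_idle follow the test usage
    # in this PR; the model path is a stand-in.
    import torch
    from sglang.test.runners import SRTRunner

    runner = SRTRunner(
        "meta-llama/Llama-2-7b-hf",  # placeholder base model path
        torch_dtype=torch.float16,
        model_type="generation",
        sleep_on_idle=True,  # new flag added by this PR; defaults to False
    )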
96 changes: 41 additions & 55 deletions test/srt/models/lora/test_lora.py
@@ -18,6 +18,7 @@
import unittest
from typing import List

import torch
from utils import (
ALL_OTHER_MULTI_LORA_MODELS,
CI_MULTI_LORA_MODELS,
@@ -46,16 +47,44 @@
The Transformers are large language models,
They're used to make predictions on text.
""",
# "AI is a field of computer science focused on", TODO: Add it back after fixing its bug
"AI is a field of computer science focused on",
"Computer science is the study of",
"Write a short story.",
"What are the main components of a computer?",
]


class TestLoRA(CustomTestCase):
def _create_test_samples(
self, lora_adapter_paths: List[str], repeated_trials: int = 3
):
random.seed(42) # Ensure reproducibility

patterns = [
[None, lora_adapter_paths[0], lora_adapter_paths[1]],
[lora_adapter_paths[0], None, lora_adapter_paths[1]],
[lora_adapter_paths[0], lora_adapter_paths[1], None],
[None, lora_adapter_paths[1], None],
[None, None, None],
]

batches = [
[random.choice(pattern) for _ in range(3)]
for pattern in patterns
for _ in range(repeated_trials)
]

return batches

def ensure_reproducibility(self):
seed = 42
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(True)

def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):

for model_case in model_cases:
for torch_dtype in TORCH_DTYPES:
max_new_tokens = 32
@@ -64,57 +93,6 @@ def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
lora_adapter_paths = [a.name for a in model_case.adaptors]
assert len(lora_adapter_paths) >= 2

batches = [
(
[
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
],
[
None,
lora_adapter_paths[0],
lora_adapter_paths[1],
],
),
(
[
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
],
[
lora_adapter_paths[0],
None,
lora_adapter_paths[1],
],
),
(
[
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
],
[lora_adapter_paths[0], lora_adapter_paths[1], None],
),
(
[
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
],
[None, lora_adapter_paths[1], None],
),
(
[
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
],
[None, None, None],
),
]

print(
f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
)
@@ -128,23 +106,31 @@ def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
max_loras_per_batch=len(lora_adapter_paths) + 1,
lora_backend=backend,
disable_radix_cache=True,
sleep_on_idle=True, # Eliminate non-determinism by forcing all requests to be processed in one batch.
attention_backend="torch_native",
)
hf_runner = HFRunner(
base_path, torch_dtype=torch_dtype, model_type="generation"
)

batches = self._create_test_samples(lora_adapter_paths)
with srt_runner, hf_runner:
for i, (prompts, lora_paths) in enumerate(batches):
for i, lora_paths in enumerate(batches, start=1):
prompts = [
random.choice(TEST_MULTIPLE_BATCH_PROMPTS) for _ in range(3)
]
print(
f"\n--- Running Batch {i+1} --- prompts: {prompts}, lora_paths: {lora_paths}"
f"\n--- Running Batch {i} --- prompts: {prompts}, lora_paths: {lora_paths}"
)

self.ensure_reproducibility()
srt_outputs = srt_runner.batch_forward(
prompts,
max_new_tokens=max_new_tokens,
lora_paths=lora_paths,
)

self.ensure_reproducibility()
hf_outputs = hf_runner.forward(
prompts,
max_new_tokens=max_new_tokens,
@@ -167,7 +153,7 @@ def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
)

print(f"--- Batch {i+1} Comparison Passed --- ")
print(f"--- Batch {i} Comparison Passed --- ")

def test_ci_lora_models(self):
self._run_lora_multiple_batch_on_model_cases(CI_MULTI_LORA_MODELS)
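As a usage illustration of the two helpers added above: _create_test_samples expands five adapter patterns into repeated_trials batches each (15 batches of 3 by default), and ensure_reproducibility pins every RNG the generation path touches before each engine call, so the SRT and HF outputs can be compared token for token. A standalone sketch under those assumptions, with lora_a/lora_b as stand-in adapter names:

    import random
    import torch

    def ensure_reproducibility(seed: int = 42) -> None:
        # Mirrors the helper in test_lora.py: fix the Python and Torch RNGs
        # and force deterministic kernels so both engines sample identically.
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.use_deterministic_algorithms(True)

    adapters = ["lora_a", "lora_b"]  # stand-in adapter paths
    patterns = [
        [None, adapters[0], adapters[1]],
        [adapters[0], None, adapters[1]],
        [adapters[0], adapters[1], None],
        [None, adapters[1], None],
        [None, None, None],
    ]
    random.seed(42)  # same seed the test helper uses
    batches = [
        [random.choice(pattern) for _ in range(3)]
        for pattern in patterns
        for _ in range(3)  # repeated_trials
    ]
    assert len(batches) == 15  # 5 patterns x 3 trials, each a batch of 3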
2 changes: 1 addition & 1 deletion test/srt/run_suite.py
@@ -13,7 +13,7 @@ class TestFile:

suites = {
"per-commit": [
TestFile("models/lora/test_lora.py", 76),
TestFile("models/lora/test_lora.py", 200),
TestFile("models/lora/test_lora_backend.py", 99),
TestFile("models/lora/test_multi_lora_backend.py", 60),
TestFile("models/lora/test_lora_cuda_graph.py", 250),
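A note on the run_suite.py change, assuming the second TestFile field is the file's estimated run time in seconds (suggested by the TestFile dataclass named in the hunk header): the bump from 76 to 200 would account for the extra batches the refactored LoRA test now runs (5 adapter patterns x 3 trials per dtype/backend combination).

    # Assumed field meaning; not confirmed by this diff.
    TestFile("models/lora/test_lora.py", 200)  # name, estimated seconds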