diff --git a/docs/advanced_features/lora.ipynb b/docs/advanced_features/lora.ipynb index 3a5e93148252..ff2c32fd6540 100644 --- a/docs/advanced_features/lora.ipynb +++ b/docs/advanced_features/lora.ipynb @@ -35,6 +35,8 @@ "\n", "* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n", "\n", + "* `lora_eviction_policy`: LoRA adapter eviction policy when GPU memory pool is full. `lru`: Least Recently Used (default, better cache efficiency). `fifo`: First-In-First-Out.\n", + "\n", "* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we support Triton LoRA backend (`triton`) and Chunked SGMV backend (`csgmv`). In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n", "\n", "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n", diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index 0bc20b416884..3673ba4d8cdb 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -213,6 +213,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--lora-paths` | The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: | = | JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool} | None | | `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 | | `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None | +| `--lora-eviction-policy` | LoRA adapter eviction policy when GPU memory pool is full. `lru`: Least Recently Used (better cache efficiency). `fifo`: First-In-First-Out. | lru | | `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton | ## Kernel backend diff --git a/python/sglang/srt/lora/eviction_policy.py b/python/sglang/srt/lora/eviction_policy.py new file mode 100644 index 000000000000..7d1f5f91adfd --- /dev/null +++ b/python/sglang/srt/lora/eviction_policy.py @@ -0,0 +1,139 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Eviction policies for LoRA adapter memory management. 
+""" + +import logging +import time +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Set + +logger = logging.getLogger(__name__) + + +class EvictionPolicy(ABC): + """Abstract base class for LoRA adapter eviction policies.""" + + @abstractmethod + def mark_used(self, uid: Optional[str]) -> None: + """Marks an adapter as used.""" + pass + + @abstractmethod + def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]: + """Selects an adapter to evict from candidates.""" + pass + + @abstractmethod + def remove(self, uid: Optional[str]) -> None: + """Removes an adapter from the policy's tracking.""" + pass + + +class LRUEvictionPolicy(EvictionPolicy): + """LRU eviction policy - evicts the least recently used adapter.""" + + def __init__(self): + self.access_order = OrderedDict() # key=uid, value=last_access_time + self.total_accesses = 0 + self.eviction_count = 0 + + def mark_used(self, uid: Optional[str]) -> None: + if uid is not None: + current_time = time.monotonic() + # Remove and re-add to move to end (most recent) + self.access_order.pop(uid, None) + self.access_order[uid] = current_time + self.total_accesses += 1 + logger.debug(f"LoRA {uid} marked as used at {current_time}") + + def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]: + """Select the least recently used adapter from candidates.""" + # Base model (currently None, will be replaced with special UID in future) + # always has lowest priority - evict it first if available + BASE_MODEL_UID = None # TODO: Replace with special UID constant + if BASE_MODEL_UID in candidates: + logger.debug(f"Selected base model for eviction (LRU)") + self.eviction_count += 1 + return BASE_MODEL_UID + + # Iterate through access_order (oldest first) to find LRU victim + for uid in list(self.access_order.keys()): + if uid in candidates: + logger.debug(f"Selected LoRA {uid} for eviction (LRU)") + self.eviction_count += 1 + return uid + + # Should never reach here if candidates is non-empty + assert False, f"Failed to select LRU victim from candidates: {candidates}" + + def remove(self, uid: Optional[str]) -> None: + if uid is not None: + self.access_order.pop(uid, None) + logger.debug(f"Removed LoRA {uid} from LRU tracking") + + +class FIFOEvictionPolicy(EvictionPolicy): + """FIFO eviction policy - for backward compatibility.""" + + def __init__(self): + self.insertion_order = ( + OrderedDict() + ) # key=uid, OrderedDict maintains insertion order + self.eviction_count = 0 + + def mark_used(self, uid: Optional[str]) -> None: + """For FIFO, we only track insertion order (not access time).""" + if uid is not None and uid not in self.insertion_order: + self.insertion_order[uid] = ( + True # Value unused, OrderedDict tracks insertion order + ) + + def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]: + """Select the first inserted adapter from candidates.""" + # Base model (currently None, will be replaced with special UID in future) + # always has lowest priority - evict it first if available + BASE_MODEL_UID = None # TODO: Replace with special UID constant + if BASE_MODEL_UID in candidates: + logger.debug(f"Selected base model for eviction (FIFO)") + self.eviction_count += 1 + return BASE_MODEL_UID + + # Iterate through insertion_order (oldest first) to find FIFO victim + for uid in list(self.insertion_order.keys()): + if uid in candidates: + logger.debug(f"Selected LoRA {uid} for eviction (FIFO)") + self.eviction_count += 1 + return 
uid + + # Should never reach here if candidates is non-empty + assert False, f"Failed to select FIFO victim from candidates: {candidates}" + + def remove(self, uid: Optional[str]) -> None: + if uid is not None: + self.insertion_order.pop(uid, None) + + +def get_eviction_policy(policy_name: str) -> EvictionPolicy: + """Factory function to create eviction policy instances.""" + policies = { + "fifo": FIFOEvictionPolicy, + "lru": LRUEvictionPolicy, + } + if policy_name not in policies: + raise ValueError(f"Unknown eviction policy: {policy_name}") + return policies[policy_name]() diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 5247f2c588b6..30d3386e28d9 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -68,6 +68,9 @@ def __init__( self.tp_size: int = tp_size self.tp_rank: int = tp_rank + # Store eviction policy from server args + self.eviction_policy = server_args.lora_eviction_policy + # LoRA backend for running sgemm kernels logger.info(f"Using {lora_backend} as backend of LoRA kernels.") backend_type = get_backend_from_name(lora_backend) @@ -131,6 +134,16 @@ def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput: lora_ref.lora_id not in self.loras ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend." + if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1: + return self.create_lora_update_result( + success=False, + error_message=( + f"Already have {self.num_pinned_loras} pinned adapters, " + f"max allowed is {self.max_loras_per_batch - 1} (reserving 1 slot for dynamic use). " + f"Please unpin some adapters or increase max_loras_per_batch." + ), + ) + try: # load configs new_adapter = LoRAConfig(lora_ref.lora_path) @@ -420,6 +433,7 @@ def init_memory_pool(self): max_lora_rank=self.max_lora_rank, target_modules=self.target_modules, base_model=self.base_model, + eviction_policy=self.eviction_policy, ) def set_lora_module(self, module_name, module): diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 107f9f508d94..f6375361700e 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -4,6 +4,7 @@ import torch from sglang.srt.distributed import divide +from sglang.srt.lora.eviction_policy import get_eviction_policy from sglang.srt.lora.layers import BaseLayerWithLoRA from sglang.srt.lora.lora import LoRAAdapter from sglang.srt.lora.lora_config import LoRAConfig @@ -54,6 +55,7 @@ def __init__( max_lora_rank: int, target_modules: Set[str], base_model: torch.nn.Module, + eviction_policy: str, ): self.base_hf_config: AutoConfig = base_hf_config self.num_layer: int = base_hf_config.num_hidden_layers @@ -64,6 +66,9 @@ def __init__( self.max_lora_rank: int = max_lora_rank self.target_modules: Set[str] = target_modules + # Initialize eviction policy + self.eviction_policy = get_eviction_policy(eviction_policy) + # Both A_buffer and B_buffer maps lora weight names to its buffer space. # A_buffer contains num_layer number of row-major tensors with shape # (max_loras_per_batch, stacked_num * max_lora_dim, input_dim) @@ -189,31 +194,50 @@ def prepare_lora_batch( lora_refs: Dict[str, LoRARef], ): def get_available_buffer_slot(): + # 1. Prioritize empty slots for buffer_id in range(self.max_loras_per_batch): - # Prioritize empty slots if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT: return buffer_id + # 2. 
Memory pool is full; select an eviction victim using the configured policy
+            candidates = set()
+            for buffer_id in range(self.max_loras_per_batch):
                 uid = self.buffer_id_to_uid[buffer_id]
-                # Evict unneeded lora
-                if uid not in cur_uids:
-                    # Skip pinned LoRAs
-                    # TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
-                    if uid is not None:
-                        lora_ref = lora_refs.get(uid)
-                        if lora_ref is not None and lora_ref.pinned:
-                            continue
-
-                    self.uid_to_buffer_id.pop(uid)
-                    logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
-                    self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
-                    return buffer_id
+                # Skip adapters that are needed by the current batch
+                if uid in cur_uids:
+                    continue
+
+                # Skip pinned adapters. The base model (uid == None) cannot be pinned, so it remains evictable.
+                # TODO (lifuhuang): we might consider supporting pinning the base model in the future.
+                if uid is not None:
+                    lora_ref = lora_refs.get(uid)
+                    if lora_ref and lora_ref.pinned:
+                        continue
+                candidates.add(uid)
 
-            raise ValueError(
-                "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
+            if not candidates:
+                raise ValueError(
+                    "No available buffer slots found. Please ensure the number of adapters that are pinned or required by the current batch is less than max_loras_per_batch."
+                )
+
+            # Select a victim using the configured eviction policy
+            victim_uid = self.eviction_policy.select_victim(candidates)
+
+            # Evict the selected victim
+            victim_buffer_id = self.uid_to_buffer_id[victim_uid]
+            self.uid_to_buffer_id.pop(victim_uid)
+            self.eviction_policy.remove(victim_uid)
+            self.buffer_id_to_uid[victim_buffer_id] = EMPTY_SLOT
+            logger.debug(
+                f"Evicting LoRA {victim_uid} from buffer slot {victim_buffer_id}."
             )
+            return victim_buffer_id
+
+        # Mark all adapters in the current batch as used (for LRU tracking)
+        for uid in cur_uids:
+            self.eviction_policy.mark_used(uid)
 
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index e81e6c53b750..7fd1e93a6b24 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -122,6 +122,8 @@
 
 DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
 
+DEFAULT_LORA_EVICTION_POLICY = "lru"
+
 NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]
 
 RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
@@ -303,6 +305,7 @@ class ServerArgs:
     ] = None
     max_loaded_loras: Optional[int] = None
     max_loras_per_batch: int = 8
+    lora_eviction_policy: str = DEFAULT_LORA_EVICTION_POLICY
     lora_backend: str = "triton"
     max_lora_chunk_size: Optional[int] = 16
@@ -2121,6 +2124,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
         default=ServerArgs.max_loaded_loras,
         help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
     )
+    parser.add_argument(
+        "--lora-eviction-policy",
+        type=str,
+        default=DEFAULT_LORA_EVICTION_POLICY,
+        choices=["lru", "fifo"],
+        help="LoRA adapter eviction policy when memory pool is full. 'lru': Least Recently Used (default, better cache efficiency). 
'fifo': First-In-First-Out.", + ) parser.add_argument( "--lora-backend", type=str, diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 9e64457fc020..dc7efe5285eb 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -519,6 +519,7 @@ def __init__( lora_target_modules: Optional[List[str]] = None, enable_lora: Optional[bool] = None, max_loaded_loras: Optional[int] = None, + lora_eviction_policy: str = "lru", ): self.model_type = model_type self.is_generation = model_type == "generation" @@ -565,6 +566,7 @@ def __init__( lora_target_modules=lora_target_modules, enable_lora=enable_lora, max_loaded_loras=max_loaded_loras, + lora_eviction_policy=lora_eviction_policy, **spec_kwargs, ) diff --git a/test/srt/lora/test_lora_eviction_policy.py b/test/srt/lora/test_lora_eviction_policy.py new file mode 100644 index 000000000000..18ff8f46743e --- /dev/null +++ b/test/srt/lora/test_lora_eviction_policy.py @@ -0,0 +1,190 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Unit tests for LoRA eviction policies. +Tests LRU and FIFO eviction behavior. +""" + +import unittest + +from sglang.srt.lora.eviction_policy import get_eviction_policy + + +class TestLoRAEvictionPolicy(unittest.TestCase): + """Unit tests for LoRA eviction policies.""" + + def _test_eviction_policy( + self, policy_name, access_sequence, candidates, expected_victim + ): + """ + Helper to test eviction policy with given access pattern. 
+ + Args: + policy_name: Name of eviction policy ("lru" or "fifo") + access_sequence: List of adapter IDs in access order + candidates: Set of adapter IDs that can be evicted + expected_victim: Expected adapter ID to be evicted + """ + policy = get_eviction_policy(policy_name) + + # Simulate access pattern + for adapter_id in access_sequence: + policy.mark_used(adapter_id) + + # Select victim from candidates + victim = policy.select_victim(candidates) + self.assertEqual( + victim, + expected_victim, + f"{policy_name.upper()}: Expected {expected_victim}, got {victim}", + ) + + def test_lru_basic(self): + """Test LRU selects least recently used adapter.""" + self._test_eviction_policy( + "lru", + access_sequence=["lora1", "lora2", "lora3", "lora4"], + candidates={"lora1", "lora2", "lora3", "lora4"}, + expected_victim="lora1", + ) + + def test_lru_with_reuse(self): + """Test LRU updates order on reuse.""" + self._test_eviction_policy( + "lru", + access_sequence=["lora1", "lora2", "lora3", "lora4", "lora1"], + candidates={"lora1", "lora2", "lora3", "lora4"}, + expected_victim="lora2", + ) + + def test_lru_multiple_reuse(self): + """Test LRU with multiple reuses.""" + self._test_eviction_policy( + "lru", + access_sequence=["lora1", "lora2", "lora3", "lora1", "lora2"], + candidates={"lora1", "lora2", "lora3"}, + expected_victim="lora3", + ) + + def test_lru_with_subset_candidates(self): + """Test LRU with subset of candidates.""" + self._test_eviction_policy( + "lru", + access_sequence=["lora1", "lora2", "lora3", "lora4"], + candidates={"lora2", "lora3", "lora4"}, + expected_victim="lora2", + ) + + def test_lru_base_model_priority(self): + """Test LRU prioritizes base model for eviction.""" + self._test_eviction_policy( + "lru", + access_sequence=["lora1", "lora2", "lora3"], + candidates={None, "lora1", "lora2", "lora3"}, + expected_victim=None, + ) + + def test_fifo_basic(self): + """Test FIFO selects first inserted adapter.""" + self._test_eviction_policy( + "fifo", + access_sequence=["lora1", "lora2", "lora3", "lora4"], + candidates={"lora1", "lora2", "lora3", "lora4"}, + expected_victim="lora1", + ) + + def test_fifo_ignores_reuse(self): + """Test FIFO ignores reuse.""" + self._test_eviction_policy( + "fifo", + access_sequence=[ + "lora1", + "lora2", + "lora3", + "lora4", + "lora4", + "lora3", + "lora2", + "lora1", + ], + candidates={"lora1", "lora2", "lora3", "lora4"}, + expected_victim="lora1", + ) + + def test_fifo_with_subset_candidates(self): + """Test FIFO with subset of candidates.""" + self._test_eviction_policy( + "fifo", + access_sequence=["lora1", "lora2", "lora3", "lora4"], + candidates={"lora2", "lora3", "lora4"}, + expected_victim="lora2", + ) + + def test_fifo_base_model_priority(self): + """Test FIFO prioritizes base model for eviction.""" + self._test_eviction_policy( + "fifo", + access_sequence=["lora1", "lora2", "lora3"], + candidates={None, "lora1", "lora2", "lora3"}, + expected_victim=None, + ) + + def test_policy_remove(self): + """Test that remove() correctly updates internal state.""" + lru = get_eviction_policy("lru") + lru.mark_used("lora1") + lru.mark_used("lora2") + lru.mark_used("lora3") + + # Remove lora1, so lora2 becomes LRU + lru.remove("lora1") + victim = lru.select_victim({"lora1", "lora2", "lora3"}) + self.assertEqual(victim, "lora2") + + def test_eviction_policy_factory(self): + """Test eviction policy factory function.""" + # Test valid policies + lru = get_eviction_policy("lru") + fifo = get_eviction_policy("fifo") + + self.assertIsNotNone(lru) + 
self.assertIsNotNone(fifo)
+
+        # Test invalid policy
+        with self.assertRaises(ValueError):
+            get_eviction_policy("invalid_policy")
+
+    def test_lru_vs_fifo_behavior(self):
+        """Test that LRU and FIFO behave differently."""
+        access_sequence = ["lora1", "lora2", "lora3", "lora1"]
+        candidates = {"lora1", "lora2", "lora3"}
+
+        lru = get_eviction_policy("lru")
+        for adapter_id in access_sequence:
+            lru.mark_used(adapter_id)
+        lru_victim = lru.select_victim(candidates)
+
+        fifo = get_eviction_policy("fifo")
+        for adapter_id in access_sequence:
+            fifo.mark_used(adapter_id)
+        fifo_victim = fifo.select_victim(candidates)
+
+        self.assertNotEqual(lru_victim, fifo_victim)
+        self.assertEqual(lru_victim, "lora2")
+        self.assertEqual(fifo_victim, "lora1")
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index c6536259452f..abd01450d588 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -20,6 +20,7 @@ class TestFile:
     TestFile("hicache/test_hicache_mla.py", 127),
     TestFile("hicache/test_hicache_storage.py", 127),
     TestFile("lora/test_lora.py", 200),
+    TestFile("lora/test_lora_eviction_policy.py", 200),
     TestFile("lora/test_lora_backend.py", 99),
     TestFile("lora/test_lora_eviction.py", 200),
     TestFile("lora/test_lora_qwen3.py", 97),
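
For illustration, here is a minimal sketch (not part of the patch; it assumes the patch is applied, and the adapter names are hypothetical) of how the two policies diverge under the same access pattern, using only the `get_eviction_policy` API introduced above:

```python
# Minimal sketch: same access pattern, different victims under LRU vs. FIFO.
from sglang.srt.lora.eviction_policy import get_eviction_policy

access_sequence = ["lora1", "lora2", "lora1"]  # "lora1" is reused last
candidates = {"lora1", "lora2"}  # neither adapter is pinned or in the current batch

lru = get_eviction_policy("lru")
fifo = get_eviction_policy("fifo")
for uid in access_sequence:
    lru.mark_used(uid)
    fifo.mark_used(uid)

print(lru.select_victim(candidates))   # "lora2": least recently used
print(fifo.select_victim(candidates))  # "lora1": first inserted; reuse is ignored
```

At serving time the policy is chosen once per server via `--lora-eviction-policy {lru,fifo}`; `lru` is the default and, per the help text above, generally gives better cache efficiency when a few adapters dominate the traffic.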