diff --git a/python/sglang/srt/mem_cache/multimodal_cache.py b/python/sglang/srt/mem_cache/multimodal_cache.py
index 604048700536..1795cf2a5e7e 100644
--- a/python/sglang/srt/mem_cache/multimodal_cache.py
+++ b/python/sglang/srt/mem_cache/multimodal_cache.py
@@ -1,4 +1,5 @@
 import abc
+import logging
 from collections import OrderedDict
 from typing import List, Optional
 
@@ -67,6 +68,9 @@ def _get_tensor_size(embedding: torch.Tensor):
     return embedding.element_size() * embedding.numel()
 
 
+logger = logging.getLogger(__name__)
+
+
 class MultiModalStaticCache(MultimodalCache):
     """
     A server-level cache for multimodal embedding.
@@ -100,11 +104,18 @@ def set(
             self.mm_cache.move_to_end(mm_hash)
             return True
         data_size = _get_tensor_size(embedding)
+        evicted_bytes = 0
         while self.current_size + data_size > self.max_size:
             if not self.mm_cache:
                 return False
             lru_hash, lru_embedding = self.mm_cache.popitem(last=False)
-            self.current_size -= _get_tensor_size(lru_embedding)
+            freed = _get_tensor_size(lru_embedding)
+            self.current_size -= freed
+            evicted_bytes += freed
+        if evicted_bytes > 0:
+            logger.debug(
+                f"Cache eviction: evicted {evicted_bytes} bytes, remaining size: {self.current_size}/{self.max_size} bytes"
+            )
         self.mm_cache[mm_hash] = embedding
         self.current_size += data_size
 
diff --git a/test/srt/models/test_vlm_models.py b/test/srt/models/test_vlm_models.py
index c195bbd81cfb..fc4ef6d5ccc3 100644
--- a/test/srt/models/test_vlm_models.py
+++ b/test/srt/models/test_vlm_models.py
@@ -24,6 +24,21 @@ class TestVLMModels(MMMUMultiModelTestBase):
+    def _detect_eviction_in_logs(self, log_output: str) -> tuple[bool, int]:
+        """Detect if eviction events occurred in the log output."""
+        eviction_keyword = "Cache eviction"
+
+        eviction_detected = False
+        eviction_count = 0
+
+        for line in log_output.split("\n"):
+            if eviction_keyword in line:
+                eviction_detected = True
+                eviction_count += 1
+                print(f"Eviction detected: {line.strip()}")
+
+        return eviction_detected, eviction_count
+
     def test_vlm_mmmu_benchmark(self):
         """Test VLM models against MMMU benchmark."""
         models_to_test = MODELS
@@ -34,6 +49,45 @@ def test_vlm_mmmu_benchmark(self):
         for model in models_to_test:
             self._run_vlm_mmmu_test(model, "./logs")
 
+    def test_vlm_mmmu_benchmark_with_small_cache(self):
+        """Test VLM models with a tiny embedding cache to exercise eviction logic."""
+        models_to_test = MODELS
+
+        if is_in_ci():
+            models_to_test = [random.choice(MODELS)]
+
+        for model in models_to_test:
+            custom_env = {"SGLANG_VLM_CACHE_SIZE_MB": "5"}
+            server_output = self._run_vlm_mmmu_test(
+                model,
+                "./logs_small_cache",
+                test_name=" with small embedding cache (evict test)",
+                custom_env=custom_env,
+                log_level="debug",
+                capture_output=True,
+            )
+            print("Server output:\n", server_output)
+
+            eviction_detected, eviction_count = self._detect_eviction_in_logs(
+                server_output
+            )
+
+            self.assertTrue(
+                eviction_detected,
+                (
+                    "Expected eviction events to be detected with small cache (5MB), "
+                    "but none found. Cache size may be too large for the workload or "
+                    "eviction logic may not be working."
+                ),
+            )
+
+            print(
+                f"Eviction detection summary: {eviction_count} eviction events detected"
+            )
+
+            if eviction_detected:
+                print("✅ Eviction logic successfully triggered and detected!")
+
 
 if __name__ == "__main__":
     # Define and parse arguments here, before unittest.main
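
For reference, below is a minimal, self-contained sketch of the eviction pattern this patch adds to MultiModalStaticCache.set: a byte-bounded LRU cache that evicts least-recently-used entries until a new embedding fits, then emits one debug log summarizing the evicted bytes. The ToyLRUCache class and all names in it are illustrative assumptions, not part of the sglang API or of this patch.

# Illustrative sketch only; not the sglang implementation.
import logging
from collections import OrderedDict

import torch

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


class ToyLRUCache:
    def __init__(self, max_size: int):
        self.max_size = max_size      # byte budget for the cache
        self.current_size = 0         # bytes currently held
        self.cache: "OrderedDict[str, torch.Tensor]" = OrderedDict()

    def set(self, key: str, value: torch.Tensor) -> bool:
        if key in self.cache:
            self.cache.move_to_end(key)   # refresh LRU order on a hit
            return True
        data_size = value.element_size() * value.numel()
        evicted_bytes = 0
        # Evict least-recently-used entries until the new item fits.
        while self.current_size + data_size > self.max_size:
            if not self.cache:
                return False              # item is larger than the whole budget
            _, old = self.cache.popitem(last=False)
            freed = old.element_size() * old.numel()
            self.current_size -= freed
            evicted_bytes += freed
        if evicted_bytes > 0:
            # One summary log per insertion, mirroring the patch's debug message.
            logger.debug(
                "Cache eviction: evicted %d bytes, remaining size: %d/%d bytes",
                evicted_bytes, self.current_size, self.max_size,
            )
        self.cache[key] = value
        self.current_size += data_size
        return True


if __name__ == "__main__":
    cache = ToyLRUCache(max_size=4096)
    for i in range(8):
        # Each float32 tensor of 256 elements is 1024 bytes, so evictions
        # start on the fifth insertion.
        cache.set(f"emb-{i}", torch.zeros(256))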