13 changes: 12 additions & 1 deletion python/sglang/srt/mem_cache/multimodal_cache.py
@@ -1,4 +1,5 @@
import abc
import logging
from collections import OrderedDict
from typing import List, Optional

@@ -67,6 +68,9 @@ def _get_tensor_size(embedding: torch.Tensor):
return embedding.element_size() * embedding.numel()


logger = logging.getLogger(__name__)


class MultiModalStaticCache(MultimodalCache):
"""
A server-level cache for multimodal embedding.
@@ -100,11 +104,18 @@ def set(
self.mm_cache.move_to_end(mm_hash)
return True
data_size = _get_tensor_size(embedding)
evicted_bytes = 0
while self.current_size + data_size > self.max_size:
if not self.mm_cache:
return False
lru_hash, lru_embedding = self.mm_cache.popitem(last=False)
self.current_size -= _get_tensor_size(lru_embedding)
freed = _get_tensor_size(lru_embedding)
self.current_size -= freed
evicted_bytes += freed
if evicted_bytes > 0:
logger.debug(
f"Cache eviction: evicted {evicted_bytes} bytes, remaining size: {self.current_size}/{self.max_size} bytes"
)

self.mm_cache[mm_hash] = embedding
self.current_size += data_size
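Note: the eviction accounting added above can be exercised in isolation. The sketch below is a minimal, self-contained reproduction of the same byte-budgeted LRU logic; the names `TinyLRUCache` and `_tensor_size` are illustrative only and not part of the SGLang API.

```python
from collections import OrderedDict

import torch


def _tensor_size(t: torch.Tensor) -> int:
    # Same accounting as _get_tensor_size in the diff: bytes = dtype size * element count.
    return t.element_size() * t.numel()


class TinyLRUCache:
    """Toy byte-budgeted LRU cache mirroring the eviction loop in MultiModalStaticCache.set."""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.current_size = 0
        self.cache: "OrderedDict[str, torch.Tensor]" = OrderedDict()

    def set(self, key: str, value: torch.Tensor) -> bool:
        if key in self.cache:
            self.cache.move_to_end(key)  # refresh recency on re-insert
            return True
        data_size = _tensor_size(value)
        evicted_bytes = 0
        # Evict least-recently-used entries until the new item fits the budget.
        while self.current_size + data_size > self.max_size:
            if not self.cache:
                return False  # the item alone exceeds the whole budget
            _, lru = self.cache.popitem(last=False)
            freed = _tensor_size(lru)
            self.current_size -= freed
            evicted_bytes += freed
        if evicted_bytes > 0:
            print(f"evicted {evicted_bytes} bytes, now {self.current_size}/{self.max_size}")
        self.cache[key] = value
        self.current_size += data_size
        return True


if __name__ == "__main__":
    cache = TinyLRUCache(max_size=4096)  # 4 KiB budget
    for i in range(8):
        cache.set(f"img-{i}", torch.zeros(256, dtype=torch.float32))  # 1 KiB per entry
```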
54 changes: 54 additions & 0 deletions test/srt/models/test_vlm_models.py
@@ -24,6 +24,21 @@


class TestVLMModels(MMMUMultiModelTestBase):
def _detect_eviction_in_logs(self, log_output: str) -> tuple[bool, int]:
"""Detect if eviction events occurred in the log output."""
eviction_keyword = "Cache eviction"

eviction_detected = False
eviction_count = 0

for line in log_output.split("\n"):
if eviction_keyword in line:
eviction_detected = True
eviction_count += 1
print(f"Eviction detected: {line.strip()}")

return eviction_detected, eviction_count

def test_vlm_mmmu_benchmark(self):
"""Test VLM models against MMMU benchmark."""
models_to_test = MODELS
@@ -34,6 +49,45 @@ def test_vlm_mmmu_benchmark(self):
for model in models_to_test:
self._run_vlm_mmmu_test(model, "./logs")

def test_vlm_mmmu_benchmark_with_small_cache(self):
"""Test VLM models with a tiny embedding cache to exercise eviction logic."""
models_to_test = MODELS

if is_in_ci():
models_to_test = [random.choice(MODELS)]

for model in models_to_test:
custom_env = {"SGLANG_VLM_CACHE_SIZE_MB": "5"}
server_output = self._run_vlm_mmmu_test(
model,
"./logs_small_cache",
test_name=" with small embedding cache (evict test)",
custom_env=custom_env,
log_level="debug",
capture_output=True,
)
print("Server output:\n", server_output)

eviction_detected, eviction_count = self._detect_eviction_in_logs(
server_output
)

self.assertTrue(
eviction_detected,
(
"Expected eviction events to be detected with small cache (5MB), "
"but none found. Cache size may be too large for the workload or "
"eviction logic may not be working."
),
)

print(
f"Eviction detection summary: {eviction_count} eviction events detected"
)

if eviction_detected:
print("✅ Eviction logic successfully triggered and detected!")


if __name__ == "__main__":
# Define and parse arguments here, before unittest.main
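Note: the keyword scan in `_detect_eviction_in_logs` can be tried standalone against a captured log string. The snippet below is an assumed, self-contained rework of that helper as a free function; the sample log lines are made up and only mirror the format of the `Cache eviction` debug message added in this PR.

```python
# Standalone sketch of the keyword scan used by _detect_eviction_in_logs.
def detect_eviction_in_logs(log_output: str) -> tuple[bool, int]:
    eviction_keyword = "Cache eviction"
    detected, count = False, 0
    for line in log_output.split("\n"):
        if eviction_keyword in line:
            detected = True
            count += 1
    return detected, count


if __name__ == "__main__":
    # Fabricated sample output, formatted like the new logger.debug message.
    sample_log = (
        "DEBUG Cache eviction: evicted 1048576 bytes, remaining size: 4194304/5242880 bytes\n"
        "INFO request finished\n"
        "DEBUG Cache eviction: evicted 2097152 bytes, remaining size: 3145728/5242880 bytes\n"
    )
    found, n = detect_eviction_in_logs(sample_log)
    assert found and n == 2
    print(f"{n} eviction events detected")
```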