diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py
index 60283080bc09..4ece53c98f77 100644
--- a/python/sglang/srt/managers/mm_utils.py
+++ b/python/sglang/srt/managers/mm_utils.py
@@ -17,7 +17,7 @@
     MultimodalDataItem,
     MultimodalInputs,
 )
-from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
+from sglang.srt.mem_cache.multimodal_cache import MultiModalStaticCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
@@ -283,17 +283,12 @@ def pad_input_tokens(
     return ret_input_ids
 
 
-embedding_cache: Optional[MultiModalCache] = None
+embedding_cache: Optional[MultiModalStaticCache] = None
 
 
-def init_embedding_cache(max_size: int = 0):
+def init_mm_embedding_cache(max_size: int = 0):
     global embedding_cache
-    embedding_cache = MultiModalCache(max_size)
-
-
-def get_embedding_hash(embedding_items: List[MultimodalDataItem]) -> int:
-    hash_list = [item.hash for item in embedding_items]
-    return hash(tuple(hash_list))
+    embedding_cache = MultiModalStaticCache(max_size)
 
 
 def get_embedding_chunk(
@@ -380,14 +375,15 @@ def _get_chunked_prefill_embedding(
         embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
         items_offset = items_offset_list[i]
         assert items_offset is not None, items_offset
-        embedding_items_hash = get_embedding_hash(embedding_items_per_req)
         # if all items has been prefixed, we do not need to calculate embedding
         if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
             continue
-        embedding_per_req = embedding_cache.get(embedding_items_hash)
+        item_hashes = [item.hash for item in embedding_items_per_req]
+        embedding_items_hash = MultiModalStaticCache.combine_hashes(item_hashes)
+        embedding_per_req = embedding_cache.get(item_hashes)
         if embedding_per_req is None:
             embedding_per_req = data_embedding_func(embedding_items_per_req)
-            if not embedding_cache.put(embedding_items_hash, embedding_per_req):
+            if not embedding_cache.set(embedding_items_hash, embedding_per_req):
                 print_warning_once(
                     "Multimodal embedding cache is full. This typically occurs when a single "
                     "embedding exceeds the cache size limit. Consider increasing the "
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 2938511112a0..6b408a85ca08 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -112,7 +112,7 @@
     UpdateWeightsFromIPCReqInput,
     UpdateWeightsFromTensorReqInput,
 )
-from sglang.srt.managers.mm_utils import init_embedding_cache
+from sglang.srt.managers.mm_utils import init_mm_embedding_cache
 from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
@@ -846,7 +846,7 @@ def init_memory_pool_and_cache(self):
         )
 
         embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
-        init_embedding_cache(embedding_cache_size * 1024 * 1024)
+        init_mm_embedding_cache(embedding_cache_size * 1024 * 1024)
 
     def init_disaggregation(self):
         self.transfer_backend = TransferBackend(
diff --git a/python/sglang/srt/mem_cache/multimodal_cache.py b/python/sglang/srt/mem_cache/multimodal_cache.py
index 42c31a8e8661..604048700536 100644
--- a/python/sglang/srt/mem_cache/multimodal_cache.py
+++ b/python/sglang/srt/mem_cache/multimodal_cache.py
@@ -1,46 +1,111 @@
-import logging
+import abc
 from collections import OrderedDict
+from typing import List, Optional
 
 import torch
 
-# Set up logging for cache behavior
-logger = logging.getLogger(__name__)
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 
 
-class MultiModalCache:
-    """MultiModalCache is used to store vlm encoder results with LRU eviction"""
+class MultimodalCache(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+    ): ...
+
+    @staticmethod
+    def combine_hashes(mm_hashes: List[int]) -> Optional[int]:
+        """
+        Get a combined hash from the individual mm item hashes.
+        """
+        if not mm_hashes:
+            return None
+        return hash(tuple(mm_hashes))
+
+    @abc.abstractmethod
+    def get(
+        self, mm_hashes: List[int], combined_hash: Optional[int] = None
+    ) -> Optional[torch.Tensor]:
+        """
+        Look up the embedding for the queried items by their hash ids: try the combined hash first and, on a miss, fall back to the individual hashes.
+        The returned tensor may not be contiguous.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set(
+        self,
+        mm_hash: int,
+        embedding: torch.Tensor,
+        mm_embedding_allocator: BaseTokenToKVPoolAllocator,
+    ) -> bool:
+        """
+        Store the embedding at the pre-allocated locations, keyed by a hash id.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def has(self, mm_hash: int) -> bool:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def free(
+        self, mm_hash: int, mm_embedding_allocator: BaseTokenToKVPoolAllocator
+    ) -> bool:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def clear(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def available_size(self):
+        raise NotImplementedError()
+
+
+def _get_tensor_size(embedding: torch.Tensor):
+    return embedding.element_size() * embedding.numel()
+
+
+class MultiModalStaticCache(MultimodalCache):
+    """
+    A server-level cache for multimodal embeddings.
+    Embeddings are computed ahead of time, and this cache does not actually pre-allocate locations for them.
+    """
 
     def __init__(
         self,
         max_size: int,
     ):
+        super().__init__()
         self.max_size = max_size
         self.mm_cache: OrderedDict[int, torch.Tensor] = OrderedDict()
         self.current_size = 0
 
-    def _allocate(self, embedding_size: int) -> bool:
-        """Allocate space by evicting least recently used entries"""
-        evictions = 0
-        while self.current_size + embedding_size > self.max_size and self.mm_cache:
-            _, old_embedding = self.mm_cache.popitem(last=False)
-            evicted_size = self._get_tensor_size(old_embedding)
-            self.current_size -= evicted_size
-            evictions += evicted_size
-
-        if evictions > 0:
-            logger.debug(
-                f"Cache eviction: evicted {evictions} bytes, remaining size: {self.current_size}/{self.max_size} bytes"
-            )
-
-        if self.current_size + embedding_size > self.max_size:
-            return False
-        return True
+    def get(
+        self, mm_hashes: List[int], combined_hash: Optional[int] = None
+    ) -> Optional[torch.Tensor]:
+        combined_hash = self.combine_hashes(mm_hashes)
+        # MultiModalStaticCache does not fall back to individual item lookups
+
+        embedding = self.mm_cache.get(combined_hash)
+        if embedding is not None:
+            self.mm_cache.move_to_end(combined_hash)
+        return embedding
+
+    def set(
+        self, mm_hash: int, embedding: torch.Tensor, loc: Optional[torch.Tensor] = None
+    ) -> bool:
+        if mm_hash in self.mm_cache:
+            self.mm_cache.move_to_end(mm_hash)
+            return True
+        data_size = _get_tensor_size(embedding)
+        while self.current_size + data_size > self.max_size:
+            if not self.mm_cache:
+                return False
+            lru_hash, lru_embedding = self.mm_cache.popitem(last=False)
+            self.current_size -= _get_tensor_size(lru_embedding)
 
-    def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
-        data_size = self._get_tensor_size(embedding)
-        # Lazy free cache if not enough space
-        if not self._allocate(data_size):
-            return False
         self.mm_cache[mm_hash] = embedding
         self.current_size += data_size
         return True
@@ -48,20 +113,21 @@ def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
     def has(self, mm_hash: int) -> bool:
         return mm_hash in self.mm_cache
 
-    def get(self, mm_hash: int) -> torch.Tensor:
-        """Get embedding and update LRU order"""
-        if mm_hash in self.mm_cache:
-            # Move to end (most recently used)
-            self.mm_cache.move_to_end(mm_hash)
-            return self.mm_cache[mm_hash]
-        return None
+    def free(
+        self, mm_hash: int, mm_embedding_allocator: BaseTokenToKVPoolAllocator
+    ) -> bool:
+        if mm_hash not in self.mm_cache:
+            return False
+        old_embedding = self.mm_cache.pop(mm_hash)
+        self.current_size -= _get_tensor_size(old_embedding)
+        return True
 
     def clear(self):
         self.mm_cache.clear()
         self.current_size = 0
 
-    def _get_tensor_size(self, embedding: torch.Tensor):
-        return embedding.element_size() * embedding.numel()
-
     def __len__(self):
         return len(self.mm_cache)
+
+    def available_size(self):
+        return self.__len__()
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 40230c645f4e..7bb04bbc34e4 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -13,13 +13,11 @@
 import traceback
 import urllib.request
 import weakref
-from concurrent.futures import ThreadPoolExecutor
 from functools import wraps
 from io import BytesIO
 from json import dumps
 from typing import Any, Callable, List, Optional, Tuple, Type, Union
 
-import numpy as np
 import pybase64
 import requests
 from IPython.display import HTML, display
diff --git a/test/srt/test_vlm_accuracy.py b/test/srt/test_vlm_accuracy.py
index ef9a2ad51b09..4dfa6ff45523 100644
--- a/test/srt/test_vlm_accuracy.py
+++ b/test/srt/test_vlm_accuracy.py
@@ -14,7 +14,7 @@
 
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
-from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
+from sglang.srt.managers.mm_utils import embed_mm_inputs, init_mm_embedding_cache
 from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
@@ -182,7 +182,7 @@ def setUpClass(cls):
             .eval()
             .to(cls.device)
         )
-        init_embedding_cache()
+        init_mm_embedding_cache()
 
     async def test_vlm_embedding_output(self):
         """
@@ -288,7 +288,7 @@ def setUpClass(cls):
             .eval()
             .to(cls.device)
         )
-        init_embedding_cache()
+        init_mm_embedding_cache()
 
     async def test_vlm_embedding_output(self):
         """
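
A minimal usage sketch of the reworked cache API, mirroring the call pattern in _get_chunked_prefill_embedding above; the hash values, tensor shape, and cache size are illustrative assumptions, not part of this patch.

# Sketch only: values below are made up for illustration.
import torch

from sglang.srt.mem_cache.multimodal_cache import MultiModalStaticCache

# The scheduler constructs this via init_mm_embedding_cache(SGLANG_VLM_CACHE_SIZE_MB * 1024 * 1024).
cache = MultiModalStaticCache(max_size=100 * 1024 * 1024)  # 100 MB

item_hashes = [101, 202]  # hypothetical MultimodalDataItem.hash values
combined = MultiModalStaticCache.combine_hashes(item_hashes)

embedding = cache.get(item_hashes)  # keyed internally on the combined hash
if embedding is None:
    embedding = torch.randn(2, 4096)  # stand-in for the real encoder output
    if not cache.set(combined, embedding):
        # set() returns False only when a single embedding exceeds max_size
        print("embedding larger than the whole cache; not stored")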