20 changes: 8 additions & 12 deletions python/sglang/srt/managers/mm_utils.py
@@ -17,7 +17,7 @@
     MultimodalDataItem,
     MultimodalInputs,
 )
-from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
+from sglang.srt.mem_cache.multimodal_cache import MultiModalStaticCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
@@ -283,17 +283,12 @@ def pad_input_tokens(
     return ret_input_ids


-embedding_cache: Optional[MultiModalCache] = None
+embedding_cache: Optional[MultiModalStaticCache] = None


-def init_embedding_cache(max_size: int = 0):
+def init_mm_embedding_cache(max_size: int = 0):
     global embedding_cache
-    embedding_cache = MultiModalCache(max_size)
-
-
-def get_embedding_hash(embedding_items: List[MultimodalDataItem]) -> int:
-    hash_list = [item.hash for item in embedding_items]
-    return hash(tuple(hash_list))
+    embedding_cache = MultiModalStaticCache(max_size)


 def get_embedding_chunk(
@@ -380,14 +375,15 @@ def _get_chunked_prefill_embedding(
         embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
         items_offset = items_offset_list[i]
         assert items_offset is not None, items_offset
-        embedding_items_hash = get_embedding_hash(embedding_items_per_req)
         # if all items has been prefixed, we do not need to calculate embedding
         if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
             continue
-        embedding_per_req = embedding_cache.get(embedding_items_hash)
+        item_hashes = [item.hash for item in embedding_items]
+        embedding_items_hash = MultiModalStaticCache.combine_hashes(item_hashes)
+        embedding_per_req = embedding_cache.get(item_hashes)
         if embedding_per_req is None:
             embedding_per_req = data_embedding_func(embedding_items_per_req)
-            if not embedding_cache.put(embedding_items_hash, embedding_per_req):
+            if not embedding_cache.set(embedding_items_hash, embedding_per_req):
                 print_warning_once(
                     "Multimodal embedding cache is full. This typically occurs when a single "
                     "embedding exceeds the cache size limit. Consider increasing the "
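For callers, the net effect is that per-item hashes are now combined by the cache class itself and lookups go through `get`/`set` instead of the old `get_embedding_hash`/`put` pair. Below is a minimal sketch of that flow, assuming this PR's `MultiModalStaticCache` API; `items` and `data_embedding_func` are placeholders for the request's `MultimodalDataItem` list and the model's encoder callback.

```python
import torch

from sglang.srt.mem_cache.multimodal_cache import MultiModalStaticCache

# Hypothetical standalone cache; in sglang this is the module-level
# `embedding_cache` created by init_mm_embedding_cache().
cache = MultiModalStaticCache(max_size=100 * 1024 * 1024)  # budget in bytes


def embed_with_cache(items, data_embedding_func):
    # Each MultimodalDataItem carries a precomputed .hash
    item_hashes = [item.hash for item in items]
    combined = MultiModalStaticCache.combine_hashes(item_hashes)
    embedding = cache.get(item_hashes)  # probes the combined hash internally
    if embedding is None:
        embedding = data_embedding_func(items)  # run the encoder on a miss
        if not cache.set(combined, embedding):
            # False only when the embedding cannot fit even after eviction.
            print("warning: embedding larger than the whole cache; not cached")
    return embedding
```

Because `get` rebuilds the combined hash from the per-item hashes, callers only pass the hash list on lookup; the combined hash is needed again only when calling `set`.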
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/scheduler.py
@@ -112,7 +112,7 @@
     UpdateWeightsFromIPCReqInput,
     UpdateWeightsFromTensorReqInput,
 )
-from sglang.srt.managers.mm_utils import init_embedding_cache
+from sglang.srt.managers.mm_utils import init_mm_embedding_cache
 from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
@@ -846,7 +846,7 @@ def init_memory_pool_and_cache(self):
         )

         embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
-        init_embedding_cache(embedding_cache_size * 1024 * 1024)
+        init_mm_embedding_cache(embedding_cache_size * 1024 * 1024)

     def init_disaggregation(self):
         self.transfer_backend = TransferBackend(
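For operators the knob is unchanged: `SGLANG_VLM_CACHE_SIZE_MB` sets the budget in megabytes and the scheduler converts it to bytes before calling the renamed initializer. A small sketch of that path, assuming only the environment-variable name and default shown in the diff above:

```python
import os

from sglang.srt.managers.mm_utils import init_mm_embedding_cache

# Optionally raise the budget from the 100 MB default before the scheduler starts.
os.environ.setdefault("SGLANG_VLM_CACHE_SIZE_MB", "256")

embedding_cache_size_mb = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
init_mm_embedding_cache(embedding_cache_size_mb * 1024 * 1024)  # MB -> bytes
```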
140 changes: 103 additions & 37 deletions python/sglang/srt/mem_cache/multimodal_cache.py
@@ -1,67 +1,133 @@
-import logging
+import abc
 from collections import OrderedDict
 from typing import List, Optional

 import torch

-# Set up logging for cache behavior
-logger = logging.getLogger(__name__)
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator


-class MultiModalCache:
-    """MultiModalCache is used to store vlm encoder results with LRU eviction"""
+class MultimodalCache(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+    ): ...
+
+    @staticmethod
+    def combine_hashes(mm_hashes: List[int]) -> Optional[int]:
+        """
+        Get a combined hash from individual mm item hashes
+        """
+        if not mm_hashes:
+            return None
+        return hash(tuple(mm_hashes))
+
+    @abc.abstractmethod
+    def get(
+        self, mm_hashes: List[int], combined_hash: Optional[int] = None
+    ) -> Optional[torch.Tensor]:
+        """
+        Extract the embedding with the hash-ids of the queried items. Try combined hash first, if missed, fallback to individual hashes
+        The returned tensor may not be contiguous
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set(
+        self,
+        mm_hash: int,
+        embedding: torch.Tensor,
+        mm_embedding_allocator: BaseTokenToKVPoolAllocator,
+    ) -> bool:
+        """
+        Set the embedding to the pre-allocated locations with a hash id
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def has(self, mm_hash: int) -> bool:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def free(
+        self, mm_hash: int, mm_embedding_allocator: BaseTokenToKVPoolAllocator
+    ) -> bool:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def clear(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def available_size(self):
+        raise NotImplementedError()
+
+
+def _get_tensor_size(embedding: torch.Tensor):
+    return embedding.element_size() * embedding.numel()
+
+
+class MultiModalStaticCache(MultimodalCache):
+    """
+    A server-level cache for multimodal embedding.
+    Embeddings are computed prior, and this cache does not really pre-alloc
+    """

     def __init__(
         self,
         max_size: int,
     ):
+        super().__init__()
         self.max_size = max_size
         self.mm_cache: OrderedDict[int, torch.Tensor] = OrderedDict()
         self.current_size = 0

-    def _allocate(self, embedding_size: int) -> bool:
-        """Allocate space by evicting least recently used entries"""
-        evictions = 0
-        while self.current_size + embedding_size > self.max_size and self.mm_cache:
-            _, old_embedding = self.mm_cache.popitem(last=False)
-            evicted_size = self._get_tensor_size(old_embedding)
-            self.current_size -= evicted_size
-            evictions += evicted_size
-
-        if evictions > 0:
-            logger.debug(
-                f"Cache eviction: evicted {evictions} bytes, remaining size: {self.current_size}/{self.max_size} bytes"
-            )
-
-        if self.current_size + embedding_size > self.max_size:
-            return False
-        return True
+    def get(
+        self, mm_hashes: List[int], combined_hash: Optional[int] = None
+    ) -> Optional[torch.Tensor]:
+        combined_hash = self.combine_hashes(mm_hashes)
+        # MultiModalStaticCache does not fallback to individual item lookup
+
+        embedding = self.mm_cache.get(combined_hash)
+        if embedding is not None:
+            self.mm_cache.move_to_end(combined_hash)
+        return embedding
+
+    def set(
+        self, mm_hash: int, embedding: torch.Tensor, loc: Optional[torch.Tensor] = None
+    ) -> bool:
+        if mm_hash in self.mm_cache:
+            self.mm_cache.move_to_end(mm_hash)
+            return True
+        data_size = _get_tensor_size(embedding)
+        while self.current_size + data_size > self.max_size:
+            if not self.mm_cache:
+                return False
+            lru_hash, lru_embedding = self.mm_cache.popitem(last=False)
+            self.current_size -= _get_tensor_size(lru_embedding)

-    def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
-        data_size = self._get_tensor_size(embedding)
-        # Lazy free cache if not enough space
-        if not self._allocate(data_size):
-            return False
         self.mm_cache[mm_hash] = embedding
         self.current_size += data_size
         return True

     def has(self, mm_hash: int) -> bool:
         return mm_hash in self.mm_cache

-    def get(self, mm_hash: int) -> torch.Tensor:
-        """Get embedding and update LRU order"""
-        if mm_hash in self.mm_cache:
-            # Move to end (most recently used)
-            self.mm_cache.move_to_end(mm_hash)
-            return self.mm_cache[mm_hash]
-        return None
+    def free(
+        self, mm_hash: int, mm_embedding_allocator: BaseTokenToKVPoolAllocator
+    ) -> bool:
+        if mm_hash not in self.mm_cache:
+            return False
+        old_embedding = self.mm_cache.pop(mm_hash)
+        self.current_size -= _get_tensor_size(old_embedding)
+        return True

     def clear(self):
         self.mm_cache.clear()
         self.current_size = 0

-    def _get_tensor_size(self, embedding: torch.Tensor):
-        return embedding.element_size() * embedding.numel()
-
     def __len__(self):
         return len(self.mm_cache)
+
+    def available_size(self):
+        return self.__len__()
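A usage sketch of the new class, assuming this PR's API as shown above; it illustrates how `set` now performs LRU eviction inline and how `get` takes the list of per-item hashes. The sizes and hash values are arbitrary placeholders.

```python
import torch

from sglang.srt.mem_cache.multimodal_cache import MultiModalStaticCache

cache = MultiModalStaticCache(max_size=4096)  # budget in bytes

emb_a = torch.zeros(1024, dtype=torch.float32)  # 4096 bytes, fills the cache
emb_b = torch.zeros(512, dtype=torch.float32)   # 2048 bytes

hash_a = MultiModalStaticCache.combine_hashes([101])  # per-item hashes are ints
hash_b = MultiModalStaticCache.combine_hashes([202])

assert cache.set(hash_a, emb_a)        # fits exactly
assert cache.set(hash_b, emb_b)        # evicts emb_a (least recently used)
assert cache.get([101]) is None        # emb_a was evicted
assert cache.get([202]) is not None    # emb_b is resident
assert cache.available_size() == 1     # currently reports the entry count
```

Note that `MultiModalStaticCache.free` accepts an `mm_embedding_allocator` but never touches it; the allocator parameter presumably only matters for implementations of the abstract base that place embeddings into pre-allocated device memory, as the `set` docstring suggests.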
2 changes: 0 additions & 2 deletions python/sglang/utils.py
@@ -13,13 +13,11 @@
 import traceback
 import urllib.request
 import weakref
-from concurrent.futures import ThreadPoolExecutor
 from functools import wraps
 from io import BytesIO
 from json import dumps
 from typing import Any, Callable, List, Optional, Tuple, Type, Union

-import numpy as np
 import pybase64
 import requests
 from IPython.display import HTML, display
6 changes: 3 additions & 3 deletions test/srt/test_vlm_accuracy.py
@@ -14,7 +14,7 @@

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
-from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
+from sglang.srt.managers.mm_utils import embed_mm_inputs, init_mm_embedding_cache
 from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
@@ -182,7 +182,7 @@ def setUpClass(cls):
             .eval()
             .to(cls.device)
         )
-        init_embedding_cache()
+        init_mm_embedding_cache()

     async def test_vlm_embedding_output(self):
         """
@@ -288,7 +288,7 @@ def setUpClass(cls):
             .eval()
             .to(cls.device)
         )
-        init_embedding_cache()
+        init_mm_embedding_cache()

     async def test_vlm_embedding_output(self):
         """