From 9c9d83a2aef5c8f9f8fae5be7a09f5fb61b92d71 Mon Sep 17 00:00:00 2001 From: hzh0425 Date: Tue, 9 Dec 2025 21:45:12 +0800 Subject: [PATCH 1/5] Init interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 晟海 --- .../mem_cache/sparsity/algorithms/__init__.py | 15 + .../sparsity/algorithms/base_algorithm.py | 175 +++++++++++ .../sparsity/algorithms/deepseek_nsa.py | 74 +++++ .../algorithms/page_wise_algorithm.py | 276 ++++++++++++++++++ .../srt/sparsity/test_knorm_page_algorithm.py | 168 +++++++++++ 5 files changed, 708 insertions(+) create mode 100644 python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py create mode 100644 python/sglang/srt/mem_cache/sparsity/algorithms/base_algorithm.py create mode 100644 python/sglang/srt/mem_cache/sparsity/algorithms/deepseek_nsa.py create mode 100644 python/sglang/srt/mem_cache/sparsity/algorithms/page_wise_algorithm.py create mode 100644 test/srt/sparsity/test_knorm_page_algorithm.py diff --git a/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py b/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py new file mode 100644 index 000000000000..412f8adbd6db --- /dev/null +++ b/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py @@ -0,0 +1,15 @@ +from sglang.srt.mem_cache.sparsity.algorithms.base_algorithm import ( + BaseSparseAlgorithm, + SparseMode, +) +from sglang.srt.mem_cache.sparsity.algorithms.deepseek_nsa import DeepSeekNSAAlgorithm +from sglang.srt.mem_cache.sparsity.algorithms.page_wise_algorithm import ( + KnormPageAlgorithm, +) + +__all__ = [ + "BaseSparseAlgorithm", + "SparseMode", + "KnormPageAlgorithm", + "DeepSeekNSAAlgorithm", +] diff --git a/python/sglang/srt/mem_cache/sparsity/algorithms/base_algorithm.py b/python/sglang/srt/mem_cache/sparsity/algorithms/base_algorithm.py new file mode 100644 index 000000000000..fe4c1e109132 --- /dev/null +++ b/python/sglang/srt/mem_cache/sparsity/algorithms/base_algorithm.py @@ -0,0 +1,175 @@ 
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional

import torch

if TYPE_CHECKING:
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class SparseMode(Enum):
    """Granularity at which a sparse-attention algorithm selects KV entries."""

    PAGE_WISE = "page_wise"
    TOKEN_WISE = "token_wise"
    DEEPSEEK_TOKEN_WISE = "deepseek_token_wise"


class BaseSparseAlgorithm(ABC):
    """Common interface for retrievable KV-cache compression algorithms.

    A concrete algorithm (e.g. ChunkKV, Quest, PQCache, SnapKV, Look-ahead
    QCache) plugs into the sparse framework through four hooks:

    * ``initialize_representation_pool`` -- one-time allocation/context setup.
    * ``construct_representations`` -- build representations during prefill.
    * ``update_representations`` -- refresh representations during decode.
    * ``retrieve_topk`` -- pick the KV subset used for attention at a layer.

    References:
        - ChunkKV: https://arxiv.org/abs/2502.00299
        - Quest: https://arxiv.org/pdf/2406.10774
        - PQCache: https://arxiv.org/abs/2407.12820
        - SnapKV: https://arxiv.org/pdf/2404.14469
        - Look-ahead QCache: https://arxiv.org/pdf/2505.20334
        - and more...
    """

    def __init__(self, config, device: torch.device, **kwargs):
        # Pool/state references stay None until the coordinator injects them
        # via initialize_representation_pool().
        self.config = config
        self.device = device
        self.req_to_token_pool = None
        self.states = None

    def initialize_representation_pool(
        self,
        start_layer: int,
        end_layer: int,
        token_to_kv_pool,
        req_to_token_pool,
        states,
    ):
        """One-time setup hook, called once during SparseCoordinator init.

        The base implementation only records the context references;
        subclasses additionally allocate whatever representation tensors
        their algorithm needs, e.g.:

        - ChunkKV: chunk scores [num_chunks, 1]
        - Quest: pooled page representations [num_pages, repr_dim]
        - PQCache: centroids [n_subvec, n_centroids, subvec_dim] + token codes
        - SnapKV: voting scores [num_tokens] + retained-position mask
        - Look-ahead QCache: importance scores, eviction mask, pseudo-query cache
        """
        self.req_to_token_pool = req_to_token_pool
        self.states = states

    @abstractmethod
    def get_sparse_mode(self) -> SparseMode:
        """Return the selection granularity of this algorithm.

        PAGE_WISE means retrieve_topk yields page/chunk indices
        (ChunkKV, Quest); TOKEN_WISE means it yields individual token
        indices (PQCache, SnapKV, Look-ahead QCache).
        """
        ...

    def construct_representations(
        self,
        layer_id: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        k_buffer: torch.Tensor,
        forward_batch: "ForwardBatch",
    ):
        """Build initial representations; invoked at every layer of the forward pass.

        The algorithm itself decides (via ``self.states.repr_constructed``)
        whether to act; typically it constructs once per request during
        prefill/extend. Default is a no-op. Examples: ChunkKV aggregates key
        L2 norms per semantic chunk; Quest mean-pools keys per page; PQCache
        runs K-means and assigns token codes; SnapKV votes with an
        observation window; Look-ahead QCache scores KVs with a pseudo query.
        """

    def update_representations(
        self,
        layer_id: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        k_buffer: torch.Tensor,
        forward_batch: "ForwardBatch",
    ):
        """Incrementally refresh representations; invoked at every layer during decode.

        Algorithms consult ``self.states.repr_constructed`` /
        ``self.states.last_extracted_token`` together with current
        ``seq_lens`` to detect newly generated tokens or pages. Default is a
        no-op. Examples: ChunkKV/Quest score newly completed chunks/pages;
        PQCache assigns fresh tokens to existing centroids; SnapKV may
        re-vote with a sliding window; Look-ahead QCache periodically
        regenerates pseudo queries.
        """

    @abstractmethod
    def retrieve_topk(
        self,
        queries: torch.Tensor,
        layer_id: int,
        req_pool_indices: torch.Tensor,
        sparse_mask: torch.Tensor,
        attn_metadata: Optional[Any],
        **kwargs,
    ) -> tuple:
        """Select the top-k KV subset for the coming attention computation.

        Called before attention at each layer, using the current query and the
        pre-computed representations.

        Args:
            queries: [bs, num_heads, head_dim] current query vectors.
            layer_id: index of the current layer.
            req_pool_indices: [bs] request pool indices.
            sparse_mask: [bs] bool; which requests use sparse attention.
            attn_metadata: attention metadata (seq_lens etc.).
            **kwargs: algorithm-specific extras.

        Returns:
            selected_indices: [bs, max_selected] page or token indices
                (depending on get_sparse_mode()), padded with -1.
            valid_lengths: [bs] count of valid entries per request.

        The returned indices are logical positions; the BackendAdaptor maps
        them to physical KV-cache slots.
        """
        ...
class DeepSeekNSAAlgorithm(BaseSparseAlgorithm):
    """Adapter exposing DeepSeek NSA's native indexer through the sparse interface.

    Representation construction/update are inherited no-ops: the NSA indexer
    maintains its own state, so this class only forwards top-k retrieval to it.
    """

    def __init__(self, config, device: torch.device, **kwargs):
        super().__init__(config, device, **kwargs)
        # Per-request token budget used when building the padded fallback result.
        self.index_topk = getattr(config, "index_topk", 2048)
        logger.info(f"DeepSeekNSAAlgorithm initialized: index_topk={self.index_topk}")

    def get_sparse_mode(self) -> SparseMode:
        return SparseMode.DEEPSEEK_TOKEN_WISE

    @nvtx.annotate("DeepSeekNSAAlgorithm.retrieve_topk", color="green")
    def retrieve_topk(
        self,
        queries: torch.Tensor,
        layer_id: int,
        req_pool_indices: torch.Tensor,
        sparse_mask: torch.Tensor,
        attn_metadata: Optional[Any],
        **kwargs,
    ) -> tuple:
        """Delegate token selection to the NSA indexer supplied via kwargs.

        Requires ``indexer``, ``x``, ``q_lora``, ``positions`` and
        ``forward_batch`` in kwargs; raises ValueError when any is missing.
        Returns ``(topk_indices, None)`` on success, or an all-(-1) padded
        result when the indexer yields nothing or fails.
        """
        indexer: Optional["Indexer"] = kwargs.get("indexer")
        forward_batch: Optional["ForwardBatch"] = kwargs.get("forward_batch")
        x = kwargs.get("x")
        q_lora = kwargs.get("q_lora")
        positions = kwargs.get("positions")

        if (
            indexer is None
            or x is None
            or q_lora is None
            or positions is None
            or forward_batch is None
        ):
            raise ValueError("Required: indexer, x, q_lora, positions, forward_batch")

        try:
            # Reuse NSA's own indexer to produce the top-k token indices.
            topk_indices = indexer(
                x=x,
                q_lora=q_lora,
                positions=positions,
                forward_batch=forward_batch,
                layer_id=layer_id,
            )
            if topk_indices is None:
                return self._empty_result(queries.shape[0], queries.device)
            return topk_indices, None
        except Exception as e:
            logger.error(f"Layer {layer_id} NSA indexer failed: {e}", exc_info=True)
            return self._empty_result(queries.shape[0], queries.device)

    def _empty_result(self, batch_size: int, device: torch.device) -> tuple:
        """Padded fallback meaning "no tokens selected" for every request."""
        pad_indices = torch.full(
            (batch_size, self.index_topk), -1, dtype=torch.int32, device=device
        )
        pad_lengths = torch.zeros(batch_size, dtype=torch.int32, device=device)
        return pad_indices, pad_lengths
+ """ + + def __init__(self, config, device: torch.device, **kwargs): + super().__init__(config, device, **kwargs) + self.compression_ratio = getattr(config, "compression_ratio", 0.2) + self.page_size = getattr(config, "page_size", 64) + self.num_recent_pages = getattr(config, "num_recent_pages", 4) + self.page_scores = {} + + def get_sparse_mode(self) -> SparseMode: + return SparseMode.PAGE_WISE + + def initialize_representation_pool( + self, + start_layer: int, + end_layer: int, + token_to_kv_pool, + req_to_token_pool, + states, + ): + """Initialize page score representation pool for each layer.""" + super().initialize_representation_pool( + start_layer, end_layer, token_to_kv_pool, req_to_token_pool, states + ) + self.start_layer = start_layer + self.end_layer = end_layer + total_num_tokens = token_to_kv_pool.get_key_buffer(start_layer).shape[0] + total_num_pages = (total_num_tokens + self.page_size - 1) // self.page_size + + # Create page score storage: [num_pages, 1] per layer + for layer_id in range(start_layer, end_layer): + self.page_scores[layer_id] = torch.zeros( + (total_num_pages, 1), dtype=torch.float32, device=self.device + ) + logger.info( + f"Initialized page score storage: {total_num_pages} pages, " + f"{end_layer - start_layer} layers" + ) + + def construct_representations( + self, + layer_id, + req_pool_indices, + seq_lens, + k_buffer, + forward_batch, + ): + if not forward_batch.forward_mode.is_extend(): + return + + # Check which requests need construction + num_pages = seq_lens // self.page_size + valid_mask = ( + ~self.states.repr_constructed[req_pool_indices] + & (seq_lens >= self.states.prompt_lens[req_pool_indices]) + & (num_pages > 0) + ) + + if not valid_mask.any(): + return + + # Compute page scores + self.compute_and_update_page_representations( + layer_id, + req_pool_indices[valid_mask], + seq_lens[valid_mask], + 0, + num_pages[valid_mask], + k_buffer, + ) + + # Update states + if layer_id == self.end_layer - 1: + success_indices = 
req_pool_indices[valid_mask] + self.states.repr_constructed[success_indices] = True + self.states.last_extracted_token[success_indices] = ( + seq_lens[valid_mask] // self.page_size * self.page_size + ) + + def update_representations( + self, + layer_id, + req_pool_indices, + seq_lens, + k_buffer, + forward_batch, + ): + if not forward_batch.forward_mode.is_decode_or_idle(): + return + + # Check if new pages were generated (compare current page to last extracted page) + start_page = ( + self.states.last_extracted_token[req_pool_indices] // self.page_size + ) + end_page = seq_lens // self.page_size + valid_mask = self.states.repr_constructed[req_pool_indices] & ( + start_page < end_page + ) # New page(s) generated + + if not valid_mask.any(): + return + + # Compute page scores for new pages + self.compute_and_update_page_representations( + layer_id, + req_pool_indices[valid_mask], + seq_lens[valid_mask], + start_page[valid_mask], + end_page[valid_mask], + k_buffer, + ) + + # Update states + if layer_id == self.end_layer - 1: + success_indices = req_pool_indices[valid_mask] + self.states.last_extracted_token[success_indices] = ( + seq_lens[valid_mask] // self.page_size * self.page_size + ) + + def retrieve_topk( + self, + queries: torch.Tensor, + layer_id: int, + req_pool_indices: torch.Tensor, + sparse_mask: torch.Tensor, + attn_metadata: Optional[Any], + **kwargs, + ) -> tuple: + bs, device = queries.shape[0], queries.device + seq_lens = attn_metadata.cache_seqlens_int32 + num_pages = (seq_lens + self.page_size - 1) // self.page_size + max_pages = max(int(num_pages.max().item()), 1) + + out_indices = torch.full((bs, max_pages), -1, dtype=torch.int32, device=device) + out_lengths = torch.zeros(bs, dtype=torch.int32, device=device) + + mask = sparse_mask & (num_pages > self.num_recent_pages) + if not mask.any(): + return out_indices, out_lengths + + # Get page scores + page_idx = torch.arange(max_pages, device=device).unsqueeze(0) + page_start_token = 
self.req_to_token_pool.req_to_token[ + req_pool_indices.unsqueeze(1).expand(bs, max_pages), + (page_idx * self.page_size).clamp( + 0, self.req_to_token_pool.req_to_token.shape[1] - 1 + ), + ] + phys_pages = (page_start_token // self.page_size).clamp( + 0, self.page_scores[layer_id].shape[0] - 1 + ) + scores = self.page_scores[layer_id][phys_pages].squeeze(-1) + + # TopK on history + keep recent + recent_start = (num_pages - self.num_recent_pages).clamp(min=0) + scores.masked_fill_(page_idx >= recent_start.unsqueeze(1), float("-inf")) + + k = max( + int((recent_start.float() * (1 - self.compression_ratio)).max().item()), 1 + ) + topk_idx = torch.topk(scores, k=k, dim=1, sorted=False)[1] + topk_mask = torch.arange(k, device=device).unsqueeze(0) < ( + recent_start * (1 - self.compression_ratio) + ).int().clamp(min=1).unsqueeze(1) + + recent_idx = recent_start.unsqueeze(1) + torch.arange( + self.num_recent_pages, device=device + ) + recent_mask = recent_idx < num_pages.unsqueeze(1) + + # Combine + combined = torch.cat( + [ + torch.where(topk_mask, topk_idx, -1), + torch.where(recent_mask, recent_idx, -1), + ], + dim=1, + ).sort(dim=1)[0] + + out_lengths[:] = torch.where(mask, (combined >= 0).sum(dim=1).int(), 0) + out_indices[:, : combined.shape[1]] = torch.where( + mask.unsqueeze(1), combined, -1 + ) + + if layer_id == 0: + logger.info( + f"Retrieve topk: layer_id={layer_id}, out_indices={out_indices}, out_lengths={out_lengths}" + ) + return out_indices, out_lengths + + def compute_and_update_page_representations( + self, + layer_id: int, + reqs: torch.Tensor, + seq_lens: torch.Tensor, + start_page, + end_page: torch.Tensor, + k_buffer: torch.Tensor, + ): + """Compute and store page scores based on key L2 norms.""" + if isinstance(start_page, int): + start_page = torch.full_like(end_page, start_page) + + device = k_buffer.device + req_to_token = self.req_to_token_pool.req_to_token + n = reqs.shape[0] + max_pages = int((end_page - start_page).max().item()) + + # Build 
ranges: [n, max_pages] or [n, max_pages, page_size] + pg_off = torch.arange(max_pages, device=device).unsqueeze(0) + pg_id = start_page.unsqueeze(1) + pg_off + pg_mask = pg_id < end_page.unsqueeze(1) + + tok_start = pg_id * self.page_size + tok_off = torch.arange(self.page_size, device=device).view(1, 1, -1) + tok_pos = tok_start.unsqueeze(2) + tok_off + tok_mask = ( + tok_pos + < (tok_start + self.page_size).clamp(max=seq_lens.unsqueeze(1)).unsqueeze(2) + ) & pg_mask.unsqueeze(2) + + # Get physical tokens and compute scores + phys_tok = req_to_token[ + reqs.view(n, 1, 1).expand(n, max_pages, self.page_size), + tok_pos.clamp(0, req_to_token.shape[1] - 1), + ].clamp(0, k_buffer.shape[0] - 1) + + tok_score = k_buffer[phys_tok].norm(dim=-1).sum(dim=-1) + pg_score = (tok_score * tok_mask).sum(dim=2) / tok_mask.sum(dim=2).clamp(min=1) + + # Store to page_scores + phys_pg = ( + req_to_token[ + reqs.unsqueeze(1).expand(n, max_pages), + tok_start.clamp(0, req_to_token.shape[1] - 1), + ] + // self.page_size + ) + idx = pg_mask.nonzero(as_tuple=False) + if idx.numel() > 0: + scores_to_store = ( + pg_score[idx[:, 0], idx[:, 1]] + .unsqueeze(-1) + .to(self.page_scores[layer_id].dtype) + ) + self.page_scores[layer_id][phys_pg[idx[:, 0], idx[:, 1]]] = scores_to_store + if layer_id == 0: + logger.info(f"Compute page scores from {start_page} to {end_page}") diff --git a/test/srt/sparsity/test_knorm_page_algorithm.py b/test/srt/sparsity/test_knorm_page_algorithm.py new file mode 100644 index 000000000000..ebef1cba11cc --- /dev/null +++ b/test/srt/sparsity/test_knorm_page_algorithm.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 +import unittest + +import torch + +from sglang.srt.mem_cache.sparsity.algorithms.base_algorithm import SparseMode +from sglang.srt.mem_cache.sparsity.algorithms.page_wise_algorithm import ( + KnormPageAlgorithm, +) +from sglang.srt.model_executor.forward_batch_info import ForwardMode + + +class MockConfig: + def __init__(self, 
class MockConfig:
    """Minimal config stub exposing the attributes KnormPageAlgorithm reads."""

    def __init__(self, compression_ratio=0.2, page_size=64, num_recent_pages=4):
        self.compression_ratio = compression_ratio
        self.page_size = page_size
        # Matches the algorithm's getattr default, so behavior is unchanged
        # for existing callers while making the knob explicit.
        self.num_recent_pages = num_recent_pages


class MockTokenToKVPool:
    """Fake KV pool: one random key buffer [num_tokens, num_heads, head_dim] per layer."""

    def __init__(self, num_tokens=1024, num_layers=2, num_heads=8, head_dim=64):
        self._k_buffer = {
            i: torch.randn(
                num_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
            )
            for i in range(num_layers)
        }

    def get_key_buffer(self, layer_id):
        return self._k_buffer[layer_id]


class MockReqToTokenPool:
    """Fake request->token map; physical slots wrap around num_physical_tokens."""

    def __init__(self, max_reqs=32, max_tokens=2048, num_physical_tokens=1024):
        self.req_to_token = (
            torch.arange(max_reqs * max_tokens, device="cuda").reshape(
                max_reqs, max_tokens
            )
            % num_physical_tokens
        )


class MockStates:
    """Per-request bookkeeping tensors mirroring the sparse coordinator state."""

    def __init__(self, max_reqs=32):
        self.prompt_lens = torch.zeros(max_reqs, dtype=torch.int32, device="cuda")
        self.repr_constructed = torch.zeros(max_reqs, dtype=torch.bool, device="cuda")
        self.last_extracted_token = torch.zeros(
            max_reqs, dtype=torch.int32, device="cuda"
        )


class MockForwardBatch:
    """Carries only the forward_mode attribute the algorithm inspects."""

    def __init__(self, mode=ForwardMode.EXTEND):
        self.forward_mode = mode


class MockAttnMetadata:
    """Carries only cache_seqlens_int32, as read by retrieve_topk."""

    def __init__(self, cache_seqlens):
        self.cache_seqlens_int32 = cache_seqlens


# All fixtures allocate on "cuda"; skip (rather than error) on CPU-only hosts.
@unittest.skipUnless(torch.cuda.is_available(), "CUDA device required")
class TestKnormPageAlgorithm(unittest.TestCase):
    def setUp(self):
        self.device = torch.device("cuda")
        self.config = MockConfig(compression_ratio=0.2, page_size=64)
        self.algorithm = KnormPageAlgorithm(self.config, self.device)
        self.token_to_kv_pool = MockTokenToKVPool(num_tokens=1024, num_layers=2)
        self.req_to_token_pool = MockReqToTokenPool(
            max_reqs=8, max_tokens=512, num_physical_tokens=1024
        )
        self.states = MockStates(max_reqs=8)

    def test_get_sparse_mode(self):
        self.assertEqual(self.algorithm.get_sparse_mode(), SparseMode.PAGE_WISE)

    def test_initialize_representation_pool(self):
        self.algorithm.initialize_representation_pool(
            start_layer=0,
            end_layer=2,
            token_to_kv_pool=self.token_to_kv_pool,
            req_to_token_pool=self.req_to_token_pool,
            states=self.states,
        )

        # One score tensor per layer, and the context refs must be recorded.
        self.assertEqual(len(self.algorithm.page_scores), 2)
        self.assertIsNotNone(self.algorithm.req_to_token_pool)
        self.assertIsNotNone(self.algorithm.states)

    def test_construct_representations(self):
        self.algorithm.initialize_representation_pool(
            0, 2, self.token_to_kv_pool, self.req_to_token_pool, self.states
        )

        req_pool_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
        seq_lens = torch.tensor([128, 192], dtype=torch.int32, device="cuda")
        k_buffer = self.token_to_kv_pool.get_key_buffer(0)
        forward_batch = MockForwardBatch(mode=ForwardMode.EXTEND)

        # Construction only fires once seq_lens has reached prompt_lens.
        self.states.prompt_lens[req_pool_indices] = seq_lens

        # layer_id=1 is the last layer (end_layer=2), so state flags flip here.
        self.algorithm.construct_representations(
            layer_id=1,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            k_buffer=k_buffer,
            forward_batch=forward_batch,
        )

        self.assertTrue(self.states.repr_constructed[0])
        self.assertTrue(self.states.repr_constructed[1])
        # Progress is recorded at page-boundary granularity (page_size=64).
        self.assertEqual(self.states.last_extracted_token[0].item(), 128)
        self.assertEqual(self.states.last_extracted_token[1].item(), 192)

    def test_update_representations(self):
        self.algorithm.initialize_representation_pool(
            0, 2, self.token_to_kv_pool, self.req_to_token_pool, self.states
        )

        req_pool_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
        seq_lens = torch.tensor([192, 256], dtype=torch.int32, device="cuda")
        k_buffer = self.token_to_kv_pool.get_key_buffer(0)
        forward_batch = MockForwardBatch(mode=ForwardMode.DECODE)

        # Simulate requests whose prefill already completed at token 128.
        self.states.repr_constructed[req_pool_indices] = True
        self.states.last_extracted_token[req_pool_indices] = torch.tensor(
            [128, 128], dtype=torch.int32, device="cuda"
        )

        self.algorithm.update_representations(
            layer_id=1,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            k_buffer=k_buffer,
            forward_batch=forward_batch,
        )

        self.assertEqual(self.states.last_extracted_token[0].item(), 192)
        self.assertEqual(self.states.last_extracted_token[1].item(), 256)

    def test_retrieve_topk(self):
        self.algorithm.initialize_representation_pool(
            0, 2, self.token_to_kv_pool, self.req_to_token_pool, self.states
        )

        req_pool_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
        queries = torch.randn(2, 8, 64, dtype=torch.float16, device="cuda")
        # 512/640 tokens -> 8/10 pages at page_size=64, above the recent
        # window of 4 pages, so sparse selection engages for both requests.
        cache_seqlens = torch.tensor([512, 640], dtype=torch.int32, device="cuda")
        sparse_mask = torch.ones(2, dtype=torch.bool, device="cuda")
        attn_metadata = MockAttnMetadata(cache_seqlens=cache_seqlens)

        # Seed non-zero scores so topk has something meaningful to rank.
        self.algorithm.page_scores[0] = torch.randn(
            16, 1, dtype=torch.float32, device="cuda"
        )

        out_indices, out_lengths = self.algorithm.retrieve_topk(
            queries=queries,
            layer_id=0,
            req_pool_indices=req_pool_indices,
            sparse_mask=sparse_mask,
            attn_metadata=attn_metadata,
        )

        self.assertEqual(out_indices.shape[0], 2)
        self.assertEqual(out_lengths.shape[0], 2)
        self.assertTrue((out_lengths > 0).all())


if __name__ == "__main__":
    unittest.main()
a/python/sglang/srt/mem_cache/sparsity/algorithms/page_wise_algorithm.py b/python/sglang/srt/mem_cache/sparsity/algorithms/page_wise_algorithm.py index f655e874d75f..18a05c10c201 100644 --- a/python/sglang/srt/mem_cache/sparsity/algorithms/page_wise_algorithm.py +++ b/python/sglang/srt/mem_cache/sparsity/algorithms/page_wise_algorithm.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import logging +from abc import abstractmethod from typing import Any, Optional import torch @@ -12,20 +13,21 @@ logger = logging.getLogger(__name__) -class KnormPageAlgorithm(BaseSparseAlgorithm): +class BasePageWiseAlgorithm(BaseSparseAlgorithm): """ - KnormPageAlgorithm: Page-wise sparse attention with ChunkKV-style scoring. + Base class for page-wise sparse attention algorithms. - This implementation combines page-wise attention with ChunkKV scoring: - - Pages (chunks) are scored based on key L2 norms (sum across tokens) - - TopK pages are selected based on pre-computed scores - - Recent pages are always included + Provides common infrastructure for algorithms that operate at page/chunk granularity: + - Generic construct/update flow with state tracking + - TopK retrieval with recent page retention (can be overridden) - Based on ChunkKV (https://arxiv.org/abs/2502.00299). + Subclasses need to implement: + - _initialize_representation_pools(): Initialize algorithm-specific representation pools + - _compute_page_representations(): Compute page scores/representations + - _retrieve_page_scores(): Retrieve page scores for TopK selection - Note: This is an experimental/example implementation for demonstrating - how to integrate algorithms into the sparse framework. - Not production-ready - use for reference and testing purposes only. 
+ Subclasses can optionally override: + - retrieve_topk(): For query-dependent retrieval logic """ def __init__(self, config, device: torch.device, **kwargs): @@ -33,7 +35,6 @@ def __init__(self, config, device: torch.device, **kwargs): self.compression_ratio = getattr(config, "compression_ratio", 0.2) self.page_size = getattr(config, "page_size", 64) self.num_recent_pages = getattr(config, "num_recent_pages", 4) - self.page_scores = {} def get_sparse_mode(self) -> SparseMode: return SparseMode.PAGE_WISE @@ -46,24 +47,16 @@ def initialize_representation_pool( req_to_token_pool, states, ): - """Initialize page score representation pool for each layer.""" super().initialize_representation_pool( start_layer, end_layer, token_to_kv_pool, req_to_token_pool, states ) self.start_layer = start_layer self.end_layer = end_layer + total_num_tokens = token_to_kv_pool.get_key_buffer(start_layer).shape[0] total_num_pages = (total_num_tokens + self.page_size - 1) // self.page_size - # Create page score storage: [num_pages, 1] per layer - for layer_id in range(start_layer, end_layer): - self.page_scores[layer_id] = torch.zeros( - (total_num_pages, 1), dtype=torch.float32, device=self.device - ) - logger.info( - f"Initialized page score storage: {total_num_pages} pages, " - f"{end_layer - start_layer} layers" - ) + self._initialize_representation_pools(start_layer, end_layer, total_num_pages) def construct_representations( self, @@ -72,11 +65,10 @@ def construct_representations( seq_lens, k_buffer, forward_batch, - ): + ) -> torch.Tensor: if not forward_batch.forward_mode.is_extend(): return - # Check which requests need construction num_pages = seq_lens // self.page_size valid_mask = ( ~self.states.repr_constructed[req_pool_indices] @@ -87,8 +79,7 @@ def construct_representations( if not valid_mask.any(): return - # Compute page scores - self.compute_and_update_page_representations( + self._compute_page_representations( layer_id, req_pool_indices[valid_mask], seq_lens[valid_mask], 
@@ -97,7 +88,6 @@ def construct_representations( k_buffer, ) - # Update states if layer_id == self.end_layer - 1: success_indices = req_pool_indices[valid_mask] self.states.repr_constructed[success_indices] = True @@ -112,24 +102,22 @@ def update_representations( seq_lens, k_buffer, forward_batch, - ): + ) -> torch.Tensor: if not forward_batch.forward_mode.is_decode_or_idle(): return - # Check if new pages were generated (compare current page to last extracted page) start_page = ( self.states.last_extracted_token[req_pool_indices] // self.page_size ) end_page = seq_lens // self.page_size valid_mask = self.states.repr_constructed[req_pool_indices] & ( start_page < end_page - ) # New page(s) generated + ) if not valid_mask.any(): return - # Compute page scores for new pages - self.compute_and_update_page_representations( + self._compute_page_representations( layer_id, req_pool_indices[valid_mask], seq_lens[valid_mask], @@ -138,7 +126,6 @@ def update_representations( k_buffer, ) - # Update states if layer_id == self.end_layer - 1: success_indices = req_pool_indices[valid_mask] self.states.last_extracted_token[success_indices] = ( @@ -154,6 +141,10 @@ def retrieve_topk( attn_metadata: Optional[Any], **kwargs, ) -> tuple: + """ + Default TopK retrieval: score-based selection + recent pages. + Subclasses can override for query-dependent retrieval. 
+ """ bs, device = queries.shape[0], queries.device seq_lens = attn_metadata.cache_seqlens_int32 num_pages = (seq_lens + self.page_size - 1) // self.page_size @@ -166,7 +157,6 @@ def retrieve_topk( if not mask.any(): return out_indices, out_lengths - # Get page scores page_idx = torch.arange(max_pages, device=device).unsqueeze(0) page_start_token = self.req_to_token_pool.req_to_token[ req_pool_indices.unsqueeze(1).expand(bs, max_pages), @@ -174,12 +164,12 @@ def retrieve_topk( 0, self.req_to_token_pool.req_to_token.shape[1] - 1 ), ] - phys_pages = (page_start_token // self.page_size).clamp( - 0, self.page_scores[layer_id].shape[0] - 1 + phys_pages = page_start_token // self.page_size + + scores = self._retrieve_page_scores( + layer_id, phys_pages, req_pool_indices, queries ) - scores = self.page_scores[layer_id][phys_pages].squeeze(-1) - # TopK on history + keep recent recent_start = (num_pages - self.num_recent_pages).clamp(min=0) scores.masked_fill_(page_idx >= recent_start.unsqueeze(1), float("-inf")) @@ -196,7 +186,6 @@ def retrieve_topk( ) recent_mask = recent_idx < num_pages.unsqueeze(1) - # Combine combined = torch.cat( [ torch.where(topk_mask, topk_idx, -1), @@ -210,13 +199,99 @@ def retrieve_topk( mask.unsqueeze(1), combined, -1 ) - if layer_id == 0: - logger.info( - f"Retrieve topk: layer_id={layer_id}, out_indices={out_indices}, out_lengths={out_lengths}" - ) return out_indices, out_lengths - def compute_and_update_page_representations( + @abstractmethod + def _initialize_representation_pools( + self, start_layer: int, end_layer: int, total_num_pages: int + ): + """ + Initialize algorithm-specific representation pools for all layers. + + Subclasses define their own representation format based on algorithm needs. 
+ Examples: + - Knorm: self.page_scores[layer_id] = torch.zeros((total_num_pages, 1)) + - Quest: self.page_reprs[layer_id] = torch.zeros((total_num_pages, head_dim)) + """ + pass + + @abstractmethod + def _compute_page_representations( + self, + layer_id: int, + reqs: torch.Tensor, + seq_lens: torch.Tensor, + start_page, + end_page: torch.Tensor, + k_buffer: torch.Tensor, + ): + """ + Compute and store page representations for given page range. + + Args: + layer_id: Current layer index + reqs: [n] Request pool indices + seq_lens: [n] Current sequence lengths + start_page: Starting page index (int or [n] tensor) + end_page: [n] Ending page indices (exclusive) + k_buffer: Key buffer for the layer + """ + pass + + @abstractmethod + def _retrieve_page_scores( + self, + layer_id: int, + phys_pages: torch.Tensor, + req_pool_indices: torch.Tensor, + queries: torch.Tensor, + ) -> torch.Tensor: + """ + Retrieve page scores for TopK selection. + + Args: + layer_id: Current layer index + phys_pages: [bs, max_pages] Physical page indices + req_pool_indices: [bs] Request pool indices + queries: [bs, num_heads, head_dim] Query vectors + + Returns: + scores: [bs, max_pages] Page scores for ranking + """ + pass + + +class KnormPageAlgorithm(BasePageWiseAlgorithm): + """ + L2-norm based page-wise sparse attention (ChunkKV-style). + + Pages are scored based on key L2 norms aggregated across tokens. + TopK pages are selected based on pre-computed scores with recent pages always included. + + Based on ChunkKV (https://arxiv.org/abs/2502.00299). + + Note: This is an experimental/example implementation for demonstrating + how to integrate algorithms into the sparse framework. + Not production-ready - use for reference and testing purposes only. 
+ """ + + def __init__(self, config, device: torch.device, **kwargs): + super().__init__(config, device, **kwargs) + self.page_scores = {} + + def _initialize_representation_pools( + self, start_layer: int, end_layer: int, total_num_pages: int + ): + for layer_id in range(start_layer, end_layer): + self.page_scores[layer_id] = torch.zeros( + (total_num_pages, 1), dtype=torch.float32, device=self.device + ) + logger.info( + f"Initialized page representation pools: {total_num_pages} pages, " + f"{end_layer - start_layer} layers" + ) + + def _compute_page_representations( self, layer_id: int, reqs: torch.Tensor, @@ -225,7 +300,6 @@ def compute_and_update_page_representations( end_page: torch.Tensor, k_buffer: torch.Tensor, ): - """Compute and store page scores based on key L2 norms.""" if isinstance(start_page, int): start_page = torch.full_like(end_page, start_page) @@ -234,7 +308,6 @@ def compute_and_update_page_representations( n = reqs.shape[0] max_pages = int((end_page - start_page).max().item()) - # Build ranges: [n, max_pages] or [n, max_pages, page_size] pg_off = torch.arange(max_pages, device=device).unsqueeze(0) pg_id = start_page.unsqueeze(1) + pg_off pg_mask = pg_id < end_page.unsqueeze(1) @@ -247,7 +320,6 @@ def compute_and_update_page_representations( < (tok_start + self.page_size).clamp(max=seq_lens.unsqueeze(1)).unsqueeze(2) ) & pg_mask.unsqueeze(2) - # Get physical tokens and compute scores phys_tok = req_to_token[ reqs.view(n, 1, 1).expand(n, max_pages, self.page_size), tok_pos.clamp(0, req_to_token.shape[1] - 1), @@ -256,7 +328,6 @@ def compute_and_update_page_representations( tok_score = k_buffer[phys_tok].norm(dim=-1).sum(dim=-1) pg_score = (tok_score * tok_mask).sum(dim=2) / tok_mask.sum(dim=2).clamp(min=1) - # Store to page_scores phys_pg = ( req_to_token[ reqs.unsqueeze(1).expand(n, max_pages), @@ -272,5 +343,15 @@ def compute_and_update_page_representations( .to(self.page_scores[layer_id].dtype) ) self.page_scores[layer_id][phys_pg[idx[:, 0], 
idx[:, 1]]] = scores_to_store - if layer_id == 0: - logger.info(f"Compute page scores from {start_page} to {end_page}") + + def _retrieve_page_scores( + self, + layer_id: int, + phys_pages: torch.Tensor, + req_pool_indices: torch.Tensor, + queries: torch.Tensor, + ) -> torch.Tensor: + phys_pages_clamped = phys_pages.clamp( + 0, self.page_scores[layer_id].shape[0] - 1 + ) + return self.page_scores[layer_id][phys_pages_clamped].squeeze(-1) From 205c2bcac2f2d88b7d586d155316756139a53c55 Mon Sep 17 00:00:00 2001 From: hzh0425 Date: Sun, 14 Dec 2025 00:20:49 +0800 Subject: [PATCH 3/5] Refactor structure --- .../mem_cache/sparsity/algorithms/__init__.py | 10 +- .../sparsity/algorithms/base_algorithm.py | 269 +++++++++++-- .../sparsity/algorithms/deepseek_nsa.py | 82 ++-- .../sparsity/algorithms/knorm_algorithm.py | 106 ++++++ .../algorithms/page_wise_algorithm.py | 357 ------------------ .../srt/sparsity/test_knorm_page_algorithm.py | 34 +- 6 files changed, 401 insertions(+), 457 deletions(-) create mode 100644 python/sglang/srt/mem_cache/sparsity/algorithms/knorm_algorithm.py delete mode 100644 python/sglang/srt/mem_cache/sparsity/algorithms/page_wise_algorithm.py diff --git a/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py b/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py index f4e55371679c..01efcdfba985 100644 --- a/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py +++ b/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py @@ -1,17 +1,13 @@ from sglang.srt.mem_cache.sparsity.algorithms.base_algorithm import ( BaseSparseAlgorithm, - SparseMode, + BaseSparseAlgorithmImpl, ) from sglang.srt.mem_cache.sparsity.algorithms.deepseek_nsa import DeepSeekNSAAlgorithm -from sglang.srt.mem_cache.sparsity.algorithms.page_wise_algorithm import ( - BasePageWiseAlgorithm, - KnormPageAlgorithm, -) +from sglang.srt.mem_cache.sparsity.algorithms.knorm_algorithm import KnormPageAlgorithm __all__ = [ "BaseSparseAlgorithm", - "SparseMode", - 
"BasePageWiseAlgorithm", + "BaseSparseAlgorithmImpl", "KnormPageAlgorithm", "DeepSeekNSAAlgorithm", ] diff --git a/python/sglang/srt/mem_cache/sparsity/algorithms/base_algorithm.py b/python/sglang/srt/mem_cache/sparsity/algorithms/base_algorithm.py index fe4c1e109132..e83cef4247e4 100644 --- a/python/sglang/srt/mem_cache/sparsity/algorithms/base_algorithm.py +++ b/python/sglang/srt/mem_cache/sparsity/algorithms/base_algorithm.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from enum import Enum from typing import TYPE_CHECKING, Any, Optional import torch @@ -8,20 +7,12 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch -class SparseMode(Enum): - """Sparse attention granularity mode.""" - - PAGE_WISE = "page_wise" - TOKEN_WISE = "token_wise" - DEEPSEEK_TOKEN_WISE = "deepseek_token_wise" - - class BaseSparseAlgorithm(ABC): """ Abstract base class for sparse attention algorithms. This class provides a unified interface for implementing various retrievable KVCache - compression algorithms, supporting both page-wise and token-wise sparsity. + compression algorithms. References: - ChunkKV: https://arxiv.org/abs/2502.00299 @@ -59,25 +50,6 @@ def initialize_representation_pool( - SnapKV: Allocate voting scores [num_tokens] and selected positions mask for retention strategy - Look-ahead QCache: Allocate importance scores [num_tokens], eviction mask, and optional pseudo query cache [cache_size, hidden_dim] """ - self.req_to_token_pool = req_to_token_pool - self.states = states - - @abstractmethod - def get_sparse_mode(self) -> SparseMode: - """ - Return the sparsity granularity mode. 
- - Returns: - SparseMode.PAGE_WISE: Selection operates on page/chunk level - SparseMode.TOKEN_WISE: Selection operates on individual token level - - Algorithm-specific modes: - - ChunkKV: PAGE_WISE (selects important semantic chunks while preserving linguistic structures) - - Quest: PAGE_WISE (selects important pages/blocks) - - PQCache: TOKEN_WISE (selects important tokens via centroid similarity) - - SnapKV: TOKEN_WISE (retention-based: keeps voted important prefix tokens + observation window) - - Look-ahead QCache: TOKEN_WISE (eviction-based: removes tokens with low pseudo query importance) - """ pass def construct_representations( @@ -92,7 +64,7 @@ def construct_representations( Construct initial representations during prefill phase. Called at every layer during forward pass. Algorithm internally decides - whether to perform construction based on self.states.repr_constructed. + whether to perform construction. Typically only constructs once per request during prefill/extend phase. Algorithm-specific implementations: @@ -118,10 +90,9 @@ def update_representations( Called at every layer during forward pass. 
Algorithm internally decides whether to update based on: - self.states.repr_constructed[req_id]: Whether initial construction done - - self.states.last_extracted_token[req_id]: Last processed position + - self.states.last_constructed_page[req_id]: Last constructed page index - Current seq_lens: To detect new tokens/pages - Algorithm-specific implementations: - ChunkKV: Incrementally compute importance scores for newly generated chunks during decode - Quest: Incrementally compute representations for newly generated pages during decode @@ -160,16 +131,240 @@ def retrieve_topk( selected_indices: [bs, max_selected] Selected page/token indices, padded with -1 valid_lengths: [bs] Actual number of selected indices per request + Note: + - Indices are logical positions that will be mapped to physical KV cache by BackendAdaptor + Algorithm-specific implementations: - ChunkKV: Select top-k chunks based on pre-computed importance scores with layer-wise index reuse - Quest: Compute query-page similarity using current query and stored page representations, select top-k pages - PQCache: Calculate query-centroid similarity, use centroid scores to rank tokens, select top-k tokens - SnapKV: Return union of voted important prefix positions (with clustered neighbors) and observation window tokens - Look-ahead QCache: Return KVs not marked for eviction (eviction based on pseudo query importance evaluation) - - Note: - - For PAGE_WISE mode: Returns page indices - - For TOKEN_WISE mode: Returns token indices - - Indices are logical positions that will be mapped to physical KV cache by BackendAdaptor """ pass + + +class BaseSparseAlgorithmImpl(BaseSparseAlgorithm): + """ + Implementation base class for sparse attention algorithms. 
+ + Provides common infrastructure for algorithms that operate at page/chunk granularity + (token-wise is simply page_size=1): + - Generic construct/update flow with state tracking + - TopK retrieval with recent page retention (can be overridden) + + Subclasses need to implement: + - _initialize_representation_pools(): Initialize algorithm-specific representation pools + - _compute_page_representations(): Compute page scores/representations + - _retrieve_page_scores(): Retrieve page scores for TopK selection + + Subclasses can also override any method for specialized behavior + """ + + def __init__(self, config, device: torch.device, **kwargs): + super().__init__(config, device, **kwargs) + self.compression_ratio = getattr(config, "compression_ratio", 0.2) + self.page_size = getattr(config, "page_size", 64) + self.num_recent_pages = getattr(config, "num_recent_pages", 4) + + def initialize_representation_pool( + self, + start_layer: int, + end_layer: int, + token_to_kv_pool, + req_to_token_pool, + states, + ): + self.req_to_token_pool = req_to_token_pool + self.token_to_kv_pool = token_to_kv_pool + self.start_layer = start_layer + self.end_layer = end_layer + self.states = states + + total_num_tokens = token_to_kv_pool.get_key_buffer(start_layer).shape[0] + total_num_pages = (total_num_tokens + self.page_size - 1) // self.page_size + + # Initialize algorithm-specific representation pools + self._initialize_representation_pools(start_layer, end_layer, total_num_pages) + + def construct_representations( + self, + layer_id, + req_pool_indices, + seq_lens, + k_buffer, + forward_batch, + ) -> torch.Tensor: + + if not forward_batch.forward_mode.is_extend(): + return + + num_pages = seq_lens // self.page_size + valid_mask = ( + ~self.states.repr_constructed[req_pool_indices] + & (seq_lens >= self.states.prompt_lens[req_pool_indices]) + & (num_pages > 0) + ) + + if not valid_mask.any(): + return + + # Compute page representations by subclass + 
self._compute_page_representations( + layer_id, + req_pool_indices[valid_mask], + seq_lens[valid_mask], + 0, + num_pages[valid_mask], + k_buffer, + ) + + # Update tracking states + if layer_id == self.end_layer - 1: + success_indices = req_pool_indices[valid_mask] + self.states.repr_constructed[success_indices] = True + self.states.last_constructed_page[success_indices] = num_pages[valid_mask] + + def update_representations( + self, + layer_id, + req_pool_indices, + seq_lens, + k_buffer, + forward_batch, + ) -> torch.Tensor: + if not forward_batch.forward_mode.is_decode_or_idle(): + return + + start_page = self.states.last_constructed_page[req_pool_indices] + end_page = seq_lens // self.page_size + valid_mask = self.states.repr_constructed[req_pool_indices] & ( + start_page < end_page + ) + + if not valid_mask.any(): + return + + # Compute page representations by subclass + self._compute_page_representations( + layer_id, + req_pool_indices[valid_mask], + seq_lens[valid_mask], + start_page[valid_mask], + end_page[valid_mask], + k_buffer, + ) + + # Update tracking states + if layer_id == self.end_layer - 1: + success_indices = req_pool_indices[valid_mask] + self.states.last_constructed_page[success_indices] = end_page[valid_mask] + + def retrieve_topk( + self, + queries: torch.Tensor, + layer_id: int, + req_pool_indices: torch.Tensor, + sparse_mask: torch.Tensor, + attn_metadata: Optional[Any], + **kwargs, + ) -> tuple: + """ + Default TopK retrieval: score-based selection + recent pages. + Subclasses can override for query-dependent retrieval. + + TODO: + 1. Using triton kernel to speed up this function + 2. 
Support CUDA Graph + """ + bs, device = queries.shape[0], queries.device + seq_lens = attn_metadata.cache_seqlens_int32 + num_pages = (seq_lens + self.page_size - 1) // self.page_size + max_pages = max(int(num_pages.max().item()), 1) + + out_indices = torch.full((bs, max_pages), -1, dtype=torch.int32, device=device) + out_lengths = torch.zeros(bs, dtype=torch.int32, device=device) + + mask = sparse_mask & (num_pages > self.num_recent_pages) + if not mask.any(): + return out_indices, out_lengths + + # Map logical page indices to physical page indices + # page_idx -> logical page indices [0, 1, 2, ...] + # phys_pages: [bs, max_pages] -> physical page indices in KV cache + page_idx = torch.arange(max_pages, device=device).unsqueeze(0) + page_start_token = self.req_to_token_pool.req_to_token[ + req_pool_indices.unsqueeze(1).expand(bs, max_pages), + (page_idx * self.page_size).clamp( + 0, self.req_to_token_pool.req_to_token.shape[1] - 1 + ), + ] + phys_pages = page_start_token // self.page_size + + # Get pre-computed page scores from subclass + scores = self._retrieve_page_scores( + layer_id, phys_pages, req_pool_indices, queries + ) + + # Mask out recent pages from TopK selection (they will be added separately) + # Layout: [history pages ... 
| recent pages (always kept)] + recent_start = (num_pages - self.num_recent_pages).clamp(min=0) + scores.masked_fill_(page_idx >= recent_start.unsqueeze(1), float("-inf")) + + # Select TopK from history pages based on compression_ratio + k = max( + int((recent_start.float() * (1 - self.compression_ratio)).max().item()), 1 + ) + topk_idx = torch.topk(scores, k=k, dim=1, sorted=False)[1] + topk_mask = torch.arange(k, device=device).unsqueeze(0) < ( + recent_start * (1 - self.compression_ratio) + ).int().clamp(min=1).unsqueeze(1) + + recent_idx = recent_start.unsqueeze(1) + torch.arange( + self.num_recent_pages, device=device + ) + recent_mask = recent_idx < num_pages.unsqueeze(1) + + # Combine TopK history pages + recent pages, sort for sequential access + combined = torch.cat( + [ + torch.where(topk_mask, topk_idx, -1), + torch.where(recent_mask, recent_idx, -1), + ], + dim=1, + ).sort(dim=1)[0] + + out_lengths[:] = torch.where(mask, (combined >= 0).sum(dim=1).int(), 0) + out_indices[:, : combined.shape[1]] = torch.where( + mask.unsqueeze(1), combined, -1 + ) + + return out_indices, out_lengths + + def _initialize_representation_pools( + self, start_layer: int, end_layer: int, total_num_pages: int + ): + """Initialize algorithm-specific representation pools for all layers.""" + raise NotImplementedError + + def _compute_page_representations( + self, + layer_id: int, + reqs: torch.Tensor, + seq_lens: torch.Tensor, + start_page, + end_page: torch.Tensor, + k_buffer: torch.Tensor, + ): + """Compute and store page representations for given page range.""" + raise NotImplementedError + + def _retrieve_page_scores( + self, + layer_id: int, + phys_pages: torch.Tensor, + req_pool_indices: torch.Tensor, + queries: torch.Tensor, + ) -> torch.Tensor: + """Retrieve page scores for TopK selection.""" + raise NotImplementedError diff --git a/python/sglang/srt/mem_cache/sparsity/algorithms/deepseek_nsa.py b/python/sglang/srt/mem_cache/sparsity/algorithms/deepseek_nsa.py index 
a319c2c17f66..6d64f10eed62 100644 --- a/python/sglang/srt/mem_cache/sparsity/algorithms/deepseek_nsa.py +++ b/python/sglang/srt/mem_cache/sparsity/algorithms/deepseek_nsa.py @@ -1,33 +1,23 @@ -import logging -from typing import TYPE_CHECKING, Any, Optional +from typing import Any, Optional -import nvtx import torch from sglang.srt.mem_cache.sparsity.algorithms.base_algorithm import ( - BaseSparseAlgorithm, - SparseMode, + BaseSparseAlgorithmImpl, ) -if TYPE_CHECKING: - from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer - from sglang.srt.model_executor.forward_batch_info import ForwardBatch -logger = logging.getLogger(__name__) +class DeepSeekNSAAlgorithm(BaseSparseAlgorithmImpl): + """ + Sparse attention algorithm for DeepSeek NSA. - -class DeepSeekNSAAlgorithm(BaseSparseAlgorithm): - """Sparse attention algorithm for DeepSeek NSA.""" + This algorithm uses NSA's native indexer for TopK retrieval. + Overrides all parent methods as NSA has its own specialized flow. + """ def __init__(self, config, device: torch.device, **kwargs): super().__init__(config, device, **kwargs) - self.index_topk = getattr(config, "index_topk", 2048) - logger.info(f"DeepSeekNSAAlgorithm initialized: index_topk={self.index_topk}") - - def get_sparse_mode(self) -> SparseMode: - return SparseMode.DEEPSEEK_TOKEN_WISE - @nvtx.annotate("DeepSeekNSAAlgorithm.retrieve_topk", color="green") def retrieve_topk( self, queries: torch.Tensor, @@ -37,38 +27,54 @@ def retrieve_topk( attn_metadata: Optional[Any], **kwargs, ) -> tuple: - indexer: Optional["Indexer"] = kwargs.get("indexer") - forward_batch: Optional["ForwardBatch"] = kwargs.get("forward_batch") - x, q_lora, positions = ( + indexer, forward_batch, x, q_lora, positions = ( + kwargs.get("indexer"), + kwargs.get("forward_batch"), kwargs.get("x"), kwargs.get("q_lora"), kwargs.get("positions"), ) if any(v is None for v in [indexer, x, q_lora, positions, forward_batch]): - raise ValueError("Required: indexer, x, q_lora, positions, 
forward_batch") + raise ValueError("Required: indexer, forward_batch, x, q_lora, positions") - try: - # Using the nsa's original indexer to get the topk indices. - topk_indices = indexer( + return ( + indexer( x=x, q_lora=q_lora, positions=positions, forward_batch=forward_batch, layer_id=layer_id, - ) + ), + None, + ) - if topk_indices is None: - return self._empty_result(queries.shape[0], queries.device) + def initialize_representation_pool( + self, + start_layer: int, + end_layer: int, + token_to_kv_pool, + req_to_token_pool, + states, + ): + pass - return topk_indices, None - except Exception as e: - logger.error(f"Layer {layer_id} NSA indexer failed: {e}", exc_info=True) - return self._empty_result(queries.shape[0], queries.device) + def construct_representations( + self, + layer_id, + req_pool_indices, + seq_lens, + k_buffer, + forward_batch, + ): + pass - def _empty_result(self, batch_size: int, device: torch.device) -> tuple: - selected_indices = torch.full( - (batch_size, self.index_topk), -1, dtype=torch.int32, device=device - ) - valid_lengths = torch.zeros(batch_size, dtype=torch.int32, device=device) - return selected_indices, valid_lengths + def update_representations( + self, + layer_id, + req_pool_indices, + seq_lens, + k_buffer, + forward_batch, + ): + pass diff --git a/python/sglang/srt/mem_cache/sparsity/algorithms/knorm_algorithm.py b/python/sglang/srt/mem_cache/sparsity/algorithms/knorm_algorithm.py new file mode 100644 index 000000000000..0ac3508d03a5 --- /dev/null +++ b/python/sglang/srt/mem_cache/sparsity/algorithms/knorm_algorithm.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +import logging + +import torch + +from sglang.srt.mem_cache.sparsity.algorithms.base_algorithm import ( + BaseSparseAlgorithmImpl, +) + +logger = logging.getLogger(__name__) + + +class KnormPageAlgorithm(BaseSparseAlgorithmImpl): + """ + L2-norm based page-wise sparse attention. + + Pages are scored based on key L2 norms aggregated across tokens. 
+ TopK pages are selected based on pre-computed scores with recent pages always included. + + Based on ChunkKV (https://arxiv.org/abs/2502.00299). + + Note: This is an experimental/example implementation for demonstrating + how to integrate algorithms into the sparse framework. + Not production-ready - use for reference and testing purposes only. + """ + + def __init__(self, config, device: torch.device, **kwargs): + super().__init__(config, device, **kwargs) + self.page_scores = {} + + def _initialize_representation_pools( + self, start_layer: int, end_layer: int, total_num_pages: int + ): + for layer_id in range(start_layer, end_layer): + self.page_scores[layer_id] = torch.zeros( + (total_num_pages, 1), dtype=torch.float32, device=self.device + ) + logger.info( + f"Initialized page representation pools: {total_num_pages} pages, " + f"{end_layer - start_layer} layers" + ) + + def _compute_page_representations( + self, + layer_id: int, + reqs: torch.Tensor, + seq_lens: torch.Tensor, + start_page, + end_page: torch.Tensor, + k_buffer: torch.Tensor, + ): + if isinstance(start_page, int): + start_page = torch.full_like(end_page, start_page) + + device = k_buffer.device + req_to_token = self.req_to_token_pool.req_to_token + n = reqs.shape[0] + max_pages = int((end_page - start_page).max().item()) + + pg_off = torch.arange(max_pages, device=device).unsqueeze(0) + pg_id = start_page.unsqueeze(1) + pg_off + pg_mask = pg_id < end_page.unsqueeze(1) + + tok_start = pg_id * self.page_size + tok_off = torch.arange(self.page_size, device=device).view(1, 1, -1) + tok_pos = tok_start.unsqueeze(2) + tok_off + tok_mask = ( + tok_pos + < (tok_start + self.page_size).clamp(max=seq_lens.unsqueeze(1)).unsqueeze(2) + ) & pg_mask.unsqueeze(2) + + phys_tok = req_to_token[ + reqs.view(n, 1, 1).expand(n, max_pages, self.page_size), + tok_pos.clamp(0, req_to_token.shape[1] - 1), + ].clamp(0, k_buffer.shape[0] - 1) + + tok_score = k_buffer[phys_tok].norm(dim=-1).sum(dim=-1) + pg_score = 
(tok_score * tok_mask).sum(dim=2) / tok_mask.sum(dim=2).clamp(min=1) + + phys_pg = ( + req_to_token[ + reqs.unsqueeze(1).expand(n, max_pages), + tok_start.clamp(0, req_to_token.shape[1] - 1), + ] + // self.page_size + ) + idx = pg_mask.nonzero(as_tuple=False) + if idx.numel() > 0: + scores_to_store = ( + pg_score[idx[:, 0], idx[:, 1]] + .unsqueeze(-1) + .to(self.page_scores[layer_id].dtype) + ) + self.page_scores[layer_id][phys_pg[idx[:, 0], idx[:, 1]]] = scores_to_store + + def _retrieve_page_scores( + self, + layer_id: int, + phys_pages: torch.Tensor, + req_pool_indices: torch.Tensor, + queries: torch.Tensor, + ) -> torch.Tensor: + phys_pages_clamped = phys_pages.clamp( + 0, self.page_scores[layer_id].shape[0] - 1 + ) + return self.page_scores[layer_id][phys_pages_clamped].squeeze(-1) diff --git a/python/sglang/srt/mem_cache/sparsity/algorithms/page_wise_algorithm.py b/python/sglang/srt/mem_cache/sparsity/algorithms/page_wise_algorithm.py deleted file mode 100644 index 18a05c10c201..000000000000 --- a/python/sglang/srt/mem_cache/sparsity/algorithms/page_wise_algorithm.py +++ /dev/null @@ -1,357 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import logging -from abc import abstractmethod -from typing import Any, Optional - -import torch - -from sglang.srt.mem_cache.sparsity.algorithms.base_algorithm import ( - BaseSparseAlgorithm, - SparseMode, -) - -logger = logging.getLogger(__name__) - - -class BasePageWiseAlgorithm(BaseSparseAlgorithm): - """ - Base class for page-wise sparse attention algorithms. 
- - Provides common infrastructure for algorithms that operate at page/chunk granularity: - - Generic construct/update flow with state tracking - - TopK retrieval with recent page retention (can be overridden) - - Subclasses need to implement: - - _initialize_representation_pools(): Initialize algorithm-specific representation pools - - _compute_page_representations(): Compute page scores/representations - - _retrieve_page_scores(): Retrieve page scores for TopK selection - - Subclasses can optionally override: - - retrieve_topk(): For query-dependent retrieval logic - """ - - def __init__(self, config, device: torch.device, **kwargs): - super().__init__(config, device, **kwargs) - self.compression_ratio = getattr(config, "compression_ratio", 0.2) - self.page_size = getattr(config, "page_size", 64) - self.num_recent_pages = getattr(config, "num_recent_pages", 4) - - def get_sparse_mode(self) -> SparseMode: - return SparseMode.PAGE_WISE - - def initialize_representation_pool( - self, - start_layer: int, - end_layer: int, - token_to_kv_pool, - req_to_token_pool, - states, - ): - super().initialize_representation_pool( - start_layer, end_layer, token_to_kv_pool, req_to_token_pool, states - ) - self.start_layer = start_layer - self.end_layer = end_layer - - total_num_tokens = token_to_kv_pool.get_key_buffer(start_layer).shape[0] - total_num_pages = (total_num_tokens + self.page_size - 1) // self.page_size - - self._initialize_representation_pools(start_layer, end_layer, total_num_pages) - - def construct_representations( - self, - layer_id, - req_pool_indices, - seq_lens, - k_buffer, - forward_batch, - ) -> torch.Tensor: - if not forward_batch.forward_mode.is_extend(): - return - - num_pages = seq_lens // self.page_size - valid_mask = ( - ~self.states.repr_constructed[req_pool_indices] - & (seq_lens >= self.states.prompt_lens[req_pool_indices]) - & (num_pages > 0) - ) - - if not valid_mask.any(): - return - - self._compute_page_representations( - layer_id, - 
req_pool_indices[valid_mask], - seq_lens[valid_mask], - 0, - num_pages[valid_mask], - k_buffer, - ) - - if layer_id == self.end_layer - 1: - success_indices = req_pool_indices[valid_mask] - self.states.repr_constructed[success_indices] = True - self.states.last_extracted_token[success_indices] = ( - seq_lens[valid_mask] // self.page_size * self.page_size - ) - - def update_representations( - self, - layer_id, - req_pool_indices, - seq_lens, - k_buffer, - forward_batch, - ) -> torch.Tensor: - if not forward_batch.forward_mode.is_decode_or_idle(): - return - - start_page = ( - self.states.last_extracted_token[req_pool_indices] // self.page_size - ) - end_page = seq_lens // self.page_size - valid_mask = self.states.repr_constructed[req_pool_indices] & ( - start_page < end_page - ) - - if not valid_mask.any(): - return - - self._compute_page_representations( - layer_id, - req_pool_indices[valid_mask], - seq_lens[valid_mask], - start_page[valid_mask], - end_page[valid_mask], - k_buffer, - ) - - if layer_id == self.end_layer - 1: - success_indices = req_pool_indices[valid_mask] - self.states.last_extracted_token[success_indices] = ( - seq_lens[valid_mask] // self.page_size * self.page_size - ) - - def retrieve_topk( - self, - queries: torch.Tensor, - layer_id: int, - req_pool_indices: torch.Tensor, - sparse_mask: torch.Tensor, - attn_metadata: Optional[Any], - **kwargs, - ) -> tuple: - """ - Default TopK retrieval: score-based selection + recent pages. - Subclasses can override for query-dependent retrieval. 
- """ - bs, device = queries.shape[0], queries.device - seq_lens = attn_metadata.cache_seqlens_int32 - num_pages = (seq_lens + self.page_size - 1) // self.page_size - max_pages = max(int(num_pages.max().item()), 1) - - out_indices = torch.full((bs, max_pages), -1, dtype=torch.int32, device=device) - out_lengths = torch.zeros(bs, dtype=torch.int32, device=device) - - mask = sparse_mask & (num_pages > self.num_recent_pages) - if not mask.any(): - return out_indices, out_lengths - - page_idx = torch.arange(max_pages, device=device).unsqueeze(0) - page_start_token = self.req_to_token_pool.req_to_token[ - req_pool_indices.unsqueeze(1).expand(bs, max_pages), - (page_idx * self.page_size).clamp( - 0, self.req_to_token_pool.req_to_token.shape[1] - 1 - ), - ] - phys_pages = page_start_token // self.page_size - - scores = self._retrieve_page_scores( - layer_id, phys_pages, req_pool_indices, queries - ) - - recent_start = (num_pages - self.num_recent_pages).clamp(min=0) - scores.masked_fill_(page_idx >= recent_start.unsqueeze(1), float("-inf")) - - k = max( - int((recent_start.float() * (1 - self.compression_ratio)).max().item()), 1 - ) - topk_idx = torch.topk(scores, k=k, dim=1, sorted=False)[1] - topk_mask = torch.arange(k, device=device).unsqueeze(0) < ( - recent_start * (1 - self.compression_ratio) - ).int().clamp(min=1).unsqueeze(1) - - recent_idx = recent_start.unsqueeze(1) + torch.arange( - self.num_recent_pages, device=device - ) - recent_mask = recent_idx < num_pages.unsqueeze(1) - - combined = torch.cat( - [ - torch.where(topk_mask, topk_idx, -1), - torch.where(recent_mask, recent_idx, -1), - ], - dim=1, - ).sort(dim=1)[0] - - out_lengths[:] = torch.where(mask, (combined >= 0).sum(dim=1).int(), 0) - out_indices[:, : combined.shape[1]] = torch.where( - mask.unsqueeze(1), combined, -1 - ) - - return out_indices, out_lengths - - @abstractmethod - def _initialize_representation_pools( - self, start_layer: int, end_layer: int, total_num_pages: int - ): - """ - Initialize 
algorithm-specific representation pools for all layers. - - Subclasses define their own representation format based on algorithm needs. - Examples: - - Knorm: self.page_scores[layer_id] = torch.zeros((total_num_pages, 1)) - - Quest: self.page_reprs[layer_id] = torch.zeros((total_num_pages, head_dim)) - """ - pass - - @abstractmethod - def _compute_page_representations( - self, - layer_id: int, - reqs: torch.Tensor, - seq_lens: torch.Tensor, - start_page, - end_page: torch.Tensor, - k_buffer: torch.Tensor, - ): - """ - Compute and store page representations for given page range. - - Args: - layer_id: Current layer index - reqs: [n] Request pool indices - seq_lens: [n] Current sequence lengths - start_page: Starting page index (int or [n] tensor) - end_page: [n] Ending page indices (exclusive) - k_buffer: Key buffer for the layer - """ - pass - - @abstractmethod - def _retrieve_page_scores( - self, - layer_id: int, - phys_pages: torch.Tensor, - req_pool_indices: torch.Tensor, - queries: torch.Tensor, - ) -> torch.Tensor: - """ - Retrieve page scores for TopK selection. - - Args: - layer_id: Current layer index - phys_pages: [bs, max_pages] Physical page indices - req_pool_indices: [bs] Request pool indices - queries: [bs, num_heads, head_dim] Query vectors - - Returns: - scores: [bs, max_pages] Page scores for ranking - """ - pass - - -class KnormPageAlgorithm(BasePageWiseAlgorithm): - """ - L2-norm based page-wise sparse attention (ChunkKV-style). - - Pages are scored based on key L2 norms aggregated across tokens. - TopK pages are selected based on pre-computed scores with recent pages always included. - - Based on ChunkKV (https://arxiv.org/abs/2502.00299). - - Note: This is an experimental/example implementation for demonstrating - how to integrate algorithms into the sparse framework. - Not production-ready - use for reference and testing purposes only. 
- """ - - def __init__(self, config, device: torch.device, **kwargs): - super().__init__(config, device, **kwargs) - self.page_scores = {} - - def _initialize_representation_pools( - self, start_layer: int, end_layer: int, total_num_pages: int - ): - for layer_id in range(start_layer, end_layer): - self.page_scores[layer_id] = torch.zeros( - (total_num_pages, 1), dtype=torch.float32, device=self.device - ) - logger.info( - f"Initialized page representation pools: {total_num_pages} pages, " - f"{end_layer - start_layer} layers" - ) - - def _compute_page_representations( - self, - layer_id: int, - reqs: torch.Tensor, - seq_lens: torch.Tensor, - start_page, - end_page: torch.Tensor, - k_buffer: torch.Tensor, - ): - if isinstance(start_page, int): - start_page = torch.full_like(end_page, start_page) - - device = k_buffer.device - req_to_token = self.req_to_token_pool.req_to_token - n = reqs.shape[0] - max_pages = int((end_page - start_page).max().item()) - - pg_off = torch.arange(max_pages, device=device).unsqueeze(0) - pg_id = start_page.unsqueeze(1) + pg_off - pg_mask = pg_id < end_page.unsqueeze(1) - - tok_start = pg_id * self.page_size - tok_off = torch.arange(self.page_size, device=device).view(1, 1, -1) - tok_pos = tok_start.unsqueeze(2) + tok_off - tok_mask = ( - tok_pos - < (tok_start + self.page_size).clamp(max=seq_lens.unsqueeze(1)).unsqueeze(2) - ) & pg_mask.unsqueeze(2) - - phys_tok = req_to_token[ - reqs.view(n, 1, 1).expand(n, max_pages, self.page_size), - tok_pos.clamp(0, req_to_token.shape[1] - 1), - ].clamp(0, k_buffer.shape[0] - 1) - - tok_score = k_buffer[phys_tok].norm(dim=-1).sum(dim=-1) - pg_score = (tok_score * tok_mask).sum(dim=2) / tok_mask.sum(dim=2).clamp(min=1) - - phys_pg = ( - req_to_token[ - reqs.unsqueeze(1).expand(n, max_pages), - tok_start.clamp(0, req_to_token.shape[1] - 1), - ] - // self.page_size - ) - idx = pg_mask.nonzero(as_tuple=False) - if idx.numel() > 0: - scores_to_store = ( - pg_score[idx[:, 0], idx[:, 1]] - .unsqueeze(-1) 
- .to(self.page_scores[layer_id].dtype) - ) - self.page_scores[layer_id][phys_pg[idx[:, 0], idx[:, 1]]] = scores_to_store - - def _retrieve_page_scores( - self, - layer_id: int, - phys_pages: torch.Tensor, - req_pool_indices: torch.Tensor, - queries: torch.Tensor, - ) -> torch.Tensor: - phys_pages_clamped = phys_pages.clamp( - 0, self.page_scores[layer_id].shape[0] - 1 - ) - return self.page_scores[layer_id][phys_pages_clamped].squeeze(-1) diff --git a/test/srt/sparsity/test_knorm_page_algorithm.py b/test/srt/sparsity/test_knorm_page_algorithm.py index ebef1cba11cc..c577393235ea 100644 --- a/test/srt/sparsity/test_knorm_page_algorithm.py +++ b/test/srt/sparsity/test_knorm_page_algorithm.py @@ -3,10 +3,7 @@ import torch -from sglang.srt.mem_cache.sparsity.algorithms.base_algorithm import SparseMode -from sglang.srt.mem_cache.sparsity.algorithms.page_wise_algorithm import ( - KnormPageAlgorithm, -) +from sglang.srt.mem_cache.sparsity.algorithms.knorm_algorithm import KnormPageAlgorithm from sglang.srt.model_executor.forward_batch_info import ForwardMode @@ -41,10 +38,10 @@ def __init__(self, max_reqs=32, max_tokens=2048, num_physical_tokens=1024): class MockStates: def __init__(self, max_reqs=32): - self.prompt_lens = torch.zeros(max_reqs, dtype=torch.int32, device="cuda") + self.prompt_lens = torch.zeros(max_reqs, dtype=torch.int64, device="cuda") self.repr_constructed = torch.zeros(max_reqs, dtype=torch.bool, device="cuda") - self.last_extracted_token = torch.zeros( - max_reqs, dtype=torch.int32, device="cuda" + self.last_constructed_page = torch.zeros( + max_reqs, dtype=torch.int64, device="cuda" ) @@ -69,9 +66,6 @@ def setUp(self): ) self.states = MockStates(max_reqs=8) - def test_get_sparse_mode(self): - self.assertEqual(self.algorithm.get_sparse_mode(), SparseMode.PAGE_WISE) - def test_initialize_representation_pool(self): self.algorithm.initialize_representation_pool( start_layer=0, @@ -91,7 +85,7 @@ def test_construct_representations(self): ) req_pool_indices 
= torch.tensor([0, 1], dtype=torch.int32, device="cuda") - seq_lens = torch.tensor([128, 192], dtype=torch.int32, device="cuda") + seq_lens = torch.tensor([128, 192], dtype=torch.int64, device="cuda") k_buffer = self.token_to_kv_pool.get_key_buffer(0) forward_batch = MockForwardBatch(mode=ForwardMode.EXTEND) @@ -107,8 +101,10 @@ def test_construct_representations(self): self.assertTrue(self.states.repr_constructed[0]) self.assertTrue(self.states.repr_constructed[1]) - self.assertEqual(self.states.last_extracted_token[0].item(), 128) - self.assertEqual(self.states.last_extracted_token[1].item(), 192) + # last_constructed_page stores page count, not token position + # 128 / 64 = 2 pages, 192 / 64 = 3 pages + self.assertEqual(self.states.last_constructed_page[0].item(), 2) + self.assertEqual(self.states.last_constructed_page[1].item(), 3) def test_update_representations(self): self.algorithm.initialize_representation_pool( @@ -116,13 +112,14 @@ def test_update_representations(self): ) req_pool_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda") - seq_lens = torch.tensor([192, 256], dtype=torch.int32, device="cuda") + seq_lens = torch.tensor([192, 256], dtype=torch.int64, device="cuda") k_buffer = self.token_to_kv_pool.get_key_buffer(0) forward_batch = MockForwardBatch(mode=ForwardMode.DECODE) self.states.repr_constructed[req_pool_indices] = True - self.states.last_extracted_token[req_pool_indices] = torch.tensor( - [128, 128], dtype=torch.int32, device="cuda" + # Start from page 2 (was 128 tokens) + self.states.last_constructed_page[req_pool_indices] = torch.tensor( + [2, 2], dtype=torch.int64, device="cuda" ) self.algorithm.update_representations( @@ -133,8 +130,9 @@ def test_update_representations(self): forward_batch=forward_batch, ) - self.assertEqual(self.states.last_extracted_token[0].item(), 192) - self.assertEqual(self.states.last_extracted_token[1].item(), 256) + # 192 / 64 = 3 pages, 256 / 64 = 4 pages + 
self.assertEqual(self.states.last_constructed_page[0].item(), 3) + self.assertEqual(self.states.last_constructed_page[1].item(), 4) def test_retrieve_topk(self): self.algorithm.initialize_representation_pool( From 0d52ca1ccf191dc12e34af9a8be25108c43f3b06 Mon Sep 17 00:00:00 2001 From: MagicYang1573 <1328657938@qq.com> Date: Wed, 17 Dec 2025 18:59:14 +0800 Subject: [PATCH 4/5] Add Implementation of quest_algorithm.py --- .../mem_cache/sparsity/algorithms/__init__.py | 2 + .../sparsity/algorithms/base_algorithm.py | 137 ++++++++------- .../sparsity/algorithms/knorm_algorithm.py | 106 ----------- .../sparsity/algorithms/quest_algorithm.py | 166 ++++++++++++++++++ .../srt/sparsity/test_knorm_page_algorithm.py | 166 ------------------ 5 files changed, 243 insertions(+), 334 deletions(-) delete mode 100644 python/sglang/srt/mem_cache/sparsity/algorithms/knorm_algorithm.py create mode 100644 python/sglang/srt/mem_cache/sparsity/algorithms/quest_algorithm.py delete mode 100644 test/srt/sparsity/test_knorm_page_algorithm.py diff --git a/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py b/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py index 01efcdfba985..577d4a3c734c 100644 --- a/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py +++ b/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py @@ -4,10 +4,12 @@ ) from sglang.srt.mem_cache.sparsity.algorithms.deepseek_nsa import DeepSeekNSAAlgorithm from sglang.srt.mem_cache.sparsity.algorithms.knorm_algorithm import KnormPageAlgorithm +from sglang.srt.mem_cache.sparsity.algorithms.quest_algorithm import QuestAlgorithm __all__ = [ "BaseSparseAlgorithm", "BaseSparseAlgorithmImpl", "KnormPageAlgorithm", "DeepSeekNSAAlgorithm", + "QuestAlgorithm", ] diff --git a/python/sglang/srt/mem_cache/sparsity/algorithms/base_algorithm.py b/python/sglang/srt/mem_cache/sparsity/algorithms/base_algorithm.py index e83cef4247e4..69423a132771 100644 --- 
a/python/sglang/srt/mem_cache/sparsity/algorithms/base_algorithm.py +++ b/python/sglang/srt/mem_cache/sparsity/algorithms/base_algorithm.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING import torch @@ -12,7 +12,7 @@ class BaseSparseAlgorithm(ABC): Abstract base class for sparse attention algorithms. This class provides a unified interface for implementing various retrievable KVCache - compression algorithms. + compression algorithms. Token-wise sparsity is treated as page-wise with page_size=1. References: - ChunkKV: https://arxiv.org/abs/2502.00299 @@ -109,7 +109,6 @@ def retrieve_topk( layer_id: int, req_pool_indices: torch.Tensor, sparse_mask: torch.Tensor, - attn_metadata: Optional[Any], **kwargs, ) -> tuple: """ @@ -163,7 +162,7 @@ class BaseSparseAlgorithmImpl(BaseSparseAlgorithm): def __init__(self, config, device: torch.device, **kwargs): super().__init__(config, device, **kwargs) - self.compression_ratio = getattr(config, "compression_ratio", 0.2) + self.compression_ratio = getattr(config, "compression_ratio", 0.3) self.page_size = getattr(config, "page_size", 64) self.num_recent_pages = getattr(config, "num_recent_pages", 4) @@ -266,7 +265,6 @@ def retrieve_topk( layer_id: int, req_pool_indices: torch.Tensor, sparse_mask: torch.Tensor, - attn_metadata: Optional[Any], **kwargs, ) -> tuple: """ @@ -278,66 +276,81 @@ def retrieve_topk( 2. 
Support CUDA Graph """ bs, device = queries.shape[0], queries.device - seq_lens = attn_metadata.cache_seqlens_int32 - num_pages = (seq_lens + self.page_size - 1) // self.page_size - max_pages = max(int(num_pages.max().item()), 1) - out_indices = torch.full((bs, max_pages), -1, dtype=torch.int32, device=device) + seq_lens_source = kwargs.get("forward_batch", None) + if seq_lens_source is None or not hasattr(seq_lens_source, "seq_lens"): + raise ValueError( + "forward_batch with seq_lens is required for TopK retrieval" + ) + seq_lens = seq_lens_source.seq_lens.to(device) + + req_to_token = self.req_to_token_pool.req_to_token + max_req_tokens = req_to_token.shape[1] + + per_request_indices = [] + per_request_lengths = [] + + for i in range(bs): + if not sparse_mask[i]: + per_request_indices.append( + torch.empty(0, device=device, dtype=torch.int32) + ) + per_request_lengths.append(0) + continue + + num_pages = int((seq_lens[i].item() + self.page_size - 1) // self.page_size) + if num_pages <= self.num_recent_pages: + per_request_indices.append( + torch.empty(0, device=device, dtype=torch.int32) + ) + per_request_lengths.append(0) + continue + + page_idx = torch.arange(num_pages, device=device) + page_start_token = req_to_token[ + req_pool_indices[i], + (page_idx * self.page_size).clamp(0, max_req_tokens - 1), + ] + phys_pages = (page_start_token // self.page_size).unsqueeze(0) + + scores = self._retrieve_page_scores( + layer_id, + phys_pages, + req_pool_indices[i : i + 1], + queries[i : i + 1], + ) + + recent_start = max(num_pages - self.num_recent_pages, 0) + scores = scores.clone() + scores[:, recent_start:] = float("-inf") + + history_pages = max(recent_start, 1) + k = max(int(history_pages * (1 - self.compression_ratio)), 1) + k = min(k, history_pages) + topk_idx = torch.topk(scores, k=k, dim=1, sorted=False)[1].squeeze(0) + + recent_idx = torch.arange( + recent_start, recent_start + self.num_recent_pages, device=device + ) + recent_idx = recent_idx[recent_idx < 
num_pages] + + combined = ( + torch.cat([topk_idx, recent_idx], dim=0).sort()[0].to(torch.int32) + ) + + per_request_indices.append(combined) + per_request_lengths.append(int(combined.numel())) + + max_len = max(max(per_request_lengths, default=0), 1) + out_indices = torch.full((bs, max_len), -1, dtype=torch.int32, device=device) out_lengths = torch.zeros(bs, dtype=torch.int32, device=device) - mask = sparse_mask & (num_pages > self.num_recent_pages) - if not mask.any(): - return out_indices, out_lengths - - # Map logical page indices to physical page indices - # page_idx -> logical page indices [0, 1, 2, ...] - # phys_pages: [bs, max_pages] -> physical page indices in KV cache - page_idx = torch.arange(max_pages, device=device).unsqueeze(0) - page_start_token = self.req_to_token_pool.req_to_token[ - req_pool_indices.unsqueeze(1).expand(bs, max_pages), - (page_idx * self.page_size).clamp( - 0, self.req_to_token_pool.req_to_token.shape[1] - 1 - ), - ] - phys_pages = page_start_token // self.page_size - - # Get pre-computed page scores from subclass - scores = self._retrieve_page_scores( - layer_id, phys_pages, req_pool_indices, queries - ) - - # Mask out recent pages from TopK selection (they will be added separately) - # Layout: [history pages ... 
| recent pages (always kept)] - recent_start = (num_pages - self.num_recent_pages).clamp(min=0) - scores.masked_fill_(page_idx >= recent_start.unsqueeze(1), float("-inf")) - - # Select TopK from history pages based on compression_ratio - k = max( - int((recent_start.float() * (1 - self.compression_ratio)).max().item()), 1 - ) - topk_idx = torch.topk(scores, k=k, dim=1, sorted=False)[1] - topk_mask = torch.arange(k, device=device).unsqueeze(0) < ( - recent_start * (1 - self.compression_ratio) - ).int().clamp(min=1).unsqueeze(1) - - recent_idx = recent_start.unsqueeze(1) + torch.arange( - self.num_recent_pages, device=device - ) - recent_mask = recent_idx < num_pages.unsqueeze(1) - - # Combine TopK history pages + recent pages, sort for sequential access - combined = torch.cat( - [ - torch.where(topk_mask, topk_idx, -1), - torch.where(recent_mask, recent_idx, -1), - ], - dim=1, - ).sort(dim=1)[0] - - out_lengths[:] = torch.where(mask, (combined >= 0).sum(dim=1).int(), 0) - out_indices[:, : combined.shape[1]] = torch.where( - mask.unsqueeze(1), combined, -1 - ) + for i, selected in enumerate(per_request_indices): + length = per_request_lengths[i] + if length == 0: + continue + out_indices[i, :length] = selected + out_lengths[i] = length return out_indices, out_lengths diff --git a/python/sglang/srt/mem_cache/sparsity/algorithms/knorm_algorithm.py b/python/sglang/srt/mem_cache/sparsity/algorithms/knorm_algorithm.py deleted file mode 100644 index 0ac3508d03a5..000000000000 --- a/python/sglang/srt/mem_cache/sparsity/algorithms/knorm_algorithm.py +++ /dev/null @@ -1,106 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import logging - -import torch - -from sglang.srt.mem_cache.sparsity.algorithms.base_algorithm import ( - BaseSparseAlgorithmImpl, -) - -logger = logging.getLogger(__name__) - - -class KnormPageAlgorithm(BaseSparseAlgorithmImpl): - """ - L2-norm based page-wise sparse attention. - - Pages are scored based on key L2 norms aggregated across tokens. 
- TopK pages are selected based on pre-computed scores with recent pages always included. - - Based on ChunkKV (https://arxiv.org/abs/2502.00299). - - Note: This is an experimental/example implementation for demonstrating - how to integrate algorithms into the sparse framework. - Not production-ready - use for reference and testing purposes only. - """ - - def __init__(self, config, device: torch.device, **kwargs): - super().__init__(config, device, **kwargs) - self.page_scores = {} - - def _initialize_representation_pools( - self, start_layer: int, end_layer: int, total_num_pages: int - ): - for layer_id in range(start_layer, end_layer): - self.page_scores[layer_id] = torch.zeros( - (total_num_pages, 1), dtype=torch.float32, device=self.device - ) - logger.info( - f"Initialized page representation pools: {total_num_pages} pages, " - f"{end_layer - start_layer} layers" - ) - - def _compute_page_representations( - self, - layer_id: int, - reqs: torch.Tensor, - seq_lens: torch.Tensor, - start_page, - end_page: torch.Tensor, - k_buffer: torch.Tensor, - ): - if isinstance(start_page, int): - start_page = torch.full_like(end_page, start_page) - - device = k_buffer.device - req_to_token = self.req_to_token_pool.req_to_token - n = reqs.shape[0] - max_pages = int((end_page - start_page).max().item()) - - pg_off = torch.arange(max_pages, device=device).unsqueeze(0) - pg_id = start_page.unsqueeze(1) + pg_off - pg_mask = pg_id < end_page.unsqueeze(1) - - tok_start = pg_id * self.page_size - tok_off = torch.arange(self.page_size, device=device).view(1, 1, -1) - tok_pos = tok_start.unsqueeze(2) + tok_off - tok_mask = ( - tok_pos - < (tok_start + self.page_size).clamp(max=seq_lens.unsqueeze(1)).unsqueeze(2) - ) & pg_mask.unsqueeze(2) - - phys_tok = req_to_token[ - reqs.view(n, 1, 1).expand(n, max_pages, self.page_size), - tok_pos.clamp(0, req_to_token.shape[1] - 1), - ].clamp(0, k_buffer.shape[0] - 1) - - tok_score = k_buffer[phys_tok].norm(dim=-1).sum(dim=-1) - pg_score = 
(tok_score * tok_mask).sum(dim=2) / tok_mask.sum(dim=2).clamp(min=1) - - phys_pg = ( - req_to_token[ - reqs.unsqueeze(1).expand(n, max_pages), - tok_start.clamp(0, req_to_token.shape[1] - 1), - ] - // self.page_size - ) - idx = pg_mask.nonzero(as_tuple=False) - if idx.numel() > 0: - scores_to_store = ( - pg_score[idx[:, 0], idx[:, 1]] - .unsqueeze(-1) - .to(self.page_scores[layer_id].dtype) - ) - self.page_scores[layer_id][phys_pg[idx[:, 0], idx[:, 1]]] = scores_to_store - - def _retrieve_page_scores( - self, - layer_id: int, - phys_pages: torch.Tensor, - req_pool_indices: torch.Tensor, - queries: torch.Tensor, - ) -> torch.Tensor: - phys_pages_clamped = phys_pages.clamp( - 0, self.page_scores[layer_id].shape[0] - 1 - ) - return self.page_scores[layer_id][phys_pages_clamped].squeeze(-1) diff --git a/python/sglang/srt/mem_cache/sparsity/algorithms/quest_algorithm.py b/python/sglang/srt/mem_cache/sparsity/algorithms/quest_algorithm.py new file mode 100644 index 000000000000..94678546af31 --- /dev/null +++ b/python/sglang/srt/mem_cache/sparsity/algorithms/quest_algorithm.py @@ -0,0 +1,166 @@ +""" +Quest sparse attention algorithm. + +This implementation follows the Quest paper's bounding-box estimation for +query-aware page selection. For each KV page, it maintains per-dimension +min/max of keys and uses them to upper-bound attention scores without +materializing full dot products. 
+""" + +import logging + +import torch + +from sglang.srt.mem_cache.sparsity.algorithms.base_algorithm import ( + BaseSparseAlgorithmImpl, +) + +logger = logging.getLogger(__name__) + + +class QuestAlgorithm(BaseSparseAlgorithmImpl): + """Quest page-wise sparse attention using bounding-box criticality.""" + + def __init__(self, config, device: torch.device, **kwargs): + super().__init__(config, device, **kwargs) + self.page_k_min = {} + self.page_k_max = {} + self.page_valid = {} + + def _initialize_representation_pools( + self, start_layer: int, end_layer: int, total_num_pages: int + ): + key_buf = self.token_to_kv_pool.get_key_buffer(start_layer) + head_num, head_dim = key_buf.shape[1], key_buf.shape[2] + + for layer_id in range(start_layer, end_layer): + self.page_k_min[layer_id] = torch.zeros( + (total_num_pages, head_num, head_dim), + dtype=torch.float32, + device=self.device, + ) + self.page_k_max[layer_id] = torch.zeros_like(self.page_k_min[layer_id]) + self.page_valid[layer_id] = torch.zeros( + total_num_pages, dtype=torch.bool, device=self.device + ) + + logger.info( + "Initialized Quest page reps: %d pages, %d layers, head_num=%d, head_dim=%d", + total_num_pages, + end_layer - start_layer, + head_num, + head_dim, + ) + + def _compute_page_representations( + self, + layer_id: int, + reqs: torch.Tensor, + seq_lens: torch.Tensor, + start_page, + end_page: torch.Tensor, + k_buffer: torch.Tensor, + ): + if isinstance(start_page, int): + start_page = torch.full_like(end_page, start_page) + + device = k_buffer.device + req_to_token = self.req_to_token_pool.req_to_token + n = reqs.shape[0] + max_pages = int((end_page - start_page).max().item()) + if max_pages <= 0: + return + + pg_off = torch.arange(max_pages, device=device).unsqueeze(0) + pg_id = start_page.unsqueeze(1) + pg_off + pg_mask = pg_id < end_page.unsqueeze(1) + + tok_start = pg_id * self.page_size + tok_off = torch.arange(self.page_size, device=device).view(1, 1, -1) + tok_pos = tok_start.unsqueeze(2) 
+ tok_off + tok_mask = ( + tok_pos + < (tok_start + self.page_size).clamp(max=seq_lens.unsqueeze(1)).unsqueeze(2) + ) & pg_mask.unsqueeze(2) + + phys_tok = req_to_token[ + reqs.view(n, 1, 1).expand(n, max_pages, self.page_size), + tok_pos.clamp(0, req_to_token.shape[1] - 1), + ].clamp(0, k_buffer.shape[0] - 1) + + keys = k_buffer[phys_tok].to(torch.float32) + mask = tok_mask.unsqueeze(-1).unsqueeze(-1) + + page_min = torch.where(mask, keys, torch.full_like(keys, float("inf"))).amin( + dim=2 + ) + page_max = torch.where(mask, keys, torch.full_like(keys, float("-inf"))).amax( + dim=2 + ) + + phys_pg = ( + req_to_token[ + reqs.unsqueeze(1).expand(n, max_pages), + tok_start.clamp(0, req_to_token.shape[1] - 1), + ] + // self.page_size + ) + + idx = pg_mask.nonzero(as_tuple=False) + if idx.numel() == 0: + return + + target_pages = phys_pg[idx[:, 0], idx[:, 1]].clamp( + 0, self.page_k_min[layer_id].shape[0] - 1 + ) + self.page_k_min[layer_id][target_pages] = page_min[idx[:, 0], idx[:, 1]] + self.page_k_max[layer_id][target_pages] = page_max[idx[:, 0], idx[:, 1]] + self.page_valid[layer_id][target_pages] = True + + def _retrieve_page_scores( + self, + layer_id: int, + phys_pages: torch.Tensor, + req_pool_indices: torch.Tensor, + queries: torch.Tensor, + ) -> torch.Tensor: + # Clamp pages to valid storage range + phys_pages_clamped = phys_pages.clamp(0, self.page_k_min[layer_id].shape[0] - 1) + + k_min = self.page_k_min[layer_id][phys_pages_clamped] + k_max = self.page_k_max[layer_id][phys_pages_clamped] + valid_mask = self.page_valid[layer_id][phys_pages_clamped] + # Align query shape to KV heads. 
+ head_dim = k_min.shape[-1] + if queries.dim() == 2: + bs, hidden = queries.shape + if hidden % head_dim != 0: + raise ValueError( + f"Quest query hidden size {hidden} not divisible by head_dim {head_dim}" + ) + q_heads = hidden // head_dim + q = queries.view(bs, q_heads, head_dim) + elif queries.dim() == 3: + q = queries + else: + raise ValueError(f"Unsupported query shape for Quest: {queries.shape}") + + kv_heads = k_min.shape[-2] + q_heads = q.shape[1] + if q_heads != kv_heads: + if q_heads % kv_heads != 0: + raise ValueError( + f"Query heads {q_heads} not divisible by KV heads {kv_heads}" + ) + group = q_heads // kv_heads + # Average grouped query heads to align with KV heads (approximation for MQA/GQA). + q = q.view(q.shape[0], kv_heads, group, head_dim).mean(dim=2) + + q = q.to(k_min.dtype).unsqueeze(1) # [bs, 1, kv_heads, head_dim] + + criticality = torch.where(q >= 0, q * k_max, q * k_min).sum(dim=(2, 3)) + criticality = torch.where( + valid_mask, criticality, torch.full_like(criticality, float("-inf")) + ) + + return criticality diff --git a/test/srt/sparsity/test_knorm_page_algorithm.py b/test/srt/sparsity/test_knorm_page_algorithm.py deleted file mode 100644 index c577393235ea..000000000000 --- a/test/srt/sparsity/test_knorm_page_algorithm.py +++ /dev/null @@ -1,166 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import unittest - -import torch - -from sglang.srt.mem_cache.sparsity.algorithms.knorm_algorithm import KnormPageAlgorithm -from sglang.srt.model_executor.forward_batch_info import ForwardMode - - -class MockConfig: - def __init__(self, compression_ratio=0.2, page_size=64): - self.compression_ratio = compression_ratio - self.page_size = page_size - - -class MockTokenToKVPool: - def __init__(self, num_tokens=1024, num_layers=2, num_heads=8, head_dim=64): - self._k_buffer = { - i: torch.randn( - num_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda" - ) - for i in range(num_layers) - } - - def get_key_buffer(self, layer_id): - return 
self._k_buffer[layer_id] - - -class MockReqToTokenPool: - def __init__(self, max_reqs=32, max_tokens=2048, num_physical_tokens=1024): - self.req_to_token = ( - torch.arange(max_reqs * max_tokens, device="cuda").reshape( - max_reqs, max_tokens - ) - % num_physical_tokens - ) - - -class MockStates: - def __init__(self, max_reqs=32): - self.prompt_lens = torch.zeros(max_reqs, dtype=torch.int64, device="cuda") - self.repr_constructed = torch.zeros(max_reqs, dtype=torch.bool, device="cuda") - self.last_constructed_page = torch.zeros( - max_reqs, dtype=torch.int64, device="cuda" - ) - - -class MockForwardBatch: - def __init__(self, mode=ForwardMode.EXTEND): - self.forward_mode = mode - - -class MockAttnMetadata: - def __init__(self, cache_seqlens): - self.cache_seqlens_int32 = cache_seqlens - - -class TestKnormPageAlgorithm(unittest.TestCase): - def setUp(self): - self.device = torch.device("cuda") - self.config = MockConfig(compression_ratio=0.2, page_size=64) - self.algorithm = KnormPageAlgorithm(self.config, self.device) - self.token_to_kv_pool = MockTokenToKVPool(num_tokens=1024, num_layers=2) - self.req_to_token_pool = MockReqToTokenPool( - max_reqs=8, max_tokens=512, num_physical_tokens=1024 - ) - self.states = MockStates(max_reqs=8) - - def test_initialize_representation_pool(self): - self.algorithm.initialize_representation_pool( - start_layer=0, - end_layer=2, - token_to_kv_pool=self.token_to_kv_pool, - req_to_token_pool=self.req_to_token_pool, - states=self.states, - ) - - self.assertEqual(len(self.algorithm.page_scores), 2) - self.assertIsNotNone(self.algorithm.req_to_token_pool) - self.assertIsNotNone(self.algorithm.states) - - def test_construct_representations(self): - self.algorithm.initialize_representation_pool( - 0, 2, self.token_to_kv_pool, self.req_to_token_pool, self.states - ) - - req_pool_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda") - seq_lens = torch.tensor([128, 192], dtype=torch.int64, device="cuda") - k_buffer = 
self.token_to_kv_pool.get_key_buffer(0) - forward_batch = MockForwardBatch(mode=ForwardMode.EXTEND) - - self.states.prompt_lens[req_pool_indices] = seq_lens - - self.algorithm.construct_representations( - layer_id=1, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - k_buffer=k_buffer, - forward_batch=forward_batch, - ) - - self.assertTrue(self.states.repr_constructed[0]) - self.assertTrue(self.states.repr_constructed[1]) - # last_constructed_page stores page count, not token position - # 128 / 64 = 2 pages, 192 / 64 = 3 pages - self.assertEqual(self.states.last_constructed_page[0].item(), 2) - self.assertEqual(self.states.last_constructed_page[1].item(), 3) - - def test_update_representations(self): - self.algorithm.initialize_representation_pool( - 0, 2, self.token_to_kv_pool, self.req_to_token_pool, self.states - ) - - req_pool_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda") - seq_lens = torch.tensor([192, 256], dtype=torch.int64, device="cuda") - k_buffer = self.token_to_kv_pool.get_key_buffer(0) - forward_batch = MockForwardBatch(mode=ForwardMode.DECODE) - - self.states.repr_constructed[req_pool_indices] = True - # Start from page 2 (was 128 tokens) - self.states.last_constructed_page[req_pool_indices] = torch.tensor( - [2, 2], dtype=torch.int64, device="cuda" - ) - - self.algorithm.update_representations( - layer_id=1, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - k_buffer=k_buffer, - forward_batch=forward_batch, - ) - - # 192 / 64 = 3 pages, 256 / 64 = 4 pages - self.assertEqual(self.states.last_constructed_page[0].item(), 3) - self.assertEqual(self.states.last_constructed_page[1].item(), 4) - - def test_retrieve_topk(self): - self.algorithm.initialize_representation_pool( - 0, 2, self.token_to_kv_pool, self.req_to_token_pool, self.states - ) - - req_pool_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda") - queries = torch.randn(2, 8, 64, dtype=torch.float16, device="cuda") - cache_seqlens = 
torch.tensor([512, 640], dtype=torch.int32, device="cuda") - sparse_mask = torch.ones(2, dtype=torch.bool, device="cuda") - attn_metadata = MockAttnMetadata(cache_seqlens=cache_seqlens) - - self.algorithm.page_scores[0] = torch.randn( - 16, 1, dtype=torch.float32, device="cuda" - ) - - out_indices, out_lengths = self.algorithm.retrieve_topk( - queries=queries, - layer_id=0, - req_pool_indices=req_pool_indices, - sparse_mask=sparse_mask, - attn_metadata=attn_metadata, - ) - - self.assertEqual(out_indices.shape[0], 2) - self.assertEqual(out_lengths.shape[0], 2) - self.assertTrue((out_lengths > 0).all()) - - -if __name__ == "__main__": - unittest.main() From 86e3a48735a51bdf3023feaea0bdf96639639b31 Mon Sep 17 00:00:00 2001 From: hzh0425 Date: Wed, 17 Dec 2025 19:03:12 +0800 Subject: [PATCH 5/5] Remove KnormPageAlgorithm --- python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py b/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py index 577d4a3c734c..7e05203af03b 100644 --- a/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py +++ b/python/sglang/srt/mem_cache/sparsity/algorithms/__init__.py @@ -3,13 +3,11 @@ BaseSparseAlgorithmImpl, ) from sglang.srt.mem_cache.sparsity.algorithms.deepseek_nsa import DeepSeekNSAAlgorithm -from sglang.srt.mem_cache.sparsity.algorithms.knorm_algorithm import KnormPageAlgorithm from sglang.srt.mem_cache.sparsity.algorithms.quest_algorithm import QuestAlgorithm __all__ = [ "BaseSparseAlgorithm", "BaseSparseAlgorithmImpl", - "KnormPageAlgorithm", "DeepSeekNSAAlgorithm", "QuestAlgorithm", ]