Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import threading
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -1145,13 +1146,13 @@ def prepare_for_extend(self):
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
self.device, non_blocking=True
)
input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
input_ids_tensor = torch.tensor(
list(chain.from_iterable(input_ids)), dtype=torch.int64
).to(self.device, non_blocking=True)
seq_lens_tensor = torch.as_tensor(seq_lens, dtype=torch.int64).to(
self.device, non_blocking=True
)
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
self.device, non_blocking=True
)
prefix_lens_tensor = torch.tensor(
prefix_lens_tensor = torch.as_tensor(
prefix_lens, dtype=torch.int64, device=self.device
)
extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
Expand Down
31 changes: 28 additions & 3 deletions python/sglang/srt/mem_cache/hiradix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def get_height(self, node: TreeNode):
return height

def write_backup(self, node: TreeNode, write_back=False):
# Nothing to copy if the node is already evicted.
if node.value is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is no need for this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is necessary to prevent redundant or incorrect backup logic:

  • If node.value is None, it’s already evicted. There's literally nothing on device to copy to host.
  • Continuing without this check would likely lead to:
    • Unnecessary allocation on host
    • Potential crash in self.cache_controller.write() expecting valid device indices

return 0
host_indices = self.cache_controller.write(
device_indices=node.value,
node_id=node.id,
Expand All @@ -104,7 +107,15 @@ def write_backup(self, node: TreeNode, write_back=False):
return len(host_indices)

def inc_hit_count(self, node: TreeNode):
if node.backuped or self.cache_controller.write_policy == "write_back":
# Skip if:
# - the node is already on host (backuped)
# - write-back policy is in use
# - the node has no live device data (evicted)
if (
node.backuped
or node.evicted
or self.cache_controller.write_policy == "write_back"
):
return
node.hit_count += 1
if node.hit_count >= self.write_through_threshold:
Expand Down Expand Up @@ -198,9 +209,15 @@ def _evict_backuped(self, node: TreeNode):

def _evict_regular(self, node: TreeNode):
# evict a node not initiated write to host
self.cache_controller.mem_pool_device_allocator.free(node.value)
num_evicted = len(node.value)
self._delete_leaf(node)
self.cache_controller.mem_pool_device_allocator.free(node.value)

self.evictable_size_ -= num_evicted # keep GPU counters in sync
node.value = None # mark as evicted

# Remove the node only if it has no (even evicted) children
if len(node.children) == 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be determined already

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, conditionally: if evict() guarantees that a node is a leaf before calling _delete_leaf, the check is redundant. But if eviction can happen preemptively on non-leaf nodes, the check is necessary. I am not confident that evict() ensures leaf-ness, so I would prefer to keep the check.

self._delete_leaf(node)
return num_evicted

def evict_host(self, num_tokens: int):
Expand Down Expand Up @@ -288,7 +305,15 @@ def init_load_back(
len(prefix_indices) == 0 or prefix_indices.is_cuda
), "indices of device kV caches should be on GPU"
if last_node.evicted:
# No host copy: fall back to recomputation
if not last_node.backuped:
# Climb to the nearest ancestor that *does* have live data.
while last_node.evicted and last_node != self.root_node:
last_node = last_node.parent
return last_node, prefix_indices
# Host copy exists: try to pull it back
loading_values = self.load_back(last_node, mem_quota)

if loading_values is not None:
prefix_indices = (
loading_values
Expand Down
67 changes: 42 additions & 25 deletions python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self, id: Optional[int] = None):
self.value = None
self.lock_ref = 0
self.last_access_time = time.monotonic()
self.child_key = None

self.hit_count = 0
# indicating the node is loading KV cache from host
Expand Down Expand Up @@ -110,6 +111,7 @@ def __init__(
self.disable = disable
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue = []
self.evict_heap = []

if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
Expand All @@ -134,6 +136,7 @@ def reset(self):
self.evictable_size_ = 0
self.protected_size_ = 0
self._record_all_cleared_event()
self.evict_heap.clear()

def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
"""Find the matching prefix from the radix tree.
Expand Down Expand Up @@ -264,26 +267,32 @@ def evict(self, num_tokens: int):
if self.disable:
return

leaves = self._collect_leaves()
heapq.heapify(leaves)

num_evicted = 0
while num_evicted < num_tokens and len(leaves):
x = heapq.heappop(leaves)

if x == self.root_node:
break
if x.lock_ref > 0:
while self.evict_heap and num_evicted < num_tokens:
node = heapq.heappop(self.evict_heap)
if node.lock_ref > 0 or node == self.root_node or node.evicted:
Copy link
Collaborator

@hanming-lu hanming-lu Jul 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest to be

if node.lock_ref > 0:
  continue
assert node != self.root_node

node.evicted is undefined.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after looking at the logic more, I think you are referring to node.value is None here instead of node.evicted

Copy link
Collaborator

@hanming-lu hanming-lu Jul 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, there is an implicit assumption here that a child's last_access_time is always older than its parent's, because the heap also contains internal nodes. That assumption does not seem to hold. We should additionally check whether the node is a leaf; if it is not, continue.

continue

self.token_to_kv_pool_allocator.free(x.value)
num_evicted += len(x.value)
self._delete_leaf(x)

if len(x.parent.children) == 0:
heapq.heappush(leaves, x.parent)

self._record_remove_event(x)
# Free memory and Update state
node_value_len = len(node.value)
self.token_to_kv_pool_allocator.free(node.value)
node.value = None # Mark node as evicted
# Update evictable size here since value is freed;
# do not update again in _delete_leaf()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to update node.value to None as it is to be deleted

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. If the node is immediately deleted and garbage-collected, setting node.value = None is redundant. However, if _delete_leaf() doesn’t fully nullify the node (or deletion is conditional), it may be safer to clear explicitly to avoid stale references.

self.evictable_size_ -= node_value_len
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can keep this in _delete_leaf still which is a clear semantic.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would disagree. The evictable_size_ tracks memory currently freeable. Adjusting it before _delete_leaf() prevents double-decrement bugs if _delete_leaf() also triggers eviction logic or tracking.

I would suggest leaving the decrement where it is (immediately after freeing). Moving it to _delete_leaf() risks coupling unrelated concerns (memory vs. tree structure).

num_evicted += node_value_len
self._delete_leaf(node)
self._record_remove_event(node)

# If parent becomes a new leaf and is evictable, push it to heap
parent = node.parent
if (
parent is not None
and parent != self.root_node
and parent.lock_ref == 0
and len(parent.children) == 0
):
heapq.heappush(self.evict_heap, parent)

def inc_lock_ref(self, node: TreeNode):
if self.disable:
Expand All @@ -309,6 +318,8 @@ def dec_lock_ref(self, node: TreeNode):
self.evictable_size_ += len(node.value)
self.protected_size_ -= len(node.value)
delta += len(node.value)
if len(node.children) == 0:
heapq.heappush(self.evict_heap, node)
node.lock_ref -= 1
node = node.parent
return delta
Expand Down Expand Up @@ -367,18 +378,24 @@ def _split_node(self, key, child: TreeNode, split_len: int):
new_node.lock_ref = child.lock_ref
new_node.key = child.key[:split_len]
new_node.value = child.value[:split_len]
new_node.child_key = self.get_child_key_fn(key)
new_node.parent.children[new_node.child_key] = new_node
child.parent = new_node
child.key = child.key[split_len:]
child.child_key = self.get_child_key_fn(child.key)
new_node.children[child.child_key] = child
child.parent = new_node
child.value = child.value[split_len:]
new_node.parent.children[self.get_child_key_fn(key)] = new_node

self._record_store_event(new_node)
self._record_store_event(child)

return new_node

def _insert_helper(self, node: TreeNode, key: List, value):
node.last_access_time = time.monotonic()
now = time.monotonic()
node.last_access_time = now

if len(key) == 0:
return 0

Expand All @@ -387,7 +404,7 @@ def _insert_helper(self, node: TreeNode, key: List, value):
total_prefix_length = 0
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = time.monotonic()
node.last_access_time = now
prefix_len = self.key_match_fn(node.key, key)
total_prefix_length += prefix_len
key = key[prefix_len:]
Expand All @@ -405,9 +422,12 @@ def _insert_helper(self, node: TreeNode, key: List, value):
new_node.parent = node
new_node.key = key
new_node.value = value
new_node.child_key = child_key
node.children[child_key] = new_node
self.evictable_size_ += len(value)
self._record_store_event(new_node)
if new_node.lock_ref == 0:
heapq.heappush(self.evict_heap, new_node)
return total_prefix_length

def _print_helper(self, node: TreeNode, indent: int):
Expand All @@ -429,11 +449,8 @@ def _print_helper(self, node: TreeNode, indent: int):
), f"{key=}, {self.get_child_key_fn(child.key)=}"

def _delete_leaf(self, node):
for k, v in node.parent.children.items():
if v == node:
break
del node.parent.children[k]
self.evictable_size_ -= len(node.key)
if node.child_key in node.parent.children:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this if is not needed, instead it should be an assert, after checking node.value is None properly.

del node.parent.children[node.child_key]

def _total_size_helper(self):
total_size = 0
Expand Down
121 changes: 121 additions & 0 deletions test/srt/test_radix_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import unittest

import torch


class MockReqToTokenPool:
    """Minimal stand-in for the request-to-token pool used by RadixCache.

    Instead of managing real slots, it just records every pool index that
    the cache asks it to free, so tests can assert on release behavior.
    """

    def __init__(self):
        # Pool indices passed to free(), in call order.
        self.released = []

    def free(self, req_pool_idx):
        """Record *req_pool_idx* as released; no real resources are touched."""
        self.released += [req_pool_idx]


class MockKVAllocator:
    """Minimal stand-in for the token-to-KV-pool allocator.

    Records every tensor handed to free() and exposes a CPU ``device``
    attribute, which RadixCache reads from the allocator at construction.
    """

    def __init__(self):
        # Tensors passed to free(), in call order.
        self.freed = []
        # Tests run on CPU; the cache only reads this attribute.
        self.device = torch.device("cpu")

    def free(self, tensor):
        """Record *tensor* as freed without releasing any real memory."""
        self.freed += [tensor]


class RadixCacheTest(unittest.TestCase):
    """Unit tests for RadixCache eviction bookkeeping using mock pools.

    Exercises insert/match, heap-based eviction, lock-ref protection of
    nodes, and promotion of a parent into the evict heap once its child
    leaf has been evicted.
    """

    def setUp(self):
        # Imported lazily so that merely collecting this test module does
        # not require sglang to be importable.
        from sglang.srt.mem_cache.radix_cache import RadixCache

        self.token_pool = MockReqToTokenPool()
        self.kv_pool = MockKVAllocator()
        # page_size=1 keeps keys token-granular; events are disabled so the
        # mocks do not need to implement the KV-event interface.
        self.cache = RadixCache(
            req_to_token_pool=self.token_pool,
            token_to_kv_pool_allocator=self.kv_pool,
            page_size=1,
            disable=False,
            enable_kv_cache_events=False,
        )

    def test_insert_and_match(self):
        """An inserted key is fully matchable as a prefix afterwards."""
        key = [1, 2, 3, 4]
        value = torch.tensor(key, dtype=torch.int64)
        self.cache.insert(key, value)

        matched, last_node = self.cache.match_prefix(key)
        self.assertTrue(torch.equal(matched, value))
        self.assertEqual(last_node.key, key)

    def test_evict_removes_evicted_node(self):
        """Evicting all tokens frees the value and zeroes evictable size."""
        key = [10, 20, 30]
        value = torch.tensor(key, dtype=torch.int64)
        self.cache.insert(key, value.clone())

        # Ensure evictable size reflects insertion
        self.assertEqual(self.cache.evictable_size(), len(value))

        self.cache.evict(len(value))
        # After eviction, evictable size should drop
        self.assertEqual(self.cache.evictable_size(), 0)

        # All memory should be marked freed (routed through the mock allocator)
        self.assertTrue(any(torch.equal(t, value) for t in self.kv_pool.freed))

    def test_lock_ref_prevents_eviction(self):
        """A node with lock_ref > 0 survives eviction; unlocking allows it."""
        key = [100, 101, 102]
        value = torch.tensor(key, dtype=torch.int64)
        self.cache.insert(key, value)

        # Get the inserted node
        _, node = self.cache.match_prefix(key)

        self.cache.inc_lock_ref(node)
        self.cache.evict(len(value))

        # Node should not be evicted while locked
        self.assertIsNotNone(node.value)
        self.assertEqual(self.cache.evictable_size(), 0)

        self.cache.dec_lock_ref(node)
        self.cache.evict(len(value))

        # Now it should be evicted (value is None marks an evicted node)
        self.assertIsNone(node.value)
        self.assertEqual(self.cache.evictable_size(), 0)

    def test_evict_heap_promotion_of_parent(self):
        """After the child leaf is evicted, the parent becomes evictable too."""
        key1 = [1, 2]
        key2 = [1, 2, 3]

        val1 = torch.tensor(key1, dtype=torch.int64)
        val2 = torch.tensor(key2, dtype=torch.int64)

        # Inserting key2 after key1 splits the tree into [1,2] -> [3].
        self.cache.insert(key1, val1)
        self.cache.insert(key2, val2)

        _, child_node = self.cache.match_prefix(key2)
        parent_node = child_node.parent

        # Lock the child (which also protects the parent)
        self.cache.inc_lock_ref(child_node)

        expected_size = len(val1) + 1  # [1,2] + [3]
        self.assertEqual(self.cache.protected_size(), expected_size)
        self.assertEqual(self.cache.evictable_size(), 0)

        # Unlock once to revert lock_ref back to 0
        self.cache.dec_lock_ref(child_node)

        self.assertEqual(self.cache.protected_size(), 0)
        self.assertEqual(self.cache.evictable_size(), expected_size)

        # Evict all 3 tokens ([3] and [1, 2])
        self.cache.evict(expected_size)

        # Both nodes are now evicted
        self.assertIsNone(child_node.value)
        self.assertIsNone(parent_node.value)

        # No evictable bytes remain
        self.assertEqual(self.cache.evictable_size(), 0)


# Allow running this test module directly: `python test_radix_cache.py`.
if __name__ == "__main__":
    unittest.main()
Loading