Minor Optimizations in Radix Cache and Schedule Batch #6907
Changes from all commits
5ff1010
1a8fa61
a2ce650
894d1c1
af9a5ce
@@ -82,6 +82,9 @@ def get_height(self, node: TreeNode):
         return height

     def write_backup(self, node: TreeNode, write_back=False):
+        # Nothing to copy if the node is already evicted.
+        if node.value is None:
+            return 0
         host_indices = self.cache_controller.write(
             device_indices=node.value,
             node_id=node.id,
@@ -104,7 +107,15 @@ def write_backup(self, node: TreeNode, write_back=False):
         return len(host_indices)

     def inc_hit_count(self, node: TreeNode):
-        if node.backuped or self.cache_controller.write_policy == "write_back":
+        # Skip if:
+        # - the node is already on host (backuped)
+        # - write-back policy is in use
+        # - the node has no live device data (evicted)
+        if (
+            node.backuped
+            or node.evicted
+            or self.cache_controller.write_policy == "write_back"
+        ):
             return
         node.hit_count += 1
         if node.hit_count >= self.write_through_threshold:
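For intuition, here is a minimal, self-contained sketch of the gating logic that inc_hit_count implements. The Node class and maybe_write_through function below are simplified stand-ins, not the real TreeNode or cache-controller API; only the skip conditions and the hit-count threshold mirror the diff.

```python
# Simplified stand-in; the real classes live in sglang.srt.mem_cache.
class Node:
    def __init__(self):
        self.backuped = False   # already has a host copy
        self.evicted = False    # device data already freed
        self.hit_count = 0

def maybe_write_through(node, write_policy="write_through", threshold=3):
    # Skip nodes that are already backed up, have no live device data,
    # or when the policy only writes back at eviction time.
    if node.backuped or node.evicted or write_policy == "write_back":
        return False
    node.hit_count += 1
    # Once a node is hot enough, it becomes a write-through candidate.
    return node.hit_count >= threshold

n = Node()
assert not any(maybe_write_through(n) for _ in range(2))
assert maybe_write_through(n)  # third hit crosses the threshold
```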
@@ -198,9 +209,15 @@ def _evict_backuped(self, node: TreeNode):

     def _evict_regular(self, node: TreeNode):
         # evict a node not initiated write to host
-        self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)
-        self._delete_leaf(node)
+        self.cache_controller.mem_pool_device_allocator.free(node.value)
+
+        self.evictable_size_ -= num_evicted  # keep GPU counters in sync
+        node.value = None  # mark as evicted
+
+        # Remove the node only if it has no (even evicted) children
+        if len(node.children) == 0:
Collaborator: this should be determined already

Author: I agree, conditionally. If evict() guarantees that a node is a leaf before calling _delete_leaf, the check is redundant. But if eviction can happen preemptively on non-leaf nodes, the check is necessary. I am not confident that evict() ensures leaf-ness, so I would prefer to keep the check.
+            self._delete_leaf(node)
         return num_evicted

     def evict_host(self, num_tokens: int):
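The reordering above is easier to see in isolation. The sketch below is a toy model of the new bookkeeping order (measure, free, adjust the counter, mark evicted, then delete only childless nodes); the class and attribute names are placeholders, not the real cache API.

```python
# Toy model of the _evict_regular bookkeeping order; FakeAllocator,
# FakeNode, and FakeCache are placeholders, not the real SGLang classes.
class FakeAllocator:
    def __init__(self):
        self.freed = 0

    def free(self, indices):
        self.freed += len(indices)

class FakeNode:
    def __init__(self, value, children=None):
        self.value = value
        self.children = children or {}

class FakeCache:
    def __init__(self):
        self.allocator = FakeAllocator()
        self.evictable_size_ = 0

    def evict_regular(self, node):
        num_evicted = len(node.value)         # measure before freeing
        self.allocator.free(node.value)       # release device memory
        self.evictable_size_ -= num_evicted   # keep GPU counters in sync
        node.value = None                     # mark as evicted
        if len(node.children) == 0:           # delete only true leaves
            node.deleted = True
        return num_evicted

cache = FakeCache()
cache.evictable_size_ = 3
node = FakeNode([7, 8, 9])
assert cache.evict_regular(node) == 3
assert cache.evictable_size_ == 0 and node.value is None
```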
@@ -288,7 +305,15 @@ def init_load_back(
             len(prefix_indices) == 0 or prefix_indices.is_cuda
         ), "indices of device kV caches should be on GPU"
         if last_node.evicted:
+            # No host copy: fall back to recomputation
+            if not last_node.backuped:
+                # Climb to the nearest ancestor that *does* have live data.
+                while last_node.evicted and last_node != self.root_node:
+                    last_node = last_node.parent
+                return last_node, prefix_indices
+
+            # Host copy exists: try to pull it back
             loading_values = self.load_back(last_node, mem_quota)
             if loading_values is not None:
                 prefix_indices = (
                     loading_values
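A minimal sketch of the ancestor-climbing fallback, using a hypothetical node type; it only illustrates the loop that walks up to the nearest node that still has live device data.

```python
# Hypothetical minimal node; only parent/evicted matter for this loop.
class N:
    def __init__(self, parent=None, evicted=False):
        self.parent = parent
        self.evicted = evicted

def climb_to_live_ancestor(node, root):
    # Walk upward until we reach a node whose KV data is still on device
    # (or the root, which is never evicted in practice).
    while node.evicted and node is not root:
        node = node.parent
    return node

root = N(evicted=False)
mid = N(parent=root, evicted=True)
leaf = N(parent=mid, evicted=True)
assert climb_to_live_ancestor(leaf, root) is root
```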
@@ -52,6 +52,7 @@ def __init__(self, id: Optional[int] = None):
         self.value = None
         self.lock_ref = 0
         self.last_access_time = time.monotonic()
+        self.child_key = None

         self.hit_count = 0
         # indicating the node is loading KV cache from host
@@ -110,6 +111,7 @@ def __init__(
         self.disable = disable
         self.enable_kv_cache_events = enable_kv_cache_events
         self.kv_event_queue = []
+        self.evict_heap = []

         if self.token_to_kv_pool_allocator:
             self.device = self.token_to_kv_pool_allocator.device
@@ -134,6 +136,7 @@ def reset(self):
         self.evictable_size_ = 0
         self.protected_size_ = 0
         self._record_all_cleared_event()
+        self.evict_heap.clear()

     def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
         """Find the matching prefix from the radix tree.
@@ -264,26 +267,32 @@ def evict(self, num_tokens: int):
         if self.disable:
             return

-        leaves = self._collect_leaves()
-        heapq.heapify(leaves)
-
         num_evicted = 0
-        while num_evicted < num_tokens and len(leaves):
-            x = heapq.heappop(leaves)
-
-            if x == self.root_node:
-                break
-            if x.lock_ref > 0:
+        while self.evict_heap and num_evicted < num_tokens:
+            node = heapq.heappop(self.evict_heap)
+            if node.lock_ref > 0 or node == self.root_node or node.evicted:
Collaborator: suggest to be

Collaborator: after looking at the logic more, I think you are referring to

Collaborator: IIUC, there's an implicit assumption here that a child's last_access_time is always older than its parent's, because the heap also contains internal nodes. That assumption doesn't seem to hold? This should also check whether the node is a leaf and, if not, continue.
                 continue

-            self.token_to_kv_pool_allocator.free(x.value)
-            num_evicted += len(x.value)
-            self._delete_leaf(x)
-
-            if len(x.parent.children) == 0:
-                heapq.heappush(leaves, x.parent)
-
-            self._record_remove_event(x)
+            # Free memory and Update state
+            node_value_len = len(node.value)
+            self.token_to_kv_pool_allocator.free(node.value)
+            node.value = None  # Mark node as evicted
+            # Update evictable size here since value is freed;
+            # do not update again in _delete_leaf()
Collaborator: No need to update node.value to None, as the node is about to be deleted.

Author: Agreed. If the node is immediately deleted and garbage-collected, setting node.value = None is redundant. However, if _delete_leaf() doesn't fully nullify the node (or deletion is conditional), it may be safer to clear it explicitly to avoid stale references.
+            self.evictable_size_ -= node_value_len
Collaborator: I think we can keep this in _delete_leaf, which is the clearer semantic.

Author: I would disagree. evictable_size_ tracks memory that is currently freeable. Adjusting it immediately after freeing, before _delete_leaf(), prevents double-decrement bugs if _delete_leaf() also triggers eviction logic or tracking. I would suggest leaving the decrement where it is (immediately after freeing); moving it into _delete_leaf() risks coupling unrelated concerns (memory accounting vs. tree structure).
+            num_evicted += node_value_len
+            self._delete_leaf(node)
+            self._record_remove_event(node)
+
+            # If parent becomes a new leaf and is evictable, push it to heap
+            parent = node.parent
+            if (
+                parent is not None
+                and parent != self.root_node
+                and parent.lock_ref == 0
+                and len(parent.children) == 0
+            ):
+                heapq.heappush(self.evict_heap, parent)

     def inc_lock_ref(self, node: TreeNode):
         if self.disable:
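The persistent evict_heap relies on lazy deletion: entries may become stale (locked again, already evicted, or no longer a leaf, as one reviewer suggests checking), so they are validated when popped rather than when the tree changes. A toy model of the pattern, not the real RadixCache, is sketched below; the __lt__ method gives least-recently-used ordering, which the real TreeNode also keys on last_access_time.

```python
import heapq

# Toy model of a persistent eviction heap with lazy deletion: entries are
# validated when popped, so stale ones are simply skipped.
class ToyNode:
    def __init__(self, last_access_time, value):
        self.last_access_time = last_access_time
        self.value = value          # None once evicted
        self.lock_ref = 0
        self.children = {}

    def __lt__(self, other):
        # Heap order: least recently used first.
        return self.last_access_time < other.last_access_time

def evict(heap, num_tokens):
    evicted = 0
    while heap and evicted < num_tokens:
        node = heapq.heappop(heap)
        # Lazy deletion: skip entries that are no longer safe to evict.
        if node.lock_ref > 0 or node.value is None or node.children:
            continue
        evicted += len(node.value)
        node.value = None  # mark evicted
    return evicted

old, new = ToyNode(1.0, [1, 2]), ToyNode(2.0, [3, 4, 5])
heap = [old, new]
heapq.heapify(heap)
assert evict(heap, 2) == 2 and old.value is None and new.value is not None
```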
@@ -309,6 +318,8 @@ def dec_lock_ref(self, node: TreeNode):
                 self.evictable_size_ += len(node.value)
                 self.protected_size_ -= len(node.value)
                 delta += len(node.value)
+                if len(node.children) == 0:
+                    heapq.heappush(self.evict_heap, node)
             node.lock_ref -= 1
             node = node.parent
         return delta
@@ -367,18 +378,24 @@ def _split_node(self, key, child: TreeNode, split_len: int):
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
         new_node.value = child.value[:split_len]
-        child.parent = new_node
+        new_node.child_key = self.get_child_key_fn(key)
+        new_node.parent.children[new_node.child_key] = new_node
         child.key = child.key[split_len:]
+        child.child_key = self.get_child_key_fn(child.key)
+        new_node.children[child.child_key] = child
+        child.parent = new_node
         child.value = child.value[split_len:]
-        new_node.parent.children[self.get_child_key_fn(key)] = new_node

         self._record_store_event(new_node)
         self._record_store_event(child)

         return new_node

     def _insert_helper(self, node: TreeNode, key: List, value):
-        node.last_access_time = time.monotonic()
+        now = time.monotonic()
+        node.last_access_time = now

         if len(key) == 0:
             return 0
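The cached child_key is what lets _delete_leaf drop the linear scan over the parent's children (see the _delete_leaf hunk further down). A rough sketch of the idea, with a stripped-down node type rather than the real TreeNode:

```python
# Stripped-down illustration of caching the dict key under which a node is
# registered in its parent, so removal is an O(1) dict lookup instead of a
# scan over parent.children.
class MiniNode:
    def __init__(self):
        self.parent = None
        self.children = {}
        self.child_key = None  # key under which this node sits in parent.children

def attach(parent, child, child_key):
    child.parent = parent
    child.child_key = child_key
    parent.children[child_key] = child

def delete_leaf(node):
    # O(1): no iteration over the parent's children.
    if node.child_key in node.parent.children:
        del node.parent.children[node.child_key]

root, a, b = MiniNode(), MiniNode(), MiniNode()
attach(root, a, ("a",))
attach(root, b, ("b",))
delete_leaf(a)
assert ("a",) not in root.children and ("b",) in root.children
```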
@@ -387,7 +404,7 @@ def _insert_helper(self, node: TreeNode, key: List, value):
         total_prefix_length = 0
         while len(key) > 0 and child_key in node.children.keys():
             node = node.children[child_key]
-            node.last_access_time = time.monotonic()
+            node.last_access_time = now
             prefix_len = self.key_match_fn(node.key, key)
             total_prefix_length += prefix_len
             key = key[prefix_len:]
@@ -405,9 +422,12 @@ def _insert_helper(self, node: TreeNode, key: List, value):
             new_node.parent = node
             new_node.key = key
             new_node.value = value
+            new_node.child_key = child_key
             node.children[child_key] = new_node
             self.evictable_size_ += len(value)
             self._record_store_event(new_node)
+            if new_node.lock_ref == 0:
+                heapq.heappush(self.evict_heap, new_node)
         return total_prefix_length

     def _print_helper(self, node: TreeNode, indent: int):
@@ -429,11 +449,8 @@ def _print_helper(self, node: TreeNode, indent: int):
             ), f"{key=}, {self.get_child_key_fn(child.key)=}"

     def _delete_leaf(self, node):
-        for k, v in node.parent.children.items():
-            if v == node:
-                break
-        del node.parent.children[k]
-        self.evictable_size_ -= len(node.key)
+        if node.child_key in node.parent.children:
Collaborator: this if is not needed; instead it should be an assert, after checking
+            del node.parent.children[node.child_key]

     def _total_size_helper(self):
         total_size = 0
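For illustration, here is a sketch of the assert-based variant the reviewer is suggesting, assuming callers only ever pass nodes that are still registered under their parent; this is not the code in the PR.

```python
from types import SimpleNamespace

# Illustrative variant: fail loudly instead of silently skipping the delete.
# Assumes every node handed to _delete_leaf is still registered in its parent.
def delete_leaf_strict(node):
    assert node.child_key in node.parent.children, (
        "node is not registered under its parent; was it already deleted?"
    )
    del node.parent.children[node.child_key]

parent = SimpleNamespace(children={})
leaf = SimpleNamespace(parent=parent, child_key=(42,))
parent.children[(42,)] = leaf
delete_leaf_strict(leaf)
assert parent.children == {}
```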
@@ -0,0 +1,121 @@
import unittest

import torch


class MockReqToTokenPool:
    def __init__(self):
        self.released = []

    def free(self, req_pool_idx):
        self.released.append(req_pool_idx)


class MockKVAllocator:
    def __init__(self):
        self.freed = []
        self.device = torch.device("cpu")

    def free(self, tensor):
        self.freed.append(tensor)


class RadixCacheTest(unittest.TestCase):
    def setUp(self):
        from sglang.srt.mem_cache.radix_cache import RadixCache

        self.token_pool = MockReqToTokenPool()
        self.kv_pool = MockKVAllocator()
        self.cache = RadixCache(
            req_to_token_pool=self.token_pool,
            token_to_kv_pool_allocator=self.kv_pool,
            page_size=1,
            disable=False,
            enable_kv_cache_events=False,
        )

    def test_insert_and_match(self):
        key = [1, 2, 3, 4]
        value = torch.tensor(key, dtype=torch.int64)
        self.cache.insert(key, value)

        matched, last_node = self.cache.match_prefix(key)
        self.assertTrue(torch.equal(matched, value))
        self.assertEqual(last_node.key, key)

    def test_evict_removes_evicted_node(self):
        key = [10, 20, 30]
        value = torch.tensor(key, dtype=torch.int64)
        self.cache.insert(key, value.clone())

        # Ensure evictable size reflects insertion
        self.assertEqual(self.cache.evictable_size(), len(value))

        self.cache.evict(len(value))
        # After eviction, evictable size should drop
        self.assertEqual(self.cache.evictable_size(), 0)

        # All memory should be marked freed
        self.assertTrue(any(torch.equal(t, value) for t in self.kv_pool.freed))

    def test_lock_ref_prevents_eviction(self):
        key = [100, 101, 102]
        value = torch.tensor(key, dtype=torch.int64)
        self.cache.insert(key, value)

        # Get the inserted node
        _, node = self.cache.match_prefix(key)

        self.cache.inc_lock_ref(node)
        self.cache.evict(len(value))

        # Node should not be evicted
        self.assertIsNotNone(node.value)
        self.assertEqual(self.cache.evictable_size(), 0)

        self.cache.dec_lock_ref(node)
        self.cache.evict(len(value))

        # Now it should be evicted
        self.assertIsNone(node.value)
        self.assertEqual(self.cache.evictable_size(), 0)

    def test_evict_heap_promotion_of_parent(self):
        key1 = [1, 2]
        key2 = [1, 2, 3]

        val1 = torch.tensor(key1, dtype=torch.int64)
        val2 = torch.tensor(key2, dtype=torch.int64)

        self.cache.insert(key1, val1)
        self.cache.insert(key2, val2)

        _, child_node = self.cache.match_prefix(key2)
        parent_node = child_node.parent

        # Lock the child (which also protects the parent)
        self.cache.inc_lock_ref(child_node)

        expected_size = len(val1) + 1  # [1,2] + [3]
        self.assertEqual(self.cache.protected_size(), expected_size)
        self.assertEqual(self.cache.evictable_size(), 0)

        # Unlock once to revert lock_ref back to 0
        self.cache.dec_lock_ref(child_node)

        self.assertEqual(self.cache.protected_size(), 0)
        self.assertEqual(self.cache.evictable_size(), expected_size)

        # Evict all 3 tokens ([3] and [1, 2])
        self.cache.evict(expected_size)

        # Both nodes are now evicted
        self.assertIsNone(child_node.value)
        self.assertIsNone(parent_node.value)

        # No evictable bytes remain
        self.assertEqual(self.cache.evictable_size(), 0)


if __name__ == "__main__":
    unittest.main()
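Since the module ends with a unittest.main() guard, it can be run directly with python, or discovered through the standard runner (python -m unittest).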
Reviewer comment: I think there is no need for this.

Reply: This check is necessary to prevent redundant or incorrect backup logic:
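For reference, a minimal sketch of the kind of guard being defended here, mirroring the write_backup and inc_hit_count hunks earlier in the diff; the Node class and should_backup function are simplified stand-ins, not the real hierarchical-cache API.

```python
# Simplified stand-in node; the real guard lives in the hierarchical cache.
class Node:
    def __init__(self, value=None, backuped=False):
        self.value = value        # device indices, or None once evicted
        self.backuped = backuped  # True if a host copy already exists

def should_backup(node):
    # Skip nodes with nothing to copy (evicted) or an existing host copy.
    return node.value is not None and not node.backuped

assert not should_backup(Node(value=None))
assert not should_backup(Node(value=[1, 2], backuped=True))
assert should_backup(Node(value=[1, 2]))
```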