Unify memory management across (overlap, non-overlap) x (page>=1) x (spec, non-spec, spec v2) x (retract, finished) #12224

Merged

hnyls2002 merged 51 commits into main from lsyin/committed-kv-len on Nov 10, 2025
Conversation

hnyls2002 (Collaborator) commented Oct 27, 2025

This PR replaces the old approach to releasing KV cache, which relied on `len(self.origin_input_ids) + max(len(self.output_ids) - 1, 0)` to determine the KV length.

That approach is brittle with overlap scheduling and with multiple finish paths (normal completion, disaggregation-decode finish, retract, abort). With speculative decoding, we also perform over-allocation, which makes allocation/freeing logic even more error-prone.

This PR introduces two explicit notions for a request’s KV cache state:

  • KV committed len: number of KV token slots that have actually been written with real tokens (i.e., valid KV).
  • KV allocated len: number of KV token slots reserved from the allocator (page-aligned), regardless of whether they have been populated.

We tightly couple the allocation steps with updates to kv_committed_len and kv_allocated_len, so these fields faithfully reflect request-level memory usage at all times.
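The coupling between allocation and these two counters can be sketched as follows. This is a minimal illustration under assumed names (`ReqKVState`, `allocate`, `commit` are hypothetical, not the actual SGLang API): allocation rounds the reservation up to whole pages, while committing tracks slots actually written and must never exceed the reservation.

```python
import math
from dataclasses import dataclass

@dataclass
class ReqKVState:
    """Illustrative per-request KV accounting; names are hypothetical."""
    page_size: int
    kv_committed_len: int = 0  # slots actually written with real tokens
    kv_allocated_len: int = 0  # slots reserved from the allocator (page-aligned)

    def allocate(self, num_new_tokens: int) -> None:
        # Reserve whole pages covering committed + incoming tokens.
        needed = self.kv_committed_len + num_new_tokens
        pages = math.ceil(needed / self.page_size)
        self.kv_allocated_len = max(self.kv_allocated_len, pages * self.page_size)

    def commit(self, num_tokens: int) -> None:
        # Mark slots as written; must never exceed the reservation.
        self.kv_committed_len += num_tokens
        assert self.kv_committed_len <= self.kv_allocated_len

req = ReqKVState(page_size=16)
req.allocate(20)  # rounds up to 2 pages: kv_allocated_len == 32
req.commit(20)    # kv_committed_len == 20
```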

The previous flow for freeing a finished request's KV cache looked roughly like this:

```python
for req in reqs_to_process:
    if req.is_finished():  # overlap scheduling may produce "extra" KV allocations
        # Previously required complex and error-prone offset math
        deallocate_extra_tokens(req)
        continue

    check_finish(req)

    if req.is_finished():
        # Free KV cache using the inferred length:
        # len(origin_input_ids) + max(len(output_ids) - 1, 0)
        deallocate_inferred_kv_length(req)
```

We now replace the ad-hoc “extra token” arithmetic with the recorded kv_committed_len and kv_allocated_len. Consumers no longer need to infer offsets: they simply consult these fields.

To ensure correctness:

  • All memory resources are freed in exactly one place.
  • As soon as a request is decided to be finished (check_finish, retract, abort, …), we immediately remove all of its KV cache.
  • We never allocate any memory resources after a request has been found finished.
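Under these invariants, the finish path reduces to a single free call driven by the recorded lengths, with no offset arithmetic. Here is a minimal self-contained sketch; `StubAllocator`, `FinishedReq`, and the allocator interface are illustrative assumptions, not SGLang's real classes:

```python
from dataclasses import dataclass

class StubAllocator:
    """Stand-in allocator that records free calls (illustrative only)."""
    def __init__(self):
        self.freed = []

    def free(self, pool_idx, start, length):
        self.freed.append((pool_idx, start, length))

@dataclass
class FinishedReq:
    req_pool_idx: int
    kv_allocated_len: int
    kv_freed_len: int = 0

def free_finished_req(req, allocator):
    # The single place a finished request's KV memory is released:
    # free exactly what was reserved, using the recorded lengths.
    to_free = req.kv_allocated_len - req.kv_freed_len
    if to_free > 0:
        allocator.free(req.req_pool_idx, start=req.kv_freed_len, length=to_free)
        req.kv_freed_len = req.kv_allocated_len
    # Invariant: nothing is allocated after finish, so freed == allocated.
    assert req.kv_freed_len == req.kv_allocated_len

alloc = StubAllocator()
req = FinishedReq(req_pool_idx=7, kv_allocated_len=32)
free_finished_req(req, alloc)  # frees the full 32 reserved slots
```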

Future TODOs (cc @cctry)

  • Resolve the possible data race when we enable the overlap scheduler: the just-released req_to_token_idx is reused, and the page mapping changes during the current forwarding cycle.
  • Further decouple the memory deallocation logic from the prefix cache logic.

@hnyls2002 hnyls2002 changed the title Unify memory management across (overlap, non-overlap) x (page>=1) x (spec, non-spec, spec v2) x (retract, finished) [WIP] Unify memory management across (overlap, non-overlap) x (page>=1) x (spec, non-spec, spec v2) x (retract, finished) Oct 27, 2025
gemini-code-assist bot commented

Summary of Changes

Hello @hnyls2002, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors and unifies the Key-Value (KV) cache memory management logic across various operational modes and states within the system. By introducing dedicated request-level attributes for tracking allocated, committed, and freed KV cache lengths, the changes aim to standardize memory handling, improve accuracy, and enhance the overall robustness of memory management during request extension, decoding, and completion.

Highlights

  • Request-level KV Cache Tracking: Introduced new attributes kv_committed_len, kv_allocated_len, and kv_freed_len to the Req object to enable more granular and unified memory management for Key-Value (KV) cache across different request states.
  • Memory Allocation Updates: The newly added KV cache length fields are now consistently updated during critical phases such as request extension (prepare_for_extend) and token decoding (prepare_for_decode), ensuring accurate tracking of allocated and committed memory.
  • Memory Deallocation Logic Refinement: The cache_finished_req function has been updated to leverage req.kv_allocated_len for precise memory deallocation upon request completion, and req.kv_freed_len is now used to track freed memory, including an assertion to prevent incorrect freeing.
  • Consistency and Debugging Assertions: An assertion was added in process_batch_result_decode to verify that kv_allocated_len matches kv_freed_len when handling the freeing of delayed tokens, enhancing the robustness and debuggability of the memory management system.
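The consistency assertion described in the last highlight amounts to something like the following. This is an illustrative sketch, not the actual code in process_batch_result_decode; the function and field names are assumptions:

```python
from types import SimpleNamespace

def check_delayed_token_accounting(req):
    """Sketch of the decode-path consistency check (hypothetical names)."""
    if req.is_finished:
        # Once a request is finished, the finish path must already have
        # released everything it reserved.
        assert req.kv_allocated_len == req.kv_freed_len, (
            f"KV leak or double free: allocated={req.kv_allocated_len}, "
            f"freed={req.kv_freed_len}"
        )

ok = SimpleNamespace(is_finished=True, kv_allocated_len=32, kv_freed_len=32)
check_delayed_token_accounting(ok)  # consistent accounting passes silently

bad = SimpleNamespace(is_finished=True, kv_allocated_len=32, kv_freed_len=16)
try:
    check_delayed_token_accounting(bad)
    caught = False
except AssertionError:
    caught = True  # the mismatch is detected
```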

gemini-code-assist bot left a comment
Code Review

This pull request unifies request-level memory management by introducing kv_committed_len, kv_allocated_len, and kv_freed_len to the Req class. These fields are consistently updated during prefill and decode stages. The logic for freeing memory for finished requests is now centralized in cache_finished_req, which uses req.kv_allocated_len as the source of truth. This is a good refactoring that improves robustness. I have one minor suggestion to remove a stale comment that became misleading after the changes.

@hnyls2002 hnyls2002 requested a review from ByronHsu as a code owner October 28, 2025 14:56
@xiezhq-hermann xiezhq-hermann self-assigned this Oct 31, 2025
@github-actions github-actions bot added the speculative-decoding and hicache (Hierarchical Caching for SGLang) labels Nov 10, 2025
@hnyls2002 hnyls2002 changed the title [WIP] Unify memory management across (overlap, non-overlap) x (page>=1) x (spec, non-spec, spec v2) x (retract, finished) Unify memory management across (overlap, non-overlap) x (page>=1) x (spec, non-spec, spec v2) x (retract, finished) Nov 10, 2025
@hnyls2002 hnyls2002 merged commit 665416f into main Nov 10, 2025
93 of 138 checks passed
@hnyls2002 hnyls2002 deleted the lsyin/committed-kv-len branch November 10, 2025 18:56
cctry (Collaborator) commented Nov 10, 2025

> Resolve the possible data race when we enable the overlap scheduler: the just-released req_to_token_idx is reused, and the page mapping changes during the current forwarding cycle

There is another thing we can do:
instead of assigning out_cache_loc to req_to_token in prepare_for_decode, we can assign it during result processing.
Any concerns with this approach for spec decoding?

hnyls2002 (Collaborator, Author) commented

@cctry Not sure, we can discuss offline.
