[feat] support in-flight weight update#10071
Conversation
Summary of Changes
Hello @ShawnY112358, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a critical feature for continuous operation of the SGLang engine by enabling model parameter updates without requiring a service interruption. This enhancement is particularly beneficial for agent reinforcement learning training, where seamless synchronization of model parameters is essential to avoid disrupting ongoing inference services. The changes ensure that the system can update its underlying model weights dynamically, maintaining responsiveness and preventing errors that previously occurred during update-related pauses.
Highlights
- Online Weight Updates: Introduced the capability to update model weights online without interrupting live inference, addressing issues where pausing the engine for updates caused runtime errors in agent training environments.
- Scheduler Integration for Weight Swaps: Modified the scheduler to perform weight swaps between two forward passes of a running batch. If KV-cache flushing is required, the scheduler now retracts the entire running batch, flushes the cache, and then re-prefills and resumes generation.
- API and Internal Synchronization: Added an 'online' flag to weight update request inputs (from disk, distributed, and tensor) and implemented synchronization mechanisms using threading events to ensure safe online updates within the tp_worker_overlap_thread.
- Enhanced Test Coverage: Added new unit tests specifically for online weight updates across disk, distributed, and tensor update methods, including scenarios with concurrent decode requests to validate non-interruption.
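The coordination described in the highlights (a weight swap applied between two forward passes, signaled with a threading event) can be sketched as follows. This is a minimal illustrative model, not SGLang's actual scheduler; all names (TinyScheduler, stage_update, step) are assumptions.

```python
import threading
import time

class TinyScheduler:
    """Toy model: a trainer thread stages new weights, and the event loop
    swaps them in between two "forward passes"."""

    def __init__(self):
        self.weights = {"version": 0}
        self.pending_update = None            # staged by the updater thread
        self.update_done = threading.Event()  # tells the updater the swap landed
        self.lock = threading.Lock()

    def stage_update(self, new_weights):
        # Called from the trainer thread; blocks until the loop applies the swap.
        with self.lock:
            self.pending_update = new_weights
        self.update_done.wait()

    def step(self):
        # One event-loop iteration: a (fake) forward pass, then maybe a swap.
        token = self.weights["version"]  # stand-in for decoding one token
        with self.lock:
            if self.pending_update is not None:
                self.weights = self.pending_update  # swap between passes
                self.pending_update = None
                self.update_done.set()
        return token
```

The key property is that the swap only ever happens at a step boundary, so no forward pass observes a half-updated parameter set.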
Code Review
This pull request introduces a significant feature: online model weight updates without pausing the inference engine. The implementation appears robust, correctly handling synchronization between the scheduler and the model worker using a threading event to prevent race conditions during forward passes. The new interrupt_generation method in the scheduler, which retracts running requests to the waiting queue before flushing the cache, is a key part of this and seems well-designed. The addition of comprehensive tests for this new online update functionality is commendable. I've found one issue in the test code that needs to be addressed.
It is quite strange. Which version of the parameter do you want to use? I do not think that this need is valid (correct me if I'm wrong). Also, are you sure that the token can be correctly sampled during the refit? For example, given a 50-layer model, layers 1~24 are the old version and the others are the new version. Could this give you a correct sample token?
In RL scenarios, our algorithm supports partial rollout, where different segments of a response can be inferred using different checkpoints. We aim to update the inference engine parameters promptly. Using writer_lock would require waiting for ongoing inference requests to complete, creating idle periods. In contrast, online parameter updates can maximize inference resource utilization.
# Whether to update weights online
online: bool = False
# Whether to wait for all ongoing inference requests to complete before updating
# parameters, or interrupt them and update parameters directly.
or abort and return all the unfinished requests, then update parameters.
When non_blocking is set to True, the running request does not return immediately. Instead, it remains in the inference engine while the scheduler's event loop pauses decoding. Inference resumes only after the parameter synchronization is complete.
There was a problem hiding this comment.
don't we have 3 cases now?
- wait for all requests in queue and update
- cancel current requests with partial completion and update
- keep current requests, update in between forward passes, and continue.
Can we still have all the options available in this struct?
- Simply set non_blocking=False.
- Call pause_generation first, then call the update_weights API—this approach is independent of the non_blocking parameter. Of course, if you use the pause_generation API introduced in this PR, you’ll need to set abort_all=True.
- First call pause_generation with abort_all=False, then call update_weights with non_blocking=True. If you want to implement a single-call non-blocking weight synchronization, you wouldn’t need to call pause_generation at all—just call update_weights directly with non_blocking=True.
In this sense:
- Set non_blocking=False: the weight refit will wait until all the remaining requests finish, then begin the update.
- First call pause_generation with abort_all=True; all the requests will be returned. Then call the update weights.
- Directly call update_weights with non_blocking=True; the engine will pause decoding, wait for the parameters to be updated, and then continue decoding with the old KV cache immediately.

Am I understanding these right? If so, could we add these three conditions and usage to the comments to let users know?
BTW, what shall happen if pause_generation but abort_all=False?
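For reference, the three patterns enumerated above can be written down as the request sequences a client would send. This is illustrative only: the endpoint path and field names (named_tensors, non_blocking, abort_all) follow this PR's discussion and are assumptions, not a stable API.

```python
# Each helper returns the ordered (endpoint, json_payload) pairs a client
# would POST for that usage pattern.

def wait_then_update(tensors):
    # (1) Wait for all remaining requests to finish, then update.
    return [("/update_weights_from_tensor",
             {"named_tensors": tensors, "non_blocking": False})]

def abort_then_update(tensors):
    # (2) Abort and return all unfinished requests, then update.
    return [
        ("/pause_generation", {"abort_all": True}),
        ("/update_weights_from_tensor", {"named_tensors": tensors}),
    ]

def inflight_update(tensors):
    # (3) Pause decoding between forward passes, update, resume with the
    #     old KV cache immediately.
    return [("/update_weights_from_tensor",
             {"named_tensors": tensors, "non_blocking": True})]
```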
# Optional format specification for loading
flattened_bucket_meta: Optional[dict] = None
# Whether to wait for all ongoing inference requests to complete before updating
# parameters, or interrupt them and update parameters directly.
as above, change it
# Whether to wait for all ongoing inference requests to complete before updating
# parameters, or interrupt them and update parameters directly.
def pause_generation(self):
    response = requests.post(
        self.base_url + "/pause_generation",
        json={"abort_all": False, "retract_all": True},
    )
    ret = response.json()
    return ret

def continue_generation(self):
    response = requests.post(
        self.base_url + "/continue_generation",
        json={},
    )
    ret = response.json()
    return ret
Sorry for not catching the development of these APIs for a long time. Just pondering what these two APIs, "pause" and "continue" generation mean in SGLang. What's the difference of abort_all? In my understanding, the pause and continue don't return the unfinished requests to the user. It just stops generation, then maybe we update the weights but not flush KV cache, then continue generation?
In the old API, pause_generation would abort all running requests and return immediately, and no new requests would be processed afterward. #7419
We now want pause_generation to not abort requests, but instead only pause the decoding process of currently running requests. To achieve this, we introduced the abort_all parameter:
When abort_all=True, pause_generation behaves as before—aborting all running requests.
When abort_all=False, it pauses the inference process in the scheduler's event loop without aborting any requests. The requests remain in either the running_queue or the waiting_queue, preserving their state for potential resumption later.
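The abort_all semantics described above can be captured in a small sketch. This is a toy model (the class and queue names are illustrative, not SGLang's scheduler); it only shows the state transitions discussed in this thread.

```python
class PausableScheduler:
    """Toy model of pause_generation with the abort_all flag."""

    def __init__(self):
        self.running_queue = ["req-a", "req-b"]
        self.waiting_queue = ["req-c"]
        self.paused = False
        self.aborted = []  # requests returned to the caller

    def pause_generation(self, abort_all: bool):
        self.paused = True
        if abort_all:
            # Old behavior (#7419): abort and return every running request.
            self.aborted += self.running_queue
            self.running_queue.clear()
        # abort_all=False: requests stay in running_queue / waiting_queue
        # with their state preserved, so decoding can resume later.

    def continue_generation(self):
        self.paused = False
```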
We cannot have the model updating parameters during forward passes. So in this unit test, we first pause the scheduler's event loop using a pause_generation request, execute non-blocking parameter updates, and then resume the scheduler's event loop.
Users have the flexibility to decide whether to flush the KV cache according to their specific requirements, and can resume inference by calling the continue_generation API after the cache has been flushed. (Note that if users want to flush the KV cache, they need to retract all running batches using the retract_all parameter within the pause_generation call, forcing all requests to re-prefill after resuming inference; otherwise, this will result in an error.)
Note that if users want to flush the KV cache, they need to retract all running batches using the retract_all parameter within the pause_generation call, forcing all requests to re-prefill after resuming inference; otherwise, this will result in an error.
Great, could you add this in the comment of the change. This shall be important to users.
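The retract-before-flush constraint quoted above can be sketched as follows. This is an illustrative model only (class and attribute names are assumptions): flush_cache succeeds only once every running request has been retracted back to the waiting_queue, which forces re-prefill on resume.

```python
class RetractingScheduler:
    """Toy model of pause_generation(retract_all=...) plus flush_cache."""

    def __init__(self):
        self.running_queue = ["req-a"]
        self.waiting_queue = []
        self.kv_cache = {"req-a": "blocks"}
        self.paused = False

    def pause_generation(self, retract_all: bool):
        self.paused = True
        if retract_all:
            # Retracted requests re-prefill when generation resumes.
            self.waiting_queue += self.running_queue
            self.running_queue.clear()

    def flush_cache(self):
        # Mirrors the note above: flushing with requests still in the
        # running batch is an error.
        if self.running_queue:
            raise RuntimeError("flush_cache fails with requests in running_batch")
        self.kv_cache.clear()

    def continue_generation(self):
        self.paused = False
```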
Will review it soon.
sorry, busy with my stuff at school. Ping me if not reply in 24h.
Thanks for implementing this! This is a feature that we are very excited about. However, I have the following concerns about the interaction design.
Instead of the pause / update / continue sequence, I would like a single call of update_nonblock:
ret = self.run_update_weights(named_tensors, non_blocking=True, flush_cache=False)
compared to what's in the current PR
ret = self.pause_generation()
ret = self.run_update_weights(named_tensors, non_blocking=True)
self.assertTrue(ret["success"])
ret = self.continue_generation()
- The key of this feature for me is to enable a true non-blocking weight update that is completely invisible to the inference caller. The inference caller will not be aware of underlying updates, beyond metadata being returned; this allows a weight update caller to work completely independently from the service caller.
From what I see it seems possible to just send the _pause_engine signal directly to the scheduler through weight_update call without even changing the pause_engine API? Obviously I don't know the overlap scheduler well enough, but handling this seems possible.
Following that "invisible" paradigm, the retract_all logic will need to be implicitly handled given the flush_cache flag, and we have one less failure point for the user.
- non_blocking + flush_cache: pause generation, flush_cache, weight_update, send all current running requests back to prefill stage, continue
- non_blocking only: pause generation, weight_update, continue.
Personally I'm not sure how much people like option 2: you are now left with a dirty cache without knowing which version of checkpoint it came from. Do we have an actual use case that would prefer this?
Let me know what you think
# Whether to update weights online
online: bool = False
# Whether to wait for all ongoing inference requests to complete before updating
# parameters, or interrupt them and update parameters directly.
don't we have 3 cases now?
- wait for all requests in queue and update
- cancel current requests with partial completion and update
- keep current requests, update in between forward passes, and continue.
Can we still have all the options available in this struct?
Question 1: Why not implement a single-call, non-blocking update? I initially designed it this way for the following reasons:
That said, I agree that the single-call, non-blocking update you proposed is also very reasonable and more user-friendly. I support both implementation approaches. If we go with the single-call design, after merging this PR, we would only need to enhance the pause_generation implementation on top of the community's baseline to satisfy the requirement mentioned in point #2 above.
Question 2: Regarding "dirty" KV cache. Some research has shown that in partial rollouts, RL training can still converge properly even without flushing the KV cache. Our recent experiments have also observed similar findings. Not flushing the cache avoids the computational overhead of re-running the prefill phase for cached prompts, which can provide a noticeable speed advantage, especially when weight synchronization happens frequently.
ding~
Thanks! Will reply today!
I agree with this discussion. Let me set these APIs as follows; would you guys agree: @JD-ETH @ShawnY112358 Let's rename:

ret = self.run_update_weights(named_tensors, block_generation=False)

ret = self.pause_generation(abort_all=True)
ret = self.run_update_weights(named_tensors, block_generation=False/True)

Note that, to continue rolling out these unfinished requests, we shall manually send these requests back to the server. Simply calling the

ret = self.run_update_weights(named_tensors, block_generation=True, flush_cache=False)
ret = self.run_update_weights(named_tensors, block_generation=True, flush_cache=True)

Would this be clearer?
final minor nits
@dataclass
class PauseGenerationReqInput(BaseReq):
    """
    abort: Abort all requests currently being processed.
Abort all requests currently being processed.
->
Abort and return all requests currently being processed.
    in_place: Pause the scheduler's event_loop from performing inference;
    only non-inference requests (e.g., control commands) will be handled.
    Note: In 'in_place' mode, flush_cache will fail if there are any requests
    in the running_batch.

    retract: Pause the scheduler's event loop from performing inference;
    only non-inference requests will be handled, and all currently running
    requests will be retracted back to the waiting_queue.
What will happen in in_place? The requests in the engine will be paused and stay in the event_loop, then continue generation after ContinueGenerationReqInput with the old kv cache?
Also, for retract, the requests in the engine will be paused and removed from the event_loop into waiting_queue, and the KV cache will be flushed and recomputed after ContinueGenerationReqInput?
If my understanding is correct, please make a more adequate description of this PR.
At the same time, it seems that in_place and retract are counterparts; should we only keep one parameter like retract=True?
Discussed this with @JD-ETH . Currently, pause only supports these three modes. If we use the abort and retract parameters for control, the retract parameter becomes meaningless when abort=True.
Please add these details in the comments:
What will happen in in_place? The requests in the engine will be paused and stay in the event_loop, then continue generation after ContinueGenerationReqInput with the old kv cache?
Also, for retract, the requests in the engine will be paused and removed from the event_loop into waiting_queue, and the KV cache will be flushed and recomputed after ContinueGenerationReqInput?
If my understanding is correct, please make a more adequate description of this PR.
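A docstring along the lines requested above could look like this. The wording is a suggestion drawn from this thread, not the merged text, and BaseReq is omitted so the sketch stands alone.

```python
from dataclasses import dataclass

@dataclass
class PauseGenerationReqInput:
    """
    abort: Abort and return all requests currently being processed.
    in_place: Pause the scheduler's event loop; running requests stay in the
        engine and continue decoding with their old KV cache after
        ContinueGenerationReqInput. flush_cache will fail if any requests
        are in the running_batch.
    retract: Pause the scheduler's event loop; running requests are retracted
        back to the waiting_queue, so the KV cache may be flushed and will be
        recomputed (re-prefill) after ContinueGenerationReqInput.
    """
    abort: bool = False
    in_place: bool = False
    retract: bool = False
```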
@dataclass
class PauseGenerationReqInput(BaseReq):
    """
    abort: Abort all requests currently being processed.
Sorry for my late reply, but I'm just not sure we should put the abort here.
If we have an abort parameter here, the abort_request and pause_generation seem to be intersecting. This makes the API ambiguous.
Could we only leave the pause to do the pause, and let users choose whether to recompute the KV cache?
I see that the current pause API can do the abort_all, which feels a bit strange to me. Let me discuss with the team.
ret = self.pause_generation(mode)
ret = self.run_update_weights(
    new_model_path, flush_cache=mode == "retract"
)
Indeed, we cannot validate whether the KV cache is indeed recomputed. We can leave a comment here stating that "we cannot validate whether the KV cache is indeed recomputed; we just trust the retract parameter."
I didn't quite understand your point. Here, we just want to clarify that if the user wants to flush the cache, they must call pause_generation in retract mode.
]

for tp_size, dp_size, model_name, backend in test_suits:
    modes = ["in_place", "retract"]
modes = ["in_place", "retract"] -> pause_generation_modes = ["in_place", "retract"].
for tp_size, dp_size, model_name, backend in test_suits:
    modes = ["in_place", "retract"]
    for mode in modes:
for pause_generation_mode in pause_generation_modes:
model_state_dict_shapes[model_name],
truncate_size,
checking_parameters,
mode=mode,
pause_generation_mode, please do not use pause_generation_mode=pause_generation_mode.
response = requests.post(
    self.base_url + "/generate",
    json={
        "text": "The capital of France is",
change the text also.
response = requests.post(
    url + "/generate",
    json={
        "text": "The capital of France is",
change the text also.
def test_update_weights(self):
    modes = ["in_place", "retract"]
    for mode in modes:
also, please do not use mode here.
origin_model_path = self.get_model_info()
print(f"[Server Mode] origin_model_path: {origin_model_path}")

modes = ["in_place", "retract"]
also, please do not use mode here.
]

for tp_size, dp_size, model_name, backend in test_suits:
    modes = ["in_place", "retract"]
At the same time, please move the pause_generation_mode into the test_suits, making it like:

for tp_size, dp_size, model_name, backend, pause_generation_mode in test_suits:

The same for other update weights unit tests.
And please help me change the code as follows (someone seems to be abusing my previous setting for random covering of the tests; basically, local tests should cover more):

if is_in_ci():
    backend = random.choice(["Server", "Engine"])
    pause_generation_mode = random.choice(["in_place", "retract"])
    test_suits = [
        (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, backend, pause_generation_mode),
    ]
else:
    test_suits = [
        (1, 1, DEFAULT_MODEL_NAME_FOR_TEST, "Server", random.choice(["in_place", "retract"])),
        (1, 1, DEFAULT_MODEL_NAME_FOR_TEST, "Engine", random.choice(["in_place", "retract"])),
    ]
    if torch.cuda.device_count() >= 4:
        test_suits.extend([
            (1, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Engine", random.choice(["in_place", "retract"])),
            (2, 1, DEFAULT_MODEL_NAME_FOR_TEST, "Server", random.choice(["in_place", "retract"])),
        ])
    if torch.cuda.device_count() >= 5:
        test_suits.append(
            (2, 2, DEFAULT_MODEL_NAME_FOR_TEST, random.choice(["Engine", "Server"]), random.choice(["in_place", "retract"])),
        )

Do it the same as other tests, adding the new parameters in the test_suits.
The pause_generation API does not expose an interface for Engine mode
    in_place: Pause the scheduler's event_loop from performing inference;
    only non-inference requests (e.g., control commands) will be handled.
    Note: In 'in_place' mode, flush_cache will fail if there are any requests
    in the running_batch.

    retract: Pause the scheduler's event loop from performing inference;
    only non-inference requests will be handled, and all currently running
    requests will be retracted back to the waiting_queue.
Please add these details in the comments:
What will happen in in_place? The requests in the engine will be paused and stay in the event_loop, then continue generation after ContinueGenerationReqInput with the old kv cache?
Also, for retract, the requests in the engine will be paused and removed from the event_loop into waiting_queue, and the KV cache will be flushed and recomputed after ContinueGenerationReqInput?
If my understanding is correct, please make a more adequate description of this PR.
1,
1,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
"Server",
If is_in_ci, also choose from ["Server", "Engine"].
The test suites I want to have in update weights generally: inside the test function, if the backend is "Server", then set the server mode accordingly. Generally, we do not want to double the test file length and the test running time. For example, keep:

if is_in_ci():
    backend = random.choice(["Engine", "Server"])
    test_suits = [
        (
            1,
            1,
            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            backend,
        ),
    ]
else:
    test_suits = [
        (
            1,
            1,
            DEFAULT_MODEL_NAME_FOR_TEST,
            "Engine",
        ),
        (
            1,
            1,
            DEFAULT_MODEL_NAME_FOR_TEST,
            "Server",
        ),
    ]
if torch.cuda.device_count() >= 4:
    test_suits.append(
        (
            1,
            2,
            DEFAULT_MODEL_NAME_FOR_TEST,
            random.choice(["Server", "Engine"]),
        ),
    )
if torch.cuda.device_count() >= 5:
    test_suits.append(
        (
            2,
            2,
            DEFAULT_MODEL_NAME_FOR_TEST,
            random.choice(["Server", "Engine"]),
        ),
    )
Add note about PauseGenerationRequests support in SGLang Server.
result = (await self.update_weights_from_distributed_communicator(obj))[
    0
]
return result.success, result.message
There are some other steps.
Please also do success, message = _Communicator.merge_results(results) and:

if success and obj.weight_version is not None:
    self._update_weight_version_if_provided(obj.weight_version)
    message += f" Weight version updated to {obj.weight_version}."
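The requested post-processing can be sketched as a pure function. The names merge_and_version and RankResult are stand-ins for _Communicator.merge_results and the per-rank result objects; the exact signatures are assumptions based on this thread.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class RankResult:
    success: bool
    message: str

def merge_and_version(results: List[RankResult],
                      weight_version: Optional[str],
                      state: dict) -> Tuple[bool, str]:
    # Stand-in for _Communicator.merge_results: every rank must succeed.
    success = all(r.success for r in results)
    message = " ".join(r.message for r in results)
    # Stand-in for _update_weight_version_if_provided.
    if success and weight_version is not None:
        state["weight_version"] = weight_version
        message += f" Weight version updated to {weight_version}."
    return success, message
```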
async with self.is_pause_cond:
    if self.is_pause:
        result = (await self.update_weights_from_tensor_communicator(obj))[0]
        return result.success, result.message
Same, please do the post-processing correctly instead of directly returning the result from rank 0.
Code skeleton:

if self.is_pause:
    result = await self.update_weights_from_tensor_communicator(obj)
else:
    async with self.model_update_lock.writer_lock:
        result = ...
# post processing
Motivation
In agent RL training, we need to synchronize model parameters without pausing the SGLang engine.
Existing pause()-based approaches break the inference service: the agent continues to send requests while the engine is paused, resulting in runtime errors.
This PR enables parameter updates without interrupting live inference, keeping the agent unaware of the sync process.
Modifications
The pause_generation API now supports three distinct modes to provide fine-grained control over the scheduler's behavior during runtime: abort (abort and return all requests currently being processed), in_place (pause the event loop while running requests stay in the engine and resume decoding with their old KV cache), and retract (pause the event loop and retract all running requests back to the waiting_queue, so the KV cache can be flushed and requests re-prefill on resume).
Building on the pause_generation mechanism, the update_weights API can now deliver new model weights directly to the scheduler without acquiring the global model_update_lock.
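The paused-versus-locked dispatch described above can be sketched with asyncio. This is a toy model under stated assumptions: ToyManager, its writer_lock (standing in for model_update_lock.writer_lock), and the integer "weights version" are illustrative, not SGLang's actual code.

```python
import asyncio

class ToyManager:
    """When the engine is paused, the update bypasses the global writer
    lock; otherwise it serializes against in-flight requests as before."""

    def __init__(self):
        self.is_pause = False
        self.writer_lock = asyncio.Lock()  # stand-in for model_update_lock.writer_lock
        self.weights_version = 0

    async def update_weights(self, new_version: int) -> int:
        if self.is_pause:
            # Paused path: deliver weights directly, no lock needed.
            self.weights_version = new_version
        else:
            # Normal path: take the writer lock.
            async with self.writer_lock:
                self.weights_version = new_version
        return self.weights_version

async def demo():
    m = ToyManager()
    await m.update_weights(1)         # normal, locked path
    m.is_pause = True
    return await m.update_weights(2)  # paused, lock-free path
```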
Accuracy Tests
Benchmarking and Profiling
Checklist