
[feat] support in-flight weight update#10071

Merged
Kangyan-Zhou merged 22 commits into sgl-project:main from
ShawnY112358:online_update
Nov 26, 2025

Conversation

@ShawnY112358
Contributor

@ShawnY112358 ShawnY112358 commented Sep 5, 2025

Motivation

In agent RL training, we need to synchronize model parameters without pausing the SGLang engine.
Existing pause()-based approaches break the inference service: the agent continues to send requests while the engine is paused, resulting in runtime errors.
This PR enables parameter updates without interrupting live inference, keeping the agent unaware of the sync process.

Modifications

  1. pause_generation API with Three Operation Modes
    The pause_generation API now supports three distinct modes to provide fine-grained control over the scheduler’s behavior during runtime:
  • abort: Immediately aborts all currently processing requests and returns them.
  • in_place: Pauses the scheduler’s event loop from performing inference; only non-inference requests (e.g., control commands) are handled. ⚠️ In this mode, flush_cache will fail if there are any requests in the running_batch.
  • retract: Pauses inference and retracts all currently running requests back to the waiting_queue, allowing them to be rescheduled later. Only non-inference requests are processed while paused.
  2. Non-blocking update_weights
    Building on the pause_generation mechanism, the update_weights API can now deliver new model weights directly to the scheduler without acquiring the global model_update_lock.
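For illustration, the three modes could be driven over the HTTP API roughly as follows. This is a hedged sketch: the exact request-body fields are an assumption pieced together from this thread (the PR's tests send flags such as abort_all / retract_all), so check the merged io_struct before relying on it.

```python
# Hypothetical helper mapping the three documented pause modes onto the
# flag-style request body seen in this PR's tests (abort_all / retract_all).
# Field names are assumptions, not a stable API.

def pause_payload(mode: str) -> dict:
    """Build the JSON body for POST /pause_generation for a given mode."""
    flags = {
        "abort": {"abort_all": True, "retract_all": False},
        "in_place": {"abort_all": False, "retract_all": False},
        "retract": {"abort_all": False, "retract_all": True},
    }
    if mode not in flags:
        raise ValueError(f"unknown pause mode: {mode!r}")
    return flags[mode]


def pause_generation(base_url: str, mode: str, post) -> dict:
    """Send the pause request; `post` is an HTTP helper such as requests.post."""
    return post(base_url + "/pause_generation", json=pause_payload(mode)).json()
```

A non-blocking update_weights call would then run while the engine is paused, and continue_generation resumes the event loop.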

Accuracy Tests

Benchmarking and Profiling

Checklist

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @ShawnY112358, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a critical feature for continuous operation of the SGLang engine by enabling model parameter updates without requiring a service interruption. This enhancement is particularly beneficial for agent reinforcement learning training, where seamless synchronization of model parameters is essential to avoid disrupting ongoing inference services. The changes ensure that the system can update its underlying model weights dynamically, maintaining responsiveness and preventing errors that previously occurred during update-related pauses.

Highlights

  • Online Weight Updates: Introduced the capability to update model weights online without interrupting live inference, addressing issues where pausing the engine for updates caused runtime errors in agent training environments.
  • Scheduler Integration for Weight Swaps: Modified the scheduler to perform weight swaps between two forward passes of a running batch. If KV-cache flushing is required, the scheduler now retracts the entire running batch, flushes the cache, and then re-prefills and resumes generation.
  • API and Internal Synchronization: Added an 'online' flag to weight update request inputs (from disk, distributed, and tensor) and implemented synchronization mechanisms using threading events to ensure safe online updates within the tp_worker_overlap_thread.
  • Enhanced Test Coverage: Added new unit tests specifically for online weight updates across disk, distributed, and tensor update methods, including scenarios with concurrent decode requests to validate non-interruption.
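The event-based handshake mentioned in the highlights can be illustrated with a small sketch. This is not SGLang's actual code (class and attribute names here are invented): the updater asks the worker loop to stop at a safe point between forward passes, swaps weights there, and lets it resume, so no single forward pass ever mixes old and new weights.

```python
import threading

class WorkerLoop:
    """Toy model of a worker that only swaps weights between forward passes."""

    def __init__(self):
        self.pause_requested = threading.Event()
        self.paused = threading.Event()
        self.resume = threading.Event()
        self.stop = threading.Event()
        self.weights = "v0"
        self.forwards = []  # records which weights each "forward pass" used

    def run(self):
        while not self.stop.is_set():
            if self.pause_requested.is_set():
                self.paused.set()    # safe point: no forward pass in flight
                self.resume.wait()   # block until the update is done
                self.pause_requested.clear()
                self.paused.clear()
                self.resume.clear()
            self.forwards.append(self.weights)  # one "forward pass"

    def update_weights(self, new_weights):
        """Called from another thread; returns once the swap is complete."""
        self.pause_requested.set()
        self.paused.wait()           # wait for the worker's safe point
        self.weights = new_weights   # swap outside any forward pass
        self.resume.set()
```

This is a single-updater toy; the real scheduler additionally has to coordinate with the overlap thread's queued batches, which is what the PR's changes in tp_worker_overlap_thread address.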

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant feature: online model weight updates without pausing the inference engine. The implementation appears robust, correctly handling synchronization between the scheduler and the model worker using a threading event to prevent race conditions during forward passes. The new interrupt_generation method in the scheduler, which retracts running requests to the waiting queue before flushing the cache, is a key part of this and seems well-designed. The addition of comprehensive tests for this new online update functionality is commendable. I've found one issue in the test code that needs to be addressed.

@ShawnY112358
Contributor Author

@ocss884

@zhaochenyang20
Collaborator

This PR enables parameter updates without interrupting live inference, keeping the agent unaware of the sync process.

It is quite strange. Which version of the parameter do you want to use? I do not think that this need is valid (correct me if I'm wrong).

Also, are you sure that the token can be correctly sampled during the refit? For example, given a 50-layer model where layers 1-24 are the old version and the others are the new version, could this give you a correct sample token?

@ShawnY112358
Contributor Author

This PR enables parameter updates without interrupting live inference, keeping the agent unaware of the sync process.

It is quite strange. Which version of the parameter do you want to use? I do not think that this need is valid (correct me if I'm wrong).

Also, are you sure that the token can be correctly sampled during the refit? For example, given a 50-layer model where layers 1-24 are the old version and the others are the new version, could this give you a correct sample token?

In RL scenarios, our algorithm supports partial rollout, where different segments of a response can be inferred using different checkpoints. We aim to update the inference-engine parameters promptly. Using writer_lock would require waiting for ongoing inference requests to complete, creating idle periods. In contrast, online parameter updates can maximize inference resource utilization.
To support this, we've added an online parameter to the update weights request, enabling users to choose whether to flush pending requests from the inference engine or proceed with online parameter updates based on their requirements.

# Whether to update weights online
online: bool = False
# Whether to wait for all ongoing inference requests to complete before updating
# parameters, or interrupt them and update parameters directly.
Collaborator

or abort and return all the unfinished requests, then update parameters.

Contributor Author

or abort and return all the unfinished requests, then update parameters.

When non_blocking is set to True, the running request does not return immediately. Instead, it remains in the inference engine while the scheduler's event loop pauses decoding. Inference resumes only after the parameter synchronization is complete.

Contributor

don't we have 3 cases now?

  1. wait for all requests in queue and update
  2. cancel current requests with partial completion and update
  3. keep current requests and update in between forward passes.

Can we still have all the options available in this struct?

Contributor Author

  1. Simply set non_blocking=False.
  2. Call pause_generation first, then call the update_weights API—this approach is independent of the non_blocking parameter. Of course, if you use the pause_generation API introduced in this PR, you’ll need to set abort_all=True.
  3. First call pause_generation with abort_all=False, then call update_weights with non_blocking=True. If you want to implement a single-call non-blocking weight synchronization, you wouldn’t need to call pause_generation at all—just call update_weights directly with non_blocking=True.

Collaborator

@zhaochenyang20 zhaochenyang20 Nov 17, 2025

In this sense:

  1. Set non_blocking=False; the weights refit will wait till all the remaining requests finish, then begin the update.
  2. First pause_generation with abort_all=True, then all the requests will be returned. Then call the update weights.
  3. Directly call update_weights with non_blocking=True, then the engine will pause decoding, wait for the parameters to be updated, and then continue decoding with the old KV cache immediately.

Am I understanding these right? If so, could we add these three conditions and usage to the comments to let users know?

BTW, what shall happen if pause_generation but abort_all=False?

# Optional format specification for loading
flattened_bucket_meta: Optional[dict] = None
# Whether to wait for all ongoing inference requests to complete before updating
# parameters, or interrupt them and update parameters directly.
Collaborator

as above, change it

Comment on lines +1096 to +1097
# Whether to wait for all ongoing inference requests to complete before updating
# parameters, or interrupt them and update parameters directly.
Collaborator

change it as above

Comment on lines +216 to +230
def pause_generation(self):
    response = requests.post(
        self.base_url + "/pause_generation",
        json={"abort_all": False, "retract_all": True},
    )
    ret = response.json()
    return ret

def continue_generation(self):
    response = requests.post(
        self.base_url + "/continue_generation",
        json={},
    )
    ret = response.json()
    return ret
Collaborator

Sorry for not catching the development of these APIs for a long time. Just pondering what these two APIs, "pause" and "continue" generation, mean in SGLang. What's the difference with abort_all? In my understanding, pause and continue don't return the unfinished requests to the user. It just stops generation; then maybe we update the weights but don't flush the KV cache, then continue generation?

Contributor Author

In the old API, pause_generation would abort all running requests and return immediately, and no new requests would be processed afterward. #7419

We now want pause_generation to not abort requests, but instead only pause the decoding process of currently running requests. To achieve this, we introduced the abort_all parameter:

When abort_all=True, pause_generation behaves as before—aborting all running requests.
When abort_all=False, it pauses the inference process in the scheduler's event loop without aborting any requests. The requests remain in either the running_queue or the waiting_queue, preserving their state for potential resumption later.

Contributor Author

We cannot have the model updating parameters during forward passes. So in this unit test, we first pause the scheduler's event loop using a pause_generation request, execute non-blocking parameter updates, and then resume the scheduler's event loop.

Users have the flexibility to decide whether to flush the KV cache according to their specific requirements, and can resume inference by calling the continue_generation API after the cache has been flushed. (Note that if users want to flush the KV cache, they need to retract all running batches using the retract_all parameter of pause_generation, forcing all requests to re-prefill after inference resumes; otherwise, flushing will result in an error.)
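As a sketch of that ordering (endpoint names follow this thread and the PR's test code; treat the request bodies as assumptions rather than a stable API), the safe sequence when a cache flush is wanted looks like this, with the HTTP function injected so the call order can be exercised without a live server:

```python
def update_with_flush(base_url: str, weights_req: dict, post) -> None:
    """Safe order for an in-flight weight update that also flushes the KV cache.

    `post` is an HTTP helper such as requests.post; it is a parameter here
    only so the sequence can be tested without a running engine.
    """
    # 1. Retract running requests so nothing still references the cache.
    post(base_url + "/pause_generation",
         json={"abort_all": False, "retract_all": True})
    # 2. Deliver the new weights while the event loop is paused.
    post(base_url + "/update_weights_from_disk", json=weights_req)
    # 3. Flushing succeeds only now, because running_batch is empty.
    post(base_url + "/flush_cache", json={})
    # 4. Resume: retracted requests re-prefill under the new weights.
    post(base_url + "/continue_generation", json={})
```

Skipping step 1 is exactly the error case described above: flush_cache fails while requests remain in the running_batch.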

Collaborator

Note that if users want to flush the KV cache, they need to retract all running batches using the retract_all parameter of pause_generation, forcing all requests to re-prefill after inference resumes; otherwise, flushing will result in an error.

Great, could you add this to the comment in the change? This will be important to users.

@zhaochenyang20
Collaborator

Will review it soon.

@zhaochenyang20
Collaborator

sorry, busy with my stuff at school. Ping me if not reply in 24h.

Contributor

@JD-ETH JD-ETH left a comment

Thanks for implementing this! This is a feature that we are very excited about. However, I have the following concerns about the interaction design.

Instead of the pause / update / continue sequence, I would like a single call of update_nonblock:

ret = self.run_update_weights(named_tensors, non_blocking=True, flush_cache=False)

compared to what's in the current PR

            ret = self.pause_generation()
            ret = self.run_update_weights(named_tensors, non_blocking=True)
            self.assertTrue(ret["success"])
            ret = self.continue_generation()
  1. The key of this feature for me is to enable a true non-blocking weight update that is completely invisible to the inference caller. The inference caller will not be aware of underlying updates beyond metadata being returned, and this allows a weight-update caller to work completely independently from the service caller.

From what I see it seems possible to just send the _pause_engine signal directly to the scheduler through weight_update call without even changing the pause_engine API? Obviously I don't know the overlap scheduler well enough, but handling this seems possible.

Following that "invisible" paradigm, the retract_all logic will need to be implicitly handled given the flush_cache flag, and we have one less failure point for the user.

  • non_blocking + flush_cache: pause generation, flush_cache, weight_update, send all current running requests back to prefill stage, continue
  • non_blocking only: pause generation, weight_update, continue.
    Personally I'm not sure how many people would like option 2: you are now left with a dirty cache without knowing which checkpoint version its entries came from. Do we have an actual use case that would prefer this?

Let me know what you think

# Whether to update weights online
online: bool = False
# Whether to wait for all ongoing inference requests to complete before updating
# parameters, or interrupt them and update parameters directly.
Contributor

don't we have 3 cases now?

  1. wait for all requests in queue and update
  2. cancel current requests with partial completion and update
  3. keep current requests and update in between forward passes.

Can we still have all the options available in this struct?

@ShawnY112358
Contributor Author

Thanks for implementing this! This is a feature that we are very excited about. However, I have the following concerns about the interaction design. [quoted in full above]

Question 1: Why not implement a single-call, non-blocking update?

I initially designed it this way for the following reasons:

  1. I believe the current pause_generation functionality is incomplete—it can only abort all requests but doesn’t truly support pausing inference.
  2. In training-inference co-located RL frameworks (e.g., where SGLang and Megatron share resources and run alternately), we often want to let SGLang finish generating enough data for one training step, then pause inference while keeping unfinished requests in SGLang. After switching to the training phase, once training completes, we update SGLang with the new model weights and resume inference. In this scenario, a simple non-blocking update alone is insufficient.
  3. Reusing the existing pause_generation interface allows us to minimize changes to the update_weights API. For example:
    To avoid OOM, we might need to release KV cache memory before weight synchronization and restore it afterward.
    If we need to flush the cache after weight synchronization, we must first retract all running requests before synchronization.
    Implementing all this logic inside the update_weights API would make it overly complex and bloated. With the current approach, we don't need to modify update_weights at all; users can call operations like releasing memory or flushing the cache themselves after pause_generation, based on their specific needs.

That said, I agree that the single-call, non-blocking update you proposed is also very reasonable and more user-friendly. I support both implementation approaches. If we go with the single-call design, after merging this PR, we would only need to enhance the pause_generation implementation on top of the community’s baseline to satisfy the requirement mentioned in point #2 above.
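The co-located cycle described in point 2 can be sketched as follows. The engine and trainer interfaces here are illustrative duck-typed stand-ins, not SGLang's exact API (release_memory_occupation / resume_memory_occupation do exist as SGLang engine methods, but the other names are placeholders):

```python
def colocated_training_cycle(engine, trainer):
    """One generate/train/sync cycle where SGLang and the trainer share GPUs.

    Both arguments are duck-typed stand-ins; the method names sketch the
    intent described in this thread rather than a shipped interface.
    """
    steps = []
    engine.pause_generation(mode="in_place")   # keep unfinished requests
    steps.append("pause")
    engine.release_memory_occupation()         # free GPU memory for training
    steps.append("release")
    trainer.train_one_step()
    steps.append("train")
    engine.resume_memory_occupation()          # take GPU memory back
    steps.append("restore")
    engine.update_weights(trainer.export_weights(), non_blocking=True)
    steps.append("update")
    engine.continue_generation()               # unfinished requests resume
    steps.append("continue")
    return steps
```

The key point is that the unfinished requests stay inside the engine for the whole cycle; a plain non-blocking update alone could not express the release/restore steps in the middle.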

@ShawnY112358
Contributor Author

Thanks for implementing this! This is a feature that we are very excited about. However, I have the following concerns about the interaction design. [quoted in full above]

Question 2: Regarding "dirty" KV cache

Some research has shown that in partial rollouts, RL training can still converge properly even without flushing the KV cache. Our recent experiments have also observed similar findings.

Not flushing the cache avoids the computational overhead of re-running the prefill phase for cached prompts, which can provide a noticeable speed advantage—especially when weight synchronization happens frequently.

@ShawnY112358
Contributor Author

sorry, busy with my stuff at school. Ping me if not reply in 24h.

ding~

@zhaochenyang20
Collaborator

sorry, busy with my stuff at school. Ping me if not reply in 24h.

ding~

Thanks! Will reply today!

@zhaochenyang20
Collaborator

I agree with this discussion. Let me set these APIs as follows; would you guys agree, @JD-ETH @ShawnY112358?

Let's rename non_blocking to block_generation.

  1. Wait for all the running requests to finish, then do the update weights:
 ret = self.run_update_weights(named_tensors, block_generation=False)
  2. Stop and return all the generated content, and update the weights:
ret = self.pause_generation(abort_all=True)
ret = self.run_update_weights(named_tensors, block_generation=False/True)

Note that, to continue rollout of these unfinished requests, we shall manually send them back to the server. Simply calling self.continue_generation() has no effect.

  3. Stop current decoding, update the weights, and continue decoding immediately after refitting with the old KV cache:
ret = self.run_update_weights(named_tensors, block_generation=True, flush_cache=False)
  4. Stop current decoding, update the weights, flush the KV cache with new parameters, and continue decoding:
ret = self.run_update_weights(named_tensors, block_generation=True, flush_cache=True)

Would this be clearer?
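The four patterns can be flattened into data for a quick sanity check. Note that block_generation was only a proposed rename in this thread, so every name below is an assumption, not a shipped API:

```python
def weight_sync_calls(block_generation: bool, flush_cache: bool = False,
                      abort_first: bool = False) -> list:
    """Return the ordered API calls for the four patterns listed above.

    Purely illustrative; call names follow the proposal in this thread.
    """
    calls = []
    if abort_first:
        # Pattern 2: abort and return everything before updating.
        calls.append("pause_generation(abort_all=True)")
    calls.append(
        "run_update_weights(block_generation=%s, flush_cache=%s)"
        % (block_generation, flush_cache)
    )
    return calls
```

Patterns 1, 3, and 4 differ only in the two boolean flags; pattern 2 is the only one that needs an explicit pause first.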

@JD-ETH
Contributor

JD-ETH commented Nov 21, 2025

final minor nits

@ShawnY112358 ShawnY112358 requested a review from JD-ETH November 21, 2025 07:44
@dataclass
class PauseGenerationReqInput(BaseReq):
    """
    abort: Abort all requests currently being processed.
Collaborator

Abort all requests currently being processed.

->

Abort and return all requests currently being processed.

Comment on lines +1072 to +1079
in_place: Pause the scheduler's event_loop from performing inference;
only non-inference requests (e.g., control commands) will be handled.
Note: In 'in_place' mode, flush_cache will fail if there are any requests
in the running_batch.

retract: Pause the scheduler's event loop from performing inference;
only non-inference requests will be handled, and all currently running
requests will be retracted back to the waiting_queue.
Collaborator

What will happen in in_place? The requests in the engine will be paused and stay in the event_loop, then continue generation after ContinueGenerationReqInput with the old kv cache?

Also, for retract, the requests in the engine will be paused and removed from the event_loop into waiting_queue, and the KV cache will be flushed and recomputed after ContinueGenerationReqInput?

If my understanding is correct, please make a more adequate description of this PR.

At the same time, it seems that in_place and retract are counterparts; should we only keep one parameter like retract=True?

Contributor Author

Discussed this with @JD-ETH . Currently, pause only supports these three modes. If we use the abort and retract parameters for control, the retract parameter becomes meaningless when abort=True.

Collaborator

Please add these details in the comments:

What will happen in in_place? The requests in the engine will be paused and stay in the event_loop, then continue generation after ContinueGenerationReqInput with the old kv cache?
Also, for retract, the requests in the engine will be paused and removed from the event_loop into waiting_queue, and the KV cache will be flushed and recomputed after ContinueGenerationReqInput?
If my understanding is correct, please make a more adequate description of this PR.

@dataclass
class PauseGenerationReqInput(BaseReq):
    """
    abort: Abort all requests currently being processed.
Collaborator

Sorry for my late reply, but I'm just not sure we should put the abort here.

If we have an abort parameter here, the abort_request and pause_generation seem to be intersecting. This makes the API ambiguous.

Could we only leave the pause to do the pause, and let users choose whether to recompute the KV cache?

Collaborator

I see that the current pause API can do the abort_all, which feels a bit strange to me. Let me discuss with the team.

Comment on lines +315 to +318
ret = self.pause_generation(mode)
ret = self.run_update_weights(
    new_model_path, flush_cache=mode == "retract"
)
Collaborator

Indeed, we cannot validate whether the KV cache is actually recomputed. We can leave a comment here, stating that “we cannot validate whether the KV cache is indeed recomputed; we just trust the retract parameter.”

Contributor Author

I didn't quite understand your point. Here, we just want to clarify that if the user wants to flush the cache, they must call pause_generation in retract mode.

]

for tp_size, dp_size, model_name, backend in test_suits:
    modes = ["in_place", "retract"]
Collaborator

modes = ["in_place", "retract"] -> pause_generation_modes = ["in_place", "retract"].


for tp_size, dp_size, model_name, backend in test_suits:
    modes = ["in_place", "retract"]
    for mode in modes:
Collaborator

for pause_generation_mode in pause_generation_modes:

model_state_dict_shapes[model_name],
truncate_size,
checking_parameters,
mode=mode,
Collaborator

pause_generation_mode, please do not use pause_generation_mode=pause_generation_mode.

response = requests.post(
    self.base_url + "/generate",
    json={
        "text": "The capital of France is",
Collaborator

change the text also.

response = requests.post(
    url + "/generate",
    json={
        "text": "The capital of France is",
Collaborator

change the text also.


def test_update_weights(self):
    modes = ["in_place", "retract"]
    for mode in modes:
Collaborator

also, please do not use mode here.

origin_model_path = self.get_model_info()
print(f"[Server Mode] origin_model_path: {origin_model_path}")

modes = ["in_place", "retract"]
Collaborator

also, please do not use mode here.

]

for tp_size, dp_size, model_name, backend in test_suits:
    modes = ["in_place", "retract"]
Collaborator

In the same time, please move the pause_generation_mode into the test_suits, make it like:

for tp_size, dp_size, model_name, backend, pause_generation_mode in test_suits:

The same for other update weights unit tests.

And, please help me change the code as follows (someone seems to be abusing my previous setting for random coverage of the tests; basically, local tests should cover more):

        if is_in_ci():
            backend = random.choice(["server", "engine"])
            pause_generation_mode = random.choice(["in_place", "retract"])
            test_suits = [
                (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, backend, pause_generation_mode),
            ]
        else:
            test_suits = [
                (1, 1, DEFAULT_MODEL_NAME_FOR_TEST, "Server", random.choice(["in_place", "retract"])),
                (1, 1, DEFAULT_MODEL_NAME_FOR_TEST, "Engine", random.choice(["in_place", "retract"])),
            ]

            if torch.cuda.device_count() >= 4:
                test_suits.extend([
                    (1, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Engine", random.choice(["in_place", "retract"])),
                    (2, 1, DEFAULT_MODEL_NAME_FOR_TEST, "Server", random.choice(["in_place", "retract"])),
                ])

            if torch.cuda.device_count() >= 5:
                test_suits.append(
                    (2, 2, DEFAULT_MODEL_NAME_FOR_TEST, random.choice(["Engine", "Server"]), random.choice(["in_place", "retract"])),
                )

Do it the same as other tests, adding new parameters in the test_suites.

Contributor Author

@ShawnY112358 ShawnY112358 Nov 25, 2025

The pause_generation API does not expose an interface for Engine mode

Comment on lines +1072 to +1079
in_place: Pause the scheduler's event_loop from performing inference;
only non-inference requests (e.g., control commands) will be handled.
Note: In 'in_place' mode, flush_cache will fail if there are any requests
in the running_batch.

retract: Pause the scheduler's event loop from performing inference;
only non-inference requests will be handled, and all currently running
requests will be retracted back to the waiting_queue.
Collaborator

Please add these details in the comments:

What happens in in_place? Do the requests in the engine get paused and stay in the event_loop, then continue generation after ContinueGenerationReqInput with the old KV cache?
And for retract, are the requests paused and removed from the event_loop into the waiting_queue, with the KV cache flushed and recomputed after ContinueGenerationReqInput?
If my understanding is correct, please write a more adequate description of this PR.
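The mode semantics described above can be sketched with a toy scheduler. This is an illustrative stand-in only: ToyScheduler, its fields, and its queue handling are assumptions for the sake of the example, not the actual sglang classes.

```python
from collections import deque

class ToyScheduler:
    """Toy model of the three pause_generation modes; not the real scheduler."""

    def __init__(self):
        self.waiting_queue = deque()
        self.running_batch = []
        self.paused = False

    def pause_generation(self, mode):
        assert mode in ("abort", "in_place", "retract")
        self.paused = True
        if mode == "abort":
            # abort: drop running requests and hand them back to the caller
            aborted, self.running_batch = self.running_batch, []
            return aborted
        if mode == "retract":
            # retract: running requests go back to the head of the waiting_queue;
            # their KV cache would be recomputed after ContinueGenerationReqInput
            self.waiting_queue.extendleft(reversed(self.running_batch))
            self.running_batch = []
        # in_place: requests stay in running_batch and keep their KV cache,
        # which is why flush_cache would fail while running_batch is non-empty
        return []

    def continue_generation(self):
        self.paused = False

s = ToyScheduler()
s.running_batch = ["req1", "req2"]
s.pause_generation("retract")
# requests are now back in the waiting_queue, ready to be rescheduled
s.continue_generation()
```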

1,
1,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
"Server",
Collaborator

For is_in_ci, also choose from ["Server", "Engine"].

@zhaochenyang20
Collaborator

The test suites I generally want to have in update weights:

Inside the test function, if the backend is "Server", set pause_generation_mode = random.choice(["in_place", "retract"]). If the backend is "Engine", there is no need to do the non-blocking update weights.

Generally, we do not want to double the test file length and the test running time.

For example, keep the run_suites as:

        if is_in_ci():
            backend = random.choice(["Engine", "Server"])
            test_suits = [
                (
                    1,
                    1,
                    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                    backend,
                ),
            ]
        else:
            test_suits = [
                (
                    1,
                    1,
                    DEFAULT_MODEL_NAME_FOR_TEST,
                    "Engine",
                ),
                (
                    1,
                    1,
                    DEFAULT_MODEL_NAME_FOR_TEST,
                    "Server",
                ),
            ]

            if torch.cuda.device_count() >= 4:
                test_suits.append(
                    (
                        1,
                        2,
                        DEFAULT_MODEL_NAME_FOR_TEST,
                        random.choice(["Server", "Engine"]),
                    ),
                )

            if torch.cuda.device_count() >= 5:
                test_suits.append(
                    (
                        2,
                        2,
                        DEFAULT_MODEL_NAME_FOR_TEST,
                        random.choice(["Server", "Engine"]),
                    ),
                )

ShawnY112358 and others added 4 commits November 25, 2025 15:10
Collaborator

@zhaochenyang20 zhaochenyang20 left a comment

Great job

@Kangyan-Zhou Kangyan-Zhou merged commit 007c3e2 into sgl-project:main Nov 26, 2025
394 of 441 checks passed
result = (await self.update_weights_from_distributed_communicator(obj))[0]
return result.success, result.message
Contributor

There are some other steps.

Please also do success, message = _Communicator.merge_results(results) and

        if success and obj.weight_version is not None:
            self._update_weight_version_if_provided(obj.weight_version)
            message += f" Weight version updated to {obj.weight_version}."
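A minimal sketch of that merge step, assuming each per-rank result carries success and message fields; the exact signature and behavior of _Communicator.merge_results in sglang may differ.

```python
from dataclasses import dataclass

@dataclass
class UpdateResult:
    success: bool
    message: str

def merge_results(results):
    # overall success only if every rank succeeded; concatenate per-rank messages
    return all(r.success for r in results), " | ".join(r.message for r in results)

success, message = merge_results(
    [UpdateResult(True, "rank0 ok"), UpdateResult(False, "rank1 failed")]
)
# success is False because rank1 failed; both messages are preserved
```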

async with self.is_pause_cond:
    if self.is_pause:
        result = (await self.update_weights_from_tensor_communicator(obj))[0]
        return result.success, result.message
Contributor

Same, please do the post-processing correctly instead of directly returning the result from rank 0.

Code skeleton:

if self.is_pause:
    result = await self.update_weights_from_tensor_communicator(obj)
else:
    async with self.model_update_lock.writer_lock:
        result = ...

# post processing
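Putting the skeleton together with the shared post-processing, here is a runnable sketch using stub classes. The attribute and helper names (is_pause, is_pause_cond, model_update_lock, update_weights_from_tensor_communicator, _update_weight_version_if_provided) follow the review comments above; everything else is illustrative, not the real tokenizer manager.

```python
import asyncio
from dataclasses import dataclass

@dataclass
class Result:
    success: bool
    message: str

class ToyTokenizerManager:
    """Stub illustrating the paused vs. locked update paths; not the real class."""

    def __init__(self, paused):
        self.is_pause = paused
        self.is_pause_cond = asyncio.Condition()
        self.model_update_lock = asyncio.Lock()  # stand-in for the writer_lock
        self.weight_version = None

    async def update_weights_from_tensor_communicator(self, obj):
        # pretend two ranks both applied the new weights successfully
        return [Result(True, "rank0 ok"), Result(True, "rank1 ok")]

    def _update_weight_version_if_provided(self, version):
        self.weight_version = version

    async def update_weights_from_tensor(self, obj):
        async with self.is_pause_cond:
            if self.is_pause:
                # paused path: the scheduler takes the update directly, no writer lock
                results = await self.update_weights_from_tensor_communicator(obj)
            else:
                async with self.model_update_lock:
                    results = await self.update_weights_from_tensor_communicator(obj)
        # shared post-processing for both paths, instead of returning rank 0 only
        success = all(r.success for r in results)
        message = " | ".join(r.message for r in results)
        if success and obj.get("weight_version") is not None:
            self._update_weight_version_if_provided(obj["weight_version"])
            message += f" Weight version updated to {obj['weight_version']}."
        return success, message

async def demo(paused, obj):
    mgr = ToyTokenizerManager(paused)  # construct inside the loop so locks bind to it
    return mgr, await mgr.update_weights_from_tensor(obj)

mgr, (success, message) = asyncio.run(demo(True, {"weight_version": "v2"}))
```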


6 participants