[feat] update bucketed weights from distributed #13824
Conversation
Summary of Changes
Hello @ShawnY112358, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request optimizes the distributed weight update mechanism within SGLang, specifically targeting scenarios where training and inference are disaggregated. By enabling the use of flattened tensors for weight synchronization, it significantly reduces the performance overhead previously encountered, leading to more efficient and faster parameter updates across distributed systems.
Code Review
This pull request introduces an optimized path for updating model weights from a distributed source by using flattened tensors, which should reduce parameter synchronization overhead. The changes involve adding a new field to UpdateWeightsFromDistributedReqInput, and implementing the corresponding logic in TpModelWorker and ModelRunner.
The implementation is mostly correct, but I've found a couple of critical issues in the new update_bucketed_weights_from_distributed method in ModelRunner:
- An incorrect `dtype` is used when creating the tensor to receive broadcasted weights.
- The `weight_version` is used without being defined, which will lead to a `NameError`.

I've provided detailed comments and suggestions to fix these issues. Once these are addressed, the changes should work as intended.
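For context, here is a minimal sketch of what the new field on the request input might look like. The surrounding fields and defaults are assumptions for illustration, not the PR's actual definition; only `flattened_bucket_meta`, `group_name`, and `weight_version` are taken from the conversation above.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class UpdateWeightsFromDistributedReqInput:
    # Illustrative subset of existing fields (assumed, not copied from the repo).
    names: List[str]
    dtypes: List[str]
    shapes: List[List[int]]
    group_name: str = "weight_update_group"
    weight_version: Optional[str] = None
    # New in this PR: metadata describing each flattened bucket, e.g. parameter
    # names, shapes, dtype, and offsets inside the flat tensor.
    flattened_bucket_meta: Optional[List[Dict]] = None
```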
```python
success, message = self.model_runner.update_bucketed_weights_from_distributed(
    recv_req.flattened_bucket_meta, recv_req.group_name,
)
```
The weight_version from recv_req is not being passed to update_bucketed_weights_from_distributed. The called method in model_runner.py attempts to use weight_version, which will cause a NameError. Please pass recv_req.weight_version to the method call.
Suggested change:

```diff
 success, message = self.model_runner.update_bucketed_weights_from_distributed(
-    recv_req.flattened_bucket_meta, recv_req.group_name,
+    recv_req.flattened_bucket_meta, recv_req.group_name, recv_req.weight_version,
 )
```
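For illustration, a rough sketch of a receive-side method that addresses both issues the review raises: the receive buffer's dtype is taken from the bucket metadata, and `weight_version` is threaded through as a parameter. The metadata layout, attribute names, and helper lookups are assumptions, not the PR's actual code.

```python
import math

import torch
import torch.distributed as dist


def update_bucketed_weights_from_distributed(
    self, flattened_bucket_meta, group_name, weight_version=None
):
    """Hypothetical receive-side sketch; not the PR's actual implementation."""
    group = self._model_update_groups[group_name]  # assumed process-group lookup
    for bucket in flattened_bucket_meta:
        # Allocate one flat receive buffer per bucket, using the dtype recorded
        # in the bucket metadata (hard-coding a dtype here is the bug flagged above).
        flat = torch.empty(bucket["numel"], dtype=bucket["dtype"], device="cuda")
        dist.broadcast(flat, src=0, group=group)
        # Slice the flat buffer back into per-parameter tensors and load them.
        named_tensors, offset = [], 0
        for name, shape in zip(bucket["names"], bucket["shapes"]):
            numel = math.prod(shape)
            named_tensors.append((name, flat[offset : offset + numel].view(shape)))
            offset += numel
        self.model.load_weights(named_tensors)
    if weight_version is not None:
        self.weight_version = weight_version  # assumed attribute
    return True, "Succeeded to update bucketed weights."
```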
Could you please share the speedup before and after the change?
hebiao064 left a comment
LGTM
posted in PR's description

wow that's a lot. Good job!
/tag-and-rerun-ci |
please rebase |
thanks! Nicely done
Co-authored-by: Stefan He <[email protected]>
Motivation
In the training-inference disaggregation RL scenario, we use update_weights_from_distributed to update SGLang's weights. However, update_weights_from_distributed currently does not support flattened tensors, resulting in significant time overhead during parameter synchronization.
Modifications
Support accelerating update_weights_from_distributed with flattened tensors.
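The core idea on the sender side is to pack a bucket of same-dtype parameters into one contiguous tensor and issue a single collective per bucket, instead of one broadcast per parameter. A conceptual sketch of that idea follows; the bucketing policy and function names are illustrative, not this PR's API.

```python
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors


def broadcast_weights_bucketed(named_params, group, bucket_size_mb=512):
    """Illustrative sender-side sketch: one broadcast per flattened bucket."""
    buckets, current, current_bytes = [], [], 0
    for name, param in named_params:  # params in a bucket are assumed to share a dtype
        current.append((name, param))
        current_bytes += param.numel() * param.element_size()
        if current_bytes >= bucket_size_mb * 1024 * 1024:
            buckets.append(current)
            current, current_bytes = [], 0
    if current:
        buckets.append(current)

    for bucket in buckets:
        tensors = [p for _, p in bucket]
        flat = _flatten_dense_tensors(tensors)    # one contiguous buffer
        dist.broadcast(flat, src=0, group=group)  # single collective per bucket
        # The receiver uses the bucket metadata (names, shapes, dtype, numel)
        # carried by the request to slice the flat buffer back into weights.
```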
Accuracy Tests
Benchmarking and Profiling
Parameter synchronization time, before -> after this change:
- Qwen3 turbopp (30B): 5.5s -> 2.5s
- Qwen3 Next (80B): 130s -> 7s
Checklist