Skip to content

[V1] Support meta loading for full and free#10236

Open
jiaqiw09 wants to merge 2 commits intohiyouga:mainfrom
jiaqiw09:meta_update
Open

[V1] Support meta loading for full and free#10236
jiaqiw09 wants to merge 2 commits intohiyouga:mainfrom
jiaqiw09:meta_update

Conversation

@jiaqiw09
Copy link
Collaborator

@jiaqiw09 jiaqiw09 commented Mar 2, 2026

What does this PR do?

Fixes # (issue)

Before submitting

@hiyouga would you mind having a check? I will make pr for lora next.

here is the test for full after I set all deterministic
Pasted Graphic 2

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly improves the model loading process for FSDP2 by enabling proper support for 'meta' device initialization. It addresses critical aspects like managing non-persistent buffers and ensuring tied weights are correctly handled, which is crucial for large model training efficiency. Concurrently, it temporarily restricts LoRA from using meta-device loading, indicating a focused approach to FSDP2's capabilities.

Highlights

  • FSDP2 Meta-Device Loading Enhancement: Implemented robust handling for non-persistent buffers and tied weights when loading models on 'meta' devices using FSDP2, ensuring correct model initialization and sharding.
  • LoRA Meta-Device Restriction: Introduced a ValueError to explicitly prevent LoRA models from being loaded on 'meta' devices, indicating current incompatibility or a planned future enhancement.
  • Comprehensive Testing: Added a new unit test to validate the consistency of FSDP2 meta-device loading, specifically verifying the correct handling of non-persistent buffers and tied weights against normal loading paths.
Changelog
  • src/llamafactory/v1/plugins/model_plugins/peft.py
    • Added a check to prevent LoRA models from loading on meta devices.
  • src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py
    • Imported the copy module.
    • Added _save_non_persistent_buffers method to collect non-persistent buffers.
    • Added _restore_non_persistent_buffers method to re-register saved non-persistent buffers.
    • Modified shard_model to utilize new buffer handling methods and re-tie weights when loading models on meta devices.
  • tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py
    • Added a new test file test_fsdp2_meta_loading_buffers_and_tied_weights.
    • Implemented a helper function collect_non_persistent_buffers.
    • Included a test case to verify non-persistent buffers and tied weights consistency after meta-device loading in FSDP2.
Activity
  • The author completed the pre-submission checklist.
  • The author requested a review from @hiyouga.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for meta device loading in FSDP2. The changes involve saving and restoring non-persistent buffers, and correctly handling tied weights during model materialization. A new test file has been added to verify this behavior. My review includes a suggestion to improve code clarity in src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py by reducing code duplication.

Comment on lines +250 to +258
if getattr(model.config, "tie_word_embeddings", None):
model.tie_weights()

model = self.prepare_model(model)
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)

# fix tied broken for no-fsdp-wrap case
if getattr(model.config, "tie_word_embeddings", None):
model.tie_weights()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve code clarity and avoid redundant calls to getattr, you can store the result of getattr(model.config, "tie_word_embeddings", None) in a variable and reuse it.

Suggested change
if getattr(model.config, "tie_word_embeddings", None):
model.tie_weights()
model = self.prepare_model(model)
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
# fix tied broken for no-fsdp-wrap case
if getattr(model.config, "tie_word_embeddings", None):
model.tie_weights()
should_tie_weights = getattr(model.config, "tie_word_embeddings", None)
if should_tie_weights:
model.tie_weights()
model = self.prepare_model(model)
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
# fix tied broken for no-fsdp-wrap case
if should_tie_weights:
model.tie_weights()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant