Add get base model state dict #3000

Status: Open. Isalia20 wants to merge 6 commits into huggingface:main from Isalia20:add-get-base-model-state-dict (+600 −1).
@@ -760,3 +760,172 @@ def get_hub_filename(use_safetensors=True):

            remapped_adapters_weights[key_with_prefix] = val

    return remapped_adapters_weights


def _get_adapter_state_dict_key_prefixes(model) -> set[str]:
    """Collect state dict key prefixes for adapter parameters by inspecting the module tree.

    For ``BaseTunerLayer`` modules this uses ``adapter_layer_names`` and ``other_param_names`` to build prefixes. For
    ``AuxiliaryTrainingWrapper`` modules it queries ``adapter_state_dict_load_map`` (the canonical source of adapter
    keys) for each registered adapter, and falls back to ``other_param_names`` for non-saveable adapter attributes.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    adapter_key_prefixes: set[str] = set()
    for module_name, module in model.base_model.model.named_modules():
        if isinstance(module, AuxiliaryTrainingWrapper):
            # Use adapter_state_dict_load_map to get the actual state dict keys owned by each adapter
            for adapter_name in module._adapters:
                load_map = module.adapter_state_dict_load_map(adapter_name)
                for state_dict_key in load_map.values():
                    prefix = f"{module_name}.{state_dict_key}" if module_name else state_dict_key
                    adapter_key_prefixes.add(prefix)
            for attr_name in module.other_param_names:
                prefix = f"{module_name}.{attr_name}" if module_name else attr_name
                adapter_key_prefixes.add(prefix)
        elif isinstance(module, BaseTunerLayer):
            for attr_name in module.adapter_layer_names + module.other_param_names:
                prefix = f"{module_name}.{attr_name}" if module_name else attr_name
                adapter_key_prefixes.add(prefix)
    return adapter_key_prefixes
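As a rough illustration of the prefix-building pattern above, here is a torch-free sketch using a stand-in class rather than the real PEFT types (`ToyTunerLayer`, `collect_prefixes`, and the module names are all hypothetical; only the dotted-prefix construction mirrors the function):

```python
# Toy stand-in for a PEFT tuner layer; attribute names mimic the real
# BaseTunerLayer interface but the class itself is hypothetical.
class ToyTunerLayer:
    adapter_layer_names = ("lora_A", "lora_B")
    other_param_names = ("scaling",)


def collect_prefixes(named_modules):
    """named_modules: iterable of (dotted_name, module) pairs, like torch's named_modules()."""
    prefixes = set()
    for module_name, module in named_modules:
        if isinstance(module, ToyTunerLayer):
            for attr in module.adapter_layer_names + module.other_param_names:
                # Root module has an empty name, so avoid a leading dot.
                prefixes.add(f"{module_name}.{attr}" if module_name else attr)
    return prefixes


modules = [("", object()), ("model.layers.0.q_proj", ToyTunerLayer())]
print(sorted(collect_prefixes(modules)))
# ['model.layers.0.q_proj.lora_A', 'model.layers.0.q_proj.lora_B', 'model.layers.0.q_proj.scaling']
```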
def _is_adapter_key(key: str, adapter_key_prefixes: set[str]) -> bool:
    """Check if a state dict key belongs to an adapter parameter."""
    return any(key == pfx or key.startswith(pfx + ".") for pfx in adapter_key_prefixes)
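A quick sanity check of this matching rule: a key matches only on exact equality or a dotted-prefix boundary, so a prefix never matches a longer sibling name by accident (the keys below are made up for illustration):

```python
def is_adapter_key(key, prefixes):
    # Same rule as _is_adapter_key above: exact match, or prefix followed
    # by a dot, so prefix "….lora_A" never matches a "….lora_AB" key.
    return any(key == p or key.startswith(p + ".") for p in prefixes)


prefixes = {"layers.0.q_proj.lora_A"}
print(is_adapter_key("layers.0.q_proj.lora_A.default.weight", prefixes))  # True
print(is_adapter_key("layers.0.q_proj.lora_AB.weight", prefixes))         # False
```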
def _peft_key_to_original_key(model, peft_key: str) -> str:
    """Transform a PEFT state dict key to its original base model key.

    Walks the module tree to strip wrapper infixes (``base_layer``, ``original_module``, and internal tuner modules
    like ``token_adapter`` inside ``AuxiliaryTrainingWrapper``).
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    parts = peft_key.split(".")
    original_parts: list[str] = []
    current = model.base_model.model

    for part in parts:
        if current is None:
            # Already past the module tree (e.g. nested parameter names)
            original_parts.append(part)
            continue

        child = getattr(current, part, None)

        if child is None or not isinstance(child, torch.nn.Module):
            # Parameter/buffer name - keep it
            original_parts.append(part)
            current = None
        elif part == "base_layer" and isinstance(current, BaseTunerLayer):
            # Skip the base_layer infix inside a tuner
            current = child
        elif part == "original_module" and isinstance(current, AuxiliaryTrainingWrapper):
            # Skip the original_module infix inside a wrapper
            current = child
        elif isinstance(child, BaseTunerLayer) and isinstance(current, AuxiliaryTrainingWrapper):
            # Internal tuner inside a wrapper (e.g. token_adapter in TrainableTokensWrapper) - skip
            current = child
        elif isinstance(child, (BaseTunerLayer, AuxiliaryTrainingWrapper)):
            # Tuner/wrapper that replaces an original module (e.g. LoRA at q_proj) - keep name
            original_parts.append(part)
            current = child
        else:
            # Regular module - keep name
            original_parts.append(part)
            current = child

    return ".".join(original_parts)

Collaborator comment on the `child = getattr(current, part, None)` line: can you motivate (in a code comment) why we're traversing the model this way and not using …
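At the string level, the effect of this traversal is roughly to drop wrapper infixes from the dotted key. A simplified, string-only sketch of that effect (hypothetical helper and keys; the real function additionally checks module types by walking the model, which a pure string transform cannot do):

```python
def strip_wrapper_infixes(peft_key):
    # Simplified: drop "base_layer" / "original_module" segments unconditionally.
    # _peft_key_to_original_key only drops them when the enclosing module really
    # is a tuner/wrapper, so this sketch can over-strip on unlucky names.
    keep = [p for p in peft_key.split(".") if p not in ("base_layer", "original_module")]
    return ".".join(keep)


print(strip_wrapper_infixes("model.layers.0.q_proj.base_layer.weight"))
# model.layers.0.q_proj.weight
```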
def get_base_model_state_dict(model) -> dict[str, torch.Tensor]:
    """Return the state dict of the base model with the original model keys.

    Extracts the base model's parameters from a PEFT-wrapped model, removing PEFT-specific key modifications and
    filtering out adapter-specific parameters.

    Args:
        model: A ``PeftModel`` instance.

    Returns:
        The base model's state dict with original keys (without PEFT modifications).
    """
    # For prompt learning methods the base model structure is not modified, so the state
    # dict already uses the original keys and contains no adapter-injected parameters.
    if model._is_prompt_learning:
        return dict(model.base_model.state_dict())

    state_dict = model.base_model.model.state_dict()
    adapter_key_prefixes = _get_adapter_state_dict_key_prefixes(model)

    result: dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        if _is_adapter_key(key, adapter_key_prefixes):
            continue
        result[_peft_key_to_original_key(model, key)] = value

    return result
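Putting the helpers together, extraction reduces to a filter-and-remap pass over the state dict. A toy, torch-free sketch with made-up keys and string placeholders for tensors (the `.replace(...)` call stands in for `_peft_key_to_original_key`):

```python
peft_state_dict = {
    "layers.0.q_proj.base_layer.weight": "W_q",    # base weight under a LoRA wrapper
    "layers.0.q_proj.lora_A.default.weight": "A",  # adapter weight - filtered out
    "layers.0.mlp.weight": "W_mlp",                # untouched module
}
adapter_prefixes = {"layers.0.q_proj.lora_A"}


def is_adapter_key(key, prefixes):
    return any(key == p or key.startswith(p + ".") for p in prefixes)


result = {
    key.replace(".base_layer.", "."): val  # stand-in for the real key translation
    for key, val in peft_state_dict.items()
    if not is_adapter_key(key, adapter_prefixes)
}
print(result)
# {'layers.0.q_proj.weight': 'W_q', 'layers.0.mlp.weight': 'W_mlp'}
```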
def set_base_model_state_dict(
    model,
    state_dict: dict[str, torch.Tensor],
    strict: bool = True,
):
    """Load a state dict with original model keys into a PEFT-wrapped model.

    Takes a state dict keyed by the original (pre-PEFT) model key names and loads it into the base model, automatically
    translating keys to account for PEFT wrapper infixes (``base_layer``, ``original_module``, etc.).

    This is the counterpart to :func:`get_base_model_state_dict` and is useful for scenarios like loading base model
    weights after FSDP wrapping.

    Args:
        model: A ``PeftModel`` instance.
        state_dict: The state dict with original model keys to load.
        strict: Whether to strictly enforce that the keys match. If ``True``,
            raises ``RuntimeError`` on missing or unexpected keys.

    Returns:
        A ``namedtuple`` with ``missing_keys`` and ``unexpected_keys`` fields.
    """
    _IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])

    # For prompt learning methods the base model structure is not modified, so the state
    # dict already uses the original keys and contains no adapter-injected parameters.
    if model._is_prompt_learning:
        return model.base_model.load_state_dict(state_dict, strict=strict)

    current_state_dict = model.base_model.model.state_dict()
    adapter_key_prefixes = _get_adapter_state_dict_key_prefixes(model)

    # Build mapping: original_key → peft_key
    original_to_peft_key: dict[str, str] = {}
    for peft_key in current_state_dict.keys():
        if _is_adapter_key(peft_key, adapter_key_prefixes):
            continue
        original_to_peft_key[_peft_key_to_original_key(model, peft_key)] = peft_key

    peft_state_dict: dict[str, torch.Tensor] = {}
    unexpected_keys: list[str] = []

    for original_key, value in state_dict.items():
        if original_key in original_to_peft_key:
            peft_state_dict[original_to_peft_key[original_key]] = value
        else:
            unexpected_keys.append(original_key)

    missing_keys = [k for k in original_to_peft_key if k not in state_dict]

    if strict and (missing_keys or unexpected_keys):
        error_msgs: list[str] = []
        if missing_keys:
            error_msgs.append(f"Missing key(s) in state_dict: {missing_keys}")
        if unexpected_keys:
            error_msgs.append(f"Unexpected key(s) in state_dict: {unexpected_keys}")
        raise RuntimeError("Error(s) in loading state_dict:\n\t" + "\n\t".join(error_msgs))

    model.base_model.model.load_state_dict(peft_state_dict, strict=False)
    return _IncompatibleKeys(missing_keys=missing_keys, unexpected_keys=unexpected_keys)
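The loading path inverts the mapping and does the missing/unexpected bookkeeping itself, since the keys it finally hands to ``load_state_dict`` are already translated. A minimal sketch of that bookkeeping with a hypothetical one-entry mapping:

```python
from collections import namedtuple

IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])

# Hypothetical original_key → peft_key mapping, as built in the function above.
original_to_peft = {"layers.0.q_proj.weight": "layers.0.q_proj.base_layer.weight"}
incoming = {"layers.0.q_proj.weight": "W", "layers.0.extra.weight": "X"}

translated = {}
unexpected = []
for orig_key, val in incoming.items():
    if orig_key in original_to_peft:
        # Known key: rename to the PEFT-internal key before loading.
        translated[original_to_peft[orig_key]] = val
    else:
        unexpected.append(orig_key)

# Keys the model expects but the incoming state dict does not provide.
missing = [k for k in original_to_peft if k not in incoming]

print(IncompatibleKeys(missing_keys=missing, unexpected_keys=unexpected))
# IncompatibleKeys(missing_keys=[], unexpected_keys=['layers.0.extra.weight'])
```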
Reviewer comments:

Let's move the changes from peft_model.py to utils/save_and_load.py.

I'm not sure there's precedent for this in PeftModel. Unless I'm mistaken, let's remove these calls since the current pattern is to import from save_and_load utils.