Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions src/peft/peft_model.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move the changes from peft_model.py to utils/save_and_load.py

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure there's precedent for this in PeftModel. Unless I'm mistaken, let's remove these calls since the current pattern is to import from save_and_load utils.

Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@
_prepare_prompt_learning_config,
_set_adapter,
_set_trainable,
get_base_model_state_dict,
get_peft_model_state_dict,
id_tensor_storage,
infer_device,
load_peft_weights,
map_cache_to_layer_device_map,
set_base_model_state_dict,
set_peft_model_state_dict,
shift_tokens_right,
)
Expand Down Expand Up @@ -1641,6 +1643,81 @@ def supports_lora_conversion(self, adapter_name: str = "default") -> bool:

return self.base_model.supports_lora_conversion()

def get_base_model_state_dict(self) -> dict[str, torch.Tensor]:
    """Return the base model's state dict keyed by the original model names.

    PEFT rewrites module paths (e.g. inserting ``.base_layer.`` infixes for tuner
    layers) and injects adapter parameters such as LoRA matrices. This method
    undoes both effects: adapter parameters are filtered out and the remaining
    keys are mapped back to the names the unwrapped base model uses. Use it when
    you need to access or save the base model's weights under their original key
    names.

    Returns:
        `dict[str, torch.Tensor]`:
            The base model's state dict with original keys (without PEFT modifications).

    Example:
        ```python
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import get_peft_model, LoraConfig

        >>> base_model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> original_keys = set(base_model.state_dict().keys())

        >>> peft_model = get_peft_model(base_model, LoraConfig(target_modules=["c_attn"]))
        >>> base_state_dict = peft_model.get_base_model_state_dict()

        >>> # The keys match the original model
        >>> assert set(base_state_dict.keys()) == original_keys
        ```
    """
    # All of the key-translation logic lives in peft.utils.save_and_load; this
    # method is a thin convenience wrapper around it.
    return get_base_model_state_dict(self)

def set_base_model_state_dict(
    self,
    state_dict: dict[str, torch.Tensor],
    strict: bool = True,
):
    """Load a state dict keyed by original model names into the wrapped base model.

    Accepts a state dict whose keys use the original (pre-PEFT) names and loads it
    into the base model, transparently translating keys to account for PEFT's
    module-path rewrites (such as the ``.base_layer.`` infix added to tuner
    layers). This is the counterpart to `get_base_model_state_dict` and is useful
    for scenarios like loading base model weights after FSDP wrapping.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict with original model keys to load.
        strict (`bool`, *optional*, defaults to `True`):
            Whether to strictly enforce that the keys in `state_dict` match the keys expected by the base model. If
            `True`, raises a `RuntimeError` if there are missing or unexpected keys.

    Returns:
        `NamedTuple` with `missing_keys` and `unexpected_keys` fields (using original key names), similar to
        `torch.nn.Module.load_state_dict`.

    Raises:
        RuntimeError: If `strict=True` and there are missing or unexpected keys.

    Example:
        ```python
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import get_peft_model, LoraConfig
        >>> import torch

        >>> base_model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> base_weights = base_model.state_dict()
        >>> peft_model = get_peft_model(base_model, LoraConfig(target_modules=["c_attn"]))

        >>> # Restore original base model weights on the peft wrapped model
        >>> result = peft_model.set_base_model_state_dict(base_weights)
        ```
    """
    # The key translation and strictness handling live in
    # peft.utils.save_and_load; delegate to the shared utility.
    result = set_base_model_state_dict(self, state_dict, strict=strict)
    return result


class PeftModelForSequenceClassification(PeftModel):
"""
Expand Down
10 changes: 9 additions & 1 deletion src/peft/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,13 @@
transpose,
)
from .peft_types import PeftType, TaskType, register_peft_method
from .save_and_load import get_peft_model_state_dict, load_peft_weights, set_peft_model_state_dict
from .save_and_load import (
get_base_model_state_dict,
get_peft_model_state_dict,
load_peft_weights,
set_base_model_state_dict,
set_peft_model_state_dict,
)
from .warning import PeftWarning


Expand Down Expand Up @@ -117,6 +123,7 @@
"_set_trainable",
"bloom_model_postprocess_past_key_value",
"cast_mixed_precision_params",
"get_base_model_state_dict",
"get_gptqmodel_quant_linear",
"get_peft_model_state_dict",
"get_quantization_config",
Expand All @@ -128,6 +135,7 @@
"register_peft_method",
"replace_lora_weights_loftq",
"set_additional_trainable_modules",
"set_base_model_state_dict",
"set_peft_model_state_dict",
"shift_tokens_right",
"transpose",
Expand Down
169 changes: 169 additions & 0 deletions src/peft/utils/save_and_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,3 +760,172 @@ def get_hub_filename(use_safetensors=True):
remapped_adapters_weights[key_with_prefix] = val

return remapped_adapters_weights


def _get_adapter_state_dict_key_prefixes(model) -> set[str]:
    """Collect state dict key prefixes for adapter parameters by inspecting the module tree.

    For ``BaseTunerLayer`` modules this uses ``adapter_layer_names`` and ``other_param_names`` to build prefixes. For
    ``AuxiliaryTrainingWrapper`` modules it queries ``adapter_state_dict_load_map`` (the canonical source of adapter
    keys) for each registered adapter, and falls back to ``other_param_names`` for non-saveable adapter attributes.
    """
    # Local import to avoid a circular dependency with the tuners package.
    from peft.tuners.tuners_utils import BaseTunerLayer

    def _qualify(module_name: str, suffix: str) -> str:
        # The root module has an empty name; avoid a leading dot in that case.
        return f"{module_name}.{suffix}" if module_name else suffix

    prefixes: set[str] = set()
    for module_name, module in model.base_model.model.named_modules():
        if isinstance(module, AuxiliaryTrainingWrapper):
            # adapter_state_dict_load_map yields the actual state dict keys owned
            # by each registered adapter.
            for adapter_name in module._adapters:
                for state_dict_key in module.adapter_state_dict_load_map(adapter_name).values():
                    prefixes.add(_qualify(module_name, state_dict_key))
            # Non-saveable adapter attributes are tracked via other_param_names.
            for attr_name in module.other_param_names:
                prefixes.add(_qualify(module_name, attr_name))
        elif isinstance(module, BaseTunerLayer):
            for attr_name in module.adapter_layer_names + module.other_param_names:
                prefixes.add(_qualify(module_name, attr_name))
    return prefixes


def _is_adapter_key(key: str, adapter_key_prefixes: set[str]) -> bool:
"""Check if a state dict key belongs to an adapter parameter."""
return any(key == pfx or key.startswith(pfx + ".") for pfx in adapter_key_prefixes)


def _peft_key_to_original_key(model, peft_key: str) -> str:
    """Transform a PEFT state dict key to its original base model key.

    Walks the module tree to strip wrapper infixes (``base_layer``, ``original_module``, and internal tuner modules
    like ``token_adapter`` inside ``AuxiliaryTrainingWrapper``).
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    parts = peft_key.split(".")
    original_parts: list[str] = []
    current = model.base_model.model

    # NOTE: we resolve the key segment-by-segment with getattr instead of using
    # named_modules() because the decision for each segment depends on the type
    # of the module reached so far (BaseTunerLayer vs. AuxiliaryTrainingWrapper)
    # and on whether the segment names a submodule or a parameter/buffer; a flat
    # named_modules() listing would not expose that per-segment context.
    for part in parts:
        if current is None:
            # Already past the module tree (e.g. nested parameter names)
            original_parts.append(part)
            continue

        child = getattr(current, part, None)
        if child is None or not isinstance(child, torch.nn.Module):
            # Parameter/buffer name - keep it
            original_parts.append(part)
            current = None
        elif part == "base_layer" and isinstance(current, BaseTunerLayer):
            # Skip the base_layer infix inside a tuner
            current = child
        elif part == "original_module" and isinstance(current, AuxiliaryTrainingWrapper):
            # Skip the original_module infix inside a wrapper
            current = child
        elif isinstance(child, BaseTunerLayer) and isinstance(current, AuxiliaryTrainingWrapper):
            # Internal tuner inside a wrapper (e.g. token_adapter in TrainableTokensWrapper) - skip
            current = child
        elif isinstance(child, (BaseTunerLayer, AuxiliaryTrainingWrapper)):
            # Tuner/wrapper that replaces an original module (e.g. LoRA at q_proj) - keep name
            original_parts.append(part)
            current = child
        else:
            # Regular module - keep name
            original_parts.append(part)
            current = child

    return ".".join(original_parts)


def get_base_model_state_dict(model) -> dict[str, torch.Tensor]:
    """Return the state dict of the base model with the original model keys.

    Extracts the base model's parameters from a PEFT-wrapped model, removing PEFT-specific key modifications and
    filtering out adapter-specific parameters.

    Args:
        model: A ``PeftModel`` instance.

    Returns:
        The base model's state dict with original keys (without PEFT modifications).
    """
    # Prompt learning leaves the base model's structure untouched: the state
    # dict already carries the original keys and contains no adapter-injected
    # parameters, so it can be returned as-is (copied into a plain dict).
    if model._is_prompt_learning:
        return dict(model.base_model.state_dict())

    adapter_key_prefixes = _get_adapter_state_dict_key_prefixes(model)

    # Drop adapter parameters and translate the surviving keys back to the
    # names the unwrapped base model uses.
    return {
        _peft_key_to_original_key(model, key): tensor
        for key, tensor in model.base_model.model.state_dict().items()
        if not _is_adapter_key(key, adapter_key_prefixes)
    }


def set_base_model_state_dict(
    model,
    state_dict: dict[str, torch.Tensor],
    strict: bool = True,
):
    """Load a state dict with original model keys into a PEFT-wrapped model.

    Takes a state dict keyed by the original (pre-PEFT) model key names and loads it into the base model, automatically
    translating keys to account for PEFT wrapper infixes (``base_layer``, ``original_module``, etc.).

    This is the counterpart to :func:`get_base_model_state_dict` and is useful for scenarios like loading base model
    weights after FSDP wrapping.

    Args:
        model: A ``PeftModel`` instance.
        state_dict: The state dict with original model keys to load.
        strict: Whether to strictly enforce that the keys match. If ``True``,
            raises ``RuntimeError`` on missing or unexpected keys.

    Returns:
        A ``namedtuple`` with ``missing_keys`` and ``unexpected_keys`` fields.
    """
    _IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])

    # Prompt learning does not modify the base model's structure: the incoming
    # keys already match, so torch can handle loading and strictness directly.
    if model._is_prompt_learning:
        return model.base_model.load_state_dict(state_dict, strict=strict)

    adapter_key_prefixes = _get_adapter_state_dict_key_prefixes(model)

    # Map every non-adapter key of the current (PEFT-transformed) state dict
    # back to its original name: original_key -> peft_key.
    original_to_peft_key: dict[str, str] = {
        _peft_key_to_original_key(model, peft_key): peft_key
        for peft_key in model.base_model.model.state_dict()
        if not _is_adapter_key(peft_key, adapter_key_prefixes)
    }

    peft_state_dict: dict[str, torch.Tensor] = {}
    unexpected_keys: list[str] = []
    for original_key, tensor in state_dict.items():
        peft_key = original_to_peft_key.get(original_key)
        if peft_key is None:
            unexpected_keys.append(original_key)
        else:
            peft_state_dict[peft_key] = tensor

    missing_keys = [key for key in original_to_peft_key if key not in state_dict]

    if strict and (missing_keys or unexpected_keys):
        error_msgs: list[str] = []
        if missing_keys:
            error_msgs.append(f"Missing key(s) in state_dict: {missing_keys}")
        if unexpected_keys:
            error_msgs.append(f"Unexpected key(s) in state_dict: {unexpected_keys}")
        raise RuntimeError("Error(s) in loading state_dict:\n\t" + "\n\t".join(error_msgs))

    # strict=False here on purpose: adapter parameters are deliberately absent
    # from peft_state_dict, and strictness was already enforced above on the
    # original key names.
    model.base_model.model.load_state_dict(peft_state_dict, strict=False)
    return _IncompatibleKeys(missing_keys=missing_keys, unexpected_keys=unexpected_keys)
4 changes: 4 additions & 0 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2988,6 +2988,10 @@ def test_delete_adapter_multiple_adapters_with_trainable_token_indices(self):
def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)

@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_get_base_model_state_dict(self, test_name, model_id, config_cls, config_kwargs):
    # Delegates to the shared helper; config_kwargs is copied so the shared
    # TEST_CASES entries are not mutated by downstream modifications.
    self._test_get_base_model_state_dict(model_id, config_cls, config_kwargs.copy())

@staticmethod
def _check_requires_grad(module, adapter_name, requires_grad):
# a bit of a clumsy way to test requires_grad on the PEFT parameters
Expand Down
6 changes: 6 additions & 0 deletions tests/test_decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,3 +1047,9 @@ def test_lora_conversion(self, model_id, config_cls, config_kwargs):

config_kwargs = set_init_weights_false(config_cls, config_kwargs)
self._test_lora_conversion(model_id, config_cls, config_kwargs)

@pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_get_base_model_state_dict(self, model_id, config_cls, config_kwargs):
    # Skips combinations where the config does not support Conv1D-based models
    # (per the helper's name — see its definition for the exact criteria).
    _skip_if_not_conv1d_supported(model_id, config_cls)
    # config_kwargs is copied so shared ALL_CONFIGS entries are not mutated.
    self._test_get_base_model_state_dict(model_id, config_cls, config_kwargs.copy())
5 changes: 5 additions & 0 deletions tests/test_encoder_decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,11 @@ def test_disable_adapter(self, model_id, config_cls, config_kwargs):
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
self._test_disable_adapter(model_id, config_cls, config_kwargs)

@pytest.mark.parametrize("model_id", PEFT_ENCODER_DECODER_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_get_base_model_state_dict(self, model_id, config_cls, config_kwargs):
    # Delegates to the shared helper; config_kwargs is copied so the shared
    # ALL_CONFIGS entries are not mutated by downstream modifications.
    self._test_get_base_model_state_dict(model_id, config_cls, config_kwargs.copy())

def test_active_adapters_prompt_learning(self):
model = AutoModelForSeq2SeqLM.from_pretrained(
"peft-internal-testing/tiny-random-BartForConditionalGeneration"
Expand Down
Loading
Loading