Add get base model state dict#3000

Open
Isalia20 wants to merge 6 commits intohuggingface:mainfrom
Isalia20:add-get-base-model-state-dict

Conversation

@Isalia20
Contributor

Fixes #2945

@Isalia20
Contributor Author

@BenjaminBossan Would be glad if you could review it when you get a chance

Member

@BenjaminBossan BenjaminBossan left a comment


Thanks for this PR. I currently don't have much time to review and will be OoO next week, so hopefully @githubnemo can take over.

Just my first observations:

  1. From the original issue, I think we concluded that we would rather need a set_base_model_state_dict. That doesn't mean that get_base_model_state_dict doesn't have its merits, but it wouldn't fully solve the issue. Ping @dvmazur.
  2. There can be deeper nesting of .base_layer, so it should run in a loop: while ".base_layer" in new_key: ....
  3. This doesn't take into account trainable tokens yet, they need to be treated similarly as modules_to_save.
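The loop suggested in point 2 could look like the following minimal sketch (the function name and the example key are illustrative, not from the PR):

```python
def strip_base_layer(key: str) -> str:
    """Remove every '.base_layer' segment from a state-dict key,
    handling arbitrarily deep nesting by looping until none remain."""
    while ".base_layer" in key:
        key = key.replace(".base_layer", "", 1)
    return key

# Example: a doubly wrapped module key maps back to the original key.
strip_base_layer("model.layers.0.q_proj.base_layer.base_layer.weight")
# -> "model.layers.0.q_proj.weight"
```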

@dvmazur

dvmazur commented Jan 17, 2026

Hi! Thanks for this PR! Yeah, I rather need a set_base_model_state_dict, but it should be pretty easy to implement once we have a get_base_model_state_dict I think. Also, maybe we should expand the test matrix to make sure this method works for other PEFTs?

@Isalia20
Contributor Author

Hi, I'll add the set method as well and more tests a little later today

@Isalia20
Contributor Author

Added the set base state dict and more tests

@Isalia20
Contributor Author

@githubnemo Would be glad if you could review this :)

Collaborator

@githubnemo githubnemo left a comment


Hey @Isalia20 :) Thanks for taking this on.

I need a bit of clarification (possibly from @dvmazur): do I understand correctly that one use-case is that we have a model that doesn't fit in memory, so we need to first shard the empty model (via FSDP) onto several devices and then read the checkpoint onto the shards (in a streaming manner)? Furthermore, do I understand correctly that it is not possible to shard the base model and then apply PEFT on top of that? If those are the reasons why this is useful, we should probably document them as well, since they are not obvious.

The implementation seems to pass at first glance but there might be a few pitfalls still. I left one comment regarding a potential bug.

Let's build a test (e.g., a merge of test_get_base_model_state_dict_keys_match and test_get_base_model_state_dict_values_match) and integrate it into tests/testing_common.py (similar to _test_save_pretrained) to be called from the more exhaustive testing suites in tests/test_decoder_models.py, tests/test_encoder_decoder_models.py and tests/test_custom_models.py which cover a lot more cases. For example, trainable tokens and parameter targeting are not covered by the current tests and there are probably a lot more special cases, so leveraging the existing tests is probably best.

Comment on lines +1771 to +1777
for prefix in adapter_prefixes:
    if f".{prefix}" in peft_key or peft_key.startswith(prefix):
        is_adapter_param = True
        break

if is_adapter_param:
    continue
Collaborator


I think this is not a sufficient filter for methods like VeRA or VB-LoRA that employ weight sharing. This will be covered by the extended tests I suppose.

An alternative approach would be to iterate over all named modules of the model and remove those keys that belong to BaseTunerLayer instances (since the weight-shared keys are caught by the prefix matching already in place). But let's see what the tests say first, maybe I'm wrong and everything works fine :)
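A rough pure-Python sketch of that alternative (illustrative only; in PEFT the tuner-layer prefixes would come from `model.named_modules()` filtered on `BaseTunerLayer`, and the key lists from the model's `state_dict()`):

```python
def adapter_only_keys(state_dict_keys, tuner_prefixes):
    """Return keys that sit under a tuner layer but not under its wrapped
    '.base_layer' submodule, i.e. keys that belong to the adapter itself."""
    dropped = set()
    for key in state_dict_keys:
        for prefix in tuner_prefixes:
            if key.startswith(prefix + ".") and ".base_layer." not in key:
                dropped.add(key)
    return dropped

keys = [
    "model.q_proj.base_layer.weight",      # base weight -> keep
    "model.q_proj.lora_A.default.weight",  # adapter weight -> drop
]
adapter_only_keys(keys, ["model.q_proj"])
# -> {"model.q_proj.lora_A.default.weight"}
```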

@dvmazur

dvmazur commented Jan 21, 2026

Hi!

I want to be able to load the base model's and adapter's state_dicts after wrapping the PEFT model in FSDP. The state_dict's keys match the original base model's keys, so I need a function that will map the wrapped model's keys to the original ones if that makes sense.

@Isalia20
Contributor Author

Thanks for the comments. I'll take a look a little later this week

@githubnemo
Collaborator

Hey @dvmazur,

I want to be able to load the base model's and adapter's state_dicts after wrapping the PEFT model in FSDP.

I got that, but why? What's your motivation? My question presupposed that memory is a constraint and that's the reason, but you neither acknowledged nor refuted that. Please give a bit more detail so that I can understand the use-case better. Thanks!

@dvmazur

dvmazur commented Jan 21, 2026

The end goal is to have PEFT working for TorchTitan, basically. Titan wraps models into FSDP to save VRAM; it also allocates GPU memory only after the model's meta-device weights have been FSDP-sharded.

I think this pseudocode snippet should give you enough info, but feel free to ask if you need any more info:

with torch.device("meta"):
    # can't load base model weights here as it is on meta device before resharding
    model = AutoModelForCausalLM.from_pretrained(...)
    # can only wrap model in peft before fsdp-sharding it
    model = get_peft_model(model, ...)

model = fsdp_shard_model(model)

# actually allocate memory for the model's weights
# state dict can be loaded after that
model.to_empty(device=init_device)

# this function loads a state dict with the original model's module keys
# so I need a way to map them to the PEFT-wrapped model
load_base_model_state_dict(model)
initialize_adapters(model)
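The key remapping that load_base_model_state_dict needs could be sketched like this (hypothetical helper; it assumes the set of tuned module paths is known and that each tuned module exposes its base weights under `.base_layer`):

```python
def to_peft_key(key: str, tuned_modules: set) -> str:
    """Map an original base-model state-dict key to its PEFT-wrapped
    equivalent by inserting '.base_layer' before the parameter name on
    tuned modules; other keys pass through unchanged."""
    module_path, _, param = key.rpartition(".")
    if module_path in tuned_modules:
        return f"{module_path}.base_layer.{param}"
    return key

to_peft_key("model.layers.0.q_proj.weight", {"model.layers.0.q_proj"})
# -> "model.layers.0.q_proj.base_layer.weight"
```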

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Member

not stale

@githubnemo
Collaborator

@Isalia20 gentle ping :)

@Isalia20
Contributor Author

Isalia20 commented Mar 4, 2026

Oh sorry, my bad, I forgot about this PR. Will try to resolve the comments this week

@Isalia20
Contributor Author

Isalia20 commented Mar 8, 2026

updated

@Isalia20
Contributor Author

@githubnemo any updates on this?

Collaborator

@githubnemo githubnemo left a comment


This looks nice! I've made some suggestions / comments, most importantly about the missing aux. training wrapper handling but otherwise this is looking quite good.

One organisational thing: it would be nice if you could update the PR description with what this change entails and what the motivation was (i.e., support for TorchTitan). The description will be the commit message of the PR, so it would be very helpful for future developers to know the context this change was made in. Precision over verbosity.

@@ -0,0 +1,349 @@
# Copyright 2025-present the HuggingFace Inc. team.
Collaborator


Suggested change
-# Copyright 2025-present the HuggingFace Inc. team.
+# Copyright 2026-present the HuggingFace Inc. team.

Collaborator


Let's move the changes from peft_model.py to utils/save_and_load.py

Collaborator


I'm not sure there's precedent for this in PeftModel. Unless I'm mistaken, let's remove these calls since the current pattern is to import from save_and_load utils.

    """Return a list of (config_name, config) tuples for parametrized testing"""
    return [
        ("lora", LoraConfig(r=4, lora_alpha=2, target_modules=["q_proj", "v_proj"])),
        ("lora_all_linear", LoraConfig(r=4, lora_alpha=2, target_modules="all-linear")),
Collaborator


Suggested change
-        ("lora_all_linear", LoraConfig(r=4, lora_alpha=2, target_modules="all-linear")),
+        ("lora_all_linear", LoraConfig(r=4, lora_alpha=2, target_modules="all-linear")),
+        ("lora_trainable_tokens", LoraConfig(r=4, trainable_token_indices=[0, 1], target_modules=["q_proj", "v_proj"])),
+        ("trainable_tokens", TrainableTokensConfig(token_indices=[0, 1])),

This will highlight why it's necessary to inspect the state dict of the aux. training wrappers.

from .testing_utils import hub_online_once


MODEL_ID = "peft-internal-testing/tiny-random-OPTForCausalLM"
Collaborator


Let's wrap the tests in a class and parametrize the model id. I think it's a good idea to test at least an encoder-decoder model like T5 as well.

Comment on lines +74 to +80
@pytest.mark.parametrize("config_name,peft_config", get_peft_configs(), ids=[c[0] for c in get_peft_configs()])
def test_get_base_model_state_dict_keys_match(base_model, config_name, peft_config):
    """Test that get_base_model_state_dict returns keys matching the original base model."""
    base_model_keys = set(base_model.state_dict().keys())
    peft_model = get_peft_model(base_model, peft_config)
    extracted_keys = set(peft_model.get_base_model_state_dict().keys())
    assert base_model_keys == extracted_keys, f"Key mismatch for {config_name}"
Collaborator


Let's have this test in testing_common.py and call it in test_encoder_decoder_models.py, test_decoder_models.py and test_custom_models.py.

>>> result = peft_model.set_base_model_state_dict(base_weights)
```
"""
from collections import namedtuple
Collaborator


Let's make this a top-level import

@Isalia20
Contributor Author

Updated

@Isalia20
Contributor Author

@githubnemo gentle ping

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@githubnemo githubnemo left a comment


Thanks for the update. Some of the tests fail in CI; can you take a look?

Collaborator


I'm not sure there's precedent for this in PeftModel. Unless I'm mistaken, let's remove these calls since the current pattern is to import from save_and_load utils.

        original_parts.append(part)
        continue

    child = getattr(current, part, None)
Collaborator

@githubnemo githubnemo Mar 23, 2026


Can you motivate (in a code comment) why we're traversing the model this way and not using model.get_named_modules() instead? This seems like an important decision, so it would be unwise not to document it.




Successfully merging this pull request may close these issues.

Return base model state_dict with original keys

5 participants