
FIX Several bugs when adding merged LoRA weights #3111

Open
BenjaminBossan wants to merge 2 commits into huggingface:main from
BenjaminBossan:fix-merging-of-lora-weights

Conversation


@BenjaminBossan BenjaminBossan commented Mar 18, 2026

To be clear: This is about combining different LoRA adapters into a single one (add_weighted_adapter), not about merging LoRA weights into the base weights (merge).

There were a few issues being addressed:

  1. In the SVD path, the LoRA scaling was applied to the combination weights (not the LoRA weights themselves), but get_delta_weight already takes the scaling into account, meaning that the scaling was effectively applied twice.
  2. In _test_weighted_combination_of_adapters_lora in testing_common.py, the sign was applied twice, canceling out for negative weights.
  3. The bug in 2. was masked because _test_weighted_combination_of_adapters called _test_weighted_combination_of_adapters_lora twice with the same model. After the first call, the combined adapters with the given names already existed, and add_weighted_adapter simply skips the creation of an adapter whose name already exists, so the second call effectively did nothing. This is now fixed by using a new model for each call.
  4. The test_add_weighted_adapter_negative_weight_with_different_scaling test did not actually test that the expected weight is equal to the actual weight and it also only tested the linear combination type, meaning that errors like in 1. were missed. This test now tests multiple combination types with different weightings and checks the resulting delta weight. Note that I only included svd, cat (both precise), and linear (imprecise). The other methods (dare, ties, magnitude pruning) had huge errors in my testing.
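To illustrate issue 1, here is a minimal numpy sketch (not PEFT code; `lora_A`, `lora_B`, `scaling`, and the shapes are made-up stand-ins): since get_delta_weight already contains the scaling, folding the scaling into the combination weight on top of that applies it twice.

```python
import numpy as np

rng = np.random.default_rng(0)
r, d = 4, 8
lora_A = rng.standard_normal((r, d))
lora_B = rng.standard_normal((d, r))
scaling = 2.0  # e.g. lora_alpha / r

# get_delta_weight already includes the scaling factor:
delta = scaling * (lora_B @ lora_A)

weight = 0.5  # user-supplied combination weight

correct = weight * delta            # scaling applied once (inside delta)
buggy = (weight * scaling) * delta  # scaling folded into the weight as well

# The buggy version is off by exactly one extra factor of `scaling`:
assert np.allclose(buggy, scaling * correct)
assert not np.allclose(buggy, correct)
```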

While working on this, I also fixed:

  • Add more precise type annotation for combination_type
  • Better document the options
  • Add comments for better understanding
  • Use more consistent names in tests

Note: Since this PR fixes multiple, partly interdependent, issues, I left comments to point to the corresponding issue.

```diff
 if adapter in target.lora_A or adapter in target.lora_embedding_A:
     valid_adapters.append(adapter)
-    valid_weights.append(weight * target.scaling[adapter])
+    valid_weights.append(weight)
```

@BenjaminBossan (Member, Author) commented: Note: This is issue 1.

```diff
-def test_add_weighted_adapter_negative_weight_with_different_scaling(self):
-    # Test negative weights with different scaling factors (lora_alpha)
-    # This edge case ensures negative weights work correctly with different scaling values
+@pytest.mark.parametrize("weights", [[1.0, 1.0], [0.0, 1.0], [5.0, 0.01], [-1.0, -1.0], [0.5, -0.3]])
```

@BenjaminBossan (Member, Author) commented: Note: This is issue 4.

```diff
     lora_alpha=16,  # scaling = 16/8 = 2
     target_modules=["lin0"],
-    lora_dropout=0.0,
-    bias="none",
```

@BenjaminBossan (Member, Author) commented: This was removed because it's the default anyway, so it's less noise.

```python
# We cannot expect the merged weights to be approximately equal because we're dealing with rough approximations.
# Therefore, we check for correlation to verify that the direction is right and MSE to verify that the magnitude
# is right.
for module in model.modules():
```

@BenjaminBossan (Member, Author) commented: Note: This is issue 4.
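A hedged sketch of the direction/magnitude check described in that comment (illustrative only; the helper name and tolerances are invented, not the actual test code):

```python
import numpy as np

def check_direction_and_magnitude(expected, actual, corr_min=0.9, mse_max=1e-2):
    """Pass if `actual` points in roughly the same direction as `expected`
    (Pearson correlation) and has a similar magnitude (mean squared error)."""
    e, a = np.ravel(expected), np.ravel(actual)
    corr = np.corrcoef(e, a)[0, 1]
    mse = np.mean((e - a) ** 2)
    return corr >= corr_min and mse <= mse_max

expected = np.linspace(-1.0, 1.0, 100).reshape(10, 10)
noisy = expected + 0.01 * np.cos(np.arange(100.0)).reshape(10, 10)

assert check_direction_and_magnitude(expected, noisy)          # right direction and size
assert not check_direction_and_magnitude(expected, -expected)  # flipped sign is caught
```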

if "single" in adapter_name:
new_delta_weight = target.get_delta_weight(adapter_name)
weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0]
sign = 1 if weight_list[0] > 0 else -1
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This is issue 2.
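The double-sign bug from issue 2 in a nutshell (plain Python arithmetic, not the actual test code): when the weight is already negative, multiplying by its extracted sign on top of it cancels the negation.

```python
delta = 3.0    # stand-in for one entry of a delta weight
weight = -0.5  # negative combination weight

sign = 1 if weight > 0 else -1   # sign extracted from the weight

buggy = sign * weight * delta    # sign applied on top of the already signed weight
correct = weight * delta         # the weight alone carries the sign

assert buggy == 1.5     # two negatives cancel: behaves like +0.5 * delta
assert correct == -1.5
```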

```python
else:
    pytest.skip(f"Test not applicable for {config}")

del model
```

@BenjaminBossan (Member, Author) commented: Note: This is issue 3.
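Why calling the test twice on the same model masked the bug: a toy sketch of the skip-if-exists behavior described in issue 3 (`FakeModel` is invented for illustration, not PEFT's implementation).

```python
class FakeModel:
    """Toy stand-in for the described skip-if-exists behavior."""
    def __init__(self):
        self.adapters = {}

    def add_weighted_adapter(self, adapter_name, value):
        if adapter_name in self.adapters:
            return  # adapter with that name exists: silently do nothing
        self.adapters[adapter_name] = value

model = FakeModel()
model.add_weighted_adapter("combined", "first result")
model.add_weighted_adapter("combined", "second result")  # silently skipped
assert model.adapters["combined"] == "first result"

# The fix: give each call a fresh model, so the second call is effective.
fresh = FakeModel()
fresh.add_weighted_adapter("combined", "second result")
assert fresh.adapters["combined"] == "second result"
```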

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copilot AI left a comment

Pull request overview

This PR fixes correctness issues in add_weighted_adapter when combining multiple LoRA adapters into a new adapter, and strengthens the test suite to catch negative-weight/scaling edge cases that were previously missed.

Changes:

  • Fix double-application of LoRA scaling in the SVD-based weighted-combination path.
  • Correct test logic for negative weights and ensure the weighted-combination test actually runs twice by using a fresh model instance.
  • Expand/parameterize tests to validate multiple combination_type modes and check merged delta-weights (direction + magnitude).

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| src/peft/tuners/lora/model.py | Adjusts SVD weighted-combination logic (avoid double scaling) and tightens combination_type typing/docs. |
| tests/testing_common.py | Fixes weighted-combination test harness so negative-weight coverage is effective and adapter assertions are robust. |
| tests/test_custom_models.py | Reworks/extends coverage for mixed scaling factors across multiple combination types, including negative weights. |


- fix typo
- fix error message and comment
- use MSE instead of SSE

I also decided to test all combination types, even if they give very bad
results. This is less a 'proof' that they work than a regression guard:
at least we get an error if anything changes.