Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/peft/tuners/lora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,19 @@ def add_weighted_adapter(
adapters: list[str],
weights: list[float],
adapter_name: str,
combination_type: str = "svd",
combination_type: Literal[
"svd",
"linear",
"cat",
"ties",
"ties_svd",
"dare_ties",
"dare_linear",
"dare_ties_svd",
"dare_linear_svd",
"magnitude_prune",
"magnitude_prune_svd",
] = "svd",
svd_rank: int | None = None,
svd_clamp: int | None = None,
svd_full_matrices: bool = True,
Expand All @@ -612,7 +624,9 @@ def add_weighted_adapter(
The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`,
`dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. When using the `cat`
combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the
mixed adapter may be too big and result in OOM errors).
mixed adapter may be too big and result in OOM errors). Note that `cat` and `svd` are precise methods
and will give you good accuracy, whereas `linear` is efficient but a rough approximation and should be
avoided if you can afford one of the more precise methods.
svd_rank (`int`, *optional*):
Rank of output adapter for svd. If None provided, will use max rank of merging adapters.
svd_clamp (`float`, *optional*):
Expand Down Expand Up @@ -738,11 +752,12 @@ def _svd_generalized_task_arithmetic_weighted_adapter(
for adapter, weight in zip(adapters, weights):
if adapter in target.lora_A or adapter in target.lora_embedding_A:
valid_adapters.append(adapter)
valid_weights.append(weight * target.scaling[adapter])
valid_weights.append(weight)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This is issue 1.


# if no valid adapter, nothing to do
if len(valid_adapters) == 0:
raise ValueError("No matching LoRAs found. Please raise an issue on Github.")
raise ValueError("No matching LoRAs found. Please raise an issue on GitHub.")
# get_delta_weight applies the scaling, no need to handle it explicitly
delta_weight = [target.get_delta_weight(adapter) for adapter in valid_adapters]
valid_weights = torch.tensor(valid_weights).to(delta_weight[0].device)
if combination_type == "svd":
Expand Down
57 changes: 42 additions & 15 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3326,9 +3326,29 @@ def test_add_weighted_adapter_subtraction_with_negative_weights(self):
dw_cancelled = module.get_delta_weight("cancelled")
assert torch.allclose(dw_cancelled, torch.zeros_like(dw_cancelled))

def test_add_weighted_adapter_negative_weight_with_different_scaling(self):
# Test negative weights with different scaling factors (lora_alpha)
# This edge case ensures negative weights work correctly with different scaling values
@pytest.mark.parametrize("weights", [[1.0, 1.0], [0.0, 1.0], [5.0, 0.01], [-1.0, -1.0], [0.5, -0.3]])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This is issue 4.

@pytest.mark.parametrize(
"combination_type, min_corr, max_mse",
[
# note: SVD and cat are 'precise', the others are approximations
("svd", 0.99, 0.01),
("cat", 0.99, 0.01),
("linear", 0.6, 1.0),
("ties", 0.4, 1.0),
("ties_svd", 0.8, 1.0),
("dare_ties", 0.1, 1.0),
("dare_ties_svd", 0.55, 1.0),
("dare_linear", 0.2, 1.0),
("dare_linear_svd", 0.6, 1.0),
("magnitude_prune", 0.55, 1.0),
("magnitude_prune_svd", 0.9, 0.1),
],
)
def test_add_weighted_adapter_with_different_scaling(self, weights, combination_type, min_corr, max_mse):
# Check that the actually merged weights correspond to what their theoretical value should be. Note that each
# method is an approximation so we can never expect exact equality. We thus test for correlation and MSE as a
# proxy. The acceptance criteria are empirically determined and thus serve more as a regression test than
# actually proving that the merging method works.
torch.manual_seed(42)
model = MLP()

Expand All @@ -3337,36 +3357,43 @@ def test_add_weighted_adapter_negative_weight_with_different_scaling(self):
r=8,
lora_alpha=16, # scaling = 16/8 = 2
target_modules=["lin0"],
lora_dropout=0.0,
bias="none",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was removed because it's the default anyway, so it's less noise.

init_lora_weights=False,
)
config2 = LoraConfig(
r=8,
lora_alpha=32, # scaling = 32/8 = 4
target_modules=["lin0"],
lora_dropout=0.0,
bias="none",
init_lora_weights=False,
)

model = get_peft_model(model, config1, adapter_name="adapter1")
model.add_adapter("adapter2", config2)

# Merge the adapters - different weights and scalings should be handled correctly
model.add_weighted_adapter(
adapters=["adapter1", "adapter2"],
weights=[0.5, -0.3],
adapter_name="merged_diff_scaling",
combination_type="linear",
weights=weights,
adapter_name="merged",
combination_type=combination_type,
density=0.5,
)

# Verify the merged adapter can run forward pass
model.set_adapter("merged_diff_scaling")
model.set_adapter("merged")
dummy_input = torch.randn(2, 10)
output = model(dummy_input)
assert output is not None

# We cannot expect the merged weights to closely match their theoretical values because we're dealing with
# rough approximations. Therefore, we check for correlation to verify that the direction is right and MSE to
# verify that the magnitude is right.
for module in model.modules():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This is issue 4.

if isinstance(module, lora.LoraLayer):
dw1 = module.get_delta_weight("adapter1")
dw2 = module.get_delta_weight("adapter2")
dw_merged = module.get_delta_weight("merged")
expected = weights[0] * dw1 + weights[1] * dw2
corr = torch.corrcoef(torch.stack((dw_merged.flatten(), expected.flatten())))
mse = ((dw_merged - expected) ** 2).mean()
assert corr[0, 1] > min_corr
assert mse < max_mse

def test_multiple_adapters_no_needless_copy_modules_to_save(self):
# See 2206
# The problem was that we keep a "global" modules_to_save on the model which contains all possible
Expand Down
35 changes: 16 additions & 19 deletions tests/testing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,21 +1591,7 @@ def _test_weighted_combination_of_adapters_lora(self, model, config, adapter_lis
density=0.5,
)

new_adapters = [
"single_adapter_reweighting",
"multi_adapter_svd_reweighting",
"multi_adapter_ties_svd_reweighting",
"multi_adapter_dare_linear_svd_reweighting",
"multi_adapter_dare_ties_svd_reweighting",
"multi_adapter_magnitude_prune_svd_reweighting",
"multi_adapter_cat_reweighting",
"multi_adapter_linear_reweighting",
"multi_adapter_linear_reweighting_single_enabled",
"multi_adapter_ties_reweighting",
"multi_adapter_dare_linear_reweighting",
"multi_adapter_dare_ties_reweighting",
"multi_adapter_magnitude_prune_reweighting",
]
new_adapters = [k for k in model.peft_config.keys() if not k.startswith("adapter_")]
for new_adapter in new_adapters:
assert new_adapter in model.peft_config

Expand All @@ -1614,11 +1600,11 @@ def _test_weighted_combination_of_adapters_lora(self, model, config, adapter_lis
_, target, _ = _get_submodules(model, key)
if isinstance(target, LoraLayer):
for adapter_name in new_adapters:
# for a single adapter, the result should be exact and we can check that; otherwise, we deal with
# approximations
if "single" in adapter_name:
new_delta_weight = target.get_delta_weight(adapter_name)
weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0]
sign = 1 if weight_list[0] > 0 else -1
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This is issue 2.

weighted_original_delta_weights = sign * weighted_original_delta_weights
assert torch.allclose(new_delta_weight, weighted_original_delta_weights, atol=1e-4, rtol=1e-4)
elif "svd" in adapter_name:
assert target.r[adapter_name] == 20
Expand Down Expand Up @@ -1673,7 +1659,7 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
if "gemma" in model_id.lower():
return pytest.skip("Combining Gemma adapters with SVD is currently failing")

adapter_list = ["adapter1", "adapter_2", "adapter_3"]
adapter_list = ["adapter_1", "adapter_2", "adapter_3"]
weight_list = [0.5, 1.5, 1.5]
negative_weight_list = [-0.5, -0.8, -1.2]
# Initialize the config
Expand All @@ -1690,11 +1676,22 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
model = self.transformers_class.from_pretrained(model_id)
model = get_peft_model(model, config, adapter_list[0])

# test positive weights
if isinstance(config, LoraConfig):
self._test_weighted_combination_of_adapters_lora(model, config, adapter_list, weight_list)
self._test_weighted_combination_of_adapters_lora(model, config, adapter_list, negative_weight_list)
elif isinstance(config, IA3Config):
self._test_weighted_combination_of_adapters_ia3(model, config, adapter_list, weight_list)
else:
pytest.skip(f"Test not applicable for {config}")

del model
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This is issue 3.

model = self.transformers_class.from_pretrained(model_id)
model = get_peft_model(model, config, adapter_list[0])

# test negative weights
if isinstance(config, LoraConfig):
self._test_weighted_combination_of_adapters_lora(model, config, adapter_list, negative_weight_list)
elif isinstance(config, IA3Config):
self._test_weighted_combination_of_adapters_ia3(model, config, adapter_list, negative_weight_list)
else:
pytest.skip(f"Test not applicable for {config}")
Expand Down
Loading