-
Notifications
You must be signed in to change notification settings - Fork 2.2k
FIX Several bugs when adding merged LoRA weights #3111
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3326,9 +3326,29 @@ def test_add_weighted_adapter_subtraction_with_negative_weights(self): | |
| dw_cancelled = module.get_delta_weight("cancelled") | ||
| assert torch.allclose(dw_cancelled, torch.zeros_like(dw_cancelled)) | ||
|
|
||
| def test_add_weighted_adapter_negative_weight_with_different_scaling(self): | ||
| # Test negative weights with different scaling factors (lora_alpha) | ||
| # This edge case ensures negative weights work correctly with different scaling values | ||
| @pytest.mark.parametrize("weights", [[1.0, 1.0], [0.0, 1.0], [5.0, 0.01], [-1.0, -1.0], [0.5, -0.3]]) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note: This is issue 4. |
||
| @pytest.mark.parametrize( | ||
| "combination_type, min_corr, max_mse", | ||
| [ | ||
| # note: SVD and cat are 'precise', the others are approximation | ||
| ("svd", 0.99, 0.01), | ||
| ("cat", 0.99, 0.01), | ||
| ("linear", 0.6, 1.0), | ||
| ("ties", 0.4, 1.0), | ||
| ("ties_svd", 0.8, 1.0), | ||
| ("dare_ties", 0.1, 1.0), | ||
| ("dare_ties_svd", 0.55, 1.0), | ||
| ("dare_linear", 0.2, 1.0), | ||
| ("dare_linear_svd", 0.6, 1.0), | ||
| ("magnitude_prune", 0.55, 1.0), | ||
| ("magnitude_prune_svd", 0.9, 0.1), | ||
| ], | ||
| ) | ||
| def test_add_weighted_adapter_with_different_scaling(self, weights, combination_type, min_corr, max_mse): | ||
| # Check that the actually merged weights correspond to what their theoretical value should be. Note that each | ||
| # method is an approximation so we can never expect exact equality. We thus test for correlation and MSE as a | ||
| # proxy. The acceptance criteria are empirically determined and thus serve more as a regression test than | ||
| # actually proving that the merging method works. | ||
| torch.manual_seed(42) | ||
| model = MLP() | ||
|
|
||
|
|
@@ -3337,36 +3357,43 @@ def test_add_weighted_adapter_negative_weight_with_different_scaling(self): | |
| r=8, | ||
| lora_alpha=16, # scaling = 16/8 = 2 | ||
| target_modules=["lin0"], | ||
| lora_dropout=0.0, | ||
| bias="none", | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This was removed because it's the default anyway, so it's less noise. |
||
| init_lora_weights=False, | ||
| ) | ||
| config2 = LoraConfig( | ||
| r=8, | ||
| lora_alpha=32, # scaling = 32/8 = 4 | ||
| target_modules=["lin0"], | ||
| lora_dropout=0.0, | ||
| bias="none", | ||
| init_lora_weights=False, | ||
| ) | ||
|
|
||
| model = get_peft_model(model, config1, adapter_name="adapter1") | ||
| model.add_adapter("adapter2", config2) | ||
|
|
||
| # Merge with negative weight - should handle different scalings correctly | ||
| model.add_weighted_adapter( | ||
| adapters=["adapter1", "adapter2"], | ||
| weights=[0.5, -0.3], | ||
| adapter_name="merged_diff_scaling", | ||
| combination_type="linear", | ||
| weights=weights, | ||
| adapter_name="merged", | ||
| combination_type=combination_type, | ||
| density=0.5, | ||
| ) | ||
|
|
||
| # Verify the merged adapter can run forward pass | ||
| model.set_adapter("merged_diff_scaling") | ||
| model.set_adapter("merged") | ||
| dummy_input = torch.randn(2, 10) | ||
| output = model(dummy_input) | ||
| assert output is not None | ||
|
|
||
| # We cannot expect the merged weights to be approximately equal because we're dealing with rough approximations. | ||
| # Therefore, we check for correlation to verify that the direction is right and MSE to verify that the magnitude | ||
| # is right. | ||
| for module in model.modules(): | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note: This is issue 4. |
||
| if isinstance(module, lora.LoraLayer): | ||
| dw1 = module.get_delta_weight("adapter1") | ||
| dw2 = module.get_delta_weight("adapter2") | ||
| dw_merged = module.get_delta_weight("merged") | ||
| expected = weights[0] * dw1 + weights[1] * dw2 | ||
| corr = torch.corrcoef(torch.stack((dw_merged.flatten(), expected.flatten()))) | ||
| mse = ((dw_merged - expected) ** 2).mean() | ||
| assert corr[0, 1] > min_corr | ||
| assert mse < max_mse | ||
|
|
||
| def test_multiple_adapters_no_needless_copy_modules_to_save(self): | ||
| # See 2206 | ||
| # The problem was that we keep a "global" modules_to_save on the model which contains all possible | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1591,21 +1591,7 @@ def _test_weighted_combination_of_adapters_lora(self, model, config, adapter_lis | |
| density=0.5, | ||
| ) | ||
|
|
||
| new_adapters = [ | ||
| "single_adapter_reweighting", | ||
| "multi_adapter_svd_reweighting", | ||
| "multi_adapter_ties_svd_reweighting", | ||
| "multi_adapter_dare_linear_svd_reweighting", | ||
| "multi_adapter_dare_ties_svd_reweighting", | ||
| "multi_adapter_magnitude_prune_svd_reweighting", | ||
| "multi_adapter_cat_reweighting", | ||
| "multi_adapter_linear_reweighting", | ||
| "multi_adapter_linear_reweighting_single_enabled", | ||
| "multi_adapter_ties_reweighting", | ||
| "multi_adapter_dare_linear_reweighting", | ||
| "multi_adapter_dare_ties_reweighting", | ||
| "multi_adapter_magnitude_prune_reweighting", | ||
| ] | ||
| new_adapters = [k for k in model.peft_config.keys() if not k.startswith("adapter_")] | ||
| for new_adapter in new_adapters: | ||
| assert new_adapter in model.peft_config | ||
|
|
||
|
|
@@ -1614,11 +1600,11 @@ def _test_weighted_combination_of_adapters_lora(self, model, config, adapter_lis | |
| _, target, _ = _get_submodules(model, key) | ||
| if isinstance(target, LoraLayer): | ||
| for adapter_name in new_adapters: | ||
| # for a single adapter, the result should be exact and we can check that; otherwise, we deal with | ||
| # approximations | ||
| if "single" in adapter_name: | ||
| new_delta_weight = target.get_delta_weight(adapter_name) | ||
| weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0] | ||
| sign = 1 if weight_list[0] > 0 else -1 | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note: This is issue 2. |
||
| weighted_original_delta_weights = sign * weighted_original_delta_weights | ||
| assert torch.allclose(new_delta_weight, weighted_original_delta_weights, atol=1e-4, rtol=1e-4) | ||
| elif "svd" in adapter_name: | ||
| assert target.r[adapter_name] == 20 | ||
|
|
@@ -1673,7 +1659,7 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw | |
| if "gemma" in model_id.lower(): | ||
| return pytest.skip("Combining Gemma adapters with SVD is currently failing") | ||
|
|
||
| adapter_list = ["adapter1", "adapter_2", "adapter_3"] | ||
| adapter_list = ["adapter_1", "adapter_2", "adapter_3"] | ||
| weight_list = [0.5, 1.5, 1.5] | ||
| negative_weight_list = [-0.5, -0.8, -1.2] | ||
| # Initialize the config | ||
|
|
@@ -1690,11 +1676,22 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw | |
| model = self.transformers_class.from_pretrained(model_id) | ||
| model = get_peft_model(model, config, adapter_list[0]) | ||
|
|
||
| # test positive weights | ||
| if isinstance(config, LoraConfig): | ||
| self._test_weighted_combination_of_adapters_lora(model, config, adapter_list, weight_list) | ||
| self._test_weighted_combination_of_adapters_lora(model, config, adapter_list, negative_weight_list) | ||
| elif isinstance(config, IA3Config): | ||
| self._test_weighted_combination_of_adapters_ia3(model, config, adapter_list, weight_list) | ||
| else: | ||
| pytest.skip(f"Test not applicable for {config}") | ||
|
|
||
| del model | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note: This is issue 3. |
||
| model = self.transformers_class.from_pretrained(model_id) | ||
| model = get_peft_model(model, config, adapter_list[0]) | ||
|
|
||
| # test negative weights | ||
| if isinstance(config, LoraConfig): | ||
| self._test_weighted_combination_of_adapters_lora(model, config, adapter_list, negative_weight_list) | ||
| elif isinstance(config, IA3Config): | ||
| self._test_weighted_combination_of_adapters_ia3(model, config, adapter_list, negative_weight_list) | ||
| else: | ||
| pytest.skip(f"Test not applicable for {config}") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: This is issue 1.