
add glora #3098

Open
not-lain wants to merge 12 commits into huggingface:main from not-lain:glora2

Conversation

@not-lain

Follow-up on #780 and #2568.
This PR adds GLoRA to the library.
I also made a small Colab notebook to test this remotely, which you can find here.

@BenjaminBossan (Member) left a comment

Thanks for reviving the GLoRA implementation for PEFT. There is still a lot missing (docs, examples, more tests), but let's focus on the integration for now and work on the rest later.

Unfortunately, this PR seems to be based on the state that PEFT was in when the PR was first suggested. We made several refactors since then, which require different patterns to implement. The good news is that the final result should be much simpler with less code required. I marked the corresponding parts, please check. Maybe this is even something that a coding agent can do if asked to update the implementation and pointed to the most recent PEFT code for reference.



@dataclass
class GLoraConfig(PeftConfig):

How about GloraConfig, GloraLayer etc. for easier typing and better consistency with the rest of PEFT?

_VALID_D_E_CONFIGS = {"constant", "none", "vector"}

r: int = field(
default=4, metadata={"help": "Default rank of the LoRA matrices if the config contains LoRA parametrization."}

Why not 8 as for LoRA?

Comment on lines +35 to +36
_VALID_A_B_CONFIGS = {"LoRA", "vector", "constant", "none"}
_VALID_C_CONFIGS = {"LoRA", "vector", "none"}

I think it is easier to have everything in lower case, i.e. "LoRA" -> "lora"

config_A_B: str = field(
default="LoRA",
metadata={
"help": "Configuration for A and B matrices in GLora."

All of these help texts need a bit more context about how the options differ from one another.



# Refactored GLoraLinear for PEFT compatibility
class GLoraLinear(GLoraLayer, nn.Linear):

You probably based this PR on the very old original PR. Since then, we had several refactors in PEFT, some of which affect how we designed the adapter layers. Could you please check the latest implementation in tuners/lora/layer or PRs of other recent PEFT additions (e.g. #2851)? Most notably, we now pass the base_layer (i.e. the original layer) and wrap it inside the PEFT layer. To get the results of the base layer, we can then call self.base_layer(x) in the forward call (no need for F.linear(x, self.weight)).
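The wrapping pattern described here can be sketched roughly as follows. This is a pure-Python stand-in for illustration only: the real layer subclasses `nn.Module` and the PEFT tuner layer base class, and `_glora_delta` is a hypothetical name for wherever the GLoRA update is computed.

```python
class GloraLinear:
    """Sketch of the base_layer pattern: store the original layer and call it."""

    def __init__(self, base_layer):
        # the original layer is wrapped, not copied into self.weight
        self.base_layer = base_layer

    def forward(self, x):
        # the base result comes from calling the wrapped layer directly,
        # replacing the old F.linear(x, self.weight) pattern
        result = self.base_layer(x)
        return result + self._glora_delta(x)

    def _glora_delta(self, x):
        # placeholder: the real layer adds the GLoRA update here
        return 0
```

The key design point is that the PEFT layer never owns the base weights itself; it delegates to the wrapped module, which keeps merging and quantized base layers manageable.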

m.bias.requires_grad = True


class GLoraModel(BaseTuner):

Similar argument as for GLoraLayer: We have substantially refactored this part of PEFT. The good news is that it should greatly simplify the overall implementation: You only need to define _create_and_replace and _create_new_module, the remaining methods should all be fine when inherited from the parent class. Moreover, we need these class attributes:

  • prefix: str = "glora_"
  • tuner_layer_cls = GloraLayer
  • target_module_mapping = TRANSFORMERS_MODELS_TO_GLORA_TARGET_MODULES_MAPPING
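
A minimal sketch of the slimmed-down model class, using the attribute names from this comment. The base classes and `GloraLinear` are stubbed out here so the snippet stands alone; the real signatures of `_create_and_replace`/`_create_new_module` should be copied from a recent PEFT method such as the one in #2851.

```python
# Stand-in stubs so the sketch runs without PEFT installed (hypothetical):
class BaseTuner:
    pass

class GloraLayer:
    pass

class GloraLinear(GloraLayer):
    def __init__(self, base_layer, adapter_name, r=8):
        self.base_layer, self.adapter_name, self.r = base_layer, adapter_name, r

TRANSFORMERS_MODELS_TO_GLORA_TARGET_MODULES_MAPPING = {}

class GloraModel(BaseTuner):
    # the three class attributes requested above
    prefix: str = "glora_"
    tuner_layer_cls = GloraLayer
    target_module_mapping = TRANSFORMERS_MODELS_TO_GLORA_TARGET_MODULES_MAPPING

    def _create_new_module(self, config, adapter_name, target, **kwargs):
        # wrap the original module rather than copying its weights
        return GloraLinear(target, adapter_name, r=getattr(config, "r", 8))
```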


There is no need to have a separate test file for each PEFT method. Instead, you should add an entry to the test matrix for GLoRA, which is just a few lines of code and lets you run a wide battery of tests. As a start, please add GLoRA here:

TEST_CASES = [

Check how it's done for other methods and use the same approach. Before pushing your other changes, ensure that the tests pass locally (pytest tests/test_custom_models.py -k glora -v).
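As a rough sketch of what such entries might look like: the exact tuple shape and config kwargs are assumptions, so mirror the neighboring LoRA entries in `tests/test_custom_models.py`, and the real entries pass the `GloraConfig` class rather than the string placeholder used here.

```python
# Hypothetical test-matrix entries mirroring how other PEFT methods register.
# "GloraConfig" is a string placeholder; the real entry passes the class itself.
GLORA_TEST_CASES = [
    ("Vanilla MLP 1 GLora", "MLP", "GloraConfig", {"target_modules": ["lin0"]}),
    ("Vanilla MLP 2 GLora", "MLP", "GloraConfig", {"target_modules": ["lin0", "lin1"]}),
]
```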

self.glora_Cu: nn.ParameterDict = nn.ParameterDict()
self.glora_D: nn.ParameterDict = nn.ParameterDict()
self.glora_E: nn.ParameterDict = nn.ParameterDict()
self.eval_config: dict[str, dict[str, object]] = {}

Why do we need this? IIUC, this is supposed to be an object that unifies the settings for each parameter, as we have distinct options like "vector" or "constant". I haven't checked all the details, but ideally we would just define during the layer initialization something like:

if config_A_B == "lora":
    self.glora_Ad[adapter_name] = nn.Linear(...)
elif config_A_B == "vector":
    self.glora_Ad[adapter_name] = ...

The arguments config_A_B etc. should be passed to the __init__ and update_layer methods, directly coming from the GloraConfig.

This change may require some custom nn.Modules, but that would be fine. I would really like to avoid the whole prepare_path call during forward and frontload the whole resolution to the initialization. This makes the forward/merge/unmerge calls simpler to understand and should also be slightly more performant.
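The frontloading idea can be sketched like this, with pure-Python stand-ins and assumed names: instead of lambdas and plain lists, the real version would return small `nn.Module`s holding `nn.Parameter`s, but the dispatch structure is the point.

```python
def build_support(config: str, in_features: int):
    """Resolve the support-tensor variant once, at init time (hypothetical helper)."""
    if config == "vector":
        scale = [1.0] * in_features  # stand-in for a learnable per-feature vector
        return lambda x: [s * xi for s, xi in zip(scale, x)]
    if config == "constant":
        c = 1.0  # stand-in for a single learnable scalar
        return lambda x: [c * xi for xi in x]
    if config == "none":
        return lambda x: [0.0] * len(x)  # contributes nothing
    raise ValueError(f"unknown support config: {config}")

# forward() then just calls the resolved function; no per-step prepare_path
```

Invalid config strings also fail loudly at init time rather than silently during forward.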

Comment on lines +80 to +83
try:
rank = int(config.split("_")[1])
except Exception:
rank = 4

Why pass make_config(..., config=f"LoRA_{r}") only to immediately parse the r into an int again? Why not pass r as an int directly? Honestly, I think we can completely remove make_param and just directly initialize the parameters.
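The round trip being pointed out can be contrasted like this (simplified stand-ins for the PR's helpers, not the actual functions):

```python
def rank_from_string(config: str) -> int:
    # current PR, simplified: r is encoded as "LoRA_{r}" and re-parsed here,
    # silently falling back to 4 on any malformed input
    try:
        return int(config.split("_")[1])
    except (IndexError, ValueError):
        return 4

def rank_direct(r: int) -> int:
    # suggested: the caller already has r as an int, so just pass it through
    return r
```

Keeping `r` as an int end to end removes both the string formatting and the bare `except`, which currently masks typos like `"LoRA-8"` by silently substituting the default rank.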

not-lain and others added 4 commits March 17, 2026 07:55
@BenjaminBossan

Note that due to another PEFT method being merged, there is now a merge conflict, but it should be straightforward to resolve. Once you're finished, please ping me for another review.
