
[TinyLoRA] tinylora implementation #3024

Open
kashif wants to merge 29 commits into huggingface:main from kashif:tinylora

Conversation

@kashif
Contributor

@kashif kashif commented Feb 6, 2026

Adds TinyLoRA, a new PEFT method based on "TinyLoRA: Learning to Reason in 13 Parameters". TinyLoRA achieves extreme parameter efficiency by replacing LoRA's trainable low-rank matrices with a tiny trainable vector projected through fixed random bases.

The key idea: given a frozen SVD decomposition W ≈ B @ A (where B = U @ sqrt(S) and A = sqrt(S) @ V^T), the weight update is delta_W = B @ R @ A where R is an r x r trainable matrix (following LoRA-XS). TinyLoRA takes this further by parameterizing R as a linear combination of fixed random projection matrices:

  R = sum_i(v[i] * P[i])

where v is the only trainable parameter (as small as 13 values) and P_i are fixed random matrices seeded deterministically.
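The decomposition above can be sketched numerically. This is a minimal illustration of the update rule with assumed toy shapes, not the PEFT implementation; all variable names here are for exposition only.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, r, u = 8, 6, 2, 13  # r: frozen SVD rank, u: length of trainable v

W = rng.standard_normal((d_out, d_in))  # stand-in for a pretrained weight

# Frozen truncated SVD: W ~= B @ A, with sqrt(S) split across both factors
U, S, Vt = np.linalg.svd(W, full_matrices=False)
B = U[:, :r] * np.sqrt(S[:r])            # (d_out, r), frozen
A = np.sqrt(S[:r])[:, None] * Vt[:r, :]  # (r, d_in), frozen

# Fixed random projection bases P_i, seeded deterministically
P = np.random.default_rng(1234).standard_normal((u, r, r))

v = np.zeros(u)  # the only trainable parameter (13 values here)
v[0] = 0.1       # pretend one optimizer step has happened

R = np.einsum("i,ijk->jk", v, P)  # R = sum_i v[i] * P[i], shape (r, r)
delta_W = B @ R @ A               # weight update applied on top of W
```
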

Features

  • Extreme efficiency: trainable parameter count is u per target module (or even fewer with weight tying), compared to r * (in + out) for LoRA
  • Weight tying: configurable sharing of v vectors across layers via weight_tying (0.0 = no sharing, 1.0 = all layers share one v)
  • SVD initialization: frozen A and B matrices computed from truncated SVD of pretrained weights, with singular values distributed equally via sqrt(S)
  • Full layer support: nn.Linear, Conv1D, and nn.Embedding
  • Merge/unmerge: full support including safe merge with NaN checking
  • LoRA conversion: supports_lora_conversion() -> True — delta weights can be converted to standard LoRA format via get_delta_weight
  • Deterministic projections: P matrices are seeded per-layer for reproducibility; optionally saved in checkpoints (save_projection=True)
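The LoRA-conversion point follows from the algebra: since delta_W = B @ R @ A, folding R into one side yields ordinary rank-r LoRA factors. A sketch with assumed shapes (not the actual `get_delta_weight` code):

```python
import numpy as np

rng = np.random.default_rng(42)
d_out, d_in, r = 8, 6, 2
B = rng.standard_normal((d_out, r))  # frozen SVD factor
A = rng.standard_normal((r, d_in))   # frozen SVD factor
R = rng.standard_normal((r, r))      # TinyLoRA's combined projection

delta_W = B @ R @ A

# Fold R into B to obtain plain rank-r LoRA factors
lora_B = B @ R  # plays the role of LoRA's B matrix, (d_out, r)
lora_A = A      # plays the role of LoRA's A matrix, (r, d_in)

assert np.allclose(delta_W, lora_B @ lora_A)
```
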

Config

  from peft import TinyLoraConfig, get_peft_model

  config = TinyLoraConfig(
      r=2,              # SVD rank (frozen)
      u=64,             # trainable vector dimension
      weight_tying=0.0, # 0.0=no sharing, 1.0=full sharing
      target_modules="all-linear",
  )
  model = get_peft_model(base_model, config)
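To make the efficiency claim concrete, here is back-of-the-envelope arithmetic for a single 4096x4096 Linear layer (illustrative numbers only, assuming no weight tying):

```python
# Trainable parameters for one Linear layer: LoRA vs TinyLoRA
d_in, d_out = 4096, 4096
r_lora, u = 8, 64

lora_params = r_lora * (d_in + d_out)  # LoRA trains A and B: r * (in + out)
tinylora_params = u                    # TinyLoRA trains only v: u values

print(lora_params)      # 65536
print(tinylora_params)  # 64
```
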

Architecture

  • TinyLoraLayer (base): SVD decomposition, projection init, get_delta_weight, supports_lora_conversion
  • Linear / Embedding: forward pass, merge/unmerge
  • TinyLoraModel: weight tying groups, shared v parameter management via nested ModuleDict/ParameterDict
  • update_layer follows LoRA's config-object pattern: (adapter_name, tinylora_v, v_key, r, config, **kwargs)
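For intuition on the weight-tying groups, here is a purely hypothetical mapping from the `weight_tying` fraction to shared-v groups; only the endpoint semantics (0.0 and 1.0) come from the PR description, and the actual grouping rule in the implementation may differ.

```python
# Hypothetical sketch: map a weight_tying fraction to v-sharing groups.
# Only the endpoints match the documented behavior; interpolation is assumed.
def tying_groups(num_layers: int, weight_tying: float) -> list:
    # weight_tying=0.0 -> one v per layer; weight_tying=1.0 -> one shared v
    num_groups = max(1, round(num_layers * (1.0 - weight_tying)))
    return [layer * num_groups // num_layers for layer in range(num_layers)]

print(tying_groups(4, 0.0))  # [0, 1, 2, 3]
print(tying_groups(4, 1.0))  # [0, 0, 0, 0]
```
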

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif kashif requested a review from githubnemo February 8, 2026 08:21
@kashif
Contributor Author

kashif commented Feb 10, 2026

cc @jxmorris12 I have an implementation of TinyLoRA if you can kindly have a look?

Collaborator

@githubnemo githubnemo left a comment


Hey @kashif :)

Thanks for the PR, this is already solid.
Merging with main should hopefully resolve the CI errors.

Some questions and comments below.

@kashif
Contributor Author

kashif commented Feb 23, 2026

thanks, fixing

kashif and others added 13 commits February 23, 2026 19:44
Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
@kashif
Contributor Author

kashif commented Feb 24, 2026

@githubnemo should be ready for another review thanks

Collaborator

@githubnemo githubnemo left a comment


Thanks for the quick response :)

I think implementation-wise this is, except for two nits, good to go.

Let's add an example that showcases the primary use-case and add the method to the method comparison suite (maybe copy from method_comparison/MetaMathQA/experiments/lora/... and see where it takes us).

It'd also be stellar to have a commit message / PR description that is meaningful in the commit history.

@kashif kashif changed the title [TinyLoRA] initial tinylora implementation [TinyLoRA] tinylora implementation Feb 24, 2026
@kashif
Contributor Author

kashif commented Feb 24, 2026

ready @githubnemo

@BenjaminBossan
Member

Thanks for the PR Kashif. I ran the experiments on my machine and got a test accuracy of 0% and 0.002% :)
This isn't really surprising with 3584 and 64 trainable parameters, respectively, and basically confirms the paper results. They are still good to include, but without context users could draw the wrong conclusion.

@kashif
Contributor Author

kashif commented Feb 25, 2026

yes @BenjaminBossan, I will test with the RL setup; we can wait if that's ok. I also want to double-check that nothing is wrong

@BenjaminBossan
Member

Out of curiosity, I wanted to check if TinyLoRA can achieve better scores if we increase the number of trainable parameters. So I took the default* setting and increased u and indeed, we can get decent results.

| u    | accelerator memory max | num trainable params | test_accuracy       |
|------|------------------------|----------------------|---------------------|
| 512  | 20310917120            | 28672                | 0.30477634571645185 |
| 1024 | 20434649088            | 57344                | 0.3434420015163002  |
| 2024 | 20673724416            | 113344               | 0.31766489764973466 |

Given the still tiny number of trainable parameters, this result is quite respectable. This is also a nice confirmation that there is no major bug in the implementation.

I wonder if it would make sense to have a "maximalist" and a "minimalist" config, i.e. one with more trainable parameters and better score and one with extremely few trainable parameters (basically the current llama-3.2-3B-weight-tying) and low score.

*One more change I did was to increase r to 32 as 2 seemed pretty small to me, but that was just on a whim.

@githubnemo
Collaborator

I wonder if it would make sense to have a "maximalist" and a "minimalist" config, i.e. one with more trainable parameters and better score and one with extremely few trainable parameters (basically the current llama-3.2-3B-weight-tying) and low score.

I think that's a good thing to have!

I also wondered if it would make sense to extend the target modules to all-linear since it is so cheap, but then again this would diverge from other experiments quite a lot and not be as comparable.

@kashif
Contributor Author

kashif commented Mar 1, 2026

should we just document this? or add it somewhere else?

@BenjaminBossan
Member

I also wondered if it would make sense to extend the target modules to all-linear since it is so cheap, but then again this would diverge from other experiments quite a lot and be not as comparable.

I'd rather target either the attention xor the MLP part for consistency with other experiments.

should we just document this? or add it somewhere else?

What does "this" reference here?

@kashif
Contributor Author

kashif commented Mar 2, 2026

ah sorry, i meant the minimal/maximal config?

@githubnemo
Collaborator

githubnemo commented Mar 3, 2026

ah sorry, i meant the minimal/maximal config?

Yes, let's add a 'maximalist' config with (possibly) r > 2 and u >= 2048. VeRA has ~128k parameters with 37.6% task accuracy according to https://huggingface.co/spaces/peft-internal-testing/PEFT-method-comparison - maybe it makes sense to match that setting (u ~= 2300) to see where the ceiling is within a reasonable memory budget (~+0.5GB).

