Add KappaTuneSelector: condition-number-based automatic LoRA target selection#3106

Open
oswaldoludwig wants to merge 9 commits into huggingface:main from oswaldoludwig:main

Conversation

@oswaldoludwig

What this PR does

This PR adds a lightweight target-selection utility that lets users automatically select the best LoRA target_modules before creating a LoraConfig, based on the lowest condition number per weight tensor, the metric used in KappaTune (https://arxiv.org/abs/2506.16289).

New API

```python
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model
from peft.utils.target_selection import KappaTuneSelector, find_kappa_target_modules

model = AutoModelForCausalLM.from_pretrained(...)   # your base model

# Option 1: class (full control)
selector = KappaTuneSelector(model)
optimal_targets = selector.get_best_targets(top_p=0.2)   # or num_modules=50, threshold=...

# Option 2: one-liner
optimal_targets = find_kappa_target_modules(model, top_p=0.2)

config = LoraConfig(
    target_modules=optimal_targets,
    r=16,
    lora_alpha=32,
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, config)
```
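
For intuition, the per-module condition number (kappa) that the selector ranks by can be sketched in plain NumPy. The helper names below are illustrative only, not the PEFT API:

```python
import numpy as np

def kappa(weight: np.ndarray) -> float:
    """Condition number of a 2-D weight matrix: ratio of the largest
    to the smallest singular value (lower = more isotropic)."""
    s = np.linalg.svd(weight, compute_uv=False)  # singular values, descending
    return float(s[0] / s[-1])

def rank_modules_by_kappa(named_weights):
    """Sort (name, weight) pairs by ascending kappa, so the best
    adaptation candidates come first."""
    return sorted(named_weights, key=lambda nw: kappa(nw[1]))

# Toy example: a near-isotropic matrix vs. a strongly anisotropic one.
rng = np.random.default_rng(0)
iso = np.eye(8) + 0.01 * rng.standard_normal((8, 8))
aniso = np.diag([100.0, 1, 1, 1, 1, 1, 1, 0.01])
ranked = rank_modules_by_kappa([("q_proj", aniso), ("v_proj", iso)])
print([name for name, _ in ranked])  # v_proj first: lowest kappa
```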

@githubnemo (Collaborator) left a comment
Thanks for the PR!

I think the interface is fine: KappaTuneSelector as a class, so the metric is computed once and different selections can be experimented with, is good, and having a shortcut one-liner is good as well.

There are only a few things missing:

  • Let's add unit tests, e.g. in tests/test_target_selection.py
  • Let's add an example, if possible re-creating the results from the paper
  • Let's add documentation (entry in _toctree.yml under Utilities and a package reference file in package_reference/target_selection)

Did you test this with other PEFT methods like MiSS or SHiRA to see if it has a similar effect? I wonder if this method generalizes to other methods as well!

3. threshold (all modules below kappa)
4. everything (fallback)

Modules are sorted by ascending kappa (lowest = best for adaptation).
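
The quoted selection precedence can be sketched as follows. This is a standalone illustration; the helper name, the exact ordering of num_modules vs. top_p, and the boundary semantics are assumptions, not the PR's actual code:

```python
from typing import List, Optional, Tuple

def select_targets(
    ranked: List[Tuple[str, float]],  # (module name, kappa), ascending kappa
    num_modules: Optional[int] = None,
    top_p: Optional[float] = None,
    threshold: Optional[float] = None,
) -> List[str]:
    """Pick target modules by the first criterion provided:
    num_modules (the N lowest-kappa modules), then top_p (the best
    fraction of all modules), then threshold (every module whose
    kappa is below the threshold), then everything as a fallback."""
    if num_modules is not None:
        return [name for name, _ in ranked[:num_modules]]
    if top_p is not None:
        keep = max(1, int(len(ranked) * top_p))
        return [name for name, _ in ranked[:keep]]
    if threshold is not None:
        return [name for name, k in ranked if k < threshold]
    return [name for name, _ in ranked]

ranked = [("v_proj", 1.2), ("q_proj", 3.5), ("k_proj", 9.0), ("o_proj", 40.0)]
print(select_targets(ranked, top_p=0.5))      # two lowest-kappa modules
print(select_targets(ranked, threshold=5.0))  # only modules with kappa < 5.0
```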
Collaborator
Let's document top_p, num_modules and threshold explicitly.

Author
I hadn't noticed that comment before. Check out the new comments in target_selection.py.

Author
I hope all reviewer feedback has been addressed (docstrings, tests, 4-bit support, processor note in the docs, and the experiment), and the PR is ready for final review and merge @githubnemo

top_p: float = 0.2,
max_dim_size_to_analyze: int = 16384,
) -> List[str]:
"""One-liner version for quick use."""
Collaborator

Let's write user-friendly documentation that explains the usage, the default parameters, and the impact of each parameter choice. Recommendations and examples are good.

oswaldoludwig and others added 7 commits March 17, 2026 14:52
Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
This script implements Kappa-Selection using the PEFT KappaTuneSelector for identifying higher-entropy, less-anisotropic modules. It includes data preparation, experiment execution, and evaluation of perplexity on IMDB and WikiText datasets.
Updated KappaTuneSelector to support bnb 4-bit models and improved comments (ready for paper experiments with QLoRA / 4-bit models).
The selector picks the layers with the lowest condition number (most isotropic = best for adaptation), exactly as shown in the KappaTune paper (https://arxiv.org/abs/2506.16289).
Add tests for KappaTuneSelector and target selection
Added detailed docstrings for target selection functions, documenting top_p, num_modules and threshold explicitly.
Avoid a compatibility issue, reported dozens of times in HF discussions, between the DeepSeek-V2-Lite custom modeling code and Transformers v4.40+: the model's forward pass still calls the old past_key_values.get_usable_length(...) method, but DynamicCache (now the default cache class) no longer implements it.