Skip to content
Open
7 changes: 7 additions & 0 deletions docs/source/package_reference/target_selection.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# KappaTuneSelector

::: peft.utils.target_selection.KappaTuneSelector
options:
heading_level: 3

::: peft.utils.target_selection.find_kappa_target_modules
151 changes: 151 additions & 0 deletions examples/experiments_SA_kappatune_peft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import torch
import torch.nn as nn
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
BitsAndBytesConfig,
DataCollatorForLanguageModeling
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
from peft.utils.target_selection import KappaTuneSelector # ← NEW: PEFT selector
import bitsandbytes as bnb
import gc
import math

# ==========================================
# 1. Kappa-Selection using OFFICIAL PEFT KappaTuneSelector
# (exact paper logic: lowest κ on experts only)
# ==========================================
def get_stable_expert_names(model, budget_k=300):
    """Return the names of the ``budget_k`` expert modules with the lowest
    condition number kappa, using the official PEFT ``KappaTuneSelector``.

    Fix vs. the original: the previous candidate pool of ``budget_k * 2``
    lowest-kappa modules could contain fewer than ``budget_k`` expert modules,
    silently shrinking the budget. Ranking ALL modules first and then filtering
    guarantees the full expert budget whenever enough experts exist.

    Args:
        model: Model to analyze (may be bnb-4bit quantized).
        budget_k: Maximum number of expert modules to return.

    Returns:
        List of module names, sorted by ascending kappa (most stable first).
    """
    print(f" [KappaTune] Identifying {budget_k} most stable expert modules using PEFT KappaTuneSelector...")

    selector = KappaTuneSelector(model)  # uses the new PEFT class

    # Rank every analyzable linear module by ascending kappa (lowest = best).
    ranked = selector.get_best_targets()

    # Keep ONLY expert modules (exact same filtering as the paper).
    expert_modules = [name for name in ranked if "experts" in name]

    selected = expert_modules[:budget_k]
    print(f" → Selected {len(selected)} expert modules with lowest κ")
    return selected

# ==========================================
# 2. Data Preparation
# ==========================================
# Base model: DeepSeek-V2-Lite, a MoE model with "experts" modules —
# required for the expert-only filtering done by get_stable_expert_names.
MODEL_ID = "deepseek-ai/DeepSeek-V2-Lite"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# The tokenizer ships without a pad token; reuse EOS so padded batching works.
tokenizer.pad_token = tokenizer.eos_token

def format_imdb(example):
    """Turn one IMDB record into a prompt/completion training string.

    Reads the sentiment from ``label`` if present, otherwise from
    ``sentiment`` (defaulting to 0 = negative), and truncates the review
    body to 512 characters.
    """
    raw = example.get('label', example.get('sentiment', 0))
    sentiment = "Positive" if raw == 1 else "Negative"
    snippet = example['text'][:512]
    return {"text": f"Review: {snippet}\n\nSentiment: {sentiment}"}

print("Loading and preprocessing datasets...")
# Task data: 1k IMDB reviews, 90/10 train/test split, formatted into
# "Review: ...\n\nSentiment: ..." prompts then tokenized to fixed length 256.
imdb_ds = load_dataset("imdb", split="train[:1000]").train_test_split(test_size=0.1)
imdb_tokenized = imdb_ds.map(format_imdb).map(
    lambda x: tokenizer(x["text"], padding="max_length", truncation=True, max_length=256),
    batched=True, remove_columns=imdb_ds["train"].column_names
)

# Control data: WikiText-2 test slice, used to measure forgetting on general
# text; near-empty rows (<= 20 chars) are dropped before tokenization.
wiki_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:400]")
wiki_tokenized = wiki_ds.filter(lambda x: len(x["text"]) > 20).map(
    lambda x: tokenizer(x["text"], padding="max_length", truncation=True, max_length=256),
    batched=True, remove_columns=wiki_ds.column_names
)

# ==========================================
# 3. Experiment Engine
# ==========================================
def evaluate_perplexity(model, dataset, name="Dataset"):
    """Compute perplexity of ``model`` on up to 41 batches of ``dataset``.

    Fix vs. the original: when the dataloader was empty, the loop variable
    ``i`` was never bound and the final division raised NameError. Batches
    are now counted explicitly and an empty dataset yields ``inf``.

    Args:
        model: Causal LM returning ``loss`` when labels are provided.
        dataset: Tokenized dataset compatible with the LM collator.
        name: Dataset label kept for call-site readability (currently unused).

    Returns:
        exp(mean loss) over the evaluated batches, or ``inf`` if empty.
    """
    model.eval()
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=data_collator)

    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(model.device) for k, v in batch.items()}
            outputs = model(**batch, use_cache=False)
            total_loss += outputs.loss.item()
            num_batches += 1
            # Cap evaluation cost at 41 batches (matches the original i >= 40 break).
            if num_batches > 40:
                break
    if num_batches == 0:
        return float("inf")
    return math.exp(total_loss / num_batches)

def run_experiment(method_name):
    """Run one fine-tuning experiment and return perplexities.

    Loads the 4-bit quantized base model fresh, optionally attaches LoRA
    adapters ("LoRA_Global" = standard targets, "KappaTune_LoRA" = expert
    modules chosen by the KappaTune selector), trains unless "Baseline",
    then evaluates perplexity on IMDB train/test and WikiText.

    Args:
        method_name: One of "Baseline", "LoRA_Global", "KappaTune_LoRA".

    Returns:
        Tuple ``(imdb_test_ppl, imdb_train_ppl, wiki_ppl)``.

    NOTE(review): any other non-"Baseline" method_name reaches the
    TrainingArguments block with LR/STP unbound (NameError) — confirm
    callers only pass the three known names.
    """
    print(f"\n{'='*40}\n>>> EXPERIMENT: {method_name}\n{'='*40}")

    # QLoRA-style quantization: NF4 with double quantization, bf16 compute.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, quantization_config=bnb_config, trust_remote_code=True, device_map="auto"
    )
    model = prepare_model_for_kbit_training(model)

    # Configure PEFT based on method
    if method_name == "LoRA_Global":
        # Standard LoRA over the usual attention/MLP projection names.
        lora_config = LoraConfig(
            r=16,
            target_modules=["q_proj", "v_proj", "up_proj", "down_proj"],
            task_type=TaskType.CAUSAL_LM, lora_dropout=0.05
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
        LR=2e-4   # learning rate for this method
        STP=10    # epochs for this method

    elif method_name == "KappaTune_LoRA":
        # Targets chosen by the PEFT KappaTuneSelector (lowest-kappa experts).
        stable_modules = get_stable_expert_names(model, budget_k=300)
        lora_config = LoraConfig(
            r=190,
            target_modules=stable_modules,
            task_type=TaskType.CAUSAL_LM, lora_dropout=0.05
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
        LR=2e-4
        STP=35

    # Baseline skips training entirely and is evaluated as-is.
    if method_name != "Baseline":
        args = TrainingArguments(
            output_dir=f"./{method_name}_out", per_device_train_batch_size=40,
            gradient_accumulation_steps=4, learning_rate=LR, num_train_epochs=STP,
            bf16=True, logging_steps=5, save_strategy="no", report_to="none"
        )
        trainer = Trainer(
            model=model, args=args, train_dataset=imdb_tokenized["train"],
            data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
        )
        trainer.train()

    # Task perplexity (IMDB test/train) and control perplexity (WikiText).
    t_ppl_test = evaluate_perplexity(model, imdb_tokenized["test"], "IMDB")
    t_ppl_train = evaluate_perplexity(model, imdb_tokenized["train"], "IMDB")
    f_ppl = evaluate_perplexity(model, wiki_tokenized, "WikiText")

    # Free GPU memory before the next experiment loads a fresh model.
    del model
    gc.collect()
    torch.cuda.empty_cache()
    return t_ppl_test, t_ppl_train, f_ppl

# ==========================================
# 4. Results (same table as paper)
# ==========================================
# Run all three experiments; each value is (imdb_test_ppl, imdb_train_ppl, wiki_ppl).
results = {}
results["Baseline"] = run_experiment("Baseline")
results["LoRA_Global"] = run_experiment("LoRA_Global")
results["KappaTune"] = run_experiment("KappaTune_LoRA")

print("\n" + "="*70)
print(f"{'METHOD':<15} | {'IMDB PPL (Task train)':<18} | {'IMDB PPL (Task test)':<18} | {'Wiki PPL (General/control)':<18}")
print("-" * 70)
# Tuples are (test, train, wiki); train is printed first to match the header.
for m, (tpte,tptr,fp) in results.items():
    print(f"{m:<15} | {tptr:<18.4f} | {tpte:<18.4f} | {fp:<18.4f}")
print("="*70)
1 change: 1 addition & 0 deletions src/peft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@
shift_tokens_right,
)

from .utils.target_selection import KappaTuneSelector, find_kappa_target_modules

__all__ = [
"MODEL_TYPE_TO_PEFT_MODEL_MAPPING",
Expand Down
126 changes: 126 additions & 0 deletions src/peft/utils/target_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2026-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# See https://arxiv.org/abs/2506.16289 for details

import torch
import torch.nn as nn
from typing import List, Optional, Dict

try:
import bitsandbytes as bnb
except ImportError:
bnb = None

class KappaTuneSelector:
    """Select LoRA target modules by weight-matrix conditioning (KappaTune).

    Computes the condition number kappa = sigma_max / sigma_min for every
    ``nn.Linear`` in the model and ranks modules ascending: a low kappa
    means a well-conditioned, numerically stable weight matrix, which the
    KappaTune paper (https://arxiv.org/abs/2506.16289) identifies as the
    best adaptation target. Supports bnb 4-bit quantized models.

    Fixes vs. the original:
      * an all-zero weight matrix produced kappa = 0 and ranked as the BEST
        module; it is degenerate and is now ranked worst (``inf``);
      * ``bitsandbytes`` is imported lazily inside the computation so the
        class is self-contained and works without the module-level import.

    Args:
        model: The model whose linear modules are analyzed.
        max_dim_size_to_analyze: Linear layers with any weight dimension
            larger than this are skipped (SVD cost guard).
    """

    def __init__(self, model: nn.Module, max_dim_size_to_analyze: int = 16384):
        self.model = model
        self.max_dim_size_to_analyze = max_dim_size_to_analyze
        # Lazily-populated cache: module name -> condition number.
        self._condition_numbers: Optional[Dict[str, float]] = None

    def _compute_kappas(self) -> None:
        """Populate the kappa cache; no-op if already computed."""
        if self._condition_numbers is not None:
            return

        # Imported lazily so the selector works without bitsandbytes installed.
        try:
            import bitsandbytes as bnb
        except ImportError:
            bnb = None

        condition_numbers: Dict[str, float] = {}

        for module_name, module in self.model.named_modules():
            if not isinstance(module, nn.Linear):
                continue

            # Handle bnb 4-bit quantization (for QLoRA / paper reproduction):
            # quantized weights must be dequantized before the SVD.
            weight = module.weight
            if bnb is not None and hasattr(weight, "quant_state"):
                try:
                    w = bnb.functional.dequantize_4bit(weight.data, weight.quant_state).float()
                except Exception:
                    # Best-effort fallback: analyze the raw storage instead.
                    w = weight.detach().float()
            else:
                w = weight.detach().float()

            # SVD on huge matrices is prohibitively slow; skip them.
            if any(dim > self.max_dim_size_to_analyze for dim in w.shape):
                continue

            try:
                # svdvals returns singular values in descending order.
                S = torch.linalg.svdvals(w.view(w.size(0), -1))
                sigma_max = S[0].item()
                sigma_min = S[-1].item()
                if sigma_max == 0.0:
                    # All-zero weight: degenerate, rank as worst, not best.
                    condition_numbers[module_name] = float("inf")
                else:
                    # Epsilon guards division when the matrix is rank-deficient.
                    condition_numbers[module_name] = sigma_max / (sigma_min + 1e-8)
            except (torch.linalg.LinAlgError, RuntimeError):
                # Non-converging SVD: rank the module as worst possible.
                condition_numbers[module_name] = float("inf")

        self._condition_numbers = condition_numbers

    def get_best_targets(
        self,
        top_p: Optional[float] = None,
        num_modules: Optional[int] = None,
        threshold: Optional[float] = None,
    ) -> List[str]:
        """Return the best target modules by one of three strategies.

        Args:
            top_p: Return the best fraction of modules (e.g. 0.2 = paper default).
            num_modules: Return exactly this many best modules (fixed budget).
            threshold: Return every module with kappa <= threshold (quality cutoff).

        Returns:
            List of module names (e.g. ``["model.layers.0.self_attn.q_proj", ...]``).

        Notes:
            - Precedence (checked in order): num_modules -> top_p -> threshold -> all modules.
            - Modules are always sorted by ascending kappa (lowest = best).
            - Recommended: top_p=0.2 for most models (Llama-3, Mistral, Qwen, etc.).
        """
        self._compute_kappas()
        if not self._condition_numbers:
            return []

        sorted_modules = sorted(self._condition_numbers.items(), key=lambda x: x[1])

        if num_modules is not None:
            k = min(num_modules, len(sorted_modules))
            return [name for name, _ in sorted_modules[:k]]
        if top_p is not None:
            # At least one module is always returned for any positive fraction.
            k = max(1, int(len(sorted_modules) * top_p))
            return [name for name, _ in sorted_modules[:k]]
        if threshold is not None:
            return [name for name, kappa in sorted_modules if kappa <= threshold]

        return [name for name, _ in sorted_modules]


def find_kappa_target_modules(
    model: nn.Module, top_p: float = 0.2, max_dim_size_to_analyze: int = 16384
) -> List[str]:
    """One-liner convenience wrapper (recommended for most users).

    Equivalent to building a :class:`KappaTuneSelector` and asking it for
    the best ``top_p`` fraction of modules:

        selector = KappaTuneSelector(model, max_dim_size_to_analyze)
        return selector.get_best_targets(top_p=top_p)

    Args:
        model: The base model.
        top_p: Fraction of best modules to return (paper default = 0.2).
        max_dim_size_to_analyze: See ``KappaTuneSelector.__init__``.

    Returns:
        List of module names ready for ``LoraConfig(target_modules=...)``.
    """
    return KappaTuneSelector(model, max_dim_size_to_analyze).get_best_targets(top_p=top_p)
27 changes: 27 additions & 0 deletions tests/test_target_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
import torch
import torch.nn as nn
from peft.utils.target_selection import KappaTuneSelector, find_kappa_target_modules

class SimpleMLP(nn.Module):
    """Minimal two-linear-layer fixture for the selector tests.

    Provides exactly two analyzable ``nn.Linear`` modules named
    "fc1" and "fc2", which the tests below assert against.
    """
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

def test_selector_basic():
    """top_p=0.5 on a two-linear-layer model must select exactly one module."""
    selector = KappaTuneSelector(SimpleMLP())
    chosen = selector.get_best_targets(top_p=0.5)
    assert len(chosen) == 1
    assert chosen[0] in ("fc1", "fc2")

def test_one_liner():
    """The convenience function with top_p=1.0 returns every linear module."""
    selected = find_kappa_target_modules(SimpleMLP(), top_p=1.0)
    assert len(selected) == 2

def test_num_modules():
    """A fixed budget of one module is honored exactly."""
    selection = KappaTuneSelector(SimpleMLP()).get_best_targets(num_modules=1)
    assert len(selection) == 1