From 2290ead557ea31c68388833ecf85fdb88062029f Mon Sep 17 00:00:00 2001 From: Oswaldo Ludwig Date: Mon, 16 Mar 2026 18:25:45 +0100 Subject: [PATCH 1/9] Refine docstring for KappaTuneSelector class --- src/peft/utils/target_selection.py | 106 +++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 src/peft/utils/target_selection.py diff --git a/src/peft/utils/target_selection.py b/src/peft/utils/target_selection.py new file mode 100644 index 0000000000..5bc782ab8b --- /dev/null +++ b/src/peft/utils/target_selection.py @@ -0,0 +1,106 @@ +# src/peft/utils/target_selection.py + +import torch +import torch.nn as nn +from typing import List, Optional, Dict + +class KappaTuneSelector: + """ + Lightweight utility to compute per-module condition numbers (κ = σ_max / σ_min) + and return the best LoRA target modules (lowest κ = most stable / least anisotropic). + Use it before creating LoraConfig (no dependency on the full KappaTune optimizer). + """ + def __init__(self, model: nn.Module, max_dim_size_to_analyze: int = 16384): + """ + Args: + model: The base model (e.g. AutoModelForCausalLM). + max_dim_size_to_analyze: Skip any weight with a dimension > this value + (safety for embeddings, very large matrices, etc.). + """ + self.model = model + self.max_dim_size_to_analyze = max_dim_size_to_analyze + self._condition_numbers: Optional[Dict[str, float]] = None # module_name -> kappa + + def _compute_kappas(self) -> None: + """Compute condition number for every nn.Linear module's weight.""" + if self._condition_numbers is not None: + return + + condition_numbers: Dict[str, float] = {} + logger = None # optional: you can add logging if you want + + for module_name, module in self.model.named_modules(): + if not isinstance(module, nn.Linear): + continue + + weight = module.weight.detach() + + # Skip if too large + if any(dim > self.max_dim_size_to_analyze for dim in weight.shape): + continue + + # SVD (GPU if possible) + try: + if weight.is_cuda: + _, s, _ = torch.linalg.svd(weight, full_matrices=False) + else: + _, s, _ = torch.linalg.svd(weight.cpu(), full_matrices=False) + + kappa = (s[0] / s[-1]).item() if s[-1] > 1e-8 else float("inf") + condition_numbers[module_name] = kappa + except (torch.linalg.LinAlgError, RuntimeError): + condition_numbers[module_name] = float("inf") # treat as bad target + + self._condition_numbers = condition_numbers + + def get_best_targets( + self, + top_p: Optional[float] = None, # e.g. 0.2 → top 20% of modules + num_modules: Optional[int] = None, # absolute number (overrides top_p) + threshold: Optional[float] = None # kappa <= threshold (rarely used) + ) -> List[str]: + """ + Returns a list of module names ready for LoraConfig(target_modules=...). + + Priority order: + 1. num_modules (fixed budget) + 2. top_p (percentage) + 3. threshold (all modules below kappa) + 4. everything (fallback) + + Modules are sorted by ascending kappa (lowest = best for adaptation). 
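+
+        Example (a minimal sketch; assumes model is an already-loaded
+        base model and that LoraConfig is imported from peft):
+
+            selector = KappaTuneSelector(model)
+            targets = selector.get_best_targets(top_p=0.2)  # lowest-kappa 20%
+            config = LoraConfig(target_modules=targets)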
+ """ + self._compute_kappas() + if not self._condition_numbers: + return [] + + # Sort: lowest kappa first + sorted_modules = sorted( + self._condition_numbers.items(), + key=lambda x: x[1] + ) + + if num_modules is not None: + k = min(num_modules, len(sorted_modules)) + return [name for name, _ in sorted_modules[:k]] + + if top_p is not None: + k = max(1, int(len(sorted_modules) * top_p)) + return [name for name, _ in sorted_modules[:k]] + + if threshold is not None: + return [name for name, kappa in sorted_modules if kappa <= threshold] + + # fallback: all modules + return [name for name, _ in sorted_modules] + + +# Optional convenience function (many people prefer this style) +def find_kappa_target_modules( + model: nn.Module, + top_p: float = 0.2, + max_dim_size_to_analyze: int = 16384, +) -> List[str]: + """One-liner version for quick use.""" + selector = KappaTuneSelector(model, max_dim_size_to_analyze) + return selector.get_best_targets(top_p=top_p) From 7dcded92e2ddd2fbd4020d1d640c994372f289d4 Mon Sep 17 00:00:00 2001 From: Oswaldo Ludwig Date: Mon, 16 Mar 2026 18:31:04 +0100 Subject: [PATCH 2/9] Import KappaTuneSelector and find_kappa_target_modules --- src/peft/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 404dd9742b..34a5d8cde7 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -150,6 +150,7 @@ shift_tokens_right, ) +from .utils.target_selection import KappaTuneSelector, find_kappa_target_modules __all__ = [ "MODEL_TYPE_TO_PEFT_MODEL_MAPPING", From f3586a75f91cdacb15ad95fbeb599b0d926e9e67 Mon Sep 17 00:00:00 2001 From: Oswaldo Ludwig Date: Tue, 17 Mar 2026 14:52:08 +0100 Subject: [PATCH 3/9] Update src/peft/utils/target_selection.py Co-authored-by: githubnemo --- src/peft/utils/target_selection.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/peft/utils/target_selection.py b/src/peft/utils/target_selection.py index 5bc782ab8b..2f43c9114b 100644 --- a/src/peft/utils/target_selection.py +++ b/src/peft/utils/target_selection.py @@ -1,4 +1,16 @@ -# src/peft/utils/target_selection.py +# Copyright 2026-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import torch import torch.nn as nn From 1ed87f42f3f740d54a646590ae689cc66fa21b92 Mon Sep 17 00:00:00 2001 From: Oswaldo Ludwig Date: Tue, 17 Mar 2026 14:58:16 +0100 Subject: [PATCH 4/9] Add KappaTune PEFT experiment framework This script implements Kappa-Selection using the PEFT KappaTuneSelector for identifying higher-entropy, less-anisotropic modules. It includes data preparation, experiment execution, and evaluation of perplexity on IMDB and WikiText datasets. 
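
The selection step itself is a thin wrapper around the selector. Roughly
(a sketch, assuming a model whose MoE expert projections carry "experts"
in their module names, as DeepSeek-V2-Lite does):

    selector = KappaTuneSelector(model)
    # over-fetch, then keep only expert modules, as in the paper
    candidates = selector.get_best_targets(num_modules=2 * budget_k)
    stable_experts = [n for n in candidates if "experts" in n][:budget_k]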
--- examples/experiments_SA_kappatune_peft.py | 151 ++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 examples/experiments_SA_kappatune_peft.py diff --git a/examples/experiments_SA_kappatune_peft.py b/examples/experiments_SA_kappatune_peft.py new file mode 100644 index 0000000000..7b5942b549 --- /dev/null +++ b/examples/experiments_SA_kappatune_peft.py @@ -0,0 +1,151 @@ +import torch +import torch.nn as nn +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + TrainingArguments, + Trainer, + BitsAndBytesConfig, + DataCollatorForLanguageModeling +) +from datasets import load_dataset +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType +from peft.utils.target_selection import KappaTuneSelector # ← NEW: PEFT selector +import bitsandbytes as bnb +import gc +import math + +# ========================================== +# 1. Kappa-Selection using OFFICIAL PEFT KappaTuneSelector +# (exact paper logic: lowest κ on experts only) +# ========================================== +def get_stable_expert_names(model, budget_k=300): + print(f" [KappaTune] Identifying {budget_k} most stable expert modules using PEFT KappaTuneSelector...") + + selector = KappaTuneSelector(model) # ← uses the new PEFT class + + # Get top candidates (more than needed) + all_low_kappa = selector.get_best_targets(num_modules=budget_k * 2) + + # Keep ONLY expert modules (exact same filtering as the paper) + expert_modules = [name for name in all_low_kappa if "experts" in name] + + selected = expert_modules[:budget_k] + print(f" → Selected {len(selected)} expert modules with lowest κ") + return selected + +# ========================================== +# 2. Data Preparation +# ========================================== +MODEL_ID = "deepseek-ai/DeepSeek-V2-Lite" +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token + +def format_imdb(example): + val = example.get('label', example.get('sentiment', 0)) + label_text = "Positive" if val == 1 else "Negative" + return {"text": f"Review: {example['text'][:512]}\n\nSentiment: {label_text}"} + +print("Loading and preprocessing datasets...") +imdb_ds = load_dataset("imdb", split="train[:1000]").train_test_split(test_size=0.1) +imdb_tokenized = imdb_ds.map(format_imdb).map( + lambda x: tokenizer(x["text"], padding="max_length", truncation=True, max_length=256), + batched=True, remove_columns=imdb_ds["train"].column_names +) + +wiki_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:400]") +wiki_tokenized = wiki_ds.filter(lambda x: len(x["text"]) > 20).map( + lambda x: tokenizer(x["text"], padding="max_length", truncation=True, max_length=256), + batched=True, remove_columns=wiki_ds.column_names +) + +# ========================================== +# 3. 
Experiment Engine +# ========================================== +def evaluate_perplexity(model, dataset, name="Dataset"): + model.eval() + total_loss = 0 + data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=data_collator) + + with torch.no_grad(): + for i, batch in enumerate(dataloader): + batch = {k: v.to(model.device) for k, v in batch.items()} + outputs = model(**batch) + total_loss += outputs.loss.item() + if i >= 40: break + return math.exp(total_loss / (i + 1)) + +def run_experiment(method_name): + print(f"\n{'='*40}\n>>> EXPERIMENT: {method_name}\n{'='*40}") + + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, + ) + + model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, quantization_config=bnb_config, trust_remote_code=True, device_map="auto" + ) + model = prepare_model_for_kbit_training(model) + + # Configure PEFT based on method + if method_name == "LoRA_Global": + lora_config = LoraConfig( + r=16, + target_modules=["q_proj", "v_proj", "up_proj", "down_proj"], + task_type=TaskType.CAUSAL_LM, lora_dropout=0.05 + ) + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() + LR=2e-4 + STP=10 + + elif method_name == "KappaTune_LoRA": + # ← NOW USES PEFT SELECTOR + stable_modules = get_stable_expert_names(model, budget_k=300) + lora_config = LoraConfig( + r=190, + target_modules=stable_modules, + task_type=TaskType.CAUSAL_LM, lora_dropout=0.05 + ) + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() + LR=2e-4 + STP=35 + + if method_name != "Baseline": + args = TrainingArguments( + output_dir=f"./{method_name}_out", per_device_train_batch_size=40, + gradient_accumulation_steps=4, learning_rate=LR, num_train_epochs=STP, + bf16=True, logging_steps=5, save_strategy="no", report_to="none" + ) + trainer = Trainer( + model=model, args=args, train_dataset=imdb_tokenized["train"], + data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False) + ) + trainer.train() + + t_ppl_test = evaluate_perplexity(model, imdb_tokenized["test"], "IMDB") + t_ppl_train = evaluate_perplexity(model, imdb_tokenized["train"], "IMDB") + f_ppl = evaluate_perplexity(model, wiki_tokenized, "WikiText") + + del model + gc.collect() + torch.cuda.empty_cache() + return t_ppl_test, t_ppl_train, f_ppl + +# ========================================== +# 4. Results (same table as paper) +# ========================================== +results = {} +results["Baseline"] = run_experiment("Baseline") +results["LoRA_Global"] = run_experiment("LoRA_Global") +results["KappaTune"] = run_experiment("KappaTune_LoRA") + +print("\n" + "="*70) +print(f"{'METHOD':<15} | {'IMDB PPL (Task train)':<18} | {'IMDB PPL (Task test)':<18} | {'Wiki PPL (General/control)':<18}") +print("-" * 70) +for m, (tpte,tptr,fp) in results.items(): + print(f"{m:<15} | {tptr:<18.4f} | {tpte:<18.4f} | {fp:<18.4f}") +print("="*70) From cce09e247dfa5cc9f85813921a51c6a07c3715ce Mon Sep 17 00:00:00 2001 From: Oswaldo Ludwig Date: Tue, 17 Mar 2026 15:06:40 +0100 Subject: [PATCH 5/9] Enhance KappaTuneSelector for bnb 4-bit support Updated KappaTuneSelector to support bnb 4-bit models and improved comments (ready for paper experiments with QLoRA / 4-bit models). 
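
In short, bnb Params4bit weights are now dequantized before the SVD. The core
of the new path is (a sketch condensed from the diff below; the real code also
wraps the dequantization in a try/except fallback):

    if bnb is not None and hasattr(weight, "quant_state"):
        w = bnb.functional.dequantize_4bit(weight.data, weight.quant_state).float()
    else:
        w = weight.data.detach().float()
    S = torch.linalg.svdvals(w.view(w.size(0), -1))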
--- src/peft/utils/target_selection.py | 79 +++++++++++------------------- 1 file changed, 29 insertions(+), 50 deletions(-) diff --git a/src/peft/utils/target_selection.py b/src/peft/utils/target_selection.py index 2f43c9114b..8243987e37 100644 --- a/src/peft/utils/target_selection.py +++ b/src/peft/utils/target_selection.py @@ -12,107 +12,86 @@ # See the License for the specific language governing permissions and # limitations under the License. +# See https://arxiv.org/abs/2506.16289 for details + import torch import torch.nn as nn from typing import List, Optional, Dict +try: + import bitsandbytes as bnb +except ImportError: + bnb = None + class KappaTuneSelector: """ Lightweight utility to compute per-module condition numbers (κ = σ_max / σ_min) - and return the best LoRA target modules (lowest κ = most stable / least anisotropic). - Use it before creating LoraConfig (no dependency on the full KappaTune optimizer). + and return the best LoRA target modules. Now supports bnb 4-bit models (paper reproduction). """ def __init__(self, model: nn.Module, max_dim_size_to_analyze: int = 16384): - """ - Args: - model: The base model (e.g. AutoModelForCausalLM). - max_dim_size_to_analyze: Skip any weight with a dimension > this value - (safety for embeddings, very large matrices, etc.). - """ self.model = model self.max_dim_size_to_analyze = max_dim_size_to_analyze - self._condition_numbers: Optional[Dict[str, float]] = None # module_name -> kappa + self._condition_numbers: Optional[Dict[str, float]] = None def _compute_kappas(self) -> None: - """Compute condition number for every nn.Linear module's weight.""" if self._condition_numbers is not None: return condition_numbers: Dict[str, float] = {} - logger = None # optional: you can add logging if you want for module_name, module in self.model.named_modules(): if not isinstance(module, nn.Linear): continue - weight = module.weight.detach() - - # Skip if too large - if any(dim > self.max_dim_size_to_analyze for dim in weight.shape): + # Handle bnb 4-bit quantization (for QLoRA / paper example) + weight = module.weight + if bnb is not None and hasattr(weight, "quant_state"): + try: + w = bnb.functional.dequantize_4bit(weight.data, weight.quant_state).float() + except Exception: + w = weight.data.detach().float() + else: + w = weight.data.detach().float() + + # Skip huge matrices + if any(dim > self.max_dim_size_to_analyze for dim in w.shape): continue - # SVD (GPU if possible) try: - if weight.is_cuda: - _, s, _ = torch.linalg.svd(weight, full_matrices=False) - else: - _, s, _ = torch.linalg.svd(weight.cpu(), full_matrices=False) - - kappa = (s[0] / s[-1]).item() if s[-1] > 1e-8 else float("inf") + S = torch.linalg.svdvals(w.view(w.size(0), -1)) + kappa = (S[0] / (S[-1] + 1e-8)).item() condition_numbers[module_name] = kappa except (torch.linalg.LinAlgError, RuntimeError): - condition_numbers[module_name] = float("inf") # treat as bad target + condition_numbers[module_name] = float("inf") self._condition_numbers = condition_numbers def get_best_targets( self, - top_p: Optional[float] = None, # e.g. 0.2 → top 20% of modules - num_modules: Optional[int] = None, # absolute number (overrides top_p) - threshold: Optional[float] = None # kappa <= threshold (rarely used) + top_p: Optional[float] = None, + num_modules: Optional[int] = None, + threshold: Optional[float] = None, ) -> List[str]: - """ - Returns a list of module names ready for LoraConfig(target_modules=...). - - Priority order: - 1. num_modules (fixed budget) - 2. top_p (percentage) - 3. 
threshold (all modules below kappa) - 4. everything (fallback) - - Modules are sorted by ascending kappa (lowest = best for adaptation). - """ self._compute_kappas() if not self._condition_numbers: return [] - # Sort: lowest kappa first - sorted_modules = sorted( - self._condition_numbers.items(), - key=lambda x: x[1] - ) + sorted_modules = sorted(self._condition_numbers.items(), key=lambda x: x[1]) if num_modules is not None: k = min(num_modules, len(sorted_modules)) return [name for name, _ in sorted_modules[:k]] - if top_p is not None: k = max(1, int(len(sorted_modules) * top_p)) return [name for name, _ in sorted_modules[:k]] - if threshold is not None: return [name for name, kappa in sorted_modules if kappa <= threshold] - # fallback: all modules return [name for name, _ in sorted_modules] -# Optional convenience function (many people prefer this style) def find_kappa_target_modules( - model: nn.Module, - top_p: float = 0.2, - max_dim_size_to_analyze: int = 16384, + model: nn.Module, top_p: float = 0.2, max_dim_size_to_analyze: int = 16384 ) -> List[str]: - """One-liner version for quick use.""" selector = KappaTuneSelector(model, max_dim_size_to_analyze) return selector.get_best_targets(top_p=top_p) From 50a9a5e00c04cdf0da8b562912e42613136d59c3 Mon Sep 17 00:00:00 2001 From: Oswaldo Ludwig Date: Tue, 17 Mar 2026 15:16:16 +0100 Subject: [PATCH 6/9] Add KappaTuneSelector documentation The selector picks the layers with the lowest condition number (most isotropic = best for adaptation), exactly as shown in the KappaTune paper (https://arxiv.org/abs/2506.16289). --- docs/source/package_reference/target_selection.md | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 docs/source/package_reference/target_selection.md diff --git a/docs/source/package_reference/target_selection.md b/docs/source/package_reference/target_selection.md new file mode 100644 index 0000000000..c34e15c50b --- /dev/null +++ b/docs/source/package_reference/target_selection.md @@ -0,0 +1,7 @@ +# KappaTuneSelector + +::: peft.utils.target_selection.KappaTuneSelector + options: + heading_level: 3 + +::: peft.utils.target_selection.find_kappa_target_modules From 7c310065e01b37952e1f377d4ea97cc65de1b92f Mon Sep 17 00:00:00 2001 From: Oswaldo Ludwig Date: Tue, 17 Mar 2026 15:21:55 +0100 Subject: [PATCH 7/9] Implement unit tests for target selection module Add tests for KappaTuneSelector and target selection --- tests/test_target_selection.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/test_target_selection.py diff --git a/tests/test_target_selection.py b/tests/test_target_selection.py new file mode 100644 index 0000000000..720eb13f62 --- /dev/null +++ b/tests/test_target_selection.py @@ -0,0 +1,27 @@ +import pytest +import torch +import torch.nn as nn +from peft.utils.target_selection import KappaTuneSelector, find_kappa_target_modules + +class SimpleMLP(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(10, 20) + self.fc2 = nn.Linear(20, 5) + +def test_selector_basic(): + model = SimpleMLP() + selector = KappaTuneSelector(model) + targets = selector.get_best_targets(top_p=0.5) + assert len(targets) == 1 + assert targets[0] in ["fc1", "fc2"] + +def test_one_liner(): + model = SimpleMLP() + targets = find_kappa_target_modules(model, top_p=1.0) + assert len(targets) == 2 + +def test_num_modules(): + model = SimpleMLP() + targets = KappaTuneSelector(model).get_best_targets(num_modules=1) + assert len(targets) == 1 From 
d38aebb1f7fc1326c334f21d949f920c5322e954 Mon Sep 17 00:00:00 2001
From: Oswaldo Ludwig
Date: Mon, 23 Mar 2026 18:25:25 +0100
Subject: [PATCH 8/9] Enhance docstrings in target_selection.py

Added detailed docstrings for the target selection functions, documenting
top_p, num_modules and threshold explicitly.
---
 src/peft/utils/target_selection.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/src/peft/utils/target_selection.py b/src/peft/utils/target_selection.py
index 8243987e37..4ccc05b69f 100644
--- a/src/peft/utils/target_selection.py
+++ b/src/peft/utils/target_selection.py
@@ -72,6 +72,21 @@ def get_best_targets(
         num_modules: Optional[int] = None,
         threshold: Optional[float] = None,
     ) -> List[str]:
+
+        """
+        Return the best target modules according to one of three mutually exclusive strategies.
+        Args:
+            top_p: Return this fraction of the lowest-κ modules (e.g. 0.2 = paper default).
+            num_modules: Return exactly this many of the best modules (fixed budget).
+            threshold: Return every module with κ ≤ threshold (quality cutoff).
+        Returns:
+            List of module names (e.g. ["model.layers.0.self_attn.q_proj", ...]).
+        Notes:
+            - Precedence (checked in order): num_modules → top_p → threshold → all modules.
+            - Modules are always sorted by ascending κ (lowest = best).
+            - Recommended: top_p=0.2 for most models (Llama-3, Mistral, Qwen, etc.).
+        """
+
         self._compute_kappas()
         if not self._condition_numbers:
             return []
@@ -93,5 +108,19 @@ def get_best_targets(
 def find_kappa_target_modules(
     model: nn.Module, top_p: float = 0.2, max_dim_size_to_analyze: int = 16384
 ) -> List[str]:
+
+    """
+    One-liner convenience function (recommended for most users).
+    Equivalent to:
+        selector = KappaTuneSelector(model, max_dim_size_to_analyze)
+        return selector.get_best_targets(top_p=top_p)
+    Args:
+        model: The base model.
+        top_p: Fraction of the best modules to return (paper default = 0.2).
+        max_dim_size_to_analyze: See KappaTuneSelector.__init__.
+    Returns:
+        List of module names ready for LoraConfig(target_modules=...).
+    """
+
     selector = KappaTuneSelector(model, max_dim_size_to_analyze)
     return selector.get_best_targets(top_p=top_p)

From 2ccdee14fda22952c207f8f2260f190493a291b8 Mon Sep 17 00:00:00 2001
From: Oswaldo Ludwig
Date: Tue, 24 Mar 2026 17:01:49 +0100
Subject: [PATCH 9/9] Disable caching in model outputs during evaluation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Avoids a compatibility issue, reported many times in HF discussions, between
the DeepSeek-V2-Lite custom modeling code and Transformers v4.40+: the model’s
forward pass still calls the old past_key_values.get_usable_length(...) method,
which DynamicCache (now the default cache class) no longer implements.
---
 examples/experiments_SA_kappatune_peft.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/experiments_SA_kappatune_peft.py b/examples/experiments_SA_kappatune_peft.py
index 7b5942b549..7b5b96bb1a 100644
--- a/examples/experiments_SA_kappatune_peft.py
+++ b/examples/experiments_SA_kappatune_peft.py
@@ -71,7 +71,7 @@ def evaluate_perplexity(model, dataset, name="Dataset"):
     with torch.no_grad():
         for i, batch in enumerate(dataloader):
             batch = {k: v.to(model.device) for k, v in batch.items()}
-            outputs = model(**batch)
+            outputs = model(**batch, use_cache=False)
             total_loss += outputs.loss.item()
             if i >= 40: break
     return math.exp(total_loss / (i + 1))
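
End-to-end usage implied by this series (a sketch; assumes a causal LM is
already loaded and uses the exports added in patch 2):

    from peft import LoraConfig, TaskType, get_peft_model, find_kappa_target_modules

    targets = find_kappa_target_modules(model, top_p=0.2)
    model = get_peft_model(model, LoraConfig(target_modules=targets, task_type=TaskType.CAUSAL_LM))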