Concept-Aware Fine-Tuning (CAFT)

arXiv Paper

Concept-aware fine-tuning (CAFT) encourages stronger conceptual understanding by incorporating multi-token prediction into fine-tuning.
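Concretely, the training loss combines the standard next-token cross-entropy with losses from auxiliary heads that predict several tokens ahead. The sketch below illustrates the idea only; the head structure, loss weighting, and function name are hypothetical, not the exact implementation in this repo.

import torch.nn.functional as F

def caft_style_loss(hidden_states, aux_heads, lm_head, labels, aux_weight=0.1):
    # Standard next-token prediction loss
    logits = lm_head(hidden_states)
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        labels[:, 1:].reshape(-1),
        ignore_index=-100,
    )
    # Auxiliary head k predicts the token (k + 1) positions ahead
    for k, head in enumerate(aux_heads, start=1):
        aux_logits = lm_head(head(hidden_states))
        shift = k + 1
        aux_loss = F.cross_entropy(
            aux_logits[:, :-shift].reshape(-1, aux_logits.size(-1)),
            labels[:, shift:].reshape(-1),
            ignore_index=-100,
        )
        loss = loss + aux_weight * aux_loss
    return loss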

Installation

git clone https://github.com/michaelchen-lab/caft-llm.git
cd caft-llm
pip install -e .
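If the editable install succeeded, the caft package (used in Method 2 below) should be importable:

python -c "import caft; print(caft.__file__)"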

Setup

  1. Create a .env file containing HUGGINGFACE_TOKEN=<token> and, optionally, WANDB_TOKEN=<token>
  2. Add train_set.jsonl and eval_set.jsonl files to scripts/datasets/. Each instance should follow this format:
{
    "id": "<int/str>", "status": "OK",
    "conversation": [
        {"role": "human", "content": "(prompt)"},
        {"role": "assistant", "content": "(ground truth answer)"}
    ]
}
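For reference, a file in this format can be produced with the standard json module; the sample conversation below is purely illustrative.

import json

instances = [
    {
        "id": 1, "status": "OK",
        "conversation": [
            {"role": "human", "content": "What is 2 + 2?"},
            {"role": "assistant", "content": "2 + 2 = 4."}
        ]
    }
]

# JSONL: one JSON object per line
with open("scripts/datasets/train_set.jsonl", "w") as f:
    for instance in instances:
        f.write(json.dumps(instance) + "\n")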

Fine-tune a model using CAFT

Currently, only the auxiliary heads of meta-llama/Llama-3.1-8B-Instruct have been pretrained.

Method 1: Use the provided training script scripts/train.py

torchrun --nproc-per-node 1 scripts/train.py -ftm lora 
torchrun --nproc-per-node 1 scripts/train.py -ftm lora -ft-heads -hpretrain
torchrun --nproc-per-node 1 scripts/train.py -ftm sft -lr 5e-6 -fr-unembed
torchrun --nproc-per-node 1 scripts/train.py -ftm sft -lr 5e-6 -fr-unembed -ft-heads -hpretrain

Selected Arguments:

  • --model-name-or-path -model: Currently only meta-llama/Llama-3.1-8B-Instruct is supported.
  • --model-max-length -maxlen
  • --finetune-method -ftm: lora or sft (full finetuning)
  • --learning-rate -lr
  • --epochs -e
  • --freeze-unembedding -fr-unembed: Only applicable for full fine-tuning. Recommended: True
  • --per-device-batch-size -micro-bs
  • --gradient-accumulation-steps -grad-acc
  • --heads-pretraining -hpretrain: Train the auxiliary heads on your dataset for 1 epoch before applying CAFT to your model. -ft-heads must also be set to True.

The full list of arguments can be found using this command:

python scripts/train.py --help

Method 2: Integrate CAFT into your existing Transformers fine-tuning pipeline

import transformers
from caft import *

# Import your pretrained Transformers model, tokenizer, TrainingArguments, and data_module

add_auxiliary_heads(model)    # attach CAFT's auxiliary prediction heads to the model
add_caft_loss(transformers)   # patch transformers' loss computation to add the auxiliary losses

# The additional CAFT components below track and save the auxiliary losses
trainer = transformers.trainer.Trainer(
    model=model, tokenizer=tokenizer, args=model_training_args,
    callbacks=[CAFTSaveLogging],
    compute_metrics=caft_compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    **data_module
)
trainer.train()

Please refer to scripts/train.py for a complete implementation example.
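The data_module above is expected to unpack into the Trainer's dataset arguments (e.g. train_dataset, eval_dataset, and optionally data_collator). Below is one possible sketch, assuming the JSONL format from Setup and deliberately simplified tokenization; scripts/train.py contains the actual preprocessing, including prompt masking.

import json
from torch.utils.data import Dataset

class ConversationDataset(Dataset):
    def __init__(self, path, tokenizer, max_length=2048):
        with open(path) as f:
            self.instances = [json.loads(line) for line in f]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        conversation = self.instances[idx]["conversation"]
        text = "\n".join(turn["content"] for turn in conversation)
        input_ids = self.tokenizer(
            text, truncation=True, max_length=self.max_length,
            return_tensors="pt",
        )["input_ids"][0]
        # Simplification: labels mirror input_ids; in practice the
        # prompt tokens are usually masked with -100
        return {"input_ids": input_ids, "labels": input_ids.clone()}

data_module = {
    "train_dataset": ConversationDataset("scripts/datasets/train_set.jsonl", tokenizer),
    "eval_dataset": ConversationDataset("scripts/datasets/eval_set.jsonl", tokenizer),
}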

Evaluation Datasets

All datasets used in the paper can be found in this Huggingface repo.

(Optional) Train Auxiliary Heads

  1. Download the training and validation datasets from this Huggingface repo and save them to scripts/datasets
  2. Run the following command
torchrun --nproc-per-node 4 scripts/train_aux_heads.py

Contributing

We welcome community contributions and feature requests for caft-llm. Feel free to open an issue or submit a pull request. If you have any questions or wish to collaborate, please contact michaelchenkj@gmail.com.

Roadmap

  • Support all model architectures.
    Currently, `LlamaDecoderLayer` is used to create the auxiliary heads; in other words, only Llama-based models are supported. Edit `core.py` to copy the last hidden layer of the given model instead of inserting a `LlamaDecoderLayer`, then reinitialize the weights (see the sketch after this list).
  • Support speculative decoding.
    Speculative decoding can be implemented using the same method as Gloeckle et al. (2024) and Stern et al. (2018).
  • Support FSDP and DeepSpeed
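For the first roadmap item, one possible direction is sketched below, under the assumption that the model follows the common Hugging Face decoder-only layout with model.model.layers; make_auxiliary_head is a hypothetical helper, not part of `core.py` today.

import copy
import torch.nn as nn

def make_auxiliary_head(model):
    # Clone the model's last decoder layer, whatever its concrete class,
    # instead of hardcoding LlamaDecoderLayer
    head = copy.deepcopy(model.model.layers[-1])
    # Reinitialize so the new head starts from fresh weights
    for module in head.modules():
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    return head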

Acknowledgements

This codebase adapts code from several amazing projects, including Medusa and Facebook Multi-Token.
