conda activate coconut
pip install -r requirements.txt

Training:

python train.py \
--model_name gpt2 \
--batch_size 128 \
--learning_rate 1e-4 \
--num_epochs 5 \
--max_length 512 \
--latent_thoughts_per_step 1 \
--max_latent_length 20 \
--num_training_stages 4 \
--warmup_steps 100

Evaluation:

python eval.py \
--model <model_path> \
--data_path <test_data_path> \
--layers_to_delete <one or more layer indices to delete> \
--early_exit_bound <confidence bound for early exit>

Model paths:
llama-3.2-1b-gsm8k-stepscot/checkpoint-16000/ for the model trained on full CoT.
checkpoints/llama-3.2.1b_gsm8k/checkpoint_75 for the model trained using the Internalized CoT script.
Test data paths:
To evaluate the internalized model trained with the Internalized CoT script, use Internalize_CoT_Step_by_Step/data/gsm8k/test.txt as the test data.
To evaluate the model trained explicitly to output CoT, use implicit_chain_of_thought/data/gsm8k/test.txt as the test data.
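To illustrate what the eval flags above control, here is a minimal, self-contained sketch of layer deletion (--layers_to_delete) and confidence-based early exit (--early_exit_bound). The toy layers, the lm_head function, and the max-probability confidence test are illustrative assumptions, not the actual eval.py implementation.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def forward_with_early_exit(layers, lm_head, hidden, layers_to_delete, bound):
    """Run the layers that survive deletion; stop as soon as the top
    token probability exceeds `bound` (the --early_exit_bound idea)."""
    kept = [l for i, l in enumerate(layers) if i not in set(layers_to_delete)]
    probs = softmax(lm_head(hidden))
    for depth, layer in enumerate(kept, start=1):
        hidden = layer(hidden)
        probs = softmax(lm_head(hidden))
        if max(probs) >= bound:
            return probs, depth  # confident enough: skip the deeper layers
    return probs, len(kept)

# Toy model: each "layer" nudges the hidden state; lm_head maps it to
# logits over a two-token vocabulary.
layers = [lambda h, k=k: [x + 0.5 * k for x in h] for k in range(1, 5)]
lm_head = lambda h: [h[0], -h[0]]

# Delete layer index 1, exit once the top token reaches 0.9 probability.
probs, depth = forward_with_early_exit(layers, lm_head, [0.0, 0.0],
                                       layers_to_delete=[1], bound=0.9)
```

In this toy run the model deletes one of four layers and exits after the second kept layer, once the confidence bound is crossed, which is the same trade-off the two eval flags expose: fewer layers and earlier exits cost accuracy but save compute.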