# 1. Load model
# 2. Load tokenizer
# 3. Load data
# 4. Tokenize data
# 5. Evaluate model
from data import prepare_data, load_data_from_file, prepare_dataloader
from transformers import (
GPT2LMHeadModel,
GPT2Tokenizer,
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling,
AutoTokenizer,
AutoModelForCausalLM,
)
from model import EarlyExitLlamaModelForCausalLM, EarlyExitLlamaModel
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
import argparse
from post_train import bi_layers_to_delete, angular_block_to_delete, least_important_layer
def evaluate(
    model,
    data_path: str,
    tokenizer,
) -> float:
    """Compute exact-match accuracy of `model` on a GSM8K-style test file.

    Generates up to 150 new tokens per question and compares the text after
    the final '####' marker in the prediction against the text after the
    first '####' marker in the reference.

    Args:
        model: A causal LM with `.generate()` and a `.device` attribute.
        data_path: Path passed to `prepare_dataloader` (project helper).
        tokenizer: Tokenizer used by the dataloader and for decoding.

    Returns:
        Fraction of examples whose extracted answer matches exactly
        (0.0 if the loader yields no examples).
    """
    model.eval()
    acc = 0.0  # defined up front so an empty loader cannot leave it unbound
    with torch.no_grad():
        test_loader = prepare_dataloader(path=data_path, tokenizer=tokenizer)
        pbar = tqdm(test_loader)
        correct = []
        for batch in pbar:
            batch = {k: v.to(model.device) for k, v in batch.items()}
            outputs = model.generate(
                input_ids=batch['question_input_ids'],
                attention_mask=batch['question_attention_mask'],
                max_new_tokens=150,
                repetition_penalty=1.0,
                temperature=1.0,
            )
            pred_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            real_outputs = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)
            for r, p in zip(real_outputs, pred_outputs):
                # Reference answer is the text after the FIRST '####'.
                # Guard against references missing the marker instead of
                # crashing with an IndexError.
                ref_parts = r.split('####')
                if len(ref_parts) < 2:
                    correct.append(False)
                    continue
                r_answer = ref_parts[1].strip()
                # Prediction answer is after the LAST '####'; str.split never
                # raises here, so no try/except is needed.
                p_answer = p.split('####')[-1].strip("#").strip()
                correct.append(r_answer == p_answer)
                print(r)
                print(p)
                print("*===*")
            acc = correct.count(True) / len(correct)
            pbar.set_description(f"Accuracy: {acc*100:.2f}%")
    print(f"Accuracy: {acc*100:.2f}%")
    return acc
def main(args):
    """Load a (possibly layer-pruned) model and evaluate it on the test set.

    Selects a tokenizer by model family, optionally chooses layers to delete
    via BI score / angular-block / greedy search on the validation set, then
    runs `evaluate` on `args.data_path`.
    """
    # Tokenizer selection by model family.
    if "gpt2" in args.model:
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2', add_eos_token=True)
    else:
        # NOTE(review): hard-coded local path — consider making this a CLI
        # argument so the script runs outside this machine.
        tokenizer = AutoTokenizer.from_pretrained("/raid/lingo/models/Llama-3.2-1B-Instruct/", add_eos_token=True)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Model selection: the early-exit Llama wrapper supports layer deletion
    # and confidence-based early exit; everything else loads vanilla.
    if "llama" in args.model:
        layers_to_delete = [] if args.layers_to_delete is None else args.layers_to_delete
        model = EarlyExitLlamaModelForCausalLM.from_pretrained(
            args.model,
            layers_to_delete=layers_to_delete,
            confidence_bound=args.early_exit_bound,
            device_map="auto",
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            device_map="auto",
        )
    model.resize_token_embeddings(len(tokenizer))

    # Choose which layers to delete using the validation split.
    valid_loader = prepare_dataloader(path=args.valid_path, tokenizer=tokenizer)
    if args.delete_by_bi:
        model.model.layers_to_delete = bi_layers_to_delete(model=model,
                                                           num_to_delete=args.num_to_delete,
                                                           data_loader=valid_loader)
    elif args.angular_block_delete:
        model.model.layers_to_delete = angular_block_to_delete(model=model,
                                                               data_loader=valid_loader,
                                                               block_size=args.block_size)
    elif args.greedy_delete:
        # BUG FIX: an unconditional breakpoint() here halted every
        # non-interactive run; removed (debug leftover).
        least_important, accuracies = least_important_layer(model=model,
                                                            tokenizer=tokenizer,
                                                            data_loader=valid_loader)
        model.model.layers_to_delete.append(least_important)
    print(f"Skipping layers {model.model.layers_to_delete}")
    evaluate(model=model, data_path=args.data_path, tokenizer=tokenizer)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate a (layer-pruned) causal LM on GSM8K-style data.")
    parser.add_argument("--model", type=str, default="llama-3.2-1b-gsm8k-stepscot/checkpoint-16000",
                        help="Model name or checkpoint path.")
    parser.add_argument("--data_path", type=str, default="Internalize_CoT_Step_by_Step/data/gsm8k/test.txt",
                        help="Test split used for the final accuracy.")
    parser.add_argument("--valid_path", type=str, default="Internalize_CoT_Step_by_Step/data/gsm8k/valid.txt",
                        help="Validation split used to pick layers to delete.")
    parser.add_argument("--layers_to_delete", nargs='+', type=int, default=None,
                        help="Explicit list of layer indices to delete (llama only).")
    parser.add_argument("--early_exit_bound", type=float, default=None,
                        help="Confidence bound for early exit (llama only).")
    # BUG FIX: `type=bool` parses any non-empty string (including "False")
    # as True; store_true gives the intended flag semantics.
    parser.add_argument("--delete_by_bi", action="store_true",
                        help="Pick layers to delete by BI score.")
    parser.add_argument("--num_to_delete", type=int, default=0,
                        help="Number of layers to delete with --delete_by_bi.")
    parser.add_argument("--angular_block_delete", action="store_true",
                        help="Pick a contiguous block of layers by angular distance.")
    parser.add_argument("--block_size", type=int, default=0,
                        help="Block size for --angular_block_delete.")
    parser.add_argument("--greedy_delete", action="store_true",
                        help="Greedily delete the least important layer.")
    args = parser.parse_args()
    main(args)