-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
332 lines (277 loc) · 11.9 KB
/
evaluate.py
File metadata and controls
332 lines (277 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
# evaluate.py
import argparse
import os
import csv
# Set HuggingFace cache to avoid disk space issues.
# NOTE: these environment variables must be set BEFORE importing
# `datasets` / `transformers` below, because both libraries read the
# cache location at import time.
_HF_CACHE = os.environ.get('HF_HOME', '/data/qbao775/lemo/.cache/huggingface')
os.environ['HF_HOME'] = _HF_CACHE
os.environ['HF_DATASETS_CACHE'] = os.path.join(_HF_CACHE, 'datasets')
os.environ['TRANSFORMERS_CACHE'] = os.path.join(_HF_CACHE, 'transformers')
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Short model keys (CLI values of --model) -> Hugging Face hub ids or
# local checkpoint paths of the corresponding BASE model.
MODEL_LIST = {
    "bert": "bert-base-uncased",
    "qwen": "Qwen/Qwen2-1.5B",
    # NOTE(review): hard-coded local path — assumes the shared Qwen3
    # checkpoint exists on this host; verify before running elsewhere.
    "qwen3": "/data/shared/qwen3/Qwen3-8B",
    "llama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
# All test splits, including multi-law variant4.
# Maps split name -> CSV path; paths may be rebased onto --data_dir by the
# __main__ block (only entries whose file exists there are overridden).
DEFAULT_TEST_FILES = {
    "base": "data/test_base.csv",
    "hard_mixed": "data/test_hard_mixed.csv",
    "variant1": "data/test_variant1.csv",
    "variant2": "data/test_variant2.csv",
    "variant3": "data/test_variant3.csv",
    "variant4_equiv_contrapositive": "data/test_variant4_equiv_contrapositive.csv",
    "variant4_equiv_double_negation": "data/test_variant4_equiv_double_negation.csv",
    "variant4_equiv_implication": "data/test_variant4_equiv_implication.csv",
    "variant4_equiv_demorgan": "data/test_variant4_equiv_demorgan.csv",
    "variant4_equiv_identity": "data/test_variant4_equiv_identity.csv",
    "variant4_equiv_commutativity": "data/test_variant4_equiv_commutativity.csv",
    "variant4_equiv_multi": "data/test_variant4_equiv_multi.csv",
}
def build_device():
    """Pick the best available torch device: CUDA, then Apple MPS, then CPU."""
    candidates = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for device_name, is_available in candidates:
        if is_available():
            return torch.device(device_name)
    return torch.device("cpu")
def predict_single(model, tokenizer, text, device):
    """Run one forward pass of a binary classifier on *text*.

    Returns "T" when the model predicts class index 1, otherwise "F".
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    # Move every input tensor onto the inference device.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits
    label_idx = logits.argmax(-1).item()
    return "T" if label_idx == 1 else "F"
def describe_change(split_name: str, laws_used: str, law_count: int) -> str:
    """
    Human-readable description of what changed for this split.
    For multi-law cases, we embed the law list and count.
    """
    # Fixed descriptions for the non-variant4 splits.
    fixed_descriptions = {
        "base": "none",
        "hard_mixed": "mixed T/F answer patterns: partial reasoning chains with distractors",
        "variant1": "removed redundant rule: 'If someone is young then they are cold.'",
        "variant2": "removed key rule: 'If someone is cold then they are rough.'",
        "variant3": "changed facts: added '<name> is not cold or not nice'",
    }
    if split_name in fixed_descriptions:
        return fixed_descriptions[split_name]
    if split_name == "variant4_equiv_multi":
        return f"multiple logical equivalence laws applied (count={law_count}): {laws_used}"
    if split_name.startswith("variant4_equiv_"):
        law_name = split_name.replace("variant4_equiv_", "")
        return f"logical equivalence law applied: {law_name}"
    return "unknown"
def eval_and_save(model, tokenizer, filename, model_key, split_name, device, out_dir):
    """
    Evaluate on one CSV file AND save predictions into a CSV.
    Each row in the output corresponds to ONE question.

    Questions and answers are stored " | "-separated inside a single
    dataset row; each question becomes one prediction row.

    Args:
        model: sequence-classification model already on *device*, in eval mode.
        tokenizer: tokenizer matching *model*.
        filename: path of the test split CSV.
        model_key: short model identifier (used in the output filename).
        split_name: split name (a key of DEFAULT_TEST_FILES).
        device: torch device used for inference.
        out_dir: directory that receives "<model_key>_<split_name>_predictions.csv".

    Returns:
        (accuracy, total_questions, correct_questions)

    Raises:
        FileNotFoundError: if *filename* does not exist.
    """
    if not os.path.exists(filename):
        # FIX: include the missing path in the message (the original
        # f-string had no placeholder, producing a useless error).
        raise FileNotFoundError(f"Test file not found: {filename}")
    ds = load_dataset("csv", data_files=filename)["train"]
    total, correct = 0, 0
    output_rows = []
    os.makedirs(out_dir, exist_ok=True)
    output_csv = os.path.join(out_dir, f"{model_key}_{split_name}_predictions.csv")
    for row in ds:
        facts = row["facts"]
        rules = row["rules"]
        questions = row["questions"].split(" | ")
        answers = row["answers"].split(" | ")
        # equiv_laws_used is only populated for variant4 splits; normalize
        # a missing/None column to the empty string.
        laws_used = row.get("equiv_laws_used", "") or ""
        law_list = [x for x in laws_used.split(",") if x]
        law_count = len(law_list)
        changed_desc = describe_change(split_name, laws_used, law_count)
        for q, truth in zip(questions, answers):
            # The classifier input is a flat concatenation: facts + rules + question.
            text = facts + " " + rules + " " + q
            pred = predict_single(model, tokenizer, text, device)
            output_rows.append({
                "group_id": row["group_id"],
                "type": split_name,
                "facts": facts,
                "rules": rules,
                "question": q,
                "ground_truth": truth,
                "prediction": pred,
                "equiv_laws_used": laws_used,
                "equiv_law_count": law_count,
                "changed_rule": changed_desc,
            })
            if pred == truth:
                correct += 1
            total += 1
    acc = correct / total if total > 0 else 0.0
    # Save prediction CSV (one row per question).
    with open(output_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "group_id",
                "type",
                "facts",
                "rules",
                "question",
                "ground_truth",
                "prediction",
                "equiv_laws_used",
                "equiv_law_count",
                "changed_rule",
            ],
        )
        writer.writeheader()
        writer.writerows(output_rows)
    print(f"📄 Predictions saved to: {output_csv}")
    return acc, total, correct
def main(model_key: str, model_dir: str = None):
    """
    Evaluate one fine-tuned model over every split in DEFAULT_TEST_FILES.

    Per split: reuse an existing, complete predictions CSV when present
    (resume support), otherwise run eval_and_save().  Finally print a
    base-vs-variants accuracy table and write it to
    "<model_dir>/accuracy_summary<suffix>.csv", where the suffix comes
    from the `main._output_suffix` attribute set by the CLI entry point.

    Args:
        model_key: key into MODEL_LIST ("bert", "qwen", "qwen3", "llama").
        model_dir: directory holding the fine-tuned checkpoint;
            defaults to "./trained_models/<model_key>".
    """
    # Allow custom model directory or use default.
    if model_dir is None:
        model_dir = f"./trained_models/{model_key}"
    base_model_name = MODEL_LIST.get(model_key, "unknown")
    print(f"▶ Loading model from: {model_dir}")
    print(f"▶ Base model type: {model_key}")
    if base_model_name != "unknown":
        print(f"▶ Base model: {base_model_name}")
    device = build_device()
    print(f"▶ Device: {device}")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.to(device)
    model.eval()
    # Decoder-style checkpoints (Qwen/LLaMA) ship without a pad token;
    # fall back to EOS so batched tokenization can pad.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    predictions_dir = os.path.join(model_dir, "predictions")
    results = []
    print("\n===== Detailed Evaluation Per Split =====")
    for split_name, filename in DEFAULT_TEST_FILES.items():
        output_csv = os.path.join(predictions_dir, f"{model_key}_{split_name}_predictions.csv")
        # Load ground truth dataset.
        ds = load_dataset("csv", data_files=filename)["train"]
        # Calculate total EXPECTED predictions (sum of questions in each row).
        expected_total = 0
        for row in ds:
            questions = row["questions"].split(" | ")
            expected_total += len(questions)
        # Resume support: skip inference if a complete predictions file exists.
        if os.path.exists(output_csv):
            print(f"\n[{split_name}] Found existing predictions: {output_csv}")
            try:
                with open(output_csv, "r", encoding="utf-8") as f:
                    reader = csv.DictReader(f)
                    rows = list(reader)
                if len(rows) == expected_total:
                    print(f" ✅ File is complete ({len(rows)} predictions). Calculating accuracy from existing file...")
                    correct = sum(1 for r in rows if r["prediction"] == r["ground_truth"])
                    acc = correct / expected_total if expected_total > 0 else 0.0
                    results.append({
                        "split": split_name,
                        "accuracy": acc,
                        "correct": correct,
                        "total": expected_total
                    })
                    print(f" accuracy: {acc:.4f}")
                    continue
                else:
                    print(f" ⚠️ File incomplete ({len(rows)}/{expected_total}). Re-running evaluation...")
            except Exception as e:
                # Corrupt/partial file: fall through and regenerate it.
                print(f" ⚠️ Error reading file: {e}. Re-running evaluation...")
        # FIX: interpolate the split's file path (the original f-string
        # had no placeholder, so the log line never named the file).
        print(f"\n[{split_name}] Evaluating {filename}...")
        acc, total, correct = eval_and_save(
            model,
            tokenizer,
            filename,
            model_key,
            split_name,
            device,
            predictions_dir,
        )
        results.append({
            "split": split_name,
            "accuracy": acc,
            "correct": correct,
            "total": total
        })
        print(f" samples (questions): {total}")
        print(f" correct: {correct}")
        print(f" accuracy: {acc:.4f}")
        print("-" * 40)
    # ------- summary table -------
    print("\n===== Base vs Variants Accuracy Table =====")
    base_acc = next((r["accuracy"] for r in results if r["split"] == "base"), 0.0)
    header = f"{'Split':<35} | {'Accuracy':>9} | {'Δ vs base':>9}"
    print(header)
    print("-" * len(header))
    # Output suffix is smuggled in via a function attribute set by __main__.
    _suffix = getattr(main, "_output_suffix", "")
    summary_csv_path = os.path.join(model_dir, f"accuracy_summary{_suffix}.csv")
    with open(summary_csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["split", "accuracy", "delta_vs_base", "correct", "total"])
        writer.writeheader()
        # Fixed display order: base first, then variants.
        ordered_splits = [
            "base",
            "hard_mixed",
            "variant1",
            "variant2",
            "variant3",
            "variant4_equiv_contrapositive",
            "variant4_equiv_double_negation",
            "variant4_equiv_implication",
            "variant4_equiv_demorgan",
            "variant4_equiv_identity",
            "variant4_equiv_commutativity",
            "variant4_equiv_multi",
        ]
        for split in ordered_splits:
            res = next((r for r in results if r["split"] == split), None)
            if not res:
                continue
            acc = res["accuracy"]
            delta = acc - base_acc
            delta_str = f"{delta:+.3f}" if split != "base" else "0.000"
            print(f"{split:<35} | {acc:>9.4f} | {delta_str:>9}")
            writer.writerow({
                "split": split,
                "accuracy": acc,
                "delta_vs_base": delta,
                "correct": res["correct"],
                "total": res["total"]
            })
    print(f"\n📄 Accuracy summary saved to: {summary_csv_path}")
    print("\n✅ Evaluation FINISHED.\n")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, choices=["bert", "qwen", "qwen3", "llama"],
                        help="Model type (bert, qwen, qwen3, llama)")
    parser.add_argument("--model_dir", type=str, default=None,
                        help="Custom model directory (default: trained_models/{model})")
    parser.add_argument("--stage", type=str, default=None, choices=["stage1", "stage2", "stage2_mixed"],
                        help="Shortcut to evaluate stage models (overrides model_dir)")
    parser.add_argument("--data_dir", type=str, default=None,
                        help="Custom data directory containing test_*.csv files (default: data/)")
    parser.add_argument("--output_suffix", type=str, default="",
                        help="Suffix for accuracy_summary filename, e.g. '_v2' → accuracy_summary_v2.csv")
    args = parser.parse_args()
    # Handle stage shortcuts: --stage overrides --model_dir entirely.
    if args.stage:
        model_dir = f"./trained_models/{args.model}_{args.stage}"
    else:
        model_dir = args.model_dir
    # Pass output suffix to main via function attribute.
    main._output_suffix = args.output_suffix
    # Override DEFAULT_TEST_FILES if custom data_dir provided; only splits
    # whose file actually exists under data_dir are rebased.
    # (FIX: removed an unused `import glob as _glob`.)
    if args.data_dir:
        for split_name, old_path in list(DEFAULT_TEST_FILES.items()):
            new_path = os.path.join(args.data_dir, os.path.basename(old_path))
            if os.path.exists(new_path):
                DEFAULT_TEST_FILES[split_name] = new_path
    main(args.model, model_dir=model_dir)