Skip to content

Commit 7c93c13

Browse files
committed
remove redundant code
1 parent b22691e commit 7c93c13

File tree

4 files changed

+28
-418
lines changed

4 files changed

+28
-418
lines changed

eval/chat_benchmarks/AIME24/eval_instruct.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,10 @@ def read_test_examples(self, data_path: str) -> Generator[Dict[str, str], None,
5252
Yields:
5353
Dictionary containing task_id and formatted prompt
5454
"""
55-
try:
56-
with open(data_path, "r") as f:
57-
problems = [json.loads(x) for x in f]
58-
self.logger.info(f"Loaded {len(problems)} problems from {data_path}")
59-
return problems
60-
except Exception as e:
61-
self.logger.error(f"Error loading dataset: {e}")
62-
raise
55+
with open(data_path, "r") as f:
56+
questions = [json.loads(x) for x in f]
57+
self.logger.info(f"Loaded {len(questions)} questions from {data_path}")
58+
return questions
6359

6460
def generate_responses(self, model: LM) -> Dict[str, Any]:
6561
"""
Lines changed: 10 additions & 201 deletions
Original file line numberDiff line numberDiff line change
@@ -1,216 +1,25 @@
11
from typing import Dict, List
22

33
import datasets
4-
4+
from lm_eval.tasks.hendrycks_math.utils import is_equiv
55

66
def extract_answer(output: str) -> str:
7+
'''
8+
Input: model-generated solution
9+
Output: extracted final answer. Output "" if the final answer is not inside \\boxed
10+
'''
711
try:
812
answer = remove_boxed(last_boxed_only_string(output))
913
return answer
1014
except:
1115
return ""
1216

1317
def process_result(answer: str, solution: str) -> Dict[str, int]:
18+
'''
19+
Input: answer - gold final answer, solution - predicted final answer
20+
Output: whether the gold answer and predicted final answers are equivalent
21+
'''
1422
retval = 0
1523
if is_equiv(answer, solution):
1624
retval = 1
17-
return retval
18-
19-
20-
# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
21-
def is_equiv(str1, str2, verbose=False):
22-
if str1 is None and str2 is None:
23-
print("WARNING: Both None")
24-
return True
25-
if str1 is None or str2 is None:
26-
return False
27-
28-
try:
29-
ss1 = strip_string(str1)
30-
ss2 = strip_string(str2)
31-
if verbose:
32-
print(ss1, ss2)
33-
return ss1 == ss2
34-
except Exception:
35-
return str1 == str2
36-
37-
38-
def remove_boxed(s):
39-
if "\\boxed " in s:
40-
left = "\\boxed "
41-
assert s[: len(left)] == left
42-
return s[len(left) :]
43-
44-
left = "\\boxed{"
45-
46-
assert s[: len(left)] == left
47-
assert s[-1] == "}"
48-
49-
return s[len(left) : -1]
50-
51-
52-
def last_boxed_only_string(string):
53-
idx = string.rfind("\\boxed")
54-
if "\\boxed " in string:
55-
return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
56-
if idx < 0:
57-
idx = string.rfind("\\fbox")
58-
if idx < 0:
59-
return None
60-
61-
i = idx
62-
right_brace_idx = None
63-
num_left_braces_open = 0
64-
while i < len(string):
65-
if string[i] == "{":
66-
num_left_braces_open += 1
67-
if string[i] == "}":
68-
num_left_braces_open -= 1
69-
if num_left_braces_open == 0:
70-
right_brace_idx = i
71-
break
72-
i += 1
73-
74-
if right_brace_idx is None:
75-
retval = None
76-
else:
77-
retval = string[idx : right_brace_idx + 1]
78-
79-
return retval
80-
81-
82-
def fix_fracs(string):
83-
substrs = string.split("\\frac")
84-
new_str = substrs[0]
85-
if len(substrs) > 1:
86-
substrs = substrs[1:]
87-
for substr in substrs:
88-
new_str += "\\frac"
89-
if substr[0] == "{":
90-
new_str += substr
91-
else:
92-
try:
93-
assert len(substr) >= 2
94-
except AssertionError:
95-
return string
96-
a = substr[0]
97-
b = substr[1]
98-
if b != "{":
99-
if len(substr) > 2:
100-
post_substr = substr[2:]
101-
new_str += "{" + a + "}{" + b + "}" + post_substr
102-
else:
103-
new_str += "{" + a + "}{" + b + "}"
104-
else:
105-
if len(substr) > 2:
106-
post_substr = substr[2:]
107-
new_str += "{" + a + "}" + b + post_substr
108-
else:
109-
new_str += "{" + a + "}" + b
110-
string = new_str
111-
return string
112-
113-
114-
def fix_a_slash_b(string):
115-
if len(string.split("/")) != 2:
116-
return string
117-
a = string.split("/")[0]
118-
b = string.split("/")[1]
119-
try:
120-
a = int(a)
121-
b = int(b)
122-
assert string == "{}/{}".format(a, b)
123-
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
124-
return new_string
125-
except AssertionError:
126-
return string
127-
128-
129-
def remove_right_units(string):
130-
# "\\text{ " only ever occurs (at least in the val set) when describing units
131-
if "\\text{ " in string:
132-
splits = string.split("\\text{ ")
133-
assert len(splits) == 2
134-
return splits[0]
135-
else:
136-
return string
137-
138-
139-
def fix_sqrt(string):
140-
if "\\sqrt" not in string:
141-
return string
142-
splits = string.split("\\sqrt")
143-
new_string = splits[0]
144-
for split in splits[1:]:
145-
if split[0] != "{":
146-
a = split[0]
147-
new_substr = "\\sqrt{" + a + "}" + split[1:]
148-
else:
149-
new_substr = "\\sqrt" + split
150-
new_string += new_substr
151-
return new_string
152-
153-
154-
def strip_string(string):
155-
# linebreaks
156-
string = string.replace("\n", "")
157-
158-
# remove inverse spaces
159-
string = string.replace("\\!", "")
160-
161-
# replace \\ with \
162-
string = string.replace("\\\\", "\\")
163-
164-
# replace tfrac and dfrac with frac
165-
string = string.replace("tfrac", "frac")
166-
string = string.replace("dfrac", "frac")
167-
168-
# remove \left and \right
169-
string = string.replace("\\left", "")
170-
string = string.replace("\\right", "")
171-
172-
# Remove circ (degrees)
173-
string = string.replace("^{\\circ}", "")
174-
string = string.replace("^\\circ", "")
175-
176-
# remove dollar signs
177-
string = string.replace("\\$", "")
178-
179-
# remove units (on the right)
180-
string = remove_right_units(string)
181-
182-
# remove percentage
183-
string = string.replace("\\%", "")
184-
string = string.replace("\%", "") # noqa: W605
185-
186-
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
187-
string = string.replace(" .", " 0.")
188-
string = string.replace("{.", "{0.")
189-
# if empty, return empty string
190-
if len(string) == 0:
191-
return string
192-
if string[0] == ".":
193-
string = "0" + string
194-
195-
# to consider: get rid of e.g. "k = " or "q = " at beginning
196-
if len(string.split("=")) == 2:
197-
if len(string.split("=")[0]) <= 2:
198-
string = string.split("=")[1]
199-
200-
# fix sqrt3 --> sqrt{3}
201-
string = fix_sqrt(string)
202-
203-
# remove spaces
204-
string = string.replace(" ", "")
205-
206-
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
207-
string = fix_fracs(string)
208-
209-
# manually change 0.5 --> \frac{1}{2}
210-
if string == "0.5":
211-
string = "\\frac{1}{2}"
212-
213-
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
214-
string = fix_a_slash_b(string)
215-
216-
return string
25+
return retval

eval/chat_benchmarks/AMC23/eval_instruct.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,10 @@ def read_test_examples(self, data_path: str) -> Generator[Dict[str, str], None,
5252
Yields:
5353
Dictionary containing task_id and formatted prompt
5454
"""
55-
try:
56-
with open(data_path, "r") as f:
57-
questions = [json.loads(x) for x in f]
58-
self.logger.info(f"Loaded {len(questions)} questions from {data_path}")
59-
return questions
60-
except Exception as e:
61-
self.logger.error(f"Error loading dataset: {e}")
62-
raise
55+
with open(data_path, "r") as f:
56+
questions = [json.loads(x) for x in f]
57+
self.logger.info(f"Loaded {len(questions)} questions from {data_path}")
58+
return questions
6359

6460
def generate_responses(self, model: LM) -> Dict[str, Any]:
6561
"""

0 commit comments

Comments
 (0)