11from typing import Dict , List
22
33import datasets
4-
4+ from lm_eval . tasks . hendrycks_math . utils import is_equiv
55
def extract_answer(output: str) -> str:
    r"""Extract the final answer from a model-generated solution.

    Args:
        output: Model-generated solution text.

    Returns:
        The content of the last ``\boxed`` expression found in ``output``,
        or ``""`` when no final answer can be extracted from a ``\boxed``
        expression.
    """
    try:
        return remove_boxed(last_boxed_only_string(output))
    except Exception:
        # Best effort: any extraction failure (no box found, malformed box)
        # yields an empty answer. ``Exception`` rather than a bare ``except``
        # so KeyboardInterrupt / SystemExit still propagate.
        return ""
1216
def process_result(answer: str, solution: str) -> int:
    """Score a prediction against the reference answer.

    Args:
        answer: Gold final answer (per the original docstring).
        solution: Predicted final answer (per the original docstring).

    Returns:
        1 if ``is_equiv`` considers the two answers equivalent, otherwise 0.
    """
    # The result is a plain int flag, not a dict as the previous annotation
    # claimed.
    return 1 if is_equiv(answer, solution) else 0
18-
19-
20- # string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
def is_equiv(str1, str2, verbose=False):
    """Tell whether two answer strings are equal after normalisation.

    Two ``None`` inputs count as equivalent (a warning is printed); a single
    ``None`` never matches. Otherwise both strings go through
    ``strip_string`` and the normalised forms are compared; if normalisation
    raises, the raw strings are compared instead. With ``verbose`` set the
    normalised pair is printed before comparing.
    """
    if str1 is None or str2 is None:
        if str1 is not None or str2 is not None:
            # Exactly one side is missing.
            return False
        print("WARNING: Both None")
        return True

    try:
        left_norm = strip_string(str1)
        right_norm = strip_string(str2)
        if verbose:
            print(left_norm, right_norm)
        return left_norm == right_norm
    except Exception:
        # Normalisation failed; fall back to a literal comparison.
        return str1 == str2
36-
37-
def remove_boxed(s):
    r"""Return the content of a ``\boxed`` expression without its wrapper.

    Two spellings are handled: the space form ``\boxed 42`` and the brace
    form ``\boxed{42}``.

    Args:
        s: A string that starts with the ``\boxed`` wrapper.

    Returns:
        The text after ``\boxed `` (space form), or the text between
        ``\boxed{`` and the final ``}`` (brace form).

    Raises:
        AssertionError: If ``s`` does not start with the wrapper, or, in the
            brace form, does not end with ``}``. Raised explicitly so the
            check still holds when Python runs with ``-O``.
    """
    spaced_prefix = "\\boxed "
    if spaced_prefix in s:
        if not s.startswith(spaced_prefix):
            raise AssertionError("expected string to start with '\\boxed '")
        return s[len(spaced_prefix):]

    braced_prefix = "\\boxed{"
    if not s.startswith(braced_prefix):
        raise AssertionError("expected string to start with '\\boxed{'")
    if not s.endswith("}"):
        raise AssertionError("expected string to end with '}'")

    return s[len(braced_prefix):-1]
50-
51-
def last_boxed_only_string(string):
    r"""Return the last ``\boxed`` (or ``\fbox``) expression in ``string``.

    Args:
        string: Text to search, typically a full worked solution.

    Returns:
        * If the space form ``\boxed `` occurs anywhere: ``\boxed `` followed
          by the text after its last occurrence, cut at the first ``$``.
        * Otherwise the substring from the last ``\boxed`` (or, failing that,
          the last ``\fbox``) up to and including the brace that closes its
          argument.
        * ``None`` when neither command is present or the braces never
          balance.
    """
    start = string.rfind("\\boxed")
    spaced_prefix = "\\boxed "
    if spaced_prefix in string:
        tail = string.split(spaced_prefix)[-1]
        return spaced_prefix + tail.split("$")[0]
    if start < 0:
        start = string.rfind("\\fbox")
        if start < 0:
            return None

    # Walk forward from the command, tracking brace depth; the expression
    # ends at the first closing brace that brings the depth back to zero.
    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start:pos + 1]

    return None
80-
81-
def fix_fracs(string):
    r"""Add braces to shorthand ``\frac`` arguments.

    ``\frac12`` becomes ``\frac{1}{2}`` and ``\frac1{72}`` becomes
    ``\frac{1}{72}``. A ``\frac`` already followed by ``{`` is left as is
    (so ``\frac{72}1`` is not repaired).

    Args:
        string: LaTeX text to normalise.

    Returns:
        The text with every shorthand fraction braced. If any ``\frac`` is
        followed by fewer than two characters (including a trailing
        ``\frac``), the input is returned unchanged.
    """
    head, *fragments = string.split("\\frac")
    rebuilt = head
    for fragment in fragments:
        rebuilt += "\\frac"
        if fragment.startswith("{"):
            # Numerator is already braced; keep the fragment verbatim.
            rebuilt += fragment
            continue
        if len(fragment) < 2:
            # Not enough characters for a numerator and denominator
            # (covers the empty fragment left by a trailing "\frac").
            return string
        numerator, second, rest = fragment[0], fragment[1], fragment[2:]
        if second != "{":
            rebuilt += "{" + numerator + "}{" + second + "}" + rest
        else:
            # Denominator is already braced; only wrap the numerator.
            rebuilt += "{" + numerator + "}" + second + rest
    return rebuilt
112-
113-
def fix_a_slash_b(string):
    r"""Rewrite a plain integer ratio ``a/b`` as ``\frac{a}{b}``.

    Args:
        string: Candidate answer text.

    Returns:
        ``\frac{a}{b}`` when ``string`` is exactly two canonical integers
        separated by a single ``/``; otherwise ``string`` unchanged. This
        includes non-integer operands such as ``x/y``, which previously
        raised ``ValueError``.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        numerator = int(parts[0])
        denominator = int(parts[1])
    except ValueError:
        return string
    # Reject spellings int() accepts but that are not canonical, e.g.
    # leading zeros or surrounding whitespace.
    if string != "{}/{}".format(numerator, denominator):
        return string
    return "\\frac{" + str(numerator) + "}{" + str(denominator) + "}"
127-
128-
def remove_right_units(string):
    r"""Drop a trailing unit annotation introduced by ``\text{ ``.

    Args:
        string: Candidate answer text.

    Returns:
        The part of ``string`` before ``\text{ `` when that marker occurs,
        otherwise ``string`` unchanged.

    Raises:
        AssertionError: If the marker occurs more than once. Raised
            explicitly so the check still holds when Python runs with ``-O``.
    """
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    marker = "\\text{ "
    if marker not in string:
        return string
    pieces = string.split(marker)
    if len(pieces) != 2:
        raise AssertionError("expected exactly one '\\text{ ' marker")
    return pieces[0]
137-
138-
def fix_sqrt(string):
    r"""Brace the argument of shorthand square roots: ``\sqrt3`` -> ``\sqrt{3}``.

    Args:
        string: LaTeX text to normalise.

    Returns:
        The text with the single character after each unbraced ``\sqrt``
        wrapped in braces. A ``\sqrt`` already followed by ``{``, or with
        nothing after it, is left untouched (a trailing ``\sqrt`` previously
        raised ``IndexError``).
    """
    if "\\sqrt" not in string:
        return string
    head, *fragments = string.split("\\sqrt")
    rebuilt = [head]
    for fragment in fragments:
        if fragment and fragment[0] != "{":
            rebuilt.append("\\sqrt{" + fragment[0] + "}" + fragment[1:])
        else:
            rebuilt.append("\\sqrt" + fragment)
    return "".join(rebuilt)
152-
153-
def strip_string(string):
    r"""Normalise a LaTeX answer string so equivalent answers compare equal.

    The steps run in a fixed order: drop line breaks and ``\!``; collapse
    ``\\`` to ``\``; map ``tfrac``/``dfrac`` to ``frac``; remove ``\left``,
    ``\right``, degree marks, ``\$``, right-hand units and ``\%``; add a
    leading zero to bare decimals; strip a short ``x =`` prefix; brace
    shorthand ``\sqrt``; remove spaces; brace shorthand ``\frac``; map
    ``0.5`` to ``\frac{1}{2}``; and rewrite a plain ``a/b`` as a fraction.

    Args:
        string: Raw answer text.

    Returns:
        The normalised string. If the text becomes empty once the removals
        and leading-zero replacements are done, it is returned at that point.
    """
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = remove_right_units(string)

    # remove percentage (the former second replace of "\%" was the same
    # runtime string as "\\%", so one call is enough)
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively,
    # add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string

    # get rid of e.g. "k = " or "q = " at beginning: only when there is a
    # single "=" and at most two characters precede it
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works
    # with \frac1{72} (but not \frac{72}1).
    string = fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix
    # in case the model output is X/Y
    string = fix_a_slash_b(string)

    return string
25+ return retval
0 commit comments