Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions evaluation/data_processing/answer_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def strip_string(string):
string = string.replace("infinity", "\\infty")
if "\\infty" not in string:
string = string.replace("inf", "\\infty")
string = string.replace("+\\inity", "\\infty")
string = string.replace("+\\infty", "\\infty")

# and
# string = string.replace("and", "")
Expand Down Expand Up @@ -305,7 +305,6 @@ def extract_ocwcourses_few_shot_answer(question, reasoning, task):
patt = regex.search(r"final answer is (?P<ans>.*)\. I hope it is correct.", reasoning)
if patt is None:
pred = "[invalid]"
print(f"DEBUG >>>\n{reasoning}", flush=True)
else:
pred = patt.group('ans')
return pred
Expand All @@ -331,7 +330,6 @@ def extract_cmath_few_shot_test(question, reasoning, task):
try:
ans = [s for s in regex.findall(r'-?\d+\.?\d*', ans)][-1]
except:
print(f"DEBUG CMATH: {reasoning}", flush=True)
ans = "[invalid]"
else:
ans = extract_last_single_answer(question, reasoning, task)
Expand Down
6 changes: 1 addition & 5 deletions evaluation/eval/eval_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ def is_correct(item, pred_key='prediction', prec=1e-3):
if is_correct(item_cpy, pred_key=pred_key, prec=prec):
pred_matched.add(i)
ans_matched.add(j)
if item_cpy[pred_key] == '2,3,4':
print(item, flush=True)
print("wtf", flush=True)
return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)
elif isinstance(pred, str) and isinstance(ans, str):
if '\\cup' in pred and '\\cup' in ans:
Expand All @@ -40,8 +37,7 @@ def is_correct(item, pred_key='prediction', prec=1e-3):
label = label or (ans and pred == ans) or math_equal(pred, ans)
return label
else:
print(item, flush=True)
raise NotImplementedError()
raise NotImplementedError(f"Unsupported types: pred={type(pred)}, ans={type(ans)}")

def eval_math(item, pred_key='prediction', prec=1e-3):
pred = item[pred_key]
Expand Down