Skip to content

Commit f3ee423

Browse files
authored
refactor: extract shared task utilities to reduce cross-task code duplication (#1122)
* refactor: extract default_template_yaml loader to shared utility

  Deduplicate the repeated YAML-loading boilerplate (open file, filter !function lines, yaml.safe_load) across 11 task utils into a single load_default_template_yaml() in _task_utils/default_template_yaml.py.
  Affected tasks: air_bench, blink, egoschema, erqa, illusionbench, mix_evals/image2text, mix_evals/video2text, perceptiontest/test, perceptiontest/val, where2place, worldqa.

* refactor: extract MMMU MCQ parsing to shared mmmu_mcq_utils

  Consolidate duplicated multi-choice response parsers (get_multi_choice_info, parse_mmmu_multi_choice_response, and variants for jmmmu, jmmmu_pro, videommmu) into a single shared module at _task_utils/mmmu_mcq_utils.py.
  Affected tasks: mmmu, mmmu/reasoning, mmmu_pro, mmmu_pro/reasoning, jmmmu, jmmmu_pro, videommmu, mmmu/utils_group_img.

* refactor: extract ASR WER computation to shared asr_wer_utils

  Consolidate duplicated EvaluationTokenizer, remove_sp, compute_wer, and normalization helpers across 7 ASR task utils into a single shared module at tasks/asr_wer_utils.py.
  Affected tasks: common_voice_15, fleurs, gigaspeech, librispeech, open_asr, people_speech, tedlium.
1 parent 8d34fca commit f3ee423

29 files changed

Lines changed: 742 additions & 1581 deletions

File tree

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from pathlib import Path
2+
3+
import yaml
4+
5+
6+
def load_default_template_yaml(task_file):
    """Load the ``_default_template_yaml`` that sits next to ``task_file``.

    Lines containing ``!function`` are dropped before parsing because
    ``yaml.safe_load`` cannot construct that custom tag.

    Args:
        task_file: Path of a file inside the task directory; callers pass
            their own ``__file__``.

    Returns:
        Whatever ``yaml.safe_load`` produces for the filtered template text.
    """
    template_path = Path(task_file).parent / "_default_template_yaml"
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(template_path, "r", encoding="utf-8") as f:
        safe_data = [line for line in f if "!function" not in line]
    return yaml.safe_load("".join(safe_data))
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
import random
2+
import re
3+
4+
import numpy as np
5+
6+
7+
def get_multi_choice_info(options, start_chr="A"):
    """Label each option with consecutive characters starting at ``start_chr``.

    Args:
        options: Iterable of answer options, in presentation order.
        start_chr: Single character used as the label of the first option.

    Returns:
        A ``(index2ans, all_choices)`` tuple: a dict mapping each label to its
        option, and the list of labels in option order.
    """
    labelled = [(chr(ord(start_chr) + offset), text) for offset, text in enumerate(options)]
    index2ans = dict(labelled)
    all_choices = [label for label, _ in labelled]
    return index2ans, all_choices
15+
16+
17+
def parse_mmmu_multi_choice_response(response, all_choices, index2ans):
    """Pick the choice label a free-form model response most likely selects.

    Matching stages are tried in order and the first stage that finds any
    candidate wins: ``(X)``, then ``X `` (letter followed by a space), then
    ``X.``, and finally - only for responses longer than five words - the
    option text itself (case-insensitive). When a stage yields several
    candidates, the one whose last occurrence is furthest right is returned.

    Args:
        response: Raw model output.
        all_choices: Choice labels, e.g. ``["A", "B", "C", "D"]``.
        index2ans: Mapping from choice label to option text.

    Returns:
        The selected label. If nothing matches, a label drawn with
        ``random.choice(all_choices)``.
    """
    # Each character is stripped from both ends once, in the listed order.
    for char in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(char)
    # Pad so a letter at either end of the response can match "X " / " X ".
    response = " " + response + " "

    # Every match is (label, needles); needles are tried in order when the
    # label has to be located for the "last mentioned" tie-break.
    matches = [(choice, (f"({choice})",)) for choice in all_choices if f"({choice})" in response]
    if not matches:
        # Prefer the stricter " X " position, but fall back to the "X " form
        # that actually matched; otherwise rfind yields a meaningless -1.
        matches = [(choice, (f" {choice} ", f"{choice} ")) for choice in all_choices if f"{choice} " in response]
    if not matches:
        matches = [(choice, (f"{choice}.",)) for choice in all_choices if f"{choice}." in response]

    haystack = response
    if not matches and len(response.split()) > 5:
        # Fall back to looking for the option text itself.
        haystack = response.lower()
        matches = [(index, (ans.lower(),)) for index, ans in index2ans.items() if ans.lower() in haystack]

    if not matches:
        return random.choice(all_choices)
    if len(matches) == 1:
        return matches[0][0]

    positions = []
    for _, needles in matches:
        hits = [pos for pos in map(haystack.rfind, needles) if pos != -1]
        positions.append(hits[0] if hits else -1)
    # argmax returns the first maximum, so ties resolve to the earliest label.
    return matches[int(np.argmax(positions))][0]
66+
67+
68+
def parse_jmmmu_multi_choice_response(response, all_choices, index2ans):
    """Pick the choice label selected by a (possibly Japanese) model response.

    Matching stages are tried in order and the first stage that finds any
    candidate wins: ``(X)``, ``X `` (letter followed by a space), the letter
    sandwiched between two Japanese characters, ``X.``, and finally - only for
    responses longer than five whitespace-separated words - the option text
    itself (case-insensitive). When a stage yields several candidates, the one
    whose last occurrence is furthest right is returned.

    Args:
        response: Raw model output.
        all_choices: Choice labels, e.g. ``["A", "B", "C", "D"]``.
        index2ans: Mapping from choice label to option text.

    Returns:
        The selected label. If nothing matches, a label drawn with
        ``random.choice(all_choices)``.
    """
    # ASCII punctuation, then Japanese punctuation. The full-width "! ? ; :"
    # are written as escapes so they cannot silently collapse into their ASCII
    # twins. Each character is stripped from both ends once, in order.
    for char in [",", ".", "!", "?", ";", ":", "'", "、", "。", "\uff01", "\uff1f", "\uff1b", "\uff1a"]:
        response = response.strip(char)
    # Pad so a letter at either end of the response can match "X " / " X ".
    response = " " + response + " "

    # Hiragana/katakana and the CJK unified ideographs block.
    japanese_char_pattern = r"[\u3040-\u30FF\u4E00-\u9FFF]"

    def last_position(*needles):
        # Rightmost index of the first needle that occurs at all, else -1.
        for needle in needles:
            position = response.rfind(needle)
            if position != -1:
                return position
        return -1

    # Every match is (label, position of its last occurrence), located with
    # the same pattern that produced the match.
    matches = [(choice, last_position(f"({choice})")) for choice in all_choices if f"({choice})" in response]

    if not matches:
        # Prefer the stricter " X " position, falling back to "X ".
        matches = [(choice, last_position(f" {choice} ", f"{choice} ")) for choice in all_choices if f"{choice} " in response]

    if not matches:
        for choice in all_choices:
            pattern = rf"{japanese_char_pattern}{choice}{japanese_char_pattern}"
            starts = [found.start() for found in re.finditer(pattern, response)]
            if starts:
                matches.append((choice, starts[-1]))

    if not matches:
        matches = [(choice, last_position(f"{choice}.")) for choice in all_choices if f"{choice}." in response]

    if not matches and len(response.split()) > 5:
        # Fall back to looking for the option text itself.
        lowered = response.lower()
        matches = [(index, lowered.rfind(ans.lower())) for index, ans in index2ans.items() if ans.lower() in lowered]

    if not matches:
        return random.choice(all_choices)
    if len(matches) == 1:
        return matches[0][0]

    # argmax returns the first maximum, so ties resolve to the earliest label.
    return matches[int(np.argmax([position for _, position in matches]))][0]
124+
125+
126+
def parse_videommmu_multi_choice_response(response, all_choices, index2ans):
    """Pick the choice label a free-form model response most likely selects.

    ``X.`` and ``X:`` matches are always collected together. Only if neither
    is found are ``(X)``, then ``X `` (letter followed by a space), and finally
    - for responses longer than five words - the option text itself
    (case-insensitive) tried. When several candidates remain, the one whose
    last occurrence is furthest right is returned.

    Args:
        response: Raw model output.
        all_choices: Choice labels, e.g. ``["A", "B", "C", "D"]``.
        index2ans: Mapping from choice label to option text.

    Returns:
        ``"API Error"`` when ``response`` is empty or equals ``"API Error"``,
        ``"No Answer Found."`` when nothing matches, otherwise the label.
    """
    if response == "API Error" or response == "":
        return "API Error"

    # Each character is stripped from both ends once, in the listed order.
    for char in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(char)
    # Pad so a letter at either end of the response can match "X " / " X ".
    response = " " + response + " "

    def last_position(*needles):
        # Rightmost index of the first needle that occurs at all, else -1.
        for needle in needles:
            position = response.rfind(needle)
            if position != -1:
                return position
        return -1

    # Every match is (label, position of its last occurrence). Each candidate
    # is located with the pattern that matched it, so a colon-only candidate
    # is not pushed to -1 merely because another candidate used a period.
    matches = [(choice, last_position(f"{choice}.")) for choice in all_choices if f"{choice}." in response]
    matches += [(choice, last_position(f"{choice}:")) for choice in all_choices if f"{choice}:" in response]

    if not matches:
        matches = [(choice, last_position(f"({choice})")) for choice in all_choices if f"({choice})" in response]

    if not matches:
        # Prefer the stricter " X " position, falling back to "X ".
        matches = [(choice, last_position(f" {choice} ", f"{choice} ")) for choice in all_choices if f"{choice} " in response]

    if not matches and len(response.split()) > 5:
        # Fall back to looking for the option text itself.
        lowered = response.lower()
        matches = [(index, lowered.rfind(ans.lower())) for index, ans in index2ans.items() if ans.lower() in lowered]

    if not matches:
        return "No Answer Found."
    if len(matches) == 1:
        return matches[0][0]

    # argmax returns the first maximum, so ties resolve to the earliest match.
    return matches[int(np.argmax([position for _, position in matches]))][0]
192+
193+
194+
def parse_jmmmu_pro_multi_choice_response(response, all_choices):
    """Extract the chosen option letter from an English or Japanese response.

    Full-width capitals are first folded to ASCII. The parser then looks for
    an explicit answer marker ("answer", "final", "correct", "solution",
    "ans", or the Japanese markers in ``explicit_re``) followed by a letter:
    first on the first non-empty line, then anywhere in the text. Failing
    that, it takes the last letter matched by ``markdown_letter_re`` on any
    line that does not look like an option listing ("A. ...", "B) ...").

    Args:
        response: Raw model output; may be empty or ``None``.
        all_choices: Valid choice labels, e.g. ``["A", "B", "C", "D"]``.

    Returns:
        The upper-cased letter if one is found and is in ``all_choices``,
        otherwise ``"X"``.
    """
    # U+FF21..U+FF3A are the full-width capitals. They are spelled as code
    # points so the table cannot degrade into an A->A identity map if the
    # source file is ever Unicode-normalised.
    fullwidth_map = {chr(0xFF21 + i): chr(ord("A") + i) for i in range(26)}
    fullwidth_trans = str.maketrans(fullwidth_map)

    # A line that starts (after an optional bullet) with a letter followed by
    # ".", ")", an ideographic comma/full stop, or an ASCII/full-width colon.
    option_line_re = re.compile(
        r"""^\s*
        (?:[-*・>\u2022]\s*)?
        [A-Z\uFF21-\uFF3A]
        [\.\)\u3001\u3002:\uFF1A]
        """,
        re.VERBOSE | re.IGNORECASE,
    )

    # NOTE(review): `\b` is Unicode-aware for str patterns, so a letter that is
    # immediately followed by kana/kanji does not satisfy it - confirm whether
    # that is intended for Japanese responses.
    explicit_re = re.compile(
        r"""(?ix)
        (?:answer|final|correct|solution|ans
        |正解(?:は)?
        |答え(?:は)?
        |解答(?:は)?
        )
        \s*[:\uFF1A]?\s*
        [【\[\(\u3010\u3011\*_-]*
        ([A-Z])
        [】\]\)\*_-]*
        \b
        """
    )

    markdown_letter_re = re.compile(
        r"""
        [【\[\(\*]*([A-Z])[】\]\)\*]*
        \b
        """,
        re.IGNORECASE | re.VERBOSE,
    )

    def _normalize(text):
        # Fold full-width capitals to their ASCII equivalents.
        return text.translate(fullwidth_trans)

    def _is_option_line(line):
        # True when the line looks like a restated option ("A. ...").
        return bool(option_line_re.match(line))

    def _explicit_in_line(line):
        # Letter following an explicit answer marker, upper-cased, or None.
        match = explicit_re.search(line)
        if match:
            return match.group(1).upper()
        return None

    def _last_standalone_letter(lines):
        # Last letter matched on any line that is not an option listing.
        candidates = []
        for line in lines:
            if _is_option_line(line):
                continue
            for match in markdown_letter_re.finditer(line):
                candidates.append(match.group(1).upper())
        return candidates[-1] if candidates else None

    def parse_answer(text):
        # Run the three strategies in priority order; None if all fail.
        if not text:
            return None

        normalized = _normalize(text)
        lines = [line.strip() for line in normalized.splitlines() if line.strip()]
        if not lines:
            return None

        first_line_hit = _explicit_in_line(lines[0])
        if first_line_hit:
            return first_line_hit

        any_line_hit = _explicit_in_line(normalized)
        if any_line_hit:
            return any_line_hit

        return _last_standalone_letter(lines)

    parsed_letter = parse_answer(response)
    if parsed_letter and parsed_letter in all_choices:
        return parsed_letter

    return "X"

lmms_eval/tasks/air_bench/utils.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
import random
44
import re
55
import time
6-
from pathlib import Path
76

87
import numpy as np
98
import requests
10-
import yaml
119
from loguru import logger as eval_logger
1210

11+
from lmms_eval.tasks._task_utils.default_template_yaml import load_default_template_yaml
1312
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
1413

1514

@@ -24,15 +23,7 @@ def air_bench_doc_to_text_chat(doc, lmms_eval_specific_kwargs):
2423
return f"{pre_prompt}{question}{post_prompt}"
2524

2625

27-
with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
28-
raw_data = f.readlines()
29-
safe_data = []
30-
for i, line in enumerate(raw_data):
31-
# remove function definition since yaml load cannot handle it
32-
if "!function" not in line:
33-
safe_data.append(line)
34-
35-
config = yaml.safe_load("".join(safe_data))
26+
config = load_default_template_yaml(__file__)
3627

3728
# specify api type and key in .env
3829
GPT_EVAL_MODEL_NAME = os.getenv("MODEL_VERSION", "gpt-4o-2024-11-20")
@@ -107,17 +98,14 @@ def get_eval(max_tokens: int, content: str, retries: int = retries):
10798

10899

109100
def air_bench_process_results_chat(doc, result):
110-
path = doc["path"]
111101
question = doc["question"]
112102
answer_gt = doc["answer_gt"]
113-
task_name = doc["task_name"]
114-
dataset_name = doc["dataset_name"]
115103
response = result[0]
116104

117-
if response == None:
105+
if response is None:
118106
exit(1)
119107

120-
if doc["meta_info"] == None:
108+
if doc["meta_info"] is None:
121109
print("lack meta info")
122110
exit(1)
123111
else:

0 commit comments

Comments
 (0)