Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions lmms_eval/tasks/_task_utils/default_template_yaml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from pathlib import Path

import yaml


def load_default_template_yaml(task_file):
with open(Path(task_file).parent / "_default_template_yaml", "r") as f:
safe_data = [line for line in f if "!function" not in line]
return yaml.safe_load("".join(safe_data))
274 changes: 274 additions & 0 deletions lmms_eval/tasks/_task_utils/mmmu_mcq_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
import random
import re

import numpy as np


def get_multi_choice_info(options, start_chr="A"):
all_choices = []
index2ans = {}
for i, option in enumerate(options):
choice = chr(ord(start_chr) + i)
index2ans[choice] = option
all_choices.append(choice)
return index2ans, all_choices


def parse_mmmu_multi_choice_response(response, all_choices, index2ans):
for char in [",", ".", "!", "?", ";", ":", "'"]:
response = response.strip(char)
response = " " + response + " "

index_ans = True
ans_with_brack = False
candidates = []

for choice in all_choices:
if f"({choice})" in response:
candidates.append(choice)
ans_with_brack = True

if len(candidates) == 0:
for choice in all_choices:
if f"{choice} " in response:
candidates.append(choice)

if len(candidates) == 0:
for choice in all_choices:
if f"{choice}." in response:
candidates.append(choice)

if len(candidates) == 0 and len(response.split()) > 5:
for index, ans in index2ans.items():
if ans.lower() in response.lower():
candidates.append(index)
index_ans = False

if len(candidates) == 0:
pred_index = random.choice(all_choices)
elif len(candidates) > 1:
start_indexes = []
if index_ans:
if ans_with_brack:
for can in candidates:
start_indexes.append(response.rfind(f"({can})"))
else:
for can in candidates:
start_indexes.append(response.rfind(f" {can} "))
else:
for can in candidates:
start_indexes.append(response.lower().rfind(index2ans[can].lower()))
pred_index = candidates[np.argmax(start_indexes)]
else:
pred_index = candidates[0]

return pred_index


def parse_jmmmu_multi_choice_response(response, all_choices, index2ans):
for char in [",", ".", "!", "?", ";", ":", "'", "、", "。", "!", "?", ";", ":"]:
response = response.strip(char)
response = " " + response + " "

japanese_char_pattern = r"[\u3040-\u30FF\u4E00-\u9FFF]"
index_ans = True
ans_with_brack = False
candidates = []

for choice in all_choices:
if f"({choice})" in response:
candidates.append(choice)
ans_with_brack = True

if len(candidates) == 0:
for choice in all_choices:
if f"{choice} " in response:
candidates.append(choice)

if len(candidates) == 0:
for choice in all_choices:
pattern = rf"{japanese_char_pattern}{choice}{japanese_char_pattern}"
if re.search(pattern, response):
candidates.append(choice)

if len(candidates) == 0:
for choice in all_choices:
if f"{choice}." in response:
candidates.append(choice)

if len(candidates) == 0 and len(response.split()) > 5:
for index, ans in index2ans.items():
if ans.lower() in response.lower():
candidates.append(index)
index_ans = False

if len(candidates) == 0:
pred_index = random.choice(all_choices)
elif len(candidates) > 1:
start_indexes = []
if index_ans:
if ans_with_brack:
for can in candidates:
start_indexes.append(response.rfind(f"({can})"))
else:
for can in candidates:
start_indexes.append(response.rfind(f" {can} "))
else:
for can in candidates:
start_indexes.append(response.lower().rfind(index2ans[can].lower()))
pred_index = candidates[np.argmax(start_indexes)]
else:
pred_index = candidates[0]

return pred_index


def parse_videommmu_multi_choice_response(response, all_choices, index2ans):
if response == "API Error" or response == "":
return "API Error"

for char in [",", ".", "!", "?", ";", ":", "'"]:
response = response.strip(char)
response = " " + response + " "

index_ans = True
ans_with_brack = False
ans_with_period = False
ans_with_colon = False
candidates = []

for choice in all_choices:
if f"{choice}." in response:
candidates.append(choice)
ans_with_period = True

for choice in all_choices:
if f"{choice}:" in response:
candidates.append(choice)
ans_with_colon = True

if len(candidates) == 0:
for choice in all_choices:
if f"({choice})" in response:
candidates.append(choice)
ans_with_brack = True

if len(candidates) == 0:
for choice in all_choices:
if f"{choice} " in response:
candidates.append(choice)

if len(candidates) == 0 and len(response.split()) > 5:
for index, ans in index2ans.items():
if ans.lower() in response.lower():
candidates.append(index)
index_ans = False

if len(candidates) == 0:
pred_index = "No Answer Found."
elif len(candidates) > 1:
start_indexes = []
if index_ans:
if ans_with_period:
for can in candidates:
start_indexes.append(response.rfind(f"{can}."))
elif ans_with_colon:
for can in candidates:
start_indexes.append(response.rfind(f"{can}:"))
elif ans_with_brack:
for can in candidates:
start_indexes.append(response.rfind(f"({can})"))
else:
for can in candidates:
start_indexes.append(response.rfind(f" {can} "))
else:
for can in candidates:
start_indexes.append(response.lower().rfind(index2ans[can].lower()))
pred_index = candidates[np.argmax(start_indexes)]
else:
pred_index = candidates[0]

return pred_index


def parse_jmmmu_pro_multi_choice_response(response, all_choices):
fullwidth_map = {chr(ord("A") + i): chr(ord("A") + i) for i in range(26)}
fullwidth_trans = str.maketrans(fullwidth_map)

option_line_re = re.compile(
r"""^\s*
(?:[-*・>\u2022]\s*)?
[A-ZA-Z]
[\.\)\u3001\u3002::]
""",
re.VERBOSE | re.IGNORECASE,
)

explicit_re = re.compile(
r"""(?ix)
(?:answer|final|correct|solution|ans
|正解(?:は)?
|答え(?:は)?
|解答(?:は)?
)
\s*[::]?\s*
[【\[\(\u3010\u3011\*_-]*
([A-Z])
[】\]\)\*_-]*
\b
"""
)

markdown_letter_re = re.compile(
r"""
[【\[\(\*]*([A-Z])[】\]\)\*]*
\b
""",
re.IGNORECASE | re.VERBOSE,
)

def _normalize(text):
return text.translate(fullwidth_trans)

def _is_option_line(line):
return bool(option_line_re.match(line))

def _explicit_in_line(line):
match = explicit_re.search(line)
if match:
return match.group(1).upper()
return None

def _last_standalone_letter(lines):
candidates = []
for line in lines:
if _is_option_line(line):
continue
for match in markdown_letter_re.finditer(line):
candidates.append(match.group(1).upper())
return candidates[-1] if candidates else None

def parse_answer(text):
if not text:
return None

normalized = _normalize(text)
lines = [line.strip() for line in normalized.splitlines() if line.strip()]
if not lines:
return None

first_line_hit = _explicit_in_line(lines[0])
if first_line_hit:
return first_line_hit

any_line_hit = _explicit_in_line(normalized)
if any_line_hit:
return any_line_hit

return _last_standalone_letter(lines)

parsed_letter = parse_answer(response)
if parsed_letter and parsed_letter in all_choices:
return parsed_letter

return "X"
20 changes: 4 additions & 16 deletions lmms_eval/tasks/air_bench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import random
import re
import time
from pathlib import Path

import numpy as np
import requests
import yaml
from loguru import logger as eval_logger

from lmms_eval.tasks._task_utils.default_template_yaml import load_default_template_yaml
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file


Expand All @@ -24,15 +23,7 @@ def air_bench_doc_to_text_chat(doc, lmms_eval_specific_kwargs):
return f"{pre_prompt}{question}{post_prompt}"


with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
raw_data = f.readlines()
safe_data = []
for i, line in enumerate(raw_data):
# remove function definition since yaml load cannot handle it
if "!function" not in line:
safe_data.append(line)

config = yaml.safe_load("".join(safe_data))
config = load_default_template_yaml(__file__)

# specify api type and key in .env
GPT_EVAL_MODEL_NAME = os.getenv("MODEL_VERSION", "gpt-4o-2024-11-20")
Expand Down Expand Up @@ -107,17 +98,14 @@ def get_eval(max_tokens: int, content: str, retries: int = retries):


def air_bench_process_results_chat(doc, result):
path = doc["path"]
question = doc["question"]
answer_gt = doc["answer_gt"]
task_name = doc["task_name"]
dataset_name = doc["dataset_name"]
response = result[0]

if response == None:
if response is None:
exit(1)

if doc["meta_info"] == None:
if doc["meta_info"] is None:
print("lack meta info")
exit(1)
else:
Expand Down
Loading