Skip to content

Commit f58a02b

Browse files
committed
Refactor vLLM model files and add OlympiadBench evaluation utilities
- Cleaned up imports and removed unused variables in `vllm.py`. - Updated threading configuration in `simple/vllm.py` to use environment variables. - Introduced new utility functions for processing OlympiadBench documents and results in `utils.py`, `zh_utils.py`, and `en_utils.py`. - Added evaluation logic for OlympiadBench tasks in `olympiadbench_evals.py`. - Created multiple YAML configuration files for various OlympiadBench tasks, including math and physics in both English and Chinese. - Implemented aggregation functions for results processing in the OlympiadBench context.
1 parent 56c0679 commit f58a02b

27 files changed

Lines changed: 1209 additions & 23 deletions

lmms_eval/models/chat/vllm.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,12 @@
1-
import asyncio
2-
import base64
3-
import json
4-
import os
5-
import time
6-
from concurrent.futures import ThreadPoolExecutor
7-
from copy import deepcopy
8-
from io import BytesIO
9-
from multiprocessing import cpu_count
10-
from typing import List, Optional, Tuple, Union
1+
from typing import List, Tuple
112

12-
import numpy as np
13-
from accelerate import Accelerator, DistributedType
14-
from decord import VideoReader, cpu
15-
from loguru import logger as eval_logger
16-
from PIL import Image
173
from tqdm import tqdm
184

195
from lmms_eval.api.instance import Instance
20-
from lmms_eval.api.model import lmms
216
from lmms_eval.api.registry import register_model
227
from lmms_eval.models.simple.vllm import VLLM as VLLMSimple
238
from lmms_eval.protocol import ChatMessages
249

25-
NUM_SECONDS_TO_SLEEP = 5
26-
2710
try:
2811
from vllm import LLM, SamplingParams
2912
except ImportError:

lmms_eval/models/simple/vllm.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from lmms_eval.api.model import lmms
2121
from lmms_eval.api.registry import register_model
2222

23-
NUM_SECONDS_TO_SLEEP = 5
23+
NUM_SECONDS_TO_SLEEP = os.getenv("NUM_SECONDS_TO_SLEEP", 5)
24+
WORKERS = os.getenv("WORKERS", 32)
2425

2526
try:
2627
from vllm import LLM, SamplingParams
@@ -37,7 +38,6 @@ def __init__(
3738
gpu_memory_utilization: float = 0.8,
3839
batch_size: int = 1,
3940
max_frame_num: int = 32,
40-
threads: int = 16, # Threads to use for decoding visuals
4141
trust_remote_code: Optional[bool] = True,
4242
chat_template: Optional[str] = None,
4343
min_image_pixels: int = 28, # minimum image dimension, required for Qwen 2/2.5-VL models
@@ -49,11 +49,10 @@ def __init__(
4949
# Here we just use the same token as llava for convenient
5050
self.model = model
5151
self.max_frame_num = max_frame_num
52-
self.threads = threads
5352
self.chat_template = chat_template
5453
self.min_image_pixels = min_image_pixels
5554
# Qwen 2/2.5-VL models enforce minimum image dimensions
56-
self._enforce_image_resize = self._is_qwen_vl_model(model_version)
55+
self._enforce_image_resize = self._is_qwen_vl_model(model)
5756

5857
# Convert any string arguments that start with { and end with } to dictionaries
5958
for key, value in kwargs.items():
@@ -188,7 +187,7 @@ def generate_until(self, requests) -> List[str]:
188187
visuals = self.flatten(visuals)
189188
imgs = [] # multiple images or frames for video
190189
all_tasks = []
191-
with ThreadPoolExecutor(max_workers=self.threads) as executor:
190+
with ThreadPoolExecutor(max_workers=WORKERS) as executor:
192191
for visual in visuals:
193192
if isinstance(visual, str) and (".mp4" in visual or ".avi" in visual or ".mov" in visual or ".flv" in visual or ".wmv" in visual):
194193
all_tasks.append(executor.submit(self.encode_video, visual))
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Copyright 2025 Xiaomi Corporation.
2+
3+
4+
import importlib
5+
6+
from wrapt_timeout_decorator import timeout
7+
8+
9+
def patch_target_module(
10+
to_patch: str,
11+
replace_with,
12+
):
13+
to_patch = to_patch.split(".")
14+
assert len(to_patch) > 1, "must have an object to patch"
15+
16+
to_patch, obj_name_to_patch = to_patch[:-1], to_patch[-1]
17+
to_patch = ".".join(to_patch)
18+
source = importlib.import_module(to_patch)
19+
setattr(source, obj_name_to_patch, replace_with)
20+
21+
22+
def timeout_adapter(func=None, **kwargs):
23+
timeout_val = kwargs.pop("timeout_seconds", None)
24+
return timeout(dec_timeout=timeout_val, use_signals=False, **kwargs)
25+
26+
27+
# replace the signal-based timeout with a non-signal-based timeout to allow multithreading
28+
patch_target_module("math_verify.utils.timeout", timeout_adapter)
29+
patch_target_module("math_verify.parser.timeout", timeout_adapter)
30+
patch_target_module("math_verify.grader.timeout", timeout_adapter)
31+
32+
33+
import os
34+
35+
from latex2sympy2_extended.latex2sympy2 import NormalizationConfig
36+
from math_verify import *
37+
38+
39+
def monkeypatch_math_verify_logger():
40+
"""
41+
replace the loggers in math_verify with a self-returning object, so that it does not print any logs
42+
"""
43+
import math_verify
44+
45+
class SelfReturningObject:
46+
def __getattr__(self, name):
47+
return self
48+
49+
def __call__(self, *args, **kwargs):
50+
return self
51+
52+
def __getitem__(self, key):
53+
return self
54+
55+
self_returning_object = SelfReturningObject()
56+
57+
def bfs_search(module, lst):
58+
lst.append(module)
59+
for name, obj in module.__dict__.items():
60+
if isinstance(obj, type(math_verify)):
61+
if obj not in lst:
62+
bfs_search(obj, lst)
63+
64+
all_modules = []
65+
bfs_search(math_verify, all_modules)
66+
all_modules = [module for module in all_modules if module.__name__.startswith("math_verify")]
67+
for module in all_modules:
68+
if hasattr(module, "logger"):
69+
module.logger = self_returning_object
70+
71+
72+
class MathVerifyFn:
73+
def __init__(self, correct_score=1.0, incorrect_score=0.0, timeout_seconds=10, strict=True, silent=True):
74+
self.correct_score = correct_score
75+
self.incorrect_score = incorrect_score
76+
self.timeout_seconds = timeout_seconds
77+
self.strict = strict
78+
if silent:
79+
monkeypatch_math_verify_logger()
80+
81+
def __call__(self, solution_str: str, ground_truth) -> float:
82+
# return self.compute_score(solution_str, ground_truth)
83+
return self.compute_score_with_ext(solution_str, ground_truth)
84+
85+
def preprocess_answer(self, annotated_answer: str) -> str:
86+
if annotated_answer:
87+
if annotated_answer.startswith("$") and annotated_answer.endswith("$"):
88+
annotated_answer = f"\\boxed{{{annotated_answer.strip('$')}}}"
89+
elif "\\boxed" not in annotated_answer:
90+
annotated_answer = f"\\boxed{{{annotated_answer}}}"
91+
return annotated_answer
92+
93+
def parse_LatexExpr(self, input_str: str):
94+
config = NormalizationConfig(
95+
basic_latex=True,
96+
units=True,
97+
malformed_operators=True,
98+
nits=True,
99+
boxed="last",
100+
equations=False,
101+
)
102+
return parse(
103+
input_str,
104+
extraction_mode="first_match",
105+
extraction_config=[
106+
LatexExtractionConfig(boxed_match_priority=0, normalization_config=config),
107+
],
108+
parsing_timeout=self.timeout_seconds,
109+
)
110+
111+
def parse_String(self, input_str: str):
112+
return parse(
113+
input_str,
114+
extraction_mode="first_match",
115+
extraction_config=[
116+
StringExtractionConfig(),
117+
],
118+
parsing_timeout=self.timeout_seconds,
119+
)
120+
121+
def judge_with_ext(self, solution_str: str, ground_truth) -> float:
122+
prediction_str = solution_str
123+
answer_str = self.preprocess_answer(ground_truth)
124+
answer_parsed = self.parse_LatexExpr(answer_str)
125+
126+
def _judger(x):
127+
if len(x) == 0:
128+
return False
129+
if verify(answer_parsed, x, timeout_seconds=self.timeout_seconds, strict=self.strict):
130+
return True
131+
return False
132+
133+
def ext_to_str(x):
134+
for item in x:
135+
if isinstance(item, str):
136+
return item
137+
for item in x:
138+
return str(item)
139+
return ""
140+
141+
ext_pred = self.parse_LatexExpr(prediction_str)
142+
ext_str = ext_to_str(ext_pred)
143+
# print(solution_str[:20], ground_truth, ext_pred, ext_str, _judger(ext_pred))
144+
if _judger(ext_pred):
145+
return True, ext_str
146+
return False, ext_str
147+
148+
def compute_score_with_ext(self, solution_str: str, ground_truth) -> float:
149+
try:
150+
is_correct, ext_pred = self.judge_with_ext(solution_str, ground_truth)
151+
if is_correct:
152+
return self.correct_score, ext_pred
153+
else:
154+
return self.incorrect_score, ext_pred
155+
except Exception as e:
156+
print(e)
157+
return self.incorrect_score, ""
158+
159+
160+
if __name__ == "__main__":
161+
math_verify_fn = MathVerifyFn()
162+
print(math_verify_fn("\\boxed{D}", "D"))
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2025 Xiaomi Corporation.
2+
3+
import datetime
4+
import json
5+
import os
6+
7+
from loguru import logger as eval_logger
8+
9+
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
10+
from lmms_eval.tasks.olympiadbench_official.olympiadbench_evals import (
11+
OlympiadBenchEvaluator,
12+
)
13+
14+
dir_name = os.path.dirname(os.path.abspath(__file__))
15+
16+
olympiadbench_evaluator = OlympiadBenchEvaluator()
17+
18+
19+
def olympiadbench_doc_to_visual(doc):
20+
res = []
21+
for i in range(1, 6):
22+
image_key = f"image_{i}"
23+
if doc[image_key] is not None:
24+
res.append(doc[image_key])
25+
return [image.convert("RGB") for image in res]
26+
27+
28+
def olympiadbench_doc_to_text(doc):
29+
question = doc["question"]
30+
subject = doc["subfield"]
31+
mul_ans = doc["is_multiple_answer"]
32+
if mul_ans is None:
33+
mul_ans = False
34+
ans_type = doc["answer_type"]
35+
if ans_type == "Need_human_evaluate":
36+
ans_type = "proof based"
37+
38+
pre_prompt = f"The following is a question from an International {subject} competition.\n"
39+
40+
post_prompt = ""
41+
if not mul_ans:
42+
post_prompt += f"The answer of the question should be {ans_type}.\n"
43+
else:
44+
post_prompt += f"The question has multiple answers, each of them should be {ans_type}.\n"
45+
post_prompt += (
46+
"Please calculate the answer according to the given requirements and the information provided. Please use LaTeX format to represent the variables and formulas used in the solution process and results. Please end your solution with "
47+
)
48+
if not mul_ans:
49+
post_prompt += '"So the final answer is \\boxed{answer}."\n'
50+
else:
51+
post_prompt += "So the final answer is \\boxed{multiple answers connected with commas}.\n"
52+
53+
final_question = pre_prompt + question + "\n" + post_prompt
54+
return final_question
55+
56+
57+
def olympiadbench_process_results(doc, results):
58+
precision = doc["error"]
59+
is_proving = doc["question_type"] == "Theorem proof" or doc["final_answer"] is None
60+
if precision is None:
61+
precision = 0
62+
prediction = results[0].strip()
63+
64+
if is_proving:
65+
return {"submission": prediction}
66+
else:
67+
prediction = prediction.split("final answer is")[-1]
68+
prediction = prediction.replace('"', "").replace("\n", "").replace(" ", "").strip(".").strip("。")
69+
accuracy = olympiadbench_evaluator.judge(prediction, doc["final_answer"][0], precision)
70+
accuracy = int(accuracy)
71+
return {"exact_match": accuracy}
72+
73+
74+
def olympiadbench_aggregate_results(results, args):
75+
now_date_time = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S")
76+
submission_file_name = f"olympiadbench-test-en-submission-{now_date_time}.json"
77+
path = generate_submission_file(submission_file_name, args)
78+
with open(path, "w") as f:
79+
json.dump(results, f, ensure_ascii=False)
80+
print(f"Submission file saved to {path}")
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2025 Xiaomi Corporation.
2+
3+
dataset_path: lscpku/OlympiadBench-official
4+
dataset_kwargs:
5+
token: True
6+
dataset_name: all_no_proof
7+
task : "olympiadbench_all_boxed"
8+
test_split: test
9+
output_type: generate_until
10+
doc_to_visual: !function utils.olympiadbench_doc_to_visual
11+
doc_to_text: !function utils.olympiadbench_doc_to_text
12+
doc_to_target: "answer"
13+
generation_kwargs:
14+
max_new_tokens: 32768
15+
temperature: 0
16+
top_p: 1.0
17+
num_beams: 1
18+
do_sample: false
19+
process_results: !function utils.olympiadbench_process_results
20+
metric_list:
21+
- metric: exact_match
22+
aggregation: mean
23+
higher_is_better: true
24+
- metric: math_verify
25+
aggregation: !function utils.olympiadbench_math_verify_aggregate_results
26+
higher_is_better: true
27+
- metric: Math_English
28+
aggregation: !function utils.olympiadbench_aggregate_results
29+
higher_is_better: true
30+
- metric: Math_Chinese
31+
aggregation: !function utils.olympiadbench_aggregate_results
32+
higher_is_better: true
33+
- metric: Physics_English
34+
aggregation: !function utils.olympiadbench_aggregate_results
35+
higher_is_better: true
36+
- metric: Physics_Chinese
37+
aggregation: !function utils.olympiadbench_aggregate_results
38+
higher_is_better: true
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2025 Xiaomi Corporation.
2+
3+
dataset_path: lscpku/OlympiadBench-image
4+
dataset_kwargs:
5+
token: True
6+
dataset_name: all
7+
task : "olympiadbench_boxed"
8+
test_split: test
9+
output_type: generate_until
10+
doc_to_visual: !function utils.olympiadbench_doc_to_visual
11+
doc_to_text: !function utils.olympiadbench_doc_to_text
12+
doc_to_target: "answer"
13+
generation_kwargs:
14+
max_new_tokens: 16384
15+
temperature: 0
16+
top_p: 1.0
17+
num_beams: 1
18+
do_sample: false
19+
process_results: !function utils.olympiadbench_process_results
20+
metric_list:
21+
- metric: exact_match
22+
aggregation: mean
23+
higher_is_better: true
24+
- metric: math_verify
25+
aggregation: !function utils.olympiadbench_math_verify_aggregate_results
26+
higher_is_better: true
27+
- metric: Math_English
28+
aggregation: !function utils.olympiadbench_aggregate_results
29+
higher_is_better: true
30+
- metric: Math_Chinese
31+
aggregation: !function utils.olympiadbench_aggregate_results
32+
higher_is_better: true
33+
- metric: Physics_English
34+
aggregation: !function utils.olympiadbench_aggregate_results
35+
higher_is_better: true
36+
- metric: Physics_Chinese
37+
aggregation: !function utils.olympiadbench_aggregate_results
38+
higher_is_better: true

0 commit comments

Comments
 (0)