Skip to content

Commit f07f7a5

Browse files
authored
fix linting (#842)
1 parent 48e8b59 commit f07f7a5

1 file changed

Lines changed: 28 additions & 36 deletions

File tree

lmms_eval/tasks/lemonade/utils.py

Lines changed: 28 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import ast
22
import os
3-
import cv2
4-
import numpy as np
53
import zipfile
6-
import yaml
74
from collections import defaultdict
85
from pathlib import Path
9-
from PIL import Image
106
from typing import Any, Optional
7+
8+
import cv2
9+
import numpy as np
10+
import yaml
1111
from huggingface_hub import hf_hub_download
12+
from PIL import Image
1213

1314
with open(Path(__file__).parent / "lemonade.yaml", "r") as f:
1415
raw_data = f.readlines()
@@ -39,7 +40,7 @@ def load_video(video_file: str, start_frame: int, end_frame: int, max_num_frames
3940
"""
4041

4142
cap = cv2.VideoCapture(video_file)
42-
try:
43+
try:
4344
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
4445
start_frame = max(0, start_frame)
4546
end_frame = min(end_frame, total_frames - 1)
@@ -61,11 +62,12 @@ def load_video(video_file: str, start_frame: int, end_frame: int, max_num_frames
6162
finally:
6263
cap.release()
6364

65+
6466
def parse_options(options: list[str]) -> str:
6567
"""
6668
Format a list of multiple-choice options into a string.
67-
The function assigns letters to each option and returns them in a newline-separated string.
68-
69+
The function assigns letters to each option and returns them in a newline-separated string.
70+
6971
Args:
7072
options (list[str]): A list of option strings.
7173
@@ -100,12 +102,9 @@ def lemonade_doc_to_visual(doc: dict[str, Any]) -> list[Image.Image]:
100102
end = int(doc["End"])
101103
frames = load_video(video_path, start, end, max_num_frames=max_num_frames)
102104
else:
103-
raise FileNotFoundError(
104-
f"Video file not found: {video_path}. "
105-
f"Expected video for clip '{doc['Clip']}' at {video_path}"
106-
)
105+
raise FileNotFoundError(f"Video file not found: {video_path}. " f"Expected video for clip '{doc['Clip']}' at {video_path}")
107106
return frames
108-
107+
109108

110109
def lemonade_doc_to_text(doc: dict[str, Any], lmms_eval_specific_kwargs: Optional[dict[str, Any]] = None) -> str:
111110
"""
@@ -119,10 +118,10 @@ def lemonade_doc_to_text(doc: dict[str, Any], lmms_eval_specific_kwargs: Optiona
119118

120119
if lmms_eval_specific_kwargs is None:
121120
lmms_eval_specific_kwargs = {}
122-
121+
123122
pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "")
124123
post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "")
125-
124+
126125
question = "Question: " + doc["Question"]
127126
parsed_options = parse_options(ast.literal_eval(doc["Answers"]))
128127
choices = "Choices:\n" + parsed_options
@@ -133,18 +132,18 @@ def lemonade_doc_to_text(doc: dict[str, Any], lmms_eval_specific_kwargs: Optiona
133132
def get_multi_choice_info(options: list[str]) -> tuple[dict[str, str], list[str]]:
134133
"""
135134
Map a list of options to letter labels (A, B, C, ...).
136-
135+
137136
Args:
138137
options: The set of answer options
139138
Returns:
140-
tuple[dict[str, str], list[str]]:
139+
tuple[dict[str, str], list[str]]:
141140
- index2ans: Mapping from letters to option text.
142141
- all_choices: List of the assigned letters.
143142
"""
144-
143+
145144
if not isinstance(options, list):
146145
raise TypeError(f"Expected list of options, got {type(options)}: {options}")
147-
146+
148147
start_chr = "A"
149148
all_choices = []
150149
index2ans = {}
@@ -157,7 +156,7 @@ def get_multi_choice_info(options: list[str]) -> tuple[dict[str, str], list[str]
157156

158157
def parse_multi_choice_response(response: str, all_choices: list[str], index2ans: dict[str, str]) -> str:
159158
"""
160-
Parse a model response and return the predicted choice label (e.g., "A", "B", "C", "D").
159+
Parse a model response and return the predicted choice label (e.g., "A", "B", "C", "D").
161160
162161
Args:
163162
response (str): The generated response to parse.
@@ -175,7 +174,7 @@ def parse_multi_choice_response(response: str, all_choices: list[str], index2ans
175174

176175
for char in [",", ".", "!", "?", ";", ":", "'"]:
177176
response = response.strip(char)
178-
response = " " + response + " "
177+
response = " " + response + " "
179178

180179
index_ans = True
181180
ans_with_brack = False
@@ -187,7 +186,7 @@ def parse_multi_choice_response(response: str, all_choices: list[str], index2ans
187186
if f"{choice}." in response:
188187
candidates.append(choice)
189188
ans_with_period = True
190-
for choice in all_choices:
189+
for choice in all_choices:
191190
if f"{choice}:" in response:
192191
candidates.append(choice)
193192
ans_with_colon = True
@@ -197,14 +196,14 @@ def parse_multi_choice_response(response: str, all_choices: list[str], index2ans
197196
candidates.append(choice)
198197
ans_with_brack = True
199198
if len(candidates) == 0:
200-
for choice in all_choices:
199+
for choice in all_choices:
201200
if f"{choice} " in response:
202201
candidates.append(choice)
203202
if len(candidates) == 0 and len(response.split()) > 5:
204203
for index, ans in index2ans.items():
205204
if ans.lower() in response.lower():
206205
candidates.append(index)
207-
index_ans = False
206+
index_ans = False
208207
if len(candidates) == 0:
209208
pred_index = "A"
210209

@@ -241,40 +240,33 @@ def parse_multi_choice_response(response: str, all_choices: list[str], index2ans
241240
def lemonade_process_results(doc: dict[str, Any], results: list[Any]) -> dict[str, dict]:
242241
"""
243242
Process the results from the model and compute accuracy.
244-
243+
245244
Args:
246245
doc: A dictionary representing an entry in the dataset.
247246
results: List of model outputs.
248247
Returns:
249-
A dictionary containing accuracy information.
248+
A dictionary containing accuracy information.
250249
"""
251-
250+
252251
pred = results[0]
253252
index2ans, all_choices = get_multi_choice_info(ast.literal_eval(doc["Answers"]))
254253
parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
255254

256-
acc = {
257-
"QID": doc["QID"],
258-
"category": doc["Category"],
259-
"subcategory": doc["Subcategory"],
260-
"difficulty": doc["Difficulty"],
261-
"answer": doc["Correct Answer"],
262-
"parsed_pred": parsed_pred,
263-
"original_pred": pred
264-
}
255+
acc = {"QID": doc["QID"], "category": doc["Category"], "subcategory": doc["Subcategory"], "difficulty": doc["Difficulty"], "answer": doc["Correct Answer"], "parsed_pred": parsed_pred, "original_pred": pred}
265256
return {"acc": acc}
266257

267258

268259
def lemonade_aggregate_results(results: list[dict[str, Any]]) -> float:
269260
"""
270261
Aggregate the results from the evaluation.
271-
262+
272263
Args:
273264
results: List of dicts containing individual evaluation results.
274265
Returns:
275266
overall_acc: Overall accuracy.
276267
277268
"""
269+
278270
def compute_accuracy(grouped_results):
279271
acc_dict = {}
280272
for key, samples in grouped_results.items():

0 commit comments

Comments
 (0)