Skip to content

Commit 300772b

Browse files
feat: add EgoPlan-Bench2 task
1 parent 7108c2c commit 300772b

2 files changed

Lines changed: 106 additions & 0 deletions

File tree

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
dataset_path: nv-njb/egoplan-bench2
2+
dataset_kwargs:
3+
features: !function utils.egoplan2_features
4+
task: egoplan2
5+
test_split: train
6+
output_type: generate_until
7+
doc_to_visual: !function utils.egoplan2_doc_to_visual
8+
doc_to_text: !function utils.egoplan2_doc_to_text
9+
doc_to_target: "ground_truth"
10+
generation_kwargs:
11+
max_new_tokens: 4096
12+
temperature: 0
13+
top_p: 1.0
14+
num_beams: 1
15+
do_sample: false
16+
process_results: !function utils.egoplan2_process_results
17+
metric_list:
18+
- metric: egoplan2_mcq_accuracy
19+
aggregation: !function utils.egoplan2_aggregate_results
20+
higher_is_better: true
21+
lmms_eval_specific_kwargs:
22+
default:
23+
pre_prompt: ""
24+
post_prompt: "\nAnswer with the option's letter from the given choices directly."
25+
gpt4v:
26+
pre_prompt: ""
27+
post_prompt: "\nAnswer the question with A, B, C, or D."
28+
xcomposer2_4khd:
29+
pre_prompt: "[UNUSED_TOKEN_146]user\n"
30+
post_prompt: " Answer this question with A, B, C, or D.[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n"
31+
metadata:
32+
version: 0.1

lmms_eval/tasks/egoplan2/utils.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import re
2+
3+
import datasets
4+
5+
# Explicit column schema for the EgoPlan-Bench2 dataset rows.
egoplan2_features = datasets.Features(
    {
        # Sample identification and task context.
        "sample_id": datasets.Value("string"),
        "domain": datasets.Value("string"),
        "task_goal": datasets.Value("string"),
        # Frame indices stored as 64-bit integers.
        "task_start_frame": datasets.Value("int64"),
        "current_observation_frame": datasets.Value("int64"),
        # Question text, the four candidate choices and the reference answer.
        "formatted_question": datasets.Value("string"),
        "choice_a": datasets.Value("string"),
        "choice_b": datasets.Value("string"),
        "choice_c": datasets.Value("string"),
        "choice_d": datasets.Value("string"),
        "ground_truth": datasets.Value("string"),
        # Visual inputs: source video file name plus a sequence of decoded images.
        "video_file": datasets.Value("string"),
        "keyframes": datasets.Sequence(datasets.Image(decode=True)),
    }
)
22+
23+
24+
def egoplan2_doc_to_visual(doc):
    """Return the visual inputs for one sample.

    Args:
        doc: Mapping for a single dataset row; must contain a
            ``"keyframes"`` entry.

    Returns:
        The object stored under ``doc["keyframes"]``, returned as-is.

    Raises:
        KeyError: If ``doc`` has no ``"keyframes"`` entry.
    """
    keyframes = doc["keyframes"]
    return keyframes
26+
27+
28+
def egoplan2_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """Return the prompt text for one sample.

    Args:
        doc: Mapping for a single dataset row; must contain a
            ``"formatted_question"`` entry.
        lmms_eval_specific_kwargs: Accepted for interface compatibility but
            not consulted; the question text is returned without any
            prefix or suffix added.

    Returns:
        The value stored under ``doc["formatted_question"]``.

    Raises:
        KeyError: If ``doc`` has no ``"formatted_question"`` entry.
    """
    question = doc["formatted_question"]
    return question
30+
31+
32+
def extract_characters_regex(s):
    """Extract the chosen option letter (A-D) from a model response.

    Known answer lead-ins (e.g. "The answer is", "Best option:") are removed
    first so that capital letters inside them are not mistaken for the
    answer; without this, "Best answer: C" would yield "B" from "Best".

    Args:
        s: Raw response text.

    Returns:
        The first uppercase ``A``, ``B``, ``C`` or ``D`` found in the
        remaining text, or ``""`` when there is none. Note that the match is
        not anchored: any other word containing an uppercase A-D that
        precedes the real answer (e.g. "Based on ...") will be picked up.
    """
    # Each prefix must be its own list element. Adjacent string literals
    # without a comma are concatenated by Python into a single string, which
    # would then never match a real response.
    answer_prefixes = [
        "The best answer is",
        "The correct answer is",
        "The answer is",
        "The answer",
        "The best option is",
        "The correct option is",
        "Best answer:",
        "Best option:",
    ]

    s = s.strip()
    for answer_prefix in answer_prefixes:
        s = s.replace(answer_prefix, "")

    matches = re.search(r"[ABCD]", s)
    return matches[0] if matches else ""
52+
53+
54+
def egoplan2_process_results(doc, results):
    """Turn one model response into a per-sample record for aggregation.

    Args:
        doc: Mapping for a single dataset row; must contain
            ``"ground_truth"`` and may contain ``"sample_id"``.
        results: Sequence of model outputs; only the first element is used.

    Returns:
        A dict with the single key ``"egoplan2_mcq_accuracy"`` whose value
        holds ``sample_id`` (``None`` if absent from ``doc``), the extracted
        ``pred_answer`` and the ``ground_truth``.
    """
    predicted_letter = extract_characters_regex(results[0])
    # Deliberately copy only the small fields needed later. Keyframes are
    # left out to avoid OOM during multi-GPU gather_object, which pickles
    # the entire dict.
    record = {
        "sample_id": doc.get("sample_id"),
        "pred_answer": predicted_letter,
        "ground_truth": doc["ground_truth"],
    }
    return {"egoplan2_mcq_accuracy": record}
65+
66+
67+
def egoplan2_aggregate_results(results):
    """Compute multiple-choice accuracy over per-sample records.

    Args:
        results: Sequence of dicts, each with ``"pred_answer"`` and
            ``"ground_truth"`` entries.

    Returns:
        The fraction of records whose ``pred_answer`` equals
        ``ground_truth``, as a float in ``[0, 1]``. Returns ``0.0`` when
        ``results`` is empty instead of dividing by zero.
    """
    # Guard the empty case (e.g. a filtered-out or zero-length split) so the
    # aggregation step does not raise ZeroDivisionError.
    if not results:
        return 0.0

    correct_num = sum(
        1 for result in results if result["pred_answer"] == result["ground_truth"]
    )
    return correct_num / len(results)

0 commit comments

Comments
 (0)