Skip to content

Commit 48e8b59

Browse files
KelvinDo183pbcong
andauthored
add csbench (#841)
* add csbench * run precommit --------- Co-authored-by: pbcong <congphamba2005@gmail.com>
1 parent 5f5a82c commit 48e8b59

4 files changed

Lines changed: 142 additions & 0 deletions

File tree

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
group: csbench
2+
task:
3+
- csbench_mcq
4+
- csbench_assertion
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
dataset_path: lmms-lab/CSBench_Assertion
2+
dataset_kwargs:
3+
token: True
4+
test_split: assertion
5+
task: "csbench_assertion"
6+
7+
doc_to_text: !function utils.csbench_assertion_doc_to_text
8+
doc_to_target: !function utils.csbench_doc_to_target
9+
doc_to_choice: !function utils.csbench_doc_to_choice
10+
11+
lmms_eval_specific_kwargs:
12+
default:
13+
pre_prompt: ""
14+
post_prompt: ""
15+
16+
metric_list:
17+
- metric: accuracy
18+
aggregation: mean
19+
higher_is_better: true
20+
21+
process_results: !function utils.csbench_process_results
22+
23+
metadata:
24+
version: 0.0
25+
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
dataset_path: lmms-lab/CSBench_MCQ
2+
dataset_kwargs:
3+
token: True
4+
test_split: mcq
5+
task: "csbench_mcq"
6+
7+
doc_to_text: !function utils.csbench_mcq_doc_to_text
8+
doc_to_target: !function utils.csbench_doc_to_target
9+
doc_to_choice: !function utils.csbench_doc_to_choice
10+
11+
lmms_eval_specific_kwargs:
12+
default:
13+
pre_prompt: ""
14+
post_prompt: ""
15+
16+
metric_list:
17+
- metric: accuracy
18+
aggregation: mean
19+
higher_is_better: true
20+
21+
process_results: !function utils.csbench_process_results
22+
23+
metadata:
24+
version: 0.0
25+

lmms_eval/tasks/csbench/utils.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import random
2+
import re
3+
from typing import Dict, List, Tuple
4+
5+
import numpy as np
6+
7+
assertion_prompt = """Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A or B."""
8+
9+
mcq_prompt = """Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, or D."""
10+
11+
12+
def csbench_mcq_doc_to_text(doc: Dict, lmms_eval_specific_kwargs: Dict) -> str:
13+
q = doc["Question"]
14+
a = doc["A"]
15+
b = doc["B"]
16+
c = doc["C"]
17+
d = doc["D"]
18+
question = f"{assertion_prompt}\nQuestion: {q}\nA: {a}\nB: {b}\nC: {c}\nD: {d}\n"
19+
return question
20+
21+
22+
def csbench_assertion_doc_to_text(doc: Dict, lmms_eval_specific_kwargs: Dict) -> str:
23+
q = doc["Question"]
24+
question = f"{assertion_prompt}\nQuestion: {q}\n A: True\n B: False\n"
25+
return question
26+
27+
28+
def csbench_doc_to_target(doc: Dict) -> str:
29+
if doc["Format"].strip() == "Multiple-choice":
30+
return doc["Answer"].strip().upper()
31+
else:
32+
return "A" if doc["Answer"].strip() == "True" else "B"
33+
34+
35+
def csbench_doc_to_choice(doc: Dict) -> List[str]:
36+
if doc["Format"].strip() == "Multiple-choice":
37+
return ["A", "B", "C", "D"]
38+
else:
39+
return ["A", "B"]
40+
41+
42+
def parse_multi_choice_response(response, all_choices):
43+
"""
44+
Parse the prediction from the generated response.
45+
Return the predicted choice letter e.g., A, B, C, D.
46+
"""
47+
# Clean response of unwanted characters
48+
for char in [",", ".", "!", "?", ";", ":", "'"]:
49+
response = response.strip(char)
50+
response = " " + response + " " # Add space to avoid partial match
51+
52+
candidates = []
53+
# Look for choices with parentheses, e.g., (A)
54+
for choice in all_choices:
55+
if f"({choice})" in response:
56+
candidates.append(choice)
57+
58+
# Look for simple choices, e.g., A, B, C
59+
if len(candidates) == 0:
60+
for choice in all_choices:
61+
if f" {choice} " in response:
62+
candidates.append(choice)
63+
64+
# Look for choices with periods, e.g., A., B., C.
65+
if len(candidates) == 0:
66+
for choice in all_choices:
67+
if f"{choice}." in response:
68+
candidates.append(choice)
69+
70+
# If no candidates, randomly choose one
71+
if len(candidates) == 0:
72+
pred_index = random.choice(all_choices)
73+
elif len(candidates) > 1:
74+
# If more than one candidate, choose the last one found
75+
start_indexes = [response.rfind(f" {can} ") for can in candidates]
76+
pred_index = candidates[np.argmax(start_indexes)]
77+
else:
78+
# If only one candidate, use it
79+
pred_index = candidates[0]
80+
81+
return pred_index
82+
83+
84+
def csbench_process_results(doc: Dict, result: List[str]) -> Dict[str, float]:
85+
pred = parse_multi_choice_response(result[0], csbench_doc_to_choice(doc))
86+
gt = csbench_doc_to_target(doc)
87+
score = 1.0 if pred == gt else 0.0
88+
return {"accuracy": score}

0 commit comments

Comments
 (0)