|
| 1 | +import random |
| 2 | +import re |
| 3 | +from typing import Dict, List, Tuple |
| 4 | + |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +assertion_prompt = """Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A or B.""" |
| 8 | + |
| 9 | +mcq_prompt = """Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, or D.""" |
| 10 | + |
| 11 | + |
| 12 | +def csbench_mcq_doc_to_text(doc: Dict, lmms_eval_specific_kwargs: Dict) -> str: |
| 13 | + q = doc["Question"] |
| 14 | + a = doc["A"] |
| 15 | + b = doc["B"] |
| 16 | + c = doc["C"] |
| 17 | + d = doc["D"] |
| 18 | + question = f"{assertion_prompt}\nQuestion: {q}\nA: {a}\nB: {b}\nC: {c}\nD: {d}\n" |
| 19 | + return question |
| 20 | + |
| 21 | + |
| 22 | +def csbench_assertion_doc_to_text(doc: Dict, lmms_eval_specific_kwargs: Dict) -> str: |
| 23 | + q = doc["Question"] |
| 24 | + question = f"{assertion_prompt}\nQuestion: {q}\n A: True\n B: False\n" |
| 25 | + return question |
| 26 | + |
| 27 | + |
| 28 | +def csbench_doc_to_target(doc: Dict) -> str: |
| 29 | + if doc["Format"].strip() == "Multiple-choice": |
| 30 | + return doc["Answer"].strip().upper() |
| 31 | + else: |
| 32 | + return "A" if doc["Answer"].strip() == "True" else "B" |
| 33 | + |
| 34 | + |
| 35 | +def csbench_doc_to_choice(doc: Dict) -> List[str]: |
| 36 | + if doc["Format"].strip() == "Multiple-choice": |
| 37 | + return ["A", "B", "C", "D"] |
| 38 | + else: |
| 39 | + return ["A", "B"] |
| 40 | + |
| 41 | + |
| 42 | +def parse_multi_choice_response(response, all_choices): |
| 43 | + """ |
| 44 | + Parse the prediction from the generated response. |
| 45 | + Return the predicted choice letter e.g., A, B, C, D. |
| 46 | + """ |
| 47 | + # Clean response of unwanted characters |
| 48 | + for char in [",", ".", "!", "?", ";", ":", "'"]: |
| 49 | + response = response.strip(char) |
| 50 | + response = " " + response + " " # Add space to avoid partial match |
| 51 | + |
| 52 | + candidates = [] |
| 53 | + # Look for choices with parentheses, e.g., (A) |
| 54 | + for choice in all_choices: |
| 55 | + if f"({choice})" in response: |
| 56 | + candidates.append(choice) |
| 57 | + |
| 58 | + # Look for simple choices, e.g., A, B, C |
| 59 | + if len(candidates) == 0: |
| 60 | + for choice in all_choices: |
| 61 | + if f" {choice} " in response: |
| 62 | + candidates.append(choice) |
| 63 | + |
| 64 | + # Look for choices with periods, e.g., A., B., C. |
| 65 | + if len(candidates) == 0: |
| 66 | + for choice in all_choices: |
| 67 | + if f"{choice}." in response: |
| 68 | + candidates.append(choice) |
| 69 | + |
| 70 | + # If no candidates, randomly choose one |
| 71 | + if len(candidates) == 0: |
| 72 | + pred_index = random.choice(all_choices) |
| 73 | + elif len(candidates) > 1: |
| 74 | + # If more than one candidate, choose the last one found |
| 75 | + start_indexes = [response.rfind(f" {can} ") for can in candidates] |
| 76 | + pred_index = candidates[np.argmax(start_indexes)] |
| 77 | + else: |
| 78 | + # If only one candidate, use it |
| 79 | + pred_index = candidates[0] |
| 80 | + |
| 81 | + return pred_index |
| 82 | + |
| 83 | + |
| 84 | +def csbench_process_results(doc: Dict, result: List[str]) -> Dict[str, float]: |
| 85 | + pred = parse_multi_choice_response(result[0], csbench_doc_to_choice(doc)) |
| 86 | + gt = csbench_doc_to_target(doc) |
| 87 | + score = 1.0 if pred == gt else 0.0 |
| 88 | + return {"accuracy": score} |
0 commit comments