|
| 1 | +import random |
| 2 | +import re |
| 3 | +from typing import Dict, List, Tuple |
| 4 | + |
| 5 | + |
| 6 | +LETTERS = [chr(65 + i) for i in range(26)] # A-Z |
| 7 | + |
| 8 | + |
| 9 | +def _get_choices(doc: Dict) -> Dict[str, str]: |
| 10 | + # Accept doc["options"] or doc["choices"] as a dict of letters, or list mapped to A..Z, or letter keys at top-level |
| 11 | + if isinstance(doc.get("options"), dict): |
| 12 | + norm = {k.upper(): str(v) for k, v in doc["options"].items() if isinstance(k, str) and len(k) == 1} |
| 13 | + letters_present = [l for l in LETTERS if l in norm] |
| 14 | + if len(letters_present) >= 2: |
| 15 | + return {l: norm[l] for l in letters_present} |
| 16 | + |
| 17 | + if isinstance(doc.get("choices"), dict): |
| 18 | + norm = {k.upper(): str(v) for k, v in doc["choices"].items() if isinstance(k, str) and len(k) == 1} |
| 19 | + letters_present = [l for l in LETTERS if l in norm] |
| 20 | + if len(letters_present) >= 2: |
| 21 | + return {l: norm[l] for l in letters_present} |
| 22 | + |
| 23 | + if isinstance(doc.get("options"), list) and len(doc["options"]) >= 2: |
| 24 | + lst = [str(x) for x in doc["options"]] |
| 25 | + n = min(len(lst), len(LETTERS)) |
| 26 | + return {LETTERS[i]: lst[i] for i in range(n)} |
| 27 | + |
| 28 | + if isinstance(doc.get("choices"), list) and len(doc["choices"]) >= 2: |
| 29 | + lst = [str(x) for x in doc["choices"]] |
| 30 | + n = min(len(lst), len(LETTERS)) |
| 31 | + return {LETTERS[i]: lst[i] for i in range(n)} |
| 32 | + |
| 33 | + letters_found = [l for l in LETTERS if l in doc] |
| 34 | + if len(letters_found) >= 2: |
| 35 | + return {l: str(doc[l]) for l in letters_found} |
| 36 | + |
| 37 | + # Fallback minimal shape |
| 38 | + return {"A": str(doc.get("choice1", "")), "B": str(doc.get("choice2", ""))} |
| 39 | + |
| 40 | + |
| 41 | +def super_gpqa_doc_to_text(doc: Dict, lmms_eval_specific_kwargs: Dict) -> str: |
| 42 | + pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] |
| 43 | + post_prompt = lmms_eval_specific_kwargs["post_prompt"] |
| 44 | + q = doc["question"] |
| 45 | + choices = _get_choices(doc) |
| 46 | + lines = [q, ""] |
| 47 | + for letter in sorted(choices.keys()): |
| 48 | + text = choices[letter] |
| 49 | + lines.append(f"{letter}) {text}") |
| 50 | + question = "\n".join(lines) |
| 51 | + return f"{pre_prompt}{question}{post_prompt}" |
| 52 | + |
| 53 | + |
| 54 | +def super_gpqa_doc_to_target(doc: Dict) -> str: |
| 55 | + choices = _get_choices(doc) |
| 56 | + allowed = list(choices.keys()) |
| 57 | + |
| 58 | + # Case 1: answer provided as letter or text string |
| 59 | + if "answer" in doc: |
| 60 | + ans = doc["answer"] |
| 61 | + if isinstance(ans, str): |
| 62 | + # Try to parse a letter |
| 63 | + letter = _extract_answer_letter(ans) |
| 64 | + if letter in allowed: |
| 65 | + return letter |
| 66 | + # Try exact text match |
| 67 | + ans_text = ans.strip() |
| 68 | + for l, t in choices.items(): |
| 69 | + if ans_text == t: |
| 70 | + return l |
| 71 | + else: |
| 72 | + # Non-string answer; try to coerce to index |
| 73 | + try: |
| 74 | + idx = int(ans) |
| 75 | + if 0 <= idx < len(allowed): |
| 76 | + return allowed[idx] |
| 77 | + except Exception: |
| 78 | + pass |
| 79 | + |
| 80 | + # Case 2: explicit index id |
| 81 | + for key in ["answer_id", "label", "gold", "correct_index", "target"]: |
| 82 | + if key in doc: |
| 83 | + try: |
| 84 | + idx = int(doc[key]) |
| 85 | + if 0 <= idx < len(allowed): |
| 86 | + return allowed[idx] |
| 87 | + except Exception: |
| 88 | + continue |
| 89 | + |
| 90 | + # Case 3: if options are top-level letter keys and value of "answer" matches text (handled above) |
| 91 | + |
| 92 | + # Fallback to first option |
| 93 | + return allowed[0] if allowed else "A" |
| 94 | + |
| 95 | + |
| 96 | +def super_gpqa_doc_to_choice(doc: Dict) -> List[str]: |
| 97 | + choices = _get_choices(doc) |
| 98 | + return list(choices.keys()) |
| 99 | + |
| 100 | + |
| 101 | +def _extract_answer_letter(response: str, allowed_letters: List[str] | None = None) -> str | None: |
| 102 | + response = (response or "").strip() |
| 103 | + # Common patterns: "Answer: A", "(A)", "A.", "A)" or lone letter |
| 104 | + patterns = [ |
| 105 | + r"(?i)answer\s*:\s*([A-Z])", |
| 106 | + r"\(([A-Z])\)", |
| 107 | + r"^\s*([A-Z])\s*[\.)\]]", |
| 108 | + r"(?:^|\s)([A-Z])(?:\s|$)", |
| 109 | + ] |
| 110 | + for pattern in patterns: |
| 111 | + m = re.search(pattern, response, flags=re.IGNORECASE) |
| 112 | + if m: |
| 113 | + cand = m.group(1).upper() |
| 114 | + if allowed_letters is None or cand in allowed_letters: |
| 115 | + return cand |
| 116 | + |
| 117 | + letters = re.findall(r"[A-Z]", response.upper()) |
| 118 | + if allowed_letters is not None: |
| 119 | + letters = [l for l in letters if l in allowed_letters] |
| 120 | + if len(letters) == 1: |
| 121 | + return letters[0] |
| 122 | + return None |
| 123 | + |
| 124 | + |
| 125 | +def _parse_multi_choice_response(response: str, all_choices: List[str]) -> str: |
| 126 | + # Clean response of unwanted characters |
| 127 | + for char in [",", ".", "!", "?", ";", ":", "'"]: |
| 128 | + response = response.strip(char) |
| 129 | + response = " " + response + " " # Add space to avoid partial match |
| 130 | + |
| 131 | + candidates: List[str] = [] |
| 132 | + # Look for choices with parentheses, e.g., (A) |
| 133 | + for choice in all_choices: |
| 134 | + if f"({choice})" in response: |
| 135 | + candidates.append(choice) |
| 136 | + |
| 137 | + # Look for simple choices, e.g., A, B, C |
| 138 | + if len(candidates) == 0: |
| 139 | + for choice in all_choices: |
| 140 | + if f" {choice} " in response: |
| 141 | + candidates.append(choice) |
| 142 | + |
| 143 | + # Look for choices with periods, e.g., A., B., C. |
| 144 | + if len(candidates) == 0: |
| 145 | + for choice in all_choices: |
| 146 | + if f"{choice}." in response or f"{choice})" in response: |
| 147 | + candidates.append(choice) |
| 148 | + |
| 149 | + if len(candidates) == 0: |
| 150 | + # Fallback to regex extractor |
| 151 | + letter = _extract_answer_letter(response, allowed_letters=all_choices) |
| 152 | + return letter if letter in all_choices else (all_choices[0] if all_choices else "A") |
| 153 | + elif len(candidates) > 1: |
| 154 | + # If more than one candidate, choose the last occurrence |
| 155 | + start_indexes = [(response.rfind(f" {can} "), can) for can in candidates] |
| 156 | + start_indexes.sort() |
| 157 | + return start_indexes[-1][1] |
| 158 | + else: |
| 159 | + return candidates[0] |
| 160 | + |
| 161 | + |
| 162 | +def super_gpqa_process_results(doc: Dict, result: List[str]) -> Dict[str, float]: |
| 163 | + response = result[0].strip() if result else "" |
| 164 | + all_choices = super_gpqa_doc_to_choice(doc) |
| 165 | + pred = _parse_multi_choice_response(response, all_choices) |
| 166 | + gt = super_gpqa_doc_to_target(doc) |
| 167 | + score = 1.0 if pred == gt else 0.0 |
| 168 | + return {"accuracy": score} |
| 169 | + |
| 170 | + |
| 171 | +# Few-shot multishot builder |
| 172 | +FEWSHOT_PROMPT = ( |
| 173 | + "Answer the following multiple-choice question. There is only one correct answer. The\n" |
| 174 | + "last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.\n" |
| 175 | + "Question: A refracting telescope consists of two converging lenses separated by 100 cm.\n" |
| 176 | + "The eye-piece lens has a focal length of 20 cm. The angular magnification of the telescope\n" |
| 177 | + "is ( ).\n" |
| 178 | + "A) 10\nB) 40\nC) 6\n69\nD) 25\nE) 15\nF) 50\nG) 30\nH) 4\nI) 5\nJ) 20\n" |
| 179 | + "Answer: Let's think step by step. In a refracting telescope, if both lenses are converging,\n" |
| 180 | + "the focus of both lenses must be between the two lenses, and thus the focal lengths of the\n" |
| 181 | + "two lenses must add up to their separation. Since the focal length of one lens is 20 cm, the\n" |
| 182 | + "focal length of the other must be 80 cm. The magnification is the ratio of these two focal\n" |
| 183 | + "lengths, or 4.\n" |
| 184 | + "Answer: H.\n" |
| 185 | + "Question: Say the pupil of your eye has a diameter of 5 mm and you have a telescope\n" |
| 186 | + "with an aperture of 50 cm. How much more light can the telescope gather than your eye?\n" |
| 187 | + "A) 1000 times more\nB) 50 times more\nC) 5000 times more\nD) 500 times more\nE) 10000 times more\nF) 20000 times more\nG) 2000 times more\nH) 100 times more\nI) 10 times more\nJ) N/A\n" |
| 188 | + "Answer: Let's think step by step. The amount of light a telescope can gather compared to the human eye is proportional to the area of its apertures. The area of a circle is given by the formula A = pi * (D/2)^2, where D is the diameter. Therefore, the relative light-gathering power is calculated as:\n" |
| 189 | + "(50 cm / 0.1 cm)^2 / (5 mm / 0.1 cm)^2 = 500^2 / 5^2 = 10000.\n" |
| 190 | + "Answer: E.\n" |
| 191 | + "Question: Where do most short-period comets come from and how do we know?\n" |
| 192 | + "A) The Kuiper belt; short period comets tend to be in the plane of the solar system like the Kuiper belt.\n" |
| 193 | + "B) The asteroid belt; short period comets tend to come from random directions indicating a spherical distribution of comets called the asteroid belt.\n" |
| 194 | + "C) The asteroid belt; short period comets tend to be in the plane of the solar system just like the asteroid belt.\n" |
| 195 | + "D) The Oort cloud; short period comets have orbital periods similar to asteroids like Vesta and are found in the plane of the solar system just like the Oort cloud.\n" |
| 196 | + "E) The Oort Cloud; short period comets tend to come from random directions indicating a spherical distribution of comets called the Oort Cloud.\n" |
| 197 | + "F) The Oort cloud; short period comets tend to be in the plane of the solar system just like the Oort cloud.\n" |
| 198 | + "G) The asteroid belt; short period comets have orbital periods similar to asteroids like Vesta and are found in the plane of the solar system just like the asteroid belt.\n" |
| 199 | + "Answer: Let's think step by step. Most short-period comets originate from the Kuiper belt. This is deduced from the observation that these comets tend to follow orbits that lie in the plane of the solar system, similar to the distribution of objects in the Kuiper belt itself. Thus, the alignment of these cometary orbits with the ecliptic plane points to their Kuiper belt origin.\n" |
| 200 | + "Answer: A.\n" |
| 201 | + "Question: Colors in a soap bubble result from light ( ).\n" |
| 202 | + "A) dispersion\nB) deflection\nC) refraction\nD) reflection\nE) interference\nF) converted to a different frequency\nG) polarization\nH) absorption\nI) diffraction\nJ) transmission\n" |
| 203 | + "Answer: Let's think step by step.The colorful patterns observed in a soap bubble are caused by the phenomenon of light interference. This occurs when light waves bounce between the two surfaces of the soap film, combining constructively or destructively based on their phase differences and the varying thickness of the film. These interactions result in vibrant color patterns due to variations in the intensity of different wavelengths of light.\n" |
| 204 | + "Answer: E.\n" |
| 205 | + "Question: A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?\n" |
| 206 | + "A) 240 W\nB) 120 W\nC) 10 W\nD) 480 W\nE) 360 W\nF) 200 W\nG) 30 W\nH) 150 W\nI) 60 W\nJ) 300 W\n" |
| 207 | + "Answer: Let's think step by step. The rate of energy usage, known as power, in an electrical circuit is calculated by the product of voltage and current. For a microwave oven connected to a 120 V outlet and drawing a current of 2 amps, the power consumption can be calculated as follows:\n" |
| 208 | + "Power = Voltage × Current = 120 V × 2 A = 240 W.\n" |
| 209 | + "Therefore, the microwave oven uses energy at a rate of 240 watts.\n" |
| 210 | + "Answer: A.\n" |
| 211 | +) |
| 212 | + |
| 213 | + |
| 214 | +def super_gpqa_multishot_doc_to_text(doc: Dict, lmms_eval_specific_kwargs: Dict) -> str: |
| 215 | + # Build few-shot + current question in the requested format |
| 216 | + q = doc["question"] |
| 217 | + choices = _get_choices(doc) |
| 218 | + current_lines = [f"Question: {q}", ""] |
| 219 | + for letter in sorted(choices.keys()): |
| 220 | + current_lines.append(f"{letter}) {choices[letter]}") |
| 221 | + current_block = "\n".join(current_lines) |
| 222 | + return FEWSHOT_PROMPT + "\n" + current_block + "\nAnswer: Let's think step by step." |
| 223 | + |
| 224 | + |
0 commit comments