|
| 1 | +import ast |
| 2 | +import json |
| 3 | +import re |
| 4 | +from collections import defaultdict |
| 5 | +from io import BytesIO |
| 6 | +from typing import Any, Optional |
| 7 | + |
| 8 | +from loguru import logger as eval_logger |
| 9 | +from PIL import Image |
| 10 | + |
| 11 | + |
| 12 | +def _coerce_bool(value: Any) -> Optional[bool]: |
| 13 | + if isinstance(value, bool): |
| 14 | + return value |
| 15 | + |
| 16 | + if isinstance(value, str): |
| 17 | + normalized = value.strip().lower() |
| 18 | + if normalized in {"true", "yes", "y", "1"}: |
| 19 | + return True |
| 20 | + if normalized in {"false", "no", "n", "0"}: |
| 21 | + return False |
| 22 | + |
| 23 | + return None |
| 24 | + |
| 25 | + |
| 26 | +def _parse_bool_from_serialized(candidate: str) -> Optional[bool]: |
| 27 | + candidate = candidate.strip() |
| 28 | + if not candidate: |
| 29 | + return None |
| 30 | + |
| 31 | + for parser in (json.loads, ast.literal_eval): |
| 32 | + try: |
| 33 | + parsed = parser(candidate) |
| 34 | + except Exception: |
| 35 | + continue |
| 36 | + |
| 37 | + if isinstance(parsed, dict): |
| 38 | + for key in ("answer", "Answer", "ANSWER"): |
| 39 | + if key in parsed: |
| 40 | + return _coerce_bool(parsed[key]) |
| 41 | + else: |
| 42 | + parsed_bool = _coerce_bool(parsed) |
| 43 | + if parsed_bool is not None: |
| 44 | + return parsed_bool |
| 45 | + |
| 46 | + return None |
| 47 | + |
| 48 | + |
| 49 | +def _parse_bool_from_response(response: str) -> Optional[bool]: |
| 50 | + if not response: |
| 51 | + return None |
| 52 | + |
| 53 | + cleaned = response.strip() |
| 54 | + |
| 55 | + direct = _coerce_bool(cleaned.strip("`\"' ")) |
| 56 | + if direct is not None: |
| 57 | + return direct |
| 58 | + |
| 59 | + serialized_candidates = [cleaned] |
| 60 | + serialized_candidates.extend(re.findall(r"```(?:json)?\s*(.*?)\s*```", cleaned, flags=re.IGNORECASE | re.DOTALL)) |
| 61 | + serialized_candidates.extend(re.findall(r"\{[\s\S]*?\}", cleaned)) |
| 62 | + |
| 63 | + for candidate in serialized_candidates: |
| 64 | + parsed = _parse_bool_from_serialized(candidate) |
| 65 | + if parsed is not None: |
| 66 | + return parsed |
| 67 | + |
| 68 | + lowered = cleaned.lower() |
| 69 | + |
| 70 | + answer_match = re.search(r'"answer"\s*:\s*(true|false)', lowered) |
| 71 | + if answer_match: |
| 72 | + return answer_match.group(1) == "true" |
| 73 | + |
| 74 | + answer_match = re.search(r"\banswer\s*[:=]\s*(true|false)\b", lowered) |
| 75 | + if answer_match: |
| 76 | + return answer_match.group(1) == "true" |
| 77 | + |
| 78 | + token_match = re.search(r"\b(true|false|yes|no)\b", lowered) |
| 79 | + if token_match: |
| 80 | + return _coerce_bool(token_match.group(1)) |
| 81 | + |
| 82 | + return None |
| 83 | + |
| 84 | + |
| 85 | +def viverbench_doc_to_visual(doc: dict[str, Any]) -> list: |
| 86 | + visuals = [] |
| 87 | + for image_bytes in doc.get("img", []): |
| 88 | + image = Image.open(BytesIO(image_bytes)).convert("RGB") |
| 89 | + visuals.append(image) |
| 90 | + return visuals |
| 91 | + |
| 92 | + |
| 93 | +def viverbench_doc_to_text(doc: dict[str, Any], lmms_eval_specific_kwargs: Optional[dict[str, Any]] = None) -> str: |
| 94 | + if lmms_eval_specific_kwargs is None: |
| 95 | + lmms_eval_specific_kwargs = {} |
| 96 | + |
| 97 | + pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "") |
| 98 | + post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "") |
| 99 | + question = doc["question"] |
| 100 | + |
| 101 | + text_parts = [part for part in (pre_prompt, question, post_prompt) if part] |
| 102 | + return "\n".join(text_parts) |
| 103 | + |
| 104 | + |
| 105 | +def viverbench_process_results(doc: dict[str, Any], results: list[str]) -> dict[str, dict[str, Any]]: |
| 106 | + response = results[0] if results else "" |
| 107 | + pred_answer = _parse_bool_from_response(response) |
| 108 | + target_answer = bool(doc["answer"]) |
| 109 | + |
| 110 | + submission = { |
| 111 | + "prompt_id": doc.get("prompt_id", ""), |
| 112 | + "task": doc.get("task", "unknown"), |
| 113 | + "target_answer": target_answer, |
| 114 | + "pred_answer": pred_answer, |
| 115 | + "raw_response": response, |
| 116 | + "is_correct": pred_answer is not None and pred_answer == target_answer, |
| 117 | + } |
| 118 | + |
| 119 | + return {"viverbench_acc": submission} |
| 120 | + |
| 121 | + |
| 122 | +def viverbench_aggregate_results(results: list[dict[str, Any]]) -> float: |
| 123 | + if not results: |
| 124 | + return 0.0 |
| 125 | + |
| 126 | + by_task = defaultdict(lambda: {"correct": 0, "total": 0}) |
| 127 | + total_correct = 0 |
| 128 | + |
| 129 | + for result in results: |
| 130 | + task = result.get("task", "unknown") |
| 131 | + is_correct = bool(result.get("is_correct", False)) |
| 132 | + |
| 133 | + by_task[task]["total"] += 1 |
| 134 | + if is_correct: |
| 135 | + by_task[task]["correct"] += 1 |
| 136 | + total_correct += 1 |
| 137 | + |
| 138 | + for task in sorted(by_task.keys()): |
| 139 | + stats = by_task[task] |
| 140 | + task_acc = stats["correct"] / stats["total"] if stats["total"] else 0.0 |
| 141 | + eval_logger.info(f"ViVerBench - {task}: {task_acc:.4f} ({stats['correct']}/{stats['total']})") |
| 142 | + |
| 143 | + overall_acc = total_correct / len(results) |
| 144 | + eval_logger.info(f"ViVerBench - overall: {overall_acc:.4f} ({total_correct}/{len(results)})") |
| 145 | + return overall_acc |
0 commit comments