|
| 1 | +import re |
| 2 | +from collections import defaultdict |
| 3 | +from functools import lru_cache |
| 4 | + |
| 5 | +import datasets |
| 6 | +from huggingface_hub import hf_hub_download |
| 7 | +from loguru import logger as eval_logger |
| 8 | +from PIL import Image |
| 9 | + |
| 10 | +MME_CC_DATASET_REPO = "MaxwellWen/MME-CC" |
| 11 | + |
| 12 | +_CODE_FENCE_PATTERN = re.compile(r"```(?:json)?|```", re.IGNORECASE) |
| 13 | + |
| 14 | + |
| 15 | +def _resolve_prompt_kwargs(lmms_eval_specific_kwargs): |
| 16 | + kwargs = lmms_eval_specific_kwargs or {} |
| 17 | + if isinstance(kwargs.get("default"), dict): |
| 18 | + merged_kwargs = dict(kwargs["default"]) |
| 19 | + for key, value in kwargs.items(): |
| 20 | + if key != "default": |
| 21 | + merged_kwargs[key] = value |
| 22 | + return merged_kwargs |
| 23 | + return kwargs |
| 24 | + |
| 25 | + |
| 26 | +def _extract_subtask(doc): |
| 27 | + image_list = doc.get("image_list") |
| 28 | + if isinstance(image_list, list) and image_list: |
| 29 | + first_image_path = str(image_list[0]).strip() |
| 30 | + if "/" in first_image_path: |
| 31 | + return first_image_path.split("/", 1)[0] |
| 32 | + |
| 33 | + extra = doc.get("extra") |
| 34 | + if isinstance(extra, dict): |
| 35 | + for key in ["Subtask", "subtask"]: |
| 36 | + value = extra.get(key) |
| 37 | + if isinstance(value, str) and value.strip(): |
| 38 | + return value.strip().replace(" ", "_") |
| 39 | + |
| 40 | + return "unknown" |
| 41 | + |
| 42 | + |
| 43 | +def _extract_reference_answer(doc): |
| 44 | + ground_truth = doc.get("ground_truth") |
| 45 | + if not isinstance(ground_truth, dict): |
| 46 | + return "" |
| 47 | + |
| 48 | + raw_reference = ground_truth.get("answer", "") |
| 49 | + if not isinstance(raw_reference, str): |
| 50 | + return "" |
| 51 | + |
| 52 | + reference = raw_reference.strip() |
| 53 | + if "## The correct answer is:" in reference: |
| 54 | + reference = reference.split("## The correct answer is:", 1)[1].strip() |
| 55 | + if "## Scoring criteria:" in reference: |
| 56 | + reference = reference.split("## Scoring criteria:", 1)[0].strip() |
| 57 | + return reference |
| 58 | + |
| 59 | + |
| 60 | +def _normalize_answer(text): |
| 61 | + if not isinstance(text, str): |
| 62 | + return "" |
| 63 | + |
| 64 | + normalized = text.strip() |
| 65 | + if "</think>" in normalized: |
| 66 | + normalized = normalized.split("</think>")[-1].strip() |
| 67 | + |
| 68 | + normalized = _CODE_FENCE_PATTERN.sub("", normalized) |
| 69 | + normalized = re.sub(r"\s+", " ", normalized) |
| 70 | + return normalized.strip().casefold() |
| 71 | + |
| 72 | + |
@lru_cache(maxsize=4096)
def _download_image(image_path):
    """Fetch one file from the MME-CC dataset repo via ``hf_hub_download``.

    Results are memoized per ``image_path`` (up to 4096 entries), so a
    repeated path does not trigger another hub call within this process.

    Args:
        image_path: Filename inside the dataset repository.

    Returns:
        Whatever ``hf_hub_download`` returns; callers in this module open it
        as a local image path.
    """
    download_args = {
        "repo_id": MME_CC_DATASET_REPO,
        "repo_type": "dataset",
        "filename": image_path,
    }
    return hf_hub_download(**download_args)
| 80 | + |
| 81 | + |
def mme_cc_process_docs(dataset):
    """Add derived ``subtask`` and ``target_answer`` fields to every document.

    Args:
        dataset: Iterable of documents, each convertible with ``dict(...)``.

    Returns:
        A new ``datasets.Dataset`` built from the annotated copies; the input
        documents themselves are not modified.
    """

    def _annotate(row):
        # Work on a copy so the source row stays untouched.
        record = dict(row)
        record["subtask"] = _extract_subtask(record)
        record["target_answer"] = _extract_reference_answer(record)
        return record

    annotated = [_annotate(row) for row in dataset]

    eval_logger.info("[mme_cc] Loaded {} samples", len(annotated))
    return datasets.Dataset.from_list(annotated)
| 92 | + |
| 93 | + |
def mme_cc_doc_to_visual(doc):
    """Load every image referenced by a document as an RGB ``PIL`` image.

    Entries of ``doc["image_list"]`` that are not non-blank strings are
    skipped. Each remaining path is resolved through ``_download_image`` and
    opened with ``Image.open``.

    Args:
        doc: Mapping-like document exposing ``get``.

    Returns:
        A list of RGB images in ``image_list`` order; empty when the field is
        missing, ``None`` or has no usable entries.
    """
    visuals = []
    # ``or []`` also covers a key that is present with a None value, which
    # the ``get`` default alone would let through and crash the loop.
    for image_path in doc.get("image_list") or []:
        if not isinstance(image_path, str) or not image_path.strip():
            continue
        local_path = _download_image(image_path)
        # convert() returns a separate image object, so the opened file
        # handle can be released as soon as the conversion is done.
        with Image.open(local_path) as image:
            visuals.append(image.convert("RGB"))
    return visuals
| 103 | + |
| 104 | + |
def mme_cc_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """Build the text prompt for a document.

    The prompt is ``pre_prompt + doc["prompt"] + post_prompt``, with the
    document prompt stripped first and the concatenation stripped afterwards.

    Args:
        doc: Mapping-like document exposing ``get``.
        lmms_eval_specific_kwargs: Optional prompt options, resolved through
            ``_resolve_prompt_kwargs``; ``pre_prompt`` and ``post_prompt``
            are read from the result.

    Returns:
        The assembled prompt string.
    """
    prompt_kwargs = _resolve_prompt_kwargs(lmms_eval_specific_kwargs)
    # A null option or prompt must contribute nothing; formatting None
    # directly would inject the literal text "None" into the model input.
    pre_prompt = prompt_kwargs.get("pre_prompt") or ""
    post_prompt = prompt_kwargs.get("post_prompt") or ""
    raw_prompt = doc.get("prompt")
    prompt = "" if raw_prompt is None else str(raw_prompt).strip()
    return f"{pre_prompt}{prompt}{post_prompt}".strip()
| 111 | + |
| 112 | + |
def mme_cc_doc_to_messages(doc, lmms_eval_specific_kwargs=None):
    """Wrap a document as a single user chat message.

    Args:
        doc: Mapping-like document.
        lmms_eval_specific_kwargs: Optional prompt options forwarded to
            ``mme_cc_doc_to_text``.

    Returns:
        A one-element list holding a ``"user"`` message whose content lists
        one image part per visual, followed by one text part.
    """
    text_part = {
        "type": "text",
        "text": mme_cc_doc_to_text(doc, lmms_eval_specific_kwargs=lmms_eval_specific_kwargs),
    }
    image_parts = [{"type": "image", "url": visual} for visual in mme_cc_doc_to_visual(doc)]
    return [{"role": "user", "content": [*image_parts, text_part]}]
| 120 | + |
| 121 | + |
def mme_cc_doc_to_target(doc):
    """Return the reference answer for a document.

    Args:
        doc: Mapping-like document exposing ``get``.

    Returns:
        ``doc["target_answer"]`` unchanged when it is a non-blank string;
        otherwise the value computed by ``_extract_reference_answer``.
    """
    stored = doc.get("target_answer")
    has_stored = isinstance(stored, str) and bool(stored.strip())
    return stored if has_stored else _extract_reference_answer(doc)
| 127 | + |
| 128 | + |
def mme_cc_process_results(doc, results):
    """Score one model response against the document's reference answer.

    Args:
        doc: Mapping-like document.
        results: Sequence of model outputs; only the first is scored, and an
            empty sequence is treated as an empty prediction.

    Returns:
        A dict with two metric entries:
        ``"mme_cc_exact_match"`` (1.0 when the normalized prediction equals
        the normalized reference, tagged with the subtask) and
        ``"mme_cc_answered_rate"`` (1.0 when the prediction is a non-blank
        string). Each entry carries ``"total": 1.0``.
    """
    if results:
        prediction = results[0]
    else:
        prediction = ""
    reference = mme_cc_doc_to_target(doc)

    is_match = _normalize_answer(prediction) == _normalize_answer(reference)
    is_answered = isinstance(prediction, str) and bool(prediction.strip())

    fallback_subtask = _extract_subtask(doc)
    subtask = str(doc.get("subtask", fallback_subtask))

    exact_match_entry = {"score": float(is_match), "total": 1.0, "subtask": subtask}
    answered_entry = {"score": float(is_answered), "total": 1.0}
    return {
        "mme_cc_exact_match": exact_match_entry,
        "mme_cc_answered_rate": answered_entry,
    }
| 141 | + |
| 142 | + |
| 143 | +def _aggregate_score(results): |
| 144 | + total_score = 0.0 |
| 145 | + total_count = 0.0 |
| 146 | + |
| 147 | + for result in results: |
| 148 | + if isinstance(result, dict): |
| 149 | + total_score += float(result.get("score", 0.0)) |
| 150 | + total_count += float(result.get("total", 1.0)) |
| 151 | + else: |
| 152 | + total_score += float(result) |
| 153 | + total_count += 1.0 |
| 154 | + |
| 155 | + if total_count == 0.0: |
| 156 | + return 0.0 |
| 157 | + return total_score / total_count |
| 158 | + |
| 159 | + |
def mme_cc_aggregate_exact_match(results):
    """Aggregate exact-match results and log a per-subtask breakdown.

    Dict items are grouped by their ``"subtask"`` value (``"unknown"`` when
    absent); for every group with a non-zero total, the ratio of score to
    total is logged in sorted subtask order. Non-dict items are ignored for
    the breakdown only.

    Args:
        results: Per-sample results as accepted by ``_aggregate_score``.

    Returns:
        The overall value computed by ``_aggregate_score(results)``.
    """
    score_by_subtask = defaultdict(float)
    total_by_subtask = defaultdict(float)

    for entry in results:
        if isinstance(entry, dict):
            name = str(entry.get("subtask", "unknown"))
            score_by_subtask[name] += float(entry.get("score", 0.0))
            total_by_subtask[name] += float(entry.get("total", 1.0))

    for name in sorted(total_by_subtask):
        count = total_by_subtask[name]
        if count == 0.0:
            continue
        ratio = score_by_subtask[name] / count
        eval_logger.info("[mme_cc] {} exact_match: {:.3f} (n={})", name, ratio, int(count))

    return _aggregate_score(results)
| 178 | + |
| 179 | + |
def mme_cc_aggregate_answered_rate(results):
    """Aggregate answered-rate results.

    Args:
        results: Per-sample results as accepted by ``_aggregate_score``.

    Returns:
        The value computed by ``_aggregate_score(results)``.
    """
    answered_rate = _aggregate_score(results)
    return answered_rate
0 commit comments