|
| 1 | +import re |
| 2 | +import zipfile |
| 3 | +from functools import lru_cache |
| 4 | +from io import BytesIO |
| 5 | +from typing import Any, Dict, List |
| 6 | + |
| 7 | +import datasets |
| 8 | +import numpy as np |
| 9 | +import requests |
| 10 | +from PIL import Image |
| 11 | + |
| 12 | +from lmms_eval.tasks._task_utils.default_template_yaml import load_default_template_yaml |
| 13 | +from lmms_eval.utils import eval_logger |
| 14 | + |
| 15 | +POINTARENA_REPO = "PointArena/pointarena-data" |
| 16 | +POINTARENA_ROWS_API = "https://datasets-server.huggingface.co/rows" |
| 17 | + |
| 18 | +PROMPT_SUFFIX_0_999 = "Your answer should be formatted as a list of tuples, i.e. [(x1, y1), (x2, y2), ...], where each tuple contains the x and y coordinates of a point satisfying the conditions above. The coordinates should be integers between 0 and 999, representing the pixel locations scaled to a 1000x1000 grid." |
| 19 | +PROMPT_SUFFIX_ORIGINAL = "Your answer should be formatted as a list of tuples, i.e. [(x1, y1), (x2, y2), ...], where each tuple contains the x and y coordinates of a point satisfying the conditions above. The coordinates should be between 0 and 1, indicating the normalized pixel locations of the points in the image." |
| 20 | +FORMAT = "Return only list of tuples, do not add anything else." |
| 21 | + |
| 22 | +config = load_default_template_yaml(__file__) |
| 23 | + |
| 24 | + |
| 25 | +def pointbench_process_docs(dataset: datasets.Dataset) -> datasets.Dataset: |
| 26 | + return dataset.map(lambda _, idx: {"question_id": idx, "row_idx": idx}, with_indices=True) |
| 27 | + |
| 28 | + |
| 29 | +def pointbench_doc_to_text(doc: Dict[str, Any], lmms_eval_specific_kwargs: Dict[str, Any] | None = None) -> str: |
| 30 | + prompt_suffix_type = config.get("metadata", {}).get("prompt_suffix_type", "0_999") |
| 31 | + suffix = PROMPT_SUFFIX_0_999 if prompt_suffix_type == "0_999" else PROMPT_SUFFIX_ORIGINAL |
| 32 | + |
| 33 | + kwargs = lmms_eval_specific_kwargs or {} |
| 34 | + pre_prompt = kwargs.get("pre_prompt", "") |
| 35 | + post_prompt = kwargs.get("post_prompt", "") |
| 36 | + user_input = str(doc.get("user_input", "")).strip() |
| 37 | + return f"{pre_prompt}{user_input} {suffix} {FORMAT}{post_prompt}".strip() |
| 38 | + |
| 39 | + |
| 40 | +@lru_cache(maxsize=4096) |
| 41 | +def _get_image_url(row_idx: int) -> str: |
| 42 | + response = requests.get( |
| 43 | + POINTARENA_ROWS_API, |
| 44 | + params={"dataset": POINTARENA_REPO, "config": "default", "split": "train", "offset": int(row_idx), "length": 1}, |
| 45 | + timeout=30, |
| 46 | + ) |
| 47 | + response.raise_for_status() |
| 48 | + payload = response.json() |
| 49 | + rows = payload.get("rows", []) |
| 50 | + if not rows: |
| 51 | + raise ValueError(f"No rows found for row_idx={row_idx}") |
| 52 | + return rows[0]["row"]["image"]["src"] |
| 53 | + |
| 54 | + |
| 55 | +def _load_image(row_idx: int) -> Image.Image: |
| 56 | + image_url = _get_image_url(row_idx) |
| 57 | + response = requests.get(image_url, timeout=60) |
| 58 | + if response.status_code == 403: |
| 59 | + _get_image_url.cache_clear() |
| 60 | + image_url = _get_image_url(row_idx) |
| 61 | + response = requests.get(image_url, timeout=60) |
| 62 | + response.raise_for_status() |
| 63 | + return Image.open(BytesIO(response.content)).convert("RGB") |
| 64 | + |
| 65 | + |
| 66 | +def pointbench_doc_to_visual(doc: Dict[str, Any]) -> List[Image.Image]: |
| 67 | + row_idx = doc.get("row_idx", doc.get("question_id")) |
| 68 | + if row_idx is None: |
| 69 | + eval_logger.warning("pointbench: missing row_idx for doc={}", doc.get("image_filename", "unknown")) |
| 70 | + return [] |
| 71 | + |
| 72 | + try: |
| 73 | + image = _load_image(int(row_idx)) |
| 74 | + except Exception as exc: |
| 75 | + eval_logger.warning("pointbench: failed to load image for row_idx={} ({})", row_idx, exc) |
| 76 | + return [] |
| 77 | + return [image] |
| 78 | + |
| 79 | + |
| 80 | +@lru_cache(maxsize=1) |
| 81 | +def _mask_zip_path() -> str: |
| 82 | + from huggingface_hub import hf_hub_download |
| 83 | + |
| 84 | + return hf_hub_download(repo_id=POINTARENA_REPO, repo_type="dataset", filename="selected_masks.zip") |
| 85 | + |
| 86 | + |
| 87 | +@lru_cache(maxsize=1) |
| 88 | +def _mask_member_map() -> Dict[str, str]: |
| 89 | + mapping: Dict[str, str] = {} |
| 90 | + with zipfile.ZipFile(_mask_zip_path()) as archive: |
| 91 | + for member in archive.namelist(): |
| 92 | + if not member.lower().endswith(".png"): |
| 93 | + continue |
| 94 | + mapping.setdefault(member.rsplit("/", 1)[-1], member) |
| 95 | + return mapping |
| 96 | + |
| 97 | + |
| 98 | +@lru_cache(maxsize=4096) |
| 99 | +def _load_mask(mask_filename: str) -> np.ndarray | None: |
| 100 | + member = _mask_member_map().get(mask_filename) |
| 101 | + if not member: |
| 102 | + return None |
| 103 | + |
| 104 | + with zipfile.ZipFile(_mask_zip_path()) as archive: |
| 105 | + with archive.open(member) as stream: |
| 106 | + mask = Image.open(BytesIO(stream.read())).convert("L") |
| 107 | + |
| 108 | + return (np.array(mask) > 127).astype(np.int32) |
| 109 | + |
| 110 | + |
| 111 | +def _text_to_points(text: str, width: int, height: int) -> np.ndarray: |
| 112 | + pattern = r"\(([-+]?\d*\.?\d+)\s*,\s*([-+]?\d*\.?\d+)\)" |
| 113 | + matches = re.findall(pattern, text) |
| 114 | + |
| 115 | + points = [] |
| 116 | + for x_raw, y_raw in matches: |
| 117 | + x = float(x_raw) |
| 118 | + y = float(y_raw) |
| 119 | + |
| 120 | + if 0.0 <= x <= 1.0 and 0.0 <= y <= 1.0: |
| 121 | + px = int(round(x * width)) |
| 122 | + py = int(round(y * height)) |
| 123 | + elif 0.0 <= x <= 1000.0 and 0.0 <= y <= 1000.0: |
| 124 | + px = int(round((x / 1000.0) * width)) |
| 125 | + py = int(round((y / 1000.0) * height)) |
| 126 | + else: |
| 127 | + px = int(round(x)) |
| 128 | + py = int(round(y)) |
| 129 | + |
| 130 | + points.append((px, py)) |
| 131 | + |
| 132 | + return np.array(points, dtype=np.int32) |
| 133 | + |
| 134 | + |
| 135 | +def pointbench_process_results(doc: Dict[str, Any], result: List[str]) -> Dict[str, Dict[str, Any]]: |
| 136 | + key_name = "pointbench_acc" |
| 137 | + mask_filename = str(doc.get("mask_filename", "")) |
| 138 | + mask = _load_mask(mask_filename) |
| 139 | + response = result[0] if result else "" |
| 140 | + |
| 141 | + if mask is None: |
| 142 | + eval_logger.warning("pointbench: failed to find mask for file={}", mask_filename) |
| 143 | + submission = { |
| 144 | + "id": doc.get("question_id", doc.get("image_filename", "unknown")), |
| 145 | + "pred": response, |
| 146 | + "parsed_points": [], |
| 147 | + "accuracy": 0.0, |
| 148 | + "category": doc.get("category", "unknown"), |
| 149 | + } |
| 150 | + return {key_name: submission} |
| 151 | + |
| 152 | + points = _text_to_points(response, mask.shape[1], mask.shape[0]) |
| 153 | + acc = 0.0 |
| 154 | + if len(points) > 0: |
| 155 | + in_range = (points[:, 0] >= 0) & (points[:, 0] < mask.shape[1]) & (points[:, 1] >= 0) & (points[:, 1] < mask.shape[0]) |
| 156 | + acc = np.concatenate([mask[points[in_range, 1], points[in_range, 0]], np.zeros(points.shape[0] - in_range.sum())]).mean() |
| 157 | + |
| 158 | + submission = { |
| 159 | + "id": doc.get("question_id", doc.get("image_filename", "unknown")), |
| 160 | + "pred": response, |
| 161 | + "parsed_points": list(map(tuple, points.tolist())), |
| 162 | + "accuracy": float(acc), |
| 163 | + "category": doc.get("category", "unknown"), |
| 164 | + } |
| 165 | + return {key_name: submission} |
| 166 | + |
| 167 | + |
| 168 | +def pointbench_aggregate_results(results: List[Dict[str, Any]]) -> float: |
| 169 | + if not results: |
| 170 | + return 0.0 |
| 171 | + return float(np.mean([sample.get("accuracy", 0.0) for sample in results])) |
0 commit comments