Skip to content

Commit cfc5a25

Browse files
authored
feat: integrate Point-Bench benchmark task (#1142) (#1157)
1 parent 0b71775 commit cfc5a25

4 files changed

Lines changed: 197 additions & 0 deletions

File tree

docs/current_tasks.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ python -m lmms_eval --tasks list_with_num
484484
- [EmbSpatial](https://github.com/EmbSpatial/EmbSpatial) (embspatial)
485485
- [ERQA](https://github.com/ERQA-Bench/ERQA) (erqa)
486486
- [OmniSpatial](https://omnispatial.github.io/) (omnispatial)
487+
- [Point-Bench](https://pointarena.github.io/) (pointbench)
487488
- [Where2Place](https://where2place.github.io/) (where2place)
488489

489490
---
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Shared Point-Bench settings. NOTE(review): presumably this is the
# `_default_template_yaml` pulled in by the task file's `include` — confirm.
dataset_path: json
dataset_kwargs:
  data_files:
    # The single JSON file is exposed as the `train` split.
    train: https://huggingface.co/datasets/PointArena/pointarena-data/raw/main/data.json
output_type: generate_until
# Hooks below are resolved from the task's utils module.
process_docs: !function utils.pointbench_process_docs
doc_to_visual: !function utils.pointbench_doc_to_visual
doc_to_text: !function utils.pointbench_doc_to_text
# No text target: utils.pointbench_process_results scores predictions against mask images.
doc_to_target: ""
process_results: !function utils.pointbench_process_results

metric_list:
  - metric: pointbench_acc
    aggregation: !function utils.pointbench_aggregate_results
    higher_is_better: true

generation_kwargs:
  max_new_tokens: 256

metadata:
  source_dataset: PointArena/pointarena-data
  # Read by utils.pointbench_doc_to_text: "0_999" selects the 0-999 grid prompt,
  # any other value selects the normalized 0-1 prompt.
  prompt_suffix_type: "0_999"
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Task name; listed as `pointbench` in docs/current_tasks.md.
task: pointbench
# Evaluation runs on `train`, the only split declared under data_files in the template.
test_split: train
include: _default_template_yaml
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import re
2+
import zipfile
3+
from functools import lru_cache
4+
from io import BytesIO
5+
from typing import Any, Dict, List
6+
7+
import datasets
8+
import numpy as np
9+
import requests
10+
from PIL import Image
11+
12+
from lmms_eval.tasks._task_utils.default_template_yaml import load_default_template_yaml
13+
from lmms_eval.utils import eval_logger
14+
15+
# Hugging Face dataset repository; used for the rows API query and the mask archive download.
POINTARENA_REPO = "PointArena/pointarena-data"
# Dataset-viewer endpoint queried to resolve the image URL of a single row.
POINTARENA_ROWS_API = "https://datasets-server.huggingface.co/rows"

# Text shared by both coordinate-format instructions below.
_PROMPT_SUFFIX_PREAMBLE = (
    "Your answer should be formatted as a list of tuples, i.e. [(x1, y1), (x2, y2), ...], "
    "where each tuple contains the x and y coordinates of a point satisfying the conditions above. "
)
# Instruction asking for integer coordinates on a 1000x1000 grid.
PROMPT_SUFFIX_0_999 = _PROMPT_SUFFIX_PREAMBLE + (
    "The coordinates should be integers between 0 and 999, "
    "representing the pixel locations scaled to a 1000x1000 grid."
)
# Instruction asking for coordinates normalized to the 0-1 range.
PROMPT_SUFFIX_ORIGINAL = _PROMPT_SUFFIX_PREAMBLE + (
    "The coordinates should be between 0 and 1, "
    "indicating the normalized pixel locations of the points in the image."
)
# Final instruction appended after the coordinate-format suffix.
FORMAT = "Return only list of tuples, do not add anything else."
21+
22+
# Task configuration returned by load_default_template_yaml for this module.
# NOTE(review): presumably the parsed `_default_template_yaml` sitting next to this file — confirm
# against the helper. pointbench_doc_to_text reads `metadata.prompt_suffix_type` from it.
config = load_default_template_yaml(__file__)
23+
24+
25+
def pointbench_process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Tag every document with its position in the dataset.

    Both ``question_id`` and ``row_idx`` are set to the row's index. The
    index is what ``pointbench_doc_to_visual`` later uses to fetch the image.

    Args:
        dataset: The loaded Point-Bench split.

    Returns:
        The dataset with ``question_id`` and ``row_idx`` columns added.
    """

    def _attach_index(_example: Dict[str, Any], position: int) -> Dict[str, int]:
        return {"question_id": position, "row_idx": position}

    return dataset.map(_attach_index, with_indices=True)
27+
28+
29+
def pointbench_doc_to_text(doc: Dict[str, Any], lmms_eval_specific_kwargs: Dict[str, Any] | None = None) -> str:
    """Build the text prompt for one Point-Bench document.

    The prompt is ``pre_prompt`` + the document's ``user_input`` + a
    coordinate-format instruction + ``FORMAT`` + ``post_prompt``, stripped of
    surrounding whitespace.

    Args:
        doc: Dataset row; ``user_input`` holds the question (missing -> "").
        lmms_eval_specific_kwargs: Optional mapping that may provide
            ``pre_prompt`` and ``post_prompt``; both default to "".

    Returns:
        The assembled prompt string.
    """
    options = lmms_eval_specific_kwargs if lmms_eval_specific_kwargs else {}
    question = str(doc.get("user_input", "")).strip()

    # Only the exact value "0_999" selects the 1000x1000 grid instruction;
    # anything else falls back to the normalized 0-1 instruction.
    if config.get("metadata", {}).get("prompt_suffix_type", "0_999") == "0_999":
        coordinate_hint = PROMPT_SUFFIX_0_999
    else:
        coordinate_hint = PROMPT_SUFFIX_ORIGINAL

    head = options.get("pre_prompt", "")
    tail = options.get("post_prompt", "")
    prompt = f"{head}{question} {coordinate_hint} {FORMAT}{tail}"
    return prompt.strip()
38+
39+
40+
@lru_cache(maxsize=4096)
def _get_image_url(row_idx: int) -> str:
    """Resolve the image URL of one dataset row via the rows API.

    Requests exactly one row (``length=1``) at offset ``row_idx`` from the
    ``default`` config / ``train`` split and returns its ``image.src`` field.
    Results are memoized per ``row_idx``.

    Args:
        row_idx: Zero-based row offset in the split.

    Returns:
        The URL stored under ``rows[0]["row"]["image"]["src"]``.

    Raises:
        requests.HTTPError: If the API responds with an error status.
        ValueError: If the response contains no rows.
    """
    query = {
        "dataset": POINTARENA_REPO,
        "config": "default",
        "split": "train",
        "offset": int(row_idx),
        "length": 1,
    }
    reply = requests.get(POINTARENA_ROWS_API, params=query, timeout=30)
    reply.raise_for_status()

    rows = reply.json().get("rows", [])
    if rows:
        return rows[0]["row"]["image"]["src"]
    raise ValueError(f"No rows found for row_idx={row_idx}")
53+
54+
55+
def _load_image(row_idx: int) -> Image.Image:
    """Download the image of one dataset row and return it as RGB.

    Args:
        row_idx: Zero-based row offset whose image should be fetched.

    Returns:
        The decoded image converted to ``RGB`` mode.

    Raises:
        requests.HTTPError: If the download still fails after the retry.
    """

    def _download() -> requests.Response:
        return requests.get(_get_image_url(row_idx), timeout=60)

    reply = _download()
    if reply.status_code == 403:
        # The cached URL was rejected (presumably an expired signed URL —
        # confirm). Drop every cached URL, resolve a fresh one and retry once.
        _get_image_url.cache_clear()
        reply = _download()
    reply.raise_for_status()
    return Image.open(BytesIO(reply.content)).convert("RGB")
64+
65+
66+
def pointbench_doc_to_visual(doc: Dict[str, Any]) -> List[Image.Image]:
    """Return the image for a document as a one-element list.

    The row index is taken from ``row_idx``, falling back to ``question_id``.
    Loading is best-effort: a missing index or any load failure is logged as
    a warning and produces an empty list instead of raising.

    Args:
        doc: Dataset row, as produced by ``pointbench_process_docs``.

    Returns:
        ``[image]`` on success, ``[]`` otherwise.
    """
    row_idx = doc["row_idx"] if "row_idx" in doc else doc.get("question_id")

    if row_idx is not None:
        try:
            loaded = _load_image(int(row_idx))
        except Exception as exc:
            eval_logger.warning("pointbench: failed to load image for row_idx={} ({})", row_idx, exc)
            return []
        else:
            return [loaded]

    eval_logger.warning("pointbench: missing row_idx for doc={}", doc.get("image_filename", "unknown"))
    return []
78+
79+
80+
@lru_cache(maxsize=1)
def _mask_zip_path() -> str:
    """Fetch ``selected_masks.zip`` from the dataset repo and return its path.

    The result is memoized, so ``hf_hub_download`` is called at most once per
    process.
    """
    # Imported here rather than at module level, so huggingface_hub is only
    # needed once masks are actually requested.
    from huggingface_hub import hf_hub_download

    archive_path = hf_hub_download(
        repo_id=POINTARENA_REPO,
        repo_type="dataset",
        filename="selected_masks.zip",
    )
    return archive_path
85+
86+
87+
@lru_cache(maxsize=1)
def _mask_member_map() -> Dict[str, str]:
    """Index the PNG members of the mask archive by their base filename.

    Only members whose name ends with ``.png`` (case-insensitive) are
    included. When two members share a base filename, the one listed first
    in the archive wins. The mapping is built once and memoized.

    Returns:
        Mapping from base filename to the member's full path in the archive.
    """
    with zipfile.ZipFile(_mask_zip_path()) as bundle:
        png_members = [name for name in bundle.namelist() if name.lower().endswith(".png")]

    by_basename: Dict[str, str] = {}
    for name in png_members:
        basename = name.rsplit("/", 1)[-1]
        if basename not in by_basename:
            by_basename[basename] = name
    return by_basename
96+
97+
98+
@lru_cache(maxsize=4096)
def _load_mask(mask_filename: str) -> np.ndarray | None:
    """Load a mask from the archive as a binary array.

    The PNG is converted to 8-bit grayscale and thresholded: pixels brighter
    than 127 become 1, all others 0. Results (including misses) are memoized
    per filename.

    Args:
        mask_filename: Base filename of the mask inside the archive.

    Returns:
        An ``int32`` array of 0/1 values, or ``None`` when the archive has no
        member with that base filename.
    """
    archive_member = _mask_member_map().get(mask_filename)
    if archive_member:
        with zipfile.ZipFile(_mask_zip_path()) as bundle:
            raw_png = bundle.read(archive_member)
        grayscale = Image.open(BytesIO(raw_png)).convert("L")
        return (np.array(grayscale) > 127).astype(np.int32)
    return None
109+
110+
111+
# Matches "(x, y)" pairs of signed decimal numbers. Whitespace is tolerated
# around both numbers, so "( 12 , 34 )" parses the same as "(12, 34)".
_POINT_PATTERN = re.compile(r"\(\s*([-+]?\d*\.?\d+)\s*,\s*([-+]?\d*\.?\d+)\s*\)")


def _text_to_points(text: str, width: int, height: int) -> np.ndarray:
    """Parse ``(x, y)`` tuples from a model response into pixel coordinates.

    The coordinate convention is guessed per point:

    * both values in ``[0, 1]``: normalized coordinates, scaled by the image
      size;
    * otherwise both values in ``[0, 1000]``: positions on a 1000x1000 grid,
      scaled to the image size;
    * anything else: taken as raw pixel coordinates (rounded, not clamped, so
      the caller can still count them as out of range).

    Scaled points are clamped to the last valid pixel index. Without the
    clamp, an answer on the far edge (``1.0``, or ``999`` on a small image)
    rounds to ``width``/``height``, which is outside the image.

    NOTE: integer answers made up only of 0 and 1 cannot be told apart from
    normalized coordinates and are treated as normalized.

    Args:
        text: Raw model output; empty or ``None`` yields no points.
        width: Image (mask) width in pixels.
        height: Image (mask) height in pixels.

    Returns:
        An ``int32`` array of shape ``(n, 2)`` holding ``(x, y)`` pixel
        coordinates; shape ``(0, 2)`` when nothing could be parsed.
    """
    if not text:
        return np.empty((0, 2), dtype=np.int32)

    def _scale(fraction: float, size: int) -> int:
        # Clamp so the far edge maps onto the last pixel instead of one past it.
        return min(int(round(fraction * size)), max(size - 1, 0))

    points = []
    for x_raw, y_raw in _POINT_PATTERN.findall(text):
        x = float(x_raw)
        y = float(y_raw)

        if 0.0 <= x <= 1.0 and 0.0 <= y <= 1.0:
            px = _scale(x, width)
            py = _scale(y, height)
        elif 0.0 <= x <= 1000.0 and 0.0 <= y <= 1000.0:
            px = _scale(x / 1000.0, width)
            py = _scale(y / 1000.0, height)
        else:
            px = int(round(x))
            py = int(round(y))

        points.append((px, py))

    # reshape keeps the result two-dimensional even when no point was found.
    return np.array(points, dtype=np.int32).reshape(-1, 2)
133+
134+
135+
def pointbench_process_results(doc: Dict[str, Any], result: List[str]) -> Dict[str, Dict[str, Any]]:
    """Score one model response against the document's mask.

    Points are parsed from the first response and looked up in the binary
    mask named by ``doc["mask_filename"]``. Accuracy is the fraction of
    parsed points that land on a mask pixel with value 1; points outside the
    image count as misses. Accuracy is 0.0 when no point is parsed or when
    the mask cannot be found (the latter is logged as a warning).

    Args:
        doc: Dataset row providing ``mask_filename`` and, optionally,
            ``question_id``, ``image_filename`` and ``category``.
        result: Model generations; only the first entry is used.

    Returns:
        ``{"pointbench_acc": submission}`` where ``submission`` holds ``id``,
        ``pred``, ``parsed_points``, ``accuracy`` and ``category``.
    """
    mask_name = str(doc.get("mask_filename", ""))
    mask = _load_mask(mask_name)
    prediction = result[0] if result else ""

    parsed: List[Any] = []
    score = 0.0
    if mask is None:
        eval_logger.warning("pointbench: failed to find mask for file={}", mask_name)
    else:
        rows, cols = mask.shape[0], mask.shape[1]
        points = _text_to_points(prediction, cols, rows)
        if len(points) > 0:
            xs, ys = points[:, 0], points[:, 1]
            inside = (xs >= 0) & (xs < cols) & (ys >= 0) & (ys < rows)
            hits = mask[ys[inside], xs[inside]]
            # Out-of-image points are padded in as zeros so they count as misses.
            misses = np.zeros(points.shape[0] - inside.sum())
            score = np.concatenate([hits, misses]).mean()
        parsed = list(map(tuple, points.tolist()))

    return {
        "pointbench_acc": {
            "id": doc.get("question_id", doc.get("image_filename", "unknown")),
            "pred": prediction,
            "parsed_points": parsed,
            "accuracy": float(score),
            "category": doc.get("category", "unknown"),
        }
    }
166+
167+
168+
def pointbench_aggregate_results(results: List[Dict[str, Any]]) -> float:
    """Average the per-sample accuracies into the final benchmark score.

    Args:
        results: Per-sample submissions, as produced by
            ``pointbench_process_results``. A sample without an ``accuracy``
            key counts as 0.0.

    Returns:
        The mean accuracy, or 0.0 when ``results`` is empty.
    """
    scores = [entry.get("accuracy", 0.0) for entry in results]
    return float(np.mean(scores)) if scores else 0.0

0 commit comments

Comments
 (0)