Skip to content

Commit 0a065f3

Browse files
authored
fix: updated sitebench to report sub category score (#1282)
1 parent 2c08ee5 commit 0a065f3

6 files changed

Lines changed: 247 additions & 12 deletions

File tree

lmms_eval/tasks/sitebench/merge_results.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,16 @@ def find_latest_sitebench_files(logs_dir: str) -> tuple[str | None, str | None]:
208208
return image_path, video_path
209209

210210

211+
# Category names treated as sub-categories: entries of a stats dict's
# "metric_stats" whose key is in this set are printed and exported as the
# sub-category breakdown; every other key is ignored.
SUBCATEGORIES = {
    "3d information understanding",
    "counting & existence",
    "movement prediction & navigation",
    "multi-view & cross-image reasoning",
    "object localization & positioning",
    "spatial relationship reasoning",
}
219+
220+
211221
def print_results(name: str, stats: dict, category_stats: dict = None, random_acc: float = None):
212222
"""Print formatted results."""
213223
print(f"\n{'='*60}")
@@ -223,6 +233,14 @@ def print_results(name: str, stats: dict, category_stats: dict = None, random_ac
223233
if random_acc is not None:
224234
print(f"Random Expected Accuracy: {random_acc*100:.2f}%")
225235

236+
# Print sub-category breakdown from metric_stats
237+
metric_stats = stats.get("metric_stats", {})
238+
subcat_stats = {k: v for k, v in metric_stats.items() if k in SUBCATEGORIES}
239+
if subcat_stats:
240+
subcat_df = stats_to_df(subcat_stats, "Sub-Category")
241+
print("\nSub-Category Breakdown:")
242+
print(subcat_df.to_string(index=False))
243+
226244
if category_stats:
227245
cat_df = stats_to_df(category_stats, "Category")
228246
print("\nCategory Breakdown:")
@@ -354,23 +372,35 @@ def main():
354372

355373
# Save to output file if requested
356374
if args.output:
375+
376+
def _stats_to_output(stats_dict: dict) -> dict:
377+
"""Convert a stats dict with acc/caa num/den to output format."""
378+
out = {}
379+
acc = stats_dict["acc_num"] / stats_dict["acc_den"] * 100 if stats_dict["acc_den"] > 0 else 0
380+
caa = stats_dict["caa_num"] / stats_dict["caa_den"] * 100 if stats_dict["caa_den"] > 0 else 0
381+
out["accuracy"] = acc
382+
out["caa"] = caa
383+
out["count"] = int(stats_dict["acc_den"])
384+
return out
385+
386+
def _subcat_output(metric_stats: dict) -> dict:
    """Build the per-sub-category section of the output from ``metric_stats``.

    Entries whose key is not in ``SUBCATEGORIES`` are dropped; each kept
    entry is converted with ``_stats_to_output``.
    """
    selected = {}
    for name, counters in metric_stats.items():
        if name not in SUBCATEGORIES:
            continue
        selected[name] = _stats_to_output(counters)
    return selected
389+
357390
output_data = {
358391
"image": {
359392
"file": image_path,
360-
"accuracy": (image_stats["overall"]["acc_num"] / image_stats["overall"]["acc_den"] * 100 if image_stats["overall"]["acc_den"] > 0 else 0),
361-
"caa": (image_stats["overall"]["caa_num"] / image_stats["overall"]["caa_den"] * 100 if image_stats["overall"]["caa_den"] > 0 else 0),
362-
"count": int(image_stats["overall"]["acc_den"]),
393+
**_stats_to_output(image_stats["overall"]),
394+
"subcategories": _subcat_output(image_stats.get("metric_stats", {})),
363395
},
364396
"video": {
365397
"file": video_path,
366-
"accuracy": (video_stats["overall"]["acc_num"] / video_stats["overall"]["acc_den"] * 100 if video_stats["overall"]["acc_den"] > 0 else 0),
367-
"caa": (video_stats["overall"]["caa_num"] / video_stats["overall"]["caa_den"] * 100 if video_stats["overall"]["caa_den"] > 0 else 0),
368-
"count": int(video_stats["overall"]["acc_den"]),
398+
**_stats_to_output(video_stats["overall"]),
399+
"subcategories": _subcat_output(video_stats.get("metric_stats", {})),
369400
},
370401
"combined": {
371-
"accuracy": (combined_overall["acc_num"] / combined_overall["acc_den"] * 100 if combined_overall["acc_den"] > 0 else 0),
372-
"caa": (combined_overall["caa_num"] / combined_overall["caa_den"] * 100 if combined_overall["caa_den"] > 0 else 0),
373-
"count": int(combined_overall["acc_den"]),
402+
**_stats_to_output(combined_overall),
403+
"subcategories": _subcat_output(combined_metric),
374404
},
375405
}
376406
with open(args.output, "w") as f:

lmms_eval/tasks/sitebench/multi_image_input/site_video_multiimage.yaml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,42 @@ metric_list:
2525
- metric: chance_adjusted_acc
2626
aggregation: !function utils.spatial_aggregate_results
2727
higher_is_better: true
28+
- metric: 3d_information_understanding_acc
29+
aggregation: !function utils.aggregate_3d_information_understanding_acc
30+
higher_is_better: true
31+
- metric: 3d_information_understanding_caa
32+
aggregation: !function utils.aggregate_3d_information_understanding_caa
33+
higher_is_better: true
34+
- metric: counting_and_existence_acc
35+
aggregation: !function utils.aggregate_counting_and_existence_acc
36+
higher_is_better: true
37+
- metric: counting_and_existence_caa
38+
aggregation: !function utils.aggregate_counting_and_existence_caa
39+
higher_is_better: true
40+
- metric: movement_prediction_and_navigation_acc
41+
aggregation: !function utils.aggregate_movement_prediction_and_navigation_acc
42+
higher_is_better: true
43+
- metric: movement_prediction_and_navigation_caa
44+
aggregation: !function utils.aggregate_movement_prediction_and_navigation_caa
45+
higher_is_better: true
46+
- metric: multiview_and_crossimage_reasoning_acc
47+
aggregation: !function utils.aggregate_multiview_and_crossimage_reasoning_acc
48+
higher_is_better: true
49+
- metric: multiview_and_crossimage_reasoning_caa
50+
aggregation: !function utils.aggregate_multiview_and_crossimage_reasoning_caa
51+
higher_is_better: true
52+
- metric: object_localization_and_positioning_acc
53+
aggregation: !function utils.aggregate_object_localization_and_positioning_acc
54+
higher_is_better: true
55+
- metric: object_localization_and_positioning_caa
56+
aggregation: !function utils.aggregate_object_localization_and_positioning_caa
57+
higher_is_better: true
58+
- metric: spatial_relationship_reasoning_acc
59+
aggregation: !function utils.aggregate_spatial_relationship_reasoning_acc
60+
higher_is_better: true
61+
- metric: spatial_relationship_reasoning_caa
62+
aggregation: !function utils.aggregate_spatial_relationship_reasoning_caa
63+
higher_is_better: true
2864
lmms_eval_specific_kwargs:
2965
default:
3066
post_prompt: "Give me the answer letter directly. The best answer is:"

lmms_eval/tasks/sitebench/multi_image_input/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@
77

88
from lmms_eval.tasks.sitebench.utils import (
99
UpperLetters,
10+
aggregate_3d_information_understanding_acc,
11+
aggregate_3d_information_understanding_caa,
12+
aggregate_counting_and_existence_acc,
13+
aggregate_counting_and_existence_caa,
14+
aggregate_movement_prediction_and_navigation_acc,
15+
aggregate_movement_prediction_and_navigation_caa,
16+
aggregate_multiview_and_crossimage_reasoning_acc,
17+
aggregate_multiview_and_crossimage_reasoning_caa,
18+
aggregate_object_localization_and_positioning_acc,
19+
aggregate_object_localization_and_positioning_caa,
20+
aggregate_spatial_relationship_reasoning_acc,
21+
aggregate_spatial_relationship_reasoning_caa,
1022
base_cache_dir,
1123
cache_name,
1224
spatial_aggregate_results,

lmms_eval/tasks/sitebench/site_image.yaml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,42 @@ metric_list:
2525
- metric: chance_adjusted_acc
2626
aggregation: !function utils.spatial_aggregate_results
2727
higher_is_better: true
28+
- metric: 3d_information_understanding_acc
29+
aggregation: !function utils.aggregate_3d_information_understanding_acc
30+
higher_is_better: true
31+
- metric: 3d_information_understanding_caa
32+
aggregation: !function utils.aggregate_3d_information_understanding_caa
33+
higher_is_better: true
34+
- metric: counting_and_existence_acc
35+
aggregation: !function utils.aggregate_counting_and_existence_acc
36+
higher_is_better: true
37+
- metric: counting_and_existence_caa
38+
aggregation: !function utils.aggregate_counting_and_existence_caa
39+
higher_is_better: true
40+
- metric: movement_prediction_and_navigation_acc
41+
aggregation: !function utils.aggregate_movement_prediction_and_navigation_acc
42+
higher_is_better: true
43+
- metric: movement_prediction_and_navigation_caa
44+
aggregation: !function utils.aggregate_movement_prediction_and_navigation_caa
45+
higher_is_better: true
46+
- metric: multiview_and_crossimage_reasoning_acc
47+
aggregation: !function utils.aggregate_multiview_and_crossimage_reasoning_acc
48+
higher_is_better: true
49+
- metric: multiview_and_crossimage_reasoning_caa
50+
aggregation: !function utils.aggregate_multiview_and_crossimage_reasoning_caa
51+
higher_is_better: true
52+
- metric: object_localization_and_positioning_acc
53+
aggregation: !function utils.aggregate_object_localization_and_positioning_acc
54+
higher_is_better: true
55+
- metric: object_localization_and_positioning_caa
56+
aggregation: !function utils.aggregate_object_localization_and_positioning_caa
57+
higher_is_better: true
58+
- metric: spatial_relationship_reasoning_acc
59+
aggregation: !function utils.aggregate_spatial_relationship_reasoning_acc
60+
higher_is_better: true
61+
- metric: spatial_relationship_reasoning_caa
62+
aggregation: !function utils.aggregate_spatial_relationship_reasoning_caa
63+
higher_is_better: true
2864
lmms_eval_specific_kwargs:
2965
default:
3066
pre_prompt: ""

lmms_eval/tasks/sitebench/site_video.yaml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,42 @@ metric_list:
2525
- metric: chance_adjusted_acc
2626
aggregation: !function utils.spatial_aggregate_results
2727
higher_is_better: true
28+
- metric: 3d_information_understanding_acc
29+
aggregation: !function utils.aggregate_3d_information_understanding_acc
30+
higher_is_better: true
31+
- metric: 3d_information_understanding_caa
32+
aggregation: !function utils.aggregate_3d_information_understanding_caa
33+
higher_is_better: true
34+
- metric: counting_and_existence_acc
35+
aggregation: !function utils.aggregate_counting_and_existence_acc
36+
higher_is_better: true
37+
- metric: counting_and_existence_caa
38+
aggregation: !function utils.aggregate_counting_and_existence_caa
39+
higher_is_better: true
40+
- metric: movement_prediction_and_navigation_acc
41+
aggregation: !function utils.aggregate_movement_prediction_and_navigation_acc
42+
higher_is_better: true
43+
- metric: movement_prediction_and_navigation_caa
44+
aggregation: !function utils.aggregate_movement_prediction_and_navigation_caa
45+
higher_is_better: true
46+
- metric: multiview_and_crossimage_reasoning_acc
47+
aggregation: !function utils.aggregate_multiview_and_crossimage_reasoning_acc
48+
higher_is_better: true
49+
- metric: multiview_and_crossimage_reasoning_caa
50+
aggregation: !function utils.aggregate_multiview_and_crossimage_reasoning_caa
51+
higher_is_better: true
52+
- metric: object_localization_and_positioning_acc
53+
aggregation: !function utils.aggregate_object_localization_and_positioning_acc
54+
higher_is_better: true
55+
- metric: object_localization_and_positioning_caa
56+
aggregation: !function utils.aggregate_object_localization_and_positioning_caa
57+
higher_is_better: true
58+
- metric: spatial_relationship_reasoning_acc
59+
aggregation: !function utils.aggregate_spatial_relationship_reasoning_acc
60+
higher_is_better: true
61+
- metric: spatial_relationship_reasoning_caa
62+
aggregation: !function utils.aggregate_spatial_relationship_reasoning_caa
63+
higher_is_better: true
2864
lmms_eval_specific_kwargs:
2965
default:
3066
post_prompt: "Give me the answer letter directly. The best answer is:"

lmms_eval/tasks/sitebench/utils.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,21 @@
1414
"counting & existence",
1515
"spatial relationship reasoning",
1616
"object localization & positioning",
17-
"depth & 3d understanding",
18-
"movement navigation & intent prediction",
17+
"3d information understanding",
18+
"movement prediction & navigation",
1919
"multi-view & cross-image reasoning",
2020
}
2121

22+
# Mapping from category name to the metric key stem. The per-category metric
# names are built from the stem by appending "_acc" (accuracy) or "_caa"
# (chance-adjusted accuracy), e.g. "counting_and_existence_acc".
CATEGORY_TO_METRIC_KEY = {
    "3d information understanding": "3d_information_understanding",
    "counting & existence": "counting_and_existence",
    "movement prediction & navigation": "movement_prediction_and_navigation",
    "multi-view & cross-image reasoning": "multiview_and_crossimage_reasoning",
    "object localization & positioning": "object_localization_and_positioning",
    "spatial relationship reasoning": "spatial_relationship_reasoning",
}
31+
2232
# Get the cache directory from the config file
2333
hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/")
2434
# cache_dir = os.path.join(hf_home, cache_dir)
@@ -228,11 +238,18 @@ def spatial_process_results(doc, results):
228238
"total": 1.0 - 1.0 / len(all_choices),
229239
}
230240

231-
return {
241+
result = {
232242
"accuracy": accuracy_dict,
233243
"chance_adjusted_acc": chance_adjusted_accuracy_dict,
234244
}
235245

246+
# Per-category accuracy and chance-adjusted accuracy
247+
for cat_name, metric_key in CATEGORY_TO_METRIC_KEY.items():
248+
result[f"{metric_key}_acc"] = {"score": score, "category": category, "target_category": cat_name}
249+
result[f"{metric_key}_caa"] = {"score": adjusted_score, "category": category, "target_category": cat_name, "total": 1.0 - 1.0 / len(all_choices)}
250+
251+
return result
252+
236253

237254
def spatial_aggregate_results(results):
238255
total_correct, total_examples = 0, 0
@@ -275,3 +292,71 @@ def spatial_aggregate_results(results):
275292
# f.write("=" * 50 + "\n")
276293

277294
return round(overall_accuracy, 5)
295+
296+
297+
def _aggregate_category_acc(results, target_category: str) -> float:
298+
total_correct = 0
299+
total_examples = 0
300+
for r in results:
301+
if r["category"] == target_category:
302+
total_correct += r["score"]
303+
total_examples += 1
304+
return round((total_correct / total_examples) * 100, 5) if total_examples > 0 else 0.0
305+
306+
307+
def _aggregate_category_caa(results, target_category: str) -> float:
308+
total_adjusted = 0.0
309+
total_baseline = 0.0
310+
for r in results:
311+
if r["category"] == target_category:
312+
total_adjusted += r["score"]
313+
total_baseline += r["total"]
314+
return round((total_adjusted / total_baseline) * 100, 5) if total_baseline > 0 else 0.0
315+
316+
317+
def aggregate_3d_information_understanding_acc(results):
    """Aggregate accuracy for the "3d information understanding" category."""
    return _aggregate_category_acc(
        results,
        target_category="3d information understanding",
    )
319+
320+
321+
def aggregate_3d_information_understanding_caa(results):
    """Aggregate chance-adjusted accuracy for "3d information understanding"."""
    return _aggregate_category_caa(
        results,
        target_category="3d information understanding",
    )
323+
324+
325+
def aggregate_counting_and_existence_acc(results):
    """Aggregate accuracy for the "counting & existence" category."""
    return _aggregate_category_acc(
        results,
        target_category="counting & existence",
    )
327+
328+
329+
def aggregate_counting_and_existence_caa(results):
    """Aggregate chance-adjusted accuracy for "counting & existence"."""
    return _aggregate_category_caa(
        results,
        target_category="counting & existence",
    )
331+
332+
333+
def aggregate_movement_prediction_and_navigation_acc(results):
    """Aggregate accuracy for the "movement prediction & navigation" category."""
    return _aggregate_category_acc(
        results,
        target_category="movement prediction & navigation",
    )
335+
336+
337+
def aggregate_movement_prediction_and_navigation_caa(results):
    """Aggregate chance-adjusted accuracy for "movement prediction & navigation"."""
    return _aggregate_category_caa(
        results,
        target_category="movement prediction & navigation",
    )
339+
340+
341+
def aggregate_multiview_and_crossimage_reasoning_acc(results):
    """Aggregate accuracy for the "multi-view & cross-image reasoning" category."""
    return _aggregate_category_acc(
        results,
        target_category="multi-view & cross-image reasoning",
    )
343+
344+
345+
def aggregate_multiview_and_crossimage_reasoning_caa(results):
    """Aggregate chance-adjusted accuracy for "multi-view & cross-image reasoning"."""
    return _aggregate_category_caa(
        results,
        target_category="multi-view & cross-image reasoning",
    )
347+
348+
349+
def aggregate_object_localization_and_positioning_acc(results):
    """Aggregate accuracy for the "object localization & positioning" category."""
    return _aggregate_category_acc(
        results,
        target_category="object localization & positioning",
    )
351+
352+
353+
def aggregate_object_localization_and_positioning_caa(results):
    """Aggregate chance-adjusted accuracy for "object localization & positioning"."""
    return _aggregate_category_caa(
        results,
        target_category="object localization & positioning",
    )
355+
356+
357+
def aggregate_spatial_relationship_reasoning_acc(results):
    """Aggregate accuracy for the "spatial relationship reasoning" category."""
    return _aggregate_category_acc(
        results,
        target_category="spatial relationship reasoning",
    )
359+
360+
361+
def aggregate_spatial_relationship_reasoning_caa(results):
    """Aggregate chance-adjusted accuracy for "spatial relationship reasoning"."""
    return _aggregate_category_caa(
        results,
        target_category="spatial relationship reasoning",
    )

0 commit comments

Comments
 (0)