|
14 | 14 | "counting & existence", |
15 | 15 | "spatial relationship reasoning", |
16 | 16 | "object localization & positioning", |
17 | | - "depth & 3d understanding", |
18 | | - "movement navigation & intent prediction", |
| 17 | + "3d information understanding", |
| 18 | + "movement prediction & navigation", |
19 | 19 | "multi-view & cross-image reasoning", |
20 | 20 | } |
21 | 21 |
|
| 22 | +# Mapping from category name to metric key suffix |
| 23 | +CATEGORY_TO_METRIC_KEY = { |
| 24 | + "3d information understanding": "3d_information_understanding", |
| 25 | + "counting & existence": "counting_and_existence", |
| 26 | + "movement prediction & navigation": "movement_prediction_and_navigation", |
| 27 | + "multi-view & cross-image reasoning": "multiview_and_crossimage_reasoning", |
| 28 | + "object localization & positioning": "object_localization_and_positioning", |
| 29 | + "spatial relationship reasoning": "spatial_relationship_reasoning", |
| 30 | +} |
| 31 | + |
22 | 32 | # Get the cache directory from the config file |
23 | 33 | hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/") |
24 | 34 | # cache_dir = os.path.join(hf_home, cache_dir) |
@@ -228,11 +238,18 @@ def spatial_process_results(doc, results): |
228 | 238 | "total": 1.0 - 1.0 / len(all_choices), |
229 | 239 | } |
230 | 240 |
|
231 | | - return { |
| 241 | + result = { |
232 | 242 | "accuracy": accuracy_dict, |
233 | 243 | "chance_adjusted_acc": chance_adjusted_accuracy_dict, |
234 | 244 | } |
235 | 245 |
|
| 246 | + # Per-category accuracy and chance-adjusted accuracy |
| 247 | + for cat_name, metric_key in CATEGORY_TO_METRIC_KEY.items(): |
| 248 | + result[f"{metric_key}_acc"] = {"score": score, "category": category, "target_category": cat_name} |
| 249 | + result[f"{metric_key}_caa"] = {"score": adjusted_score, "category": category, "target_category": cat_name, "total": 1.0 - 1.0 / len(all_choices)} |
| 250 | + |
| 251 | + return result |
| 252 | + |
236 | 253 |
|
237 | 254 | def spatial_aggregate_results(results): |
238 | 255 | total_correct, total_examples = 0, 0 |
@@ -275,3 +292,71 @@ def spatial_aggregate_results(results): |
275 | 292 | # f.write("=" * 50 + "\n") |
276 | 293 |
|
277 | 294 | return round(overall_accuracy, 5) |
| 295 | + |
| 296 | + |
| 297 | +def _aggregate_category_acc(results, target_category: str) -> float: |
| 298 | + total_correct = 0 |
| 299 | + total_examples = 0 |
| 300 | + for r in results: |
| 301 | + if r["category"] == target_category: |
| 302 | + total_correct += r["score"] |
| 303 | + total_examples += 1 |
| 304 | + return round((total_correct / total_examples) * 100, 5) if total_examples > 0 else 0.0 |
| 305 | + |
| 306 | + |
| 307 | +def _aggregate_category_caa(results, target_category: str) -> float: |
| 308 | + total_adjusted = 0.0 |
| 309 | + total_baseline = 0.0 |
| 310 | + for r in results: |
| 311 | + if r["category"] == target_category: |
| 312 | + total_adjusted += r["score"] |
| 313 | + total_baseline += r["total"] |
| 314 | + return round((total_adjusted / total_baseline) * 100, 5) if total_baseline > 0 else 0.0 |
| 315 | + |
| 316 | + |
| 317 | +def aggregate_3d_information_understanding_acc(results): |
| 318 | + return _aggregate_category_acc(results, "3d information understanding") |
| 319 | + |
| 320 | + |
| 321 | +def aggregate_3d_information_understanding_caa(results): |
| 322 | + return _aggregate_category_caa(results, "3d information understanding") |
| 323 | + |
| 324 | + |
| 325 | +def aggregate_counting_and_existence_acc(results): |
| 326 | + return _aggregate_category_acc(results, "counting & existence") |
| 327 | + |
| 328 | + |
| 329 | +def aggregate_counting_and_existence_caa(results): |
| 330 | + return _aggregate_category_caa(results, "counting & existence") |
| 331 | + |
| 332 | + |
| 333 | +def aggregate_movement_prediction_and_navigation_acc(results): |
| 334 | + return _aggregate_category_acc(results, "movement prediction & navigation") |
| 335 | + |
| 336 | + |
| 337 | +def aggregate_movement_prediction_and_navigation_caa(results): |
| 338 | + return _aggregate_category_caa(results, "movement prediction & navigation") |
| 339 | + |
| 340 | + |
| 341 | +def aggregate_multiview_and_crossimage_reasoning_acc(results): |
| 342 | + return _aggregate_category_acc(results, "multi-view & cross-image reasoning") |
| 343 | + |
| 344 | + |
| 345 | +def aggregate_multiview_and_crossimage_reasoning_caa(results): |
| 346 | + return _aggregate_category_caa(results, "multi-view & cross-image reasoning") |
| 347 | + |
| 348 | + |
| 349 | +def aggregate_object_localization_and_positioning_acc(results): |
| 350 | + return _aggregate_category_acc(results, "object localization & positioning") |
| 351 | + |
| 352 | + |
| 353 | +def aggregate_object_localization_and_positioning_caa(results): |
| 354 | + return _aggregate_category_caa(results, "object localization & positioning") |
| 355 | + |
| 356 | + |
| 357 | +def aggregate_spatial_relationship_reasoning_acc(results): |
| 358 | + return _aggregate_category_acc(results, "spatial relationship reasoning") |
| 359 | + |
| 360 | + |
| 361 | +def aggregate_spatial_relationship_reasoning_caa(results): |
| 362 | + return _aggregate_category_caa(results, "spatial relationship reasoning") |
0 commit comments