|
| 1 | +"""Utility functions for the EgoTaskQA-MCQ benchmark. |
| 2 | +
|
| 3 | +Annotations are loaded from the ``nv-njb/EgoTaskQA-MCQ`` HuggingFace dataset. |
| 4 | +Videos are *not* redistributed — users must obtain them through the official |
| 5 | +EgoTaskQA license process at https://sites.google.com/view/egotaskqa. |
| 6 | +
|
| 7 | +The video directory is resolved in this order: |
| 8 | +
|
| 9 | +1. ``EGOTASKQA_VIDEO_DIR`` environment variable, if set. |
| 10 | +2. ``~/.cache/lmms_eval/egotaskqa/videos/`` (default). |
| 11 | +
|
| 12 | +Place the downloaded ``qa_videos/*.mp4`` files in that directory — no rename |
| 13 | +is needed; filenames already match the ``video_path`` field. |
| 14 | +""" |
| 15 | + |
| 16 | +import os |
| 17 | +import re |
| 18 | +from pathlib import Path |
| 19 | + |
| 20 | +from loguru import logger as eval_logger |
| 21 | + |
| 22 | + |
| 23 | +def _video_dir() -> str: |
| 24 | + override = os.environ.get("EGOTASKQA_VIDEO_DIR") |
| 25 | + if override: |
| 26 | + return override |
| 27 | + return str(Path.home() / ".cache" / "lmms_eval" / "egotaskqa" / "videos") |
| 28 | + |
| 29 | + |
| 30 | +def egotaskqa_doc_to_visual(doc): |
| 31 | + video_path = os.path.join(_video_dir(), doc["video_path"]) |
| 32 | + if not os.path.exists(video_path): |
| 33 | + eval_logger.warning( |
| 34 | + f"Video not found: {video_path}. Set EGOTASKQA_VIDEO_DIR or place " |
| 35 | + f"qa_videos/ under ~/.cache/lmms_eval/egotaskqa/videos/." |
| 36 | + ) |
| 37 | + return [video_path] |
| 38 | + |
| 39 | + |
| 40 | +def egotaskqa_doc_to_text(doc, lmms_eval_specific_kwargs=None): |
| 41 | + if lmms_eval_specific_kwargs is None: |
| 42 | + lmms_eval_specific_kwargs = {} |
| 43 | + pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "") |
| 44 | + post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "") |
| 45 | + |
| 46 | + question = ( |
| 47 | + "Select the best answer to the following multiple-choice question " |
| 48 | + f"based on the video.\n{doc['q']}\nOptions:\n" |
| 49 | + ) |
| 50 | + options = doc["option"] |
| 51 | + for letter in sorted(options.keys()): |
| 52 | + question += f"({letter}) {options[letter]}\n" |
| 53 | + |
| 54 | + return f"{pre_prompt}{question}{post_prompt}" |
| 55 | + |
| 56 | + |
| 57 | +def _extract_answer(response, options): |
| 58 | + letters = sorted(options.keys()) |
| 59 | + |
| 60 | + response = response.replace("answer", "").replace("Answer", "") |
| 61 | + pred_answer = re.findall(r"[\(\ ]*([A-E])[\)\ ]*", response) |
| 62 | + |
| 63 | + if pred_answer: |
| 64 | + pred_letter = pred_answer[0].strip() |
| 65 | + if pred_letter in letters: |
| 66 | + return pred_letter |
| 67 | + |
| 68 | + for letter in letters: |
| 69 | + opt = options[letter].strip().strip(".") |
| 70 | + if opt.lower() in response.lower(): |
| 71 | + return letter |
| 72 | + |
| 73 | + return "" |
| 74 | + |
| 75 | + |
| 76 | +def egotaskqa_process_results(doc, results): |
| 77 | + pred = results[0] |
| 78 | + pred_ans = _extract_answer(pred, doc["option"]) |
| 79 | + return { |
| 80 | + "egotaskqa_accuracy": { |
| 81 | + "pred_answer": pred_ans, |
| 82 | + "ground_truth": doc["a"], |
| 83 | + } |
| 84 | + } |
| 85 | + |
| 86 | + |
| 87 | +def egotaskqa_aggregate_results(results): |
| 88 | + correct = sum(1 for r in results if r["pred_answer"] == r["ground_truth"]) |
| 89 | + return correct / len(results) if results else 0 |
0 commit comments