Skip to content

Commit cc5267e

Browse files
committed
snapshot download
1 parent 453f106 commit cc5267e

9 files changed

Lines changed: 133 additions & 34 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,4 @@ CLAUDE.md
8383
.worktrees/
8484
Bagel/
8585
MMaDA/
86+
.codex

lmms_eval/api/task.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,8 +1036,11 @@ def _download_from_youtube(path):
10361036
force_unzip = dataset_kwargs.get("force_unzip", False)
10371037
revision = dataset_kwargs.get("revision", "main")
10381038
create_link = dataset_kwargs.get("create_link", False)
1039-
# If the user already has a cache dir, we skip download the zip files
1040-
if not os.path.exists(cache_dir):
1039+
cache_path = None
1040+
# If the user already has a cache dir, we skip downloading archives.
1041+
# Tasks that set create_link need the snapshot path even when the
1042+
# cache dir already exists as a symlink from a previous run.
1043+
if not os.path.exists(cache_dir) or (create_link and os.path.islink(cache_dir)):
10411044
cache_path = snapshot_download(repo_id=self.DATASET_PATH, revision=revision, repo_type="dataset", force_download=force_download, etag_timeout=60)
10421045
zip_files = glob(os.path.join(cache_path, "**/*.zip"), recursive=True)
10431046
tar_files = glob(os.path.join(cache_path, "**/*.tar*"), recursive=True)
@@ -1106,7 +1109,7 @@ def concat_tar_parts(tar_parts, output_tar):
11061109
untar_video_data(output_tar)
11071110

11081111
# Link cache_path to cache_dir if needed.
1109-
if create_link:
1112+
if create_link and cache_path is not None:
11101113
if not os.path.exists(cache_dir) or os.path.islink(cache_dir):
11111114
if os.path.islink(cache_dir):
11121115
os.remove(cache_dir)

lmms_eval/models/chat/fastvideo.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ def _safe(name: str, default: str = "x") -> str:
7070
return s[:128]
7171

7272

73+
def _default_output_dir() -> str:
74+
hf_home = os.path.expanduser(os.getenv("HF_HOME", "~/.cache/huggingface"))
75+
return os.path.join(hf_home, "lmms_eval", "generated_videos", "fastvideo")
76+
77+
7378
_DTYPES = {
7479
"float32": torch.float32,
7580
"fp32": torch.float32,
@@ -182,9 +187,10 @@ def __init__(
182187
vae_cpu_offload: bool = True,
183188
# Misc
184189
trust_remote_code: bool = True,
185-
output_dir: str = "./fastvideo_generated_videos",
190+
output_dir: Optional[str] = None,
186191
batch_size: int = 1,
187-
# Resume support: skip samples whose output mp4 already exists.
192+
# Artifact reuse: lmms-eval's response cache stores the JSON response,
193+
# while VBVR still needs the referenced mp4 on disk.
188194
overwrite: bool = False,
189195
**kwargs,
190196
):
@@ -206,7 +212,7 @@ def __init__(
206212
self.seed = seed
207213
self.negative_prompt = negative_prompt
208214

209-
self.output_dir = os.path.abspath(os.path.expanduser(output_dir))
215+
self.output_dir = os.path.abspath(os.path.expanduser(output_dir or _default_output_dir()))
210216
os.makedirs(self.output_dir, exist_ok=True)
211217
self._tmp_img_dir = tempfile.mkdtemp(prefix="fastvideo_inputs_")
212218

@@ -522,8 +528,8 @@ def generate_until(self, requests: List[Instance]) -> List[GenerationResult]:
522528
with ThreadPoolExecutor(max_workers=WORKERS) as executor:
523529
prepared = list(executor.map(self.make_one_request, requests))
524530

525-
# Resume: if the target mp4 already exists and is non-empty, reuse it.
526-
# Set overwrite=True in model_args to force regeneration.
531+
# Reuse generated artifacts when the target mp4 already exists and is
532+
# non-empty. Set overwrite=True in model_args to force regeneration.
527533
presults: List[Optional[GenerationResult]] = [None] * len(prepared)
528534
skipped_indices: List[int] = []
529535
if not self.overwrite:

lmms_eval/tasks/vbvr/README.md

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@ MP4 video; scoring is rule-based and per-task (no LLM judge, no CLIP).
1717
| `vbvr_in_domain` | In-Domain_50 only |
1818
| `vbvr_out_of_domain` | Out-of-Domain_50 only |
1919

20-
## One-time setup
20+
## Data Cache
2121

2222
The HF dataset (`Video-Reason/VBVR-Bench-Data`) carries the base64-encoded
2323
first frame plus **relative** paths to `ground_truth.mp4`, `first_frame.png`,
24-
`final_frame.png`, `prompt.txt` etc. The rule-based evaluators read those GT
25-
files, so you must first download the repo and point `VBVR_GT_PATH` at it:
24+
`final_frame.png`, `prompt.txt` etc. The task config uses
25+
`dataset_kwargs.cache_dir: vbvr`, so lmms-eval downloads the dataset snapshot
26+
and links it under `$HF_HOME/vbvr` by default. The rule-based evaluators resolve
27+
GT files from that cache path automatically.
2628

27-
```bash
28-
hf download Video-Reason/VBVR-Bench-Data \
29-
--repo-type dataset \
30-
--local-dir /data/VBVR-Bench
29+
If you already have a local checkout, you can still override the GT root with:
3130

31+
```bash
3232
export VBVR_GT_PATH=/data/VBVR-Bench
3333
```
3434

@@ -58,16 +58,13 @@ The model must output JSON of the form:
5858
cd /path/to/lmms-eval; or exit 1
5959
6060
# Rule-based VBVR scorers read the GT mp4s/pngs from this root.
61-
set -gx VBVR_GT_PATH /path/to/VBVR-Bench
61+
# By default this is populated automatically at $HF_HOME/vbvr.
62+
# Uncomment this only if you want to use an existing local checkout.
63+
# set -gx VBVR_GT_PATH /path/to/VBVR-Bench
6264
6365
set MODEL_DIR /path/to/Wan2.2-I2V-A14B-Diffusers
64-
set OUT_ROOT /path/to/eval_out/vbvr_wan22_full_highres
65-
set VIDEOS_DIR $OUT_ROOT/videos
66-
set METRICS_DIR $OUT_ROOT/metrics
67-
mkdir -p $VIDEOS_DIR $METRICS_DIR
6866
6967
set MODEL_ARGS "model=$MODEL_DIR"
70-
set MODEL_ARGS "$MODEL_ARGS,output_dir=$VIDEOS_DIR"
7168
set MODEL_ARGS "$MODEL_ARGS,data_parallel=4,num_gpus=2,sp_size=2,tp_size=1"
7269
set MODEL_ARGS "$MODEL_ARGS,num_inference_steps=50,num_frames=81"
7370
set MODEL_ARGS "$MODEL_ARGS,height=1024,width=1024,fps=16"
@@ -81,12 +78,16 @@ exec stdbuf -oL -eL .venv/bin/python -m lmms_eval eval \
8178
--tasks vbvr \
8279
--batch_size 1 \
8380
--log_samples \
84-
--output_path $METRICS_DIR
81+
--output_path logs
8582
```
8683

87-
Generated videos land in `$VIDEOS_DIR`; per-sample logs and aggregated metrics
88-
land in `$METRICS_DIR`. Tune `data_parallel`, `num_gpus`, `sp_size`, and the
89-
`*_cpu_offload` flags to match your hardware.
84+
Generated videos land in `$HF_HOME/lmms_eval/generated_videos/fastvideo` by
85+
default. Per-sample logs and aggregated metrics land under `--output_path`, and
86+
the detailed VBVR evaluation JSON is written through `generate_submission_file()`
87+
under `--output_path/submissions/`. Add `--use_cache <path>` only if you want
88+
lmms-eval response caching in addition to FastVideo's generated-mp4 reuse. Tune
89+
`data_parallel`, `num_gpus`, `sp_size`, and the `*_cpu_offload` flags to match
90+
your hardware.
9091

9192
## Metrics
9293

lmms_eval/tasks/vbvr/_default_template_yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
dataset_path: Video-Reason/VBVR-Bench-Data
2+
dataset_kwargs:
3+
cache_dir: vbvr
4+
video: True
5+
create_link: True
26
test_split: test
37
output_type: generate_until
48

lmms_eval/tasks/vbvr/utils.py

Lines changed: 84 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
1. Each sample carries a first-frame image (base64 PNG) + a text prompt.
66
2. The model generates an MP4 and returns JSON: ``{"text": "", "videos": [path]}``.
77
3. ``vbvr_process_results`` parses the JSON, resolves the matching ground-truth
8-
folder under ``$VBVR_GT_PATH``, dispatches to the per-task rule-based
9-
evaluator from the vendored ``vbvr_bench`` package, and records a per-sample
10-
score + dimension breakdown.
8+
folder from the configured cache or ``$VBVR_GT_PATH``, dispatches to the
9+
per-task rule-based evaluator from the vendored ``vbvr_bench`` package, and
10+
records a per-sample score + dimension breakdown.
1111
4. Aggregation functions compute In-Domain / Out-of-Domain / per-category means
1212
and an overall mean matching the upstream VBVRBench output.
1313
1414
Environment variables
1515
---------------------
16-
- ``VBVR_GT_PATH``: local root of the downloaded Video-Reason/VBVR-Bench-Data
17-
dataset. Must contain ``In-Domain_50/`` and ``Out-of-Domain_50/`` folders with
18-
``{task_name}/{video_idx}/{first_frame.png,final_frame.png,ground_truth.mp4,prompt.txt}``.
16+
- ``VBVR_GT_PATH``: optional local root of the downloaded
17+
Video-Reason/VBVR-Bench-Data dataset. If unset, the task uses the configured
18+
lmms-eval/Hugging Face cache directory and falls back to ``snapshot_download``.
1919
"""
2020

2121
from __future__ import annotations
@@ -26,12 +26,18 @@
2626
import os
2727
import re
2828
from collections import defaultdict
29+
from functools import lru_cache
30+
from pathlib import Path
2931
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
3032

3133
import numpy as np
34+
import yaml
35+
from huggingface_hub import snapshot_download
3236
from loguru import logger as eval_logger
3337
from PIL import Image
3438

39+
from lmms_eval import utils as lmms_utils
40+
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
3541
from lmms_eval.tasks.vbvr.vbvr_bench.evaluators import (
3642
get_evaluator,
3743
get_split,
@@ -44,11 +50,43 @@
4450
_BASE64_PREFIX = re.compile(r"^data:image/[^;]+;base64,", re.IGNORECASE)
4551

4652

53+
@lru_cache(maxsize=1)
def _task_config() -> Dict[str, Any]:
    """Load the task's ``_default_template_yaml`` as a plain dictionary.

    Lines containing ``!function`` are dropped before parsing so that
    ``yaml.safe_load`` never sees that custom tag. The parsed result is
    memoized, so the template file is read at most once per process.

    Returns:
        The parsed mapping, or an empty dict when the template parses to a
        falsy value.
    """
    template_path = Path(__file__).parent / "_default_template_yaml"
    with open(template_path, "r", encoding="utf-8") as handle:
        kept_lines = [row for row in handle if "!function" not in row]
    parsed = yaml.safe_load("".join(kept_lines))
    return parsed or {}
58+
59+
60+
def _dataset_repo_id() -> str:
    """Return the ``dataset_path`` entry of the task template as a string.

    Raises:
        KeyError: If the template has no ``dataset_path`` key.
    """
    config = _task_config()
    return str(config["dataset_path"])
62+
63+
64+
def _cache_dir_name() -> str:
    """Return ``dataset_kwargs.cache_dir`` from the task template as a string.

    Raises:
        KeyError: If ``dataset_kwargs`` or its ``cache_dir`` key is missing.
    """
    dataset_kwargs = _task_config()["dataset_kwargs"]
    return str(dataset_kwargs["cache_dir"])
66+
67+
68+
def _looks_like_vbvr_root(root: str) -> bool:
69+
return all(os.path.isdir(os.path.join(root, split)) for split in ("In-Domain_50", "Out-of-Domain_50"))
70+
71+
72+
@lru_cache(maxsize=1)
def _gt_root() -> str:
    """Locate the root directory holding the VBVR ground-truth files.

    Candidates are tried in this order:

    1. ``$VBVR_GT_PATH`` (with ``~`` and environment variables expanded).
       When it is set it must pass ``_looks_like_vbvr_root``; no other
       location is tried.
    2. The directory returned by ``lmms_utils.resolve_cache_dir`` for the
       task's configured cache dir name, based at ``$HF_HOME`` (default
       ``~/.cache/huggingface``).
    3. The path returned by ``snapshot_download`` for the task's dataset repo.

    The result is memoized, so the lookup runs at most once per process.

    Returns:
        The first candidate that passes ``_looks_like_vbvr_root``.

    Raises:
        RuntimeError: If ``VBVR_GT_PATH`` is set but does not look like a
            VBVR checkout, or if neither the cache dir nor the snapshot does.
    """
    override = os.getenv("VBVR_GT_PATH")
    if override:
        override = os.path.expanduser(os.path.expandvars(override))
        if not _looks_like_vbvr_root(override):
            raise RuntimeError(f"VBVR_GT_PATH does not look like a VBVR-Bench checkout: {override}")
        return override

    hf_home = os.path.expanduser(os.getenv("HF_HOME", "~/.cache/huggingface"))
    linked_cache = lmms_utils.resolve_cache_dir(_cache_dir_name(), base_dir=hf_home)
    if _looks_like_vbvr_root(linked_cache):
        return linked_cache

    # Last resort: ask the hub for the dataset snapshot path.
    hub_snapshot = snapshot_download(repo_id=_dataset_repo_id(), repo_type="dataset")
    if not _looks_like_vbvr_root(hub_snapshot):
        raise RuntimeError(f"Could not locate VBVR GT files in {linked_cache} or HF snapshot {hub_snapshot}.")
    return hub_snapshot
5290

5391

5492
def _decode_base64_image(data: str) -> Image.Image:
@@ -169,6 +207,7 @@ def _fanout_metrics(entry: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
169207
"vbvr_perception": entry,
170208
"vbvr_spatiality": entry,
171209
"vbvr_transformation": entry,
210+
"submission": entry,
172211
}
173212

174213

@@ -237,6 +276,42 @@ def _agg_by(results, key: str, value: str, label: str) -> float:
237276
return mean
238277

239278

279+
def _summary(entries: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
    """Build the summary section of the detailed VBVR results file.

    Only entries whose ``"score"`` is an ``int`` or ``float`` contribute.

    Args:
        entries: Per-sample result dictionaries.

    Returns:
        A dict with ``"overall"`` (mean score) and ``"n"`` (number of scored
        entries), plus one ``{"score": mean, "n": count}`` record for every
        name in ``SPLITS`` and for every lower-cased name in ``CATEGORIES``.
    """
    # Apply the numeric-score filter once; every bucket below draws from it.
    scored = [e for e in entries if isinstance(e.get("score"), (int, float))]

    all_scores = [float(e["score"]) for e in scored]
    report: Dict[str, Any] = {
        "overall": _mean(all_scores),
        "n": len(all_scores),
    }

    for split in SPLITS:
        bucket = [float(e["score"]) for e in scored if e.get("split") == split]
        report[split] = {"score": _mean(bucket), "n": len(bucket)}

    for category in CATEGORIES:
        bucket = [float(e["score"]) for e in scored if e.get("category") == category]
        report[category.lower()] = {"score": _mean(bucket), "n": len(bucket)}

    return report
292+
293+
294+
def _submission_file_name(entries: Sequence[Dict[str, Any]]) -> str:
295+
splits = {e.get("split") for e in entries if e.get("split")}
296+
if splits == {"In_Domain"}:
297+
return "vbvr_in_domain_eval_results.json"
298+
if splits == {"Out_of_Domain"}:
299+
return "vbvr_out_of_domain_eval_results.json"
300+
return "vbvr_eval_results.json"
301+
302+
303+
def vbvr_aggregate_submission(results, args) -> None:
    """Write the detailed VBVR evaluation results to a JSON submission file.

    The entries obtained from ``_entries(results)`` are sorted by
    ``file_split``, ``task_name`` and ``video_idx`` (each compared as a
    string, missing values as ``""``) and dumped together with their
    ``_summary`` to the path returned by ``generate_submission_file``.

    Args:
        results: Per-sample results handed over by the aggregation step.
        args: Passed through unchanged to ``generate_submission_file``.
    """

    def _sort_key(entry):
        return (
            str(entry.get("file_split", "")),
            str(entry.get("task_name", "")),
            str(entry.get("video_idx", "")),
        )

    ordered = sorted(_entries(results), key=_sort_key)
    out_path = generate_submission_file(_submission_file_name(ordered), args)
    # Build the payload before opening the file so a failure while
    # summarising does not leave a truncated file behind.
    document = {"summary": _summary(ordered), "results": ordered}
    with open(out_path, "w", encoding="utf-8") as handle:
        json.dump(document, handle, indent=2)
    eval_logger.info(f"[VBVR] Detailed evaluation results saved to {out_path}")
313+
314+
240315
def vbvr_aggregate_overall(results) -> float:
241316
entries = _entries(results)
242317
if not entries:

lmms_eval/tasks/vbvr/vbvr.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ metric_list:
2626
- metric: vbvr_transformation
2727
aggregation: !function utils.vbvr_aggregate_transformation
2828
higher_is_better: true
29+
- metric: submission
30+
aggregation: !function utils.vbvr_aggregate_submission
31+
higher_is_better: true
2932

3033
metadata:
3134
- version: 0.1

lmms_eval/tasks/vbvr/vbvr_in_domain.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ metric_list:
2222
- metric: vbvr_transformation
2323
aggregation: !function utils.vbvr_aggregate_transformation
2424
higher_is_better: true
25+
- metric: submission
26+
aggregation: !function utils.vbvr_aggregate_submission
27+
higher_is_better: true
2528

2629
metadata:
2730
- version: 0.1

lmms_eval/tasks/vbvr/vbvr_out_of_domain.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ metric_list:
2222
- metric: vbvr_transformation
2323
aggregation: !function utils.vbvr_aggregate_transformation
2424
higher_is_better: true
25+
- metric: submission
26+
aggregation: !function utils.vbvr_aggregate_submission
27+
higher_is_better: true
2528

2629
metadata:
2730
- version: 0.1

0 commit comments

Comments
 (0)