Skip to content

Commit 88ab7ea

Browse files
committed
refactor(recording): make filename and required_context required
The previous defaults (`filename="{task_name}_ep{episode_idx}_{status}.mp4"` and `required_context=("episode_idx",)`) were a footgun: - `task_name` is not a universal key. RoboMME exposes `env_id`, LIBERO uses `suite`/`task_id`, SimplerEnv uses `task` — the default name doesn't fit any of them out of the box. - The default `required_context` only covered one of the two keys the default template referenced. A caller passing `episode_idx` but forgetting `task_name` would pass `start()` validation, succeed at every `record()`, then silently lose the mp4 at `save()` when `str.format` raised `KeyError`. Make both required so each benchmark spells out its naming scheme and its context contract once at construction. Mismatches surface at `start()` rather than as a dropped mp4 at the end of the episode. RoboMME wiring updated to pass `{env_id}_ep{episode_idx}_{status}.mp4` with `required_context=("env_id", "episode_idx")`.
1 parent 521d9d4 commit 88ab7ea

3 files changed

Lines changed: 70 additions & 34 deletions

File tree

src/vla_eval/benchmarks/recording.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
* **Logging-style filename templating**: the filename is a ``str.format``
2222
template (or callable) over a context dict that the caller passes at
2323
``start()`` time, plus a ``status`` key injected at ``save()`` time.
24-
Default template ``"{task_name}_ep{episode_idx}_{status}.mp4"``;
25-
benchmarks with richer identifiers can use any context they like, e.g.
26-
``"{suite}/{task}/{episode_idx:04d}_seed{seed}_{status}.mp4"``.
24+
``filename`` and ``required_context`` are required at construction time
25+
— there is no universal default because every benchmark names tasks
26+
differently (``env_id``, ``task_id``, ``suite/task`` …). Forcing the
27+
caller to spell it out catches mismatched context keys at ``start()``
28+
rather than as a silent dropped mp4 at ``save()``.
2729
* **Best-effort**: every encode-side failure logs a warning with the
2830
context for debuggability and clears state — a corrupted video should
2931
never bring down an otherwise good eval episode.
@@ -34,13 +36,13 @@
3436
3537
recorder = EpisodeVideoRecorder(
3638
output_dir="/workspace/results/videos",
37-
# filename can stay default, or e.g.
38-
# filename="{suite}/{task}_ep{episode_idx}_{status}.mp4",
39+
filename="{env_id}_ep{episode_idx}_{status}.mp4",
40+
required_context=("env_id", "episode_idx"),
3941
fps=20,
4042
)
4143
4244
# In benchmark.reset(task):
43-
recorder.start({"task_name": task["env_id"], "episode_idx": task["episode_idx"]})
45+
recorder.start({"env_id": task["env_id"], "episode_idx": task["episode_idx"]})
4446
recorder.record(initial_frame)
4547
4648
# In benchmark.step(action):
@@ -100,9 +102,9 @@ class EpisodeVideoRecorder:
100102
def __init__(
101103
self,
102104
output_dir: str | os.PathLike[str],
103-
filename: FilenameSpec = "{task_name}_ep{episode_idx}_{status}.mp4",
105+
filename: FilenameSpec,
106+
required_context: Sequence[str],
104107
fps: int = 20,
105-
required_context: Sequence[str] = ("episode_idx",),
106108
writer_kwargs: Mapping[str, Any] | None = None,
107109
) -> None:
108110
"""
@@ -112,13 +114,17 @@ def __init__(
112114
``"{suite}/{task}_..."``); intermediate dirs are also created.
113115
filename: ``str.format`` template or callable producing the
114116
filename relative to ``output_dir``. Resolved at ``save()``
115-
time over ``{**start_context, "status": status}``.
116-
fps: Output framerate.
117+
time over ``{**start_context, "status": status}``. Required
118+
because every benchmark identifies tasks differently
119+
(``env_id``, ``task_id``, ``suite/task``) — there is no
120+
universally safe default.
117121
required_context: Keys that must be present in the dict passed to
118122
``start()``. ``ValueError`` is raised at ``start()`` if any
119-
are missing. Default ``("episode_idx",)`` because without an
120-
episode index, multi-episode runs of the same task collide
121-
on a single ``_ep0_`` filename.
123+
are missing. Required so callers explicitly declare the
124+
template's expectations; failing fast at ``start()`` avoids
125+
a silent ``KeyError`` -> dropped mp4 at ``save()`` time.
126+
Should include every key the ``filename`` template references.
127+
fps: Output framerate.
122128
writer_kwargs: Extra kwargs forwarded to ``imageio.get_writer``
123129
(e.g. ``{"codec": "libx264", "quality": 8}``).
124130
"""

src/vla_eval/benchmarks/robomme/benchmark.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,12 @@ def __init__(
193193
self.subgoal_mode = subgoal_mode
194194

195195
self._recorder: EpisodeVideoRecorder | None = (
196-
EpisodeVideoRecorder(output_dir=video_dir or "/workspace/results/videos", fps=20)
196+
EpisodeVideoRecorder(
197+
output_dir=video_dir or "/workspace/results/videos",
198+
filename="{env_id}_ep{episode_idx}_{status}.mp4",
199+
required_context=("env_id", "episode_idx"),
200+
fps=20,
201+
)
197202
if save_episode_video
198203
else None
199204
)
@@ -385,8 +390,12 @@ def reset(self, task: Task) -> Any:
385390
self._current_subgoal = self._extract_subgoal(info_flat)
386391

387392
if self._recorder is not None:
388-
task_name = task.get("name") or task.get("env_id", "unknown")
389-
self._recorder.start({"task_name": task_name, "episode_idx": task.get("episode_idx", 0)})
393+
self._recorder.start(
394+
{
395+
"env_id": task.get("env_id") or task.get("name") or "unknown",
396+
"episode_idx": task.get("episode_idx", 0),
397+
}
398+
)
390399
front_list = obs_batch.get("front_rgb_list", [])
391400
if front_list:
392401
self._recorder.record(front_list[-1])

tests/test_recording.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,31 @@ def _count_frames(path: Path) -> int:
3737
return sum(1 for _ in r) # ty: ignore[not-iterable]
3838

3939

40+
def _rec(tmp_path: Path, **overrides: Any) -> EpisodeVideoRecorder:
41+
"""Construct a recorder with stable test-suite defaults.
42+
43+
The recorder itself has no defaults for ``filename`` /
44+
``required_context`` (every benchmark spells them out explicitly).
45+
Tests don't need to repeat that boilerplate, so this helper picks
46+
the same ``task_name``/``episode_idx`` template the original
47+
test data assumed.
48+
"""
49+
kwargs: dict[str, Any] = {
50+
"output_dir": tmp_path,
51+
"filename": "{task_name}_ep{episode_idx}_{status}.mp4",
52+
"required_context": ("task_name", "episode_idx"),
53+
}
54+
kwargs.update(overrides)
55+
return EpisodeVideoRecorder(**kwargs)
56+
57+
4058
# ---------------------------------------------------------------------------
4159
# Happy path
4260
# ---------------------------------------------------------------------------
4361

4462

4563
def test_save_writes_mp4_with_correct_framecount(tmp_path: Path) -> None:
46-
rec = EpisodeVideoRecorder(output_dir=tmp_path, fps=10)
64+
rec = _rec(tmp_path, fps=10)
4765
rec.start({"task_name": "PickCube", "episode_idx": 0})
4866
for _ in range(5):
4967
rec.record(_frame())
@@ -56,7 +74,7 @@ def test_save_writes_mp4_with_correct_framecount(tmp_path: Path) -> None:
5674

5775

5876
def test_save_uses_status_in_filename(tmp_path: Path) -> None:
59-
rec = EpisodeVideoRecorder(output_dir=tmp_path)
77+
rec = _rec(tmp_path)
6078
rec.start({"task_name": "T", "episode_idx": 7})
6179
rec.record(_frame())
6280
final = rec.save(status="fail")
@@ -66,7 +84,7 @@ def test_save_uses_status_in_filename(tmp_path: Path) -> None:
6684

6785

6886
def test_active_flag_tracks_lifecycle(tmp_path: Path) -> None:
69-
rec = EpisodeVideoRecorder(output_dir=tmp_path)
87+
rec = _rec(tmp_path)
7088
assert rec.active is False
7189
rec.start({"task_name": "T", "episode_idx": 0})
7290
assert rec.active is True
@@ -75,7 +93,7 @@ def test_active_flag_tracks_lifecycle(tmp_path: Path) -> None:
7593

7694

7795
def test_consecutive_episodes_each_produce_their_own_file(tmp_path: Path) -> None:
78-
rec = EpisodeVideoRecorder(output_dir=tmp_path)
96+
rec = _rec(tmp_path)
7997
for ep in range(3):
8098
rec.start({"task_name": "T", "episode_idx": ep})
8199
rec.record(_frame())
@@ -125,18 +143,20 @@ def naming(ctx: Mapping[str, Any]) -> str:
125143
assert final == tmp_path / "abc-3-ok.mp4"
126144

127145

128-
def test_save_with_missing_template_key_is_handled(tmp_path: Path) -> None:
146+
def test_save_with_template_key_not_in_required_context_is_handled(tmp_path: Path) -> None:
147+
# required_context is the caller's contract for what must be present at
148+
# start(); it's permitted to be a subset of the keys the template uses
149+
# (e.g. an optional `seed`). When a template key is genuinely missing
150+
# at save() time, resolution should fail gracefully rather than raise.
129151
rec = EpisodeVideoRecorder(
130152
output_dir=tmp_path,
131153
filename="{task_name}_{seed}_{status}.mp4",
154+
required_context=("task_name",),
132155
)
133-
rec.start({"task_name": "T", "episode_idx": 0}) # `seed` missing
156+
rec.start({"task_name": "T"}) # `seed` missing
134157
rec.record(_frame())
135-
# Resolution happens at save() time; a missing key logs and returns None
136-
# rather than raising.
137158
final = rec.save(status="success")
138159
assert final is None
139-
# Tempfile must have been cleaned up.
140160
assert list(tmp_path.glob(".recorder-*.mp4")) == []
141161

142162

@@ -146,31 +166,31 @@ def test_save_with_missing_template_key_is_handled(tmp_path: Path) -> None:
146166

147167

148168
def test_start_missing_required_context_raises(tmp_path: Path) -> None:
149-
rec = EpisodeVideoRecorder(output_dir=tmp_path)
169+
rec = _rec(tmp_path)
150170
with pytest.raises(ValueError, match="missing required context keys"):
151171
rec.start({"task_name": "T"}) # episode_idx missing
152172
assert rec.active is False
153173

154174

155175
def test_record_before_start_is_noop(tmp_path: Path) -> None:
156-
rec = EpisodeVideoRecorder(output_dir=tmp_path)
176+
rec = _rec(tmp_path)
157177
rec.record(_frame()) # must not raise
158178
assert rec.active is False
159179

160180

161181
def test_save_before_start_returns_none(tmp_path: Path) -> None:
162-
rec = EpisodeVideoRecorder(output_dir=tmp_path)
182+
rec = _rec(tmp_path)
163183
assert rec.save() is None
164184

165185

166186
def test_discard_before_start_is_noop(tmp_path: Path) -> None:
167-
rec = EpisodeVideoRecorder(output_dir=tmp_path)
187+
rec = _rec(tmp_path)
168188
rec.discard() # must not raise
169189
assert rec.active is False
170190

171191

172192
def test_writer_open_failure_leaves_recorder_inactive(tmp_path: Path) -> None:
173-
rec = EpisodeVideoRecorder(output_dir=tmp_path)
193+
rec = _rec(tmp_path)
174194
with patch("imageio.get_writer", side_effect=RuntimeError("nope")):
175195
rec.start({"task_name": "T", "episode_idx": 0})
176196
assert rec.active is False
@@ -187,13 +207,14 @@ def test_writer_open_failure_leaves_recorder_inactive(tmp_path: Path) -> None:
187207

188208

189209
def test_start_again_without_save_discards_prior_episode(tmp_path: Path) -> None:
190-
rec = EpisodeVideoRecorder(output_dir=tmp_path)
210+
rec = _rec(tmp_path)
191211
rec.start({"task_name": "T", "episode_idx": 0})
192212
rec.record(_frame())
193213
# Simulate orchestrator skipping save() / discard() and starting next ep:
194214
rec.start({"task_name": "T", "episode_idx": 1})
195215
rec.record(_frame())
196216
final = rec.save(status="success")
217+
assert final is not None
197218
assert final == tmp_path / "T_ep1_success.mp4"
198219
# Only ep1 mp4 should exist; ep0's tempfile was cleaned up.
199220
mp4s = sorted(p.name for p in tmp_path.glob("*.mp4"))
@@ -202,7 +223,7 @@ def test_start_again_without_save_discards_prior_episode(tmp_path: Path) -> None
202223

203224

204225
def test_discard_cleans_up_tempfile(tmp_path: Path) -> None:
205-
rec = EpisodeVideoRecorder(output_dir=tmp_path)
226+
rec = _rec(tmp_path)
206227
rec.start({"task_name": "T", "episode_idx": 0})
207228
rec.record(_frame())
208229
rec.discard()
@@ -219,7 +240,7 @@ def test_discard_cleans_up_tempfile(tmp_path: Path) -> None:
219240
def test_output_dir_created_lazily(tmp_path: Path) -> None:
220241
target = tmp_path / "nested" / "videos"
221242
assert not target.exists()
222-
rec = EpisodeVideoRecorder(output_dir=target)
243+
rec = _rec(target)
223244
rec.start({"task_name": "T", "episode_idx": 0})
224245
rec.record(_frame())
225246
final = rec.save()
@@ -229,7 +250,7 @@ def test_output_dir_created_lazily(tmp_path: Path) -> None:
229250

230251

231252
def test_str_path_accepted(tmp_path: Path) -> None:
232-
rec = EpisodeVideoRecorder(output_dir=str(tmp_path))
253+
rec = _rec(tmp_path, output_dir=str(tmp_path))
233254
rec.start({"task_name": "T", "episode_idx": 0})
234255
rec.record(_frame())
235256
final = rec.save()

0 commit comments

Comments
 (0)