Skip to content

Commit dd64554

Browse files
authored
Merge pull request #349 from alan-turing-institute/add-render-rollout-snapshot
Add rollout snapshot rendering to eval
2 parents bed4611 + 1ee7ea8 commit dd64554

6 files changed

Lines changed: 347 additions & 4 deletions

File tree

src/autocast/configs/eval/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ All eval configs support these parameters:
6262
- `video_format`: Video format (mp4 or gif)
6363
- `video_sample_index`: Sample index within batch to visualize
6464
- `fps`: Frames per second for videos
65+
- `save_rollout_snapshots`: Save still rollout panels from raw tensors
66+
- `rollout_snapshot_timesteps`: Timestep indices shown in each still panel
67+
- `rollout_snapshot_channels`: Channel indices to render (`null` means all)
68+
- `rollout_snapshot_dir`: Custom snapshot directory (default:
69+
work_dir/videos/snapshots)
6570
- `accelerator`: Accelerator for evaluation (auto, cpu, cuda, mps)
6671
- `devices`: Number of GPUs for DDP evaluation (default: 1; set explicitly,
6772
e.g. 4, for multi-GPU runs)

src/autocast/configs/eval/default.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ video_format: mp4 # gif or mp4
5454
video_sample_index: 0
5555
fps: 5
5656

57+
# Rollout snapshot settings. When enabled, eval saves still panels for the
58+
# requested rollout samples using raw tensors rather than extracting MP4 frames.
59+
save_rollout_snapshots: false
60+
rollout_snapshot_dir: null # default: <video_dir>/snapshots
61+
rollout_snapshot_timesteps: [0, 4, 12, 30, 99]
62+
rollout_snapshot_channels: null # null means all channels
63+
rollout_snapshot_format: png
64+
5765
# Accelerator for evaluation (mirrors Lightning Fabric API)
5866
accelerator: auto # auto, cpu, cuda, or mps
5967
devices: 1 # default single-GPU eval; override with int (e.g. 4) for multi-GPU DDP

src/autocast/configs/eval/encoder_processor_decoder.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,5 @@ compute_rollout_metrics: true
5959
metric_windows: [null]
6060
metric_windows_rollout: [[0, 1], [0, 4], [6, 12], [13, 30], [31, 99]]
6161
batch_indices: [0, 1, 2, 3, 4, 5, 6, 7]
62-
preserve_aspect: true
62+
preserve_aspect: true
63+
save_rollout_snapshots: true

src/autocast/scripts/eval/encoder_processor_decoder.py

Lines changed: 179 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
from autocast.scripts.training import apply_float32_matmul_precision
7676
from autocast.scripts.utils import get_default_config_path
7777
from autocast.types.batch import Batch, EncodedBatch
78-
from autocast.utils import plot_spatiotemporal_video
78+
from autocast.utils import plot_spatiotemporal_snapshots, plot_spatiotemporal_video
7979
from autocast.utils.plots import (
8080
compute_metrics_from_dataloader,
8181
compute_metrics_per_timestep_from_dataloader,
@@ -298,6 +298,13 @@ def _resolve_video_dir(eval_cfg: DictConfig, work_dir: Path) -> Path:
298298
return (work_dir / "videos").resolve()
299299

300300

301+
def _resolve_rollout_snapshot_dir(eval_cfg: DictConfig, video_dir: Path) -> Path:
302+
snapshot_dir = eval_cfg.get("rollout_snapshot_dir")
303+
if snapshot_dir is not None:
304+
return Path(snapshot_dir).expanduser().resolve()
305+
return (video_dir / "snapshots").resolve()
306+
307+
301308
def _unwrap_module(module: Any) -> Any:
302309
"""Return the underlying model when wrapped by Fabric/DDP-style wrappers."""
303310
unwrapped = module
@@ -500,6 +507,125 @@ def _collect_rollout_sample_targets_for_batch(
500507
return sample_targets
501508

502509

510+
def _select_snapshot_timesteps(
511+
timesteps: Sequence[int] | None,
512+
n_timesteps: int,
513+
) -> list[int]:
514+
if not timesteps:
515+
return []
516+
517+
selected: list[int] = []
518+
invalid: list[int] = []
519+
seen: set[int] = set()
520+
for raw_timestep in timesteps:
521+
timestep = int(raw_timestep)
522+
if 0 <= timestep < n_timesteps:
523+
if timestep not in seen:
524+
selected.append(timestep)
525+
seen.add(timestep)
526+
else:
527+
invalid.append(timestep)
528+
529+
if invalid:
530+
log.warning(
531+
"Ignoring rollout snapshot timestep(s) outside [0, %s]: %s",
532+
n_timesteps - 1,
533+
invalid,
534+
)
535+
return selected
536+
537+
538+
def _select_snapshot_channels(
539+
channels: Sequence[int] | None,
540+
n_channels: int,
541+
) -> list[int]:
542+
if channels is None:
543+
return list(range(n_channels))
544+
545+
selected: list[int] = []
546+
invalid: list[int] = []
547+
seen: set[int] = set()
548+
for raw_channel in channels:
549+
channel = int(raw_channel)
550+
if 0 <= channel < n_channels:
551+
if channel not in seen:
552+
selected.append(channel)
553+
seen.add(channel)
554+
else:
555+
invalid.append(channel)
556+
557+
if invalid:
558+
log.warning(
559+
"Ignoring rollout snapshot channel(s) outside [0, %s]: %s",
560+
n_channels - 1,
561+
invalid,
562+
)
563+
return selected
564+
565+
566+
def _prepare_rollout_snapshot_plan(
567+
*,
568+
snapshot_dir: Path | None,
569+
snapshot_timesteps: Sequence[int] | None,
570+
snapshot_channels: Sequence[int] | None,
571+
n_timesteps: int,
572+
n_channels: int,
573+
) -> tuple[Path | None, list[int], list[int]]:
574+
if snapshot_dir is None:
575+
return None, [], []
576+
577+
selected_timesteps = _select_snapshot_timesteps(snapshot_timesteps, n_timesteps)
578+
if not selected_timesteps:
579+
return None, [], []
580+
581+
selected_channels = _select_snapshot_channels(snapshot_channels, n_channels)
582+
if not selected_channels:
583+
return None, [], []
584+
585+
snapshot_dir.mkdir(parents=True, exist_ok=True)
586+
return snapshot_dir, selected_timesteps, selected_channels
587+
588+
589+
def _save_rollout_snapshot_panels(
590+
*,
591+
trues_mean: torch.Tensor,
592+
preds_mean: torch.Tensor,
593+
preds_uq: torch.Tensor | None,
594+
local_idx: int,
595+
target_idx: int,
596+
snapshot_dir: Path,
597+
snapshot_timesteps: Sequence[int],
598+
snapshot_channels: Sequence[int],
599+
snapshot_ext: str,
600+
saved_paths: list[Path],
601+
names_for_plot: list[str] | None,
602+
preserve_aspect: bool,
603+
) -> None:
604+
for channel_idx in snapshot_channels:
605+
snapshot_filename = snapshot_dir / (
606+
f"batch_{target_idx}_sample_{local_idx}_"
607+
f"channel_{channel_idx}_snapshots.{snapshot_ext}"
608+
)
609+
plot_spatiotemporal_snapshots(
610+
true=trues_mean[local_idx : local_idx + 1].cpu(),
611+
pred=preds_mean[local_idx : local_idx + 1].cpu(),
612+
pred_uq=(
613+
preds_uq[local_idx : local_idx + 1].cpu()
614+
if preds_uq is not None
615+
else None
616+
),
617+
timesteps=snapshot_timesteps,
618+
channel=channel_idx,
619+
batch_idx=0,
620+
save_path=str(snapshot_filename),
621+
title="Rollout snapshots",
622+
channel_names=names_for_plot,
623+
preserve_aspect=preserve_aspect,
624+
)
625+
saved_paths.append(snapshot_filename)
626+
log.info("Saved rollout snapshot panel to %s", snapshot_filename)
627+
628+
503629
def _render_rollouts( # noqa: PLR0912
504630
model: (
505631
EncoderProcessorDecoder
@@ -520,6 +646,10 @@ def _render_rollouts( # noqa: PLR0912
520646
channel_names: list[str] | None = None,
521647
preserve_aspect: bool = False,
522648
decode_fn: Callable | None = None,
649+
snapshot_timesteps: Sequence[int] | None = None,
650+
snapshot_dir: Path | None = None,
651+
snapshot_format: str = "png",
652+
snapshot_channels: Sequence[int] | None = None,
523653
) -> list[Path]:
524654
# Return early if no rollout indices are requested
525655
if not batch_indices:
@@ -533,6 +663,7 @@ def _render_rollouts( # noqa: PLR0912
533663
rendered_targets: set[int] = set()
534664
global_sample_offset = 0
535665
video_dir.mkdir(parents=True, exist_ok=True)
666+
snapshot_ext = snapshot_format.removeprefix(".") or "png"
536667

537668
# Perform rollouts and save videos for requested target indices.
538669
with torch.no_grad():
@@ -599,6 +730,17 @@ def _render_rollouts( # noqa: PLR0912
599730
sample_index=sample_index,
600731
global_sample_offset=global_sample_offset,
601732
)
733+
(
734+
snapshot_dir_for_batch,
735+
snapshot_timesteps_for_batch,
736+
snapshot_channels_for_batch,
737+
) = _prepare_rollout_snapshot_plan(
738+
snapshot_dir=snapshot_dir,
739+
snapshot_timesteps=snapshot_timesteps,
740+
snapshot_channels=snapshot_channels,
741+
n_timesteps=int(trues_mean.shape[1]),
742+
n_channels=n_channels,
743+
)
602744

603745
for target_idx, local_idx in sample_targets.items():
604746
if target_idx in rendered_targets:
@@ -627,6 +769,22 @@ def _render_rollouts( # noqa: PLR0912
627769
rendered_targets.add(target_idx)
628770
log.info("Saved rollout visualization to %s", filename)
629771

772+
if snapshot_dir_for_batch is not None:
773+
_save_rollout_snapshot_panels(
774+
trues_mean=trues_mean,
775+
preds_mean=preds_mean,
776+
preds_uq=preds_uq,
777+
local_idx=local_idx,
778+
target_idx=target_idx,
779+
snapshot_dir=snapshot_dir_for_batch,
780+
snapshot_timesteps=snapshot_timesteps_for_batch,
781+
snapshot_channels=snapshot_channels_for_batch,
782+
snapshot_ext=snapshot_ext,
783+
saved_paths=saved_paths,
784+
names_for_plot=names_for_plot,
785+
preserve_aspect=preserve_aspect,
786+
)
787+
630788
global_sample_offset += batch_size
631789

632790
# Check for any missing rollout sample indices that were requested
@@ -1920,6 +2078,22 @@ def coverage_factory() -> Metric:
19202078
rollout_test_loader,
19212079
max_rollout_batches,
19222080
)
2081+
save_rollout_snapshots = bool(eval_cfg.get("save_rollout_snapshots", False))
2082+
rollout_snapshot_dir = (
2083+
_resolve_rollout_snapshot_dir(eval_cfg, video_dir)
2084+
if save_rollout_snapshots
2085+
else None
2086+
)
2087+
rollout_snapshot_timesteps = (
2088+
eval_cfg.get("rollout_snapshot_timesteps", [])
2089+
if save_rollout_snapshots
2090+
else None
2091+
)
2092+
rollout_snapshot_channels = (
2093+
eval_cfg.get("rollout_snapshot_channels", None)
2094+
if save_rollout_snapshots
2095+
else None
2096+
)
19232097
_render_rollouts(
19242098
model, # pyright: ignore[reportArgumentType]
19252099
rollout_loader,
@@ -1935,6 +2109,10 @@ def coverage_factory() -> Metric:
19352109
channel_names=rollout_channel_names,
19362110
preserve_aspect=eval_cfg.get("preserve_aspect", False),
19372111
decode_fn=decode_fn,
2112+
snapshot_timesteps=rollout_snapshot_timesteps,
2113+
snapshot_dir=rollout_snapshot_dir,
2114+
snapshot_format=eval_cfg.get("rollout_snapshot_format", "png"),
2115+
snapshot_channels=rollout_snapshot_channels,
19382116
)
19392117

19402118
# Prepare metric functions for rollouts

src/autocast/utils/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
from .optimizer import get_optimizer_config
2-
from .plots import plot_spatiotemporal_video
2+
from .plots import plot_spatiotemporal_snapshots, plot_spatiotemporal_video
33

4-
__all__ = ["get_optimizer_config", "plot_spatiotemporal_video"]
4+
__all__ = [
5+
"get_optimizer_config",
6+
"plot_spatiotemporal_snapshots",
7+
"plot_spatiotemporal_video",
8+
]

0 commit comments

Comments
 (0)