7575from autocast .scripts .training import apply_float32_matmul_precision
7676from autocast .scripts .utils import get_default_config_path
7777from autocast .types .batch import Batch , EncodedBatch
78- from autocast .utils import plot_spatiotemporal_video
78+ from autocast .utils import plot_spatiotemporal_snapshots , plot_spatiotemporal_video
7979from autocast .utils .plots import (
8080 compute_metrics_from_dataloader ,
8181 compute_metrics_per_timestep_from_dataloader ,
@@ -298,6 +298,13 @@ def _resolve_video_dir(eval_cfg: DictConfig, work_dir: Path) -> Path:
298298 return (work_dir / "videos" ).resolve ()
299299
300300
301+ def _resolve_rollout_snapshot_dir (eval_cfg : DictConfig , video_dir : Path ) -> Path :
302+ snapshot_dir = eval_cfg .get ("rollout_snapshot_dir" )
303+ if snapshot_dir is not None :
304+ return Path (snapshot_dir ).expanduser ().resolve ()
305+ return (video_dir / "snapshots" ).resolve ()
306+
307+
301308def _unwrap_module (module : Any ) -> Any :
302309 """Return the underlying model when wrapped by Fabric/DDP-style wrappers."""
303310 unwrapped = module
@@ -500,6 +507,125 @@ def _collect_rollout_sample_targets_for_batch(
500507 return sample_targets
501508
502509
510+ def _select_snapshot_timesteps (
511+ timesteps : Sequence [int ] | None ,
512+ n_timesteps : int ,
513+ ) -> list [int ]:
514+ if not timesteps :
515+ return []
516+
517+ selected : list [int ] = []
518+ invalid : list [int ] = []
519+ seen : set [int ] = set ()
520+ for raw_timestep in timesteps :
521+ timestep = int (raw_timestep )
522+ if 0 <= timestep < n_timesteps :
523+ if timestep not in seen :
524+ selected .append (timestep )
525+ seen .add (timestep )
526+ else :
527+ invalid .append (timestep )
528+
529+ if invalid :
530+ log .warning (
531+ "Ignoring rollout snapshot timestep(s) outside [0, %s]: %s" ,
532+ n_timesteps - 1 ,
533+ invalid ,
534+ )
535+ return selected
536+
537+
538+ def _select_snapshot_channels (
539+ channels : Sequence [int ] | None ,
540+ n_channels : int ,
541+ ) -> list [int ]:
542+ if channels is None :
543+ return list (range (n_channels ))
544+
545+ selected : list [int ] = []
546+ invalid : list [int ] = []
547+ seen : set [int ] = set ()
548+ for raw_channel in channels :
549+ channel = int (raw_channel )
550+ if 0 <= channel < n_channels :
551+ if channel not in seen :
552+ selected .append (channel )
553+ seen .add (channel )
554+ else :
555+ invalid .append (channel )
556+
557+ if invalid :
558+ log .warning (
559+ "Ignoring rollout snapshot channel(s) outside [0, %s]: %s" ,
560+ n_channels - 1 ,
561+ invalid ,
562+ )
563+ return selected
564+
565+
566+ def _prepare_rollout_snapshot_plan (
567+ * ,
568+ snapshot_dir : Path | None ,
569+ snapshot_timesteps : Sequence [int ] | None ,
570+ snapshot_channels : Sequence [int ] | None ,
571+ n_timesteps : int ,
572+ n_channels : int ,
573+ ) -> tuple [Path | None , list [int ], list [int ]]:
574+ if snapshot_dir is None :
575+ return None , [], []
576+
577+ selected_timesteps = _select_snapshot_timesteps (snapshot_timesteps , n_timesteps )
578+ if not selected_timesteps :
579+ return None , [], []
580+
581+ selected_channels = _select_snapshot_channels (snapshot_channels , n_channels )
582+ if not selected_channels :
583+ return None , [], []
584+
585+ snapshot_dir .mkdir (parents = True , exist_ok = True )
586+ return snapshot_dir , selected_timesteps , selected_channels
587+
588+
589+ def _save_rollout_snapshot_panels (
590+ * ,
591+ trues_mean : torch .Tensor ,
592+ preds_mean : torch .Tensor ,
593+ preds_uq : torch .Tensor | None ,
594+ local_idx : int ,
595+ target_idx : int ,
596+ snapshot_dir : Path ,
597+ snapshot_timesteps : Sequence [int ],
598+ snapshot_channels : Sequence [int ],
599+ snapshot_ext : str ,
600+ saved_paths : list [Path ],
601+ names_for_plot : list [str ] | None ,
602+ preserve_aspect : bool ,
603+ ) -> None :
604+ for channel_idx in snapshot_channels :
605+ snapshot_filename = snapshot_dir / (
606+ f"batch_{ target_idx } _sample_{ local_idx } _"
607+ f"channel_{ channel_idx } _snapshots.{ snapshot_ext } "
608+ )
609+ plot_spatiotemporal_snapshots (
610+ true = trues_mean [local_idx : local_idx + 1 ].cpu (),
611+ pred = preds_mean [local_idx : local_idx + 1 ].cpu (),
612+ pred_uq = (
613+ preds_uq [local_idx : local_idx + 1 ].cpu ()
614+ if preds_uq is not None
615+ else None
616+ ),
617+ timesteps = snapshot_timesteps ,
618+ channel = channel_idx ,
619+ batch_idx = 0 ,
620+ save_path = str (snapshot_filename ),
621+ title = "Rollout snapshots" ,
622+ channel_names = names_for_plot ,
623+ preserve_aspect = preserve_aspect ,
624+ )
625+ saved_paths .append (snapshot_filename )
626+ log .info ("Saved rollout snapshot panel to %s" , snapshot_filename )
627+
628+
503629def _render_rollouts ( # noqa: PLR0912
504630 model : (
505631 EncoderProcessorDecoder
@@ -520,6 +646,10 @@ def _render_rollouts( # noqa: PLR0912
520646 channel_names : list [str ] | None = None ,
521647 preserve_aspect : bool = False ,
522648 decode_fn : Callable | None = None ,
649+ snapshot_timesteps : Sequence [int ] | None = None ,
650+ snapshot_dir : Path | None = None ,
651+ snapshot_format : str = "png" ,
652+ snapshot_channels : Sequence [int ] | None = None ,
523653) -> list [Path ]:
524654 # Return early if no rollout indices are requested
525655 if not batch_indices :
@@ -533,6 +663,7 @@ def _render_rollouts( # noqa: PLR0912
533663 rendered_targets : set [int ] = set ()
534664 global_sample_offset = 0
535665 video_dir .mkdir (parents = True , exist_ok = True )
666+ snapshot_ext = snapshot_format .removeprefix ("." ) or "png"
536667
537668 # Perform rollouts and save videos for requested target indices.
538669 with torch .no_grad ():
@@ -599,6 +730,17 @@ def _render_rollouts( # noqa: PLR0912
599730 sample_index = sample_index ,
600731 global_sample_offset = global_sample_offset ,
601732 )
733+ (
734+ snapshot_dir_for_batch ,
735+ snapshot_timesteps_for_batch ,
736+ snapshot_channels_for_batch ,
737+ ) = _prepare_rollout_snapshot_plan (
738+ snapshot_dir = snapshot_dir ,
739+ snapshot_timesteps = snapshot_timesteps ,
740+ snapshot_channels = snapshot_channels ,
741+ n_timesteps = int (trues_mean .shape [1 ]),
742+ n_channels = n_channels ,
743+ )
602744
603745 for target_idx , local_idx in sample_targets .items ():
604746 if target_idx in rendered_targets :
@@ -627,6 +769,22 @@ def _render_rollouts( # noqa: PLR0912
627769 rendered_targets .add (target_idx )
628770 log .info ("Saved rollout visualization to %s" , filename )
629771
772+ if snapshot_dir_for_batch is not None :
773+ _save_rollout_snapshot_panels (
774+ trues_mean = trues_mean ,
775+ preds_mean = preds_mean ,
776+ preds_uq = preds_uq ,
777+ local_idx = local_idx ,
778+ target_idx = target_idx ,
779+ snapshot_dir = snapshot_dir_for_batch ,
780+ snapshot_timesteps = snapshot_timesteps_for_batch ,
781+ snapshot_channels = snapshot_channels_for_batch ,
782+ snapshot_ext = snapshot_ext ,
783+ saved_paths = saved_paths ,
784+ names_for_plot = names_for_plot ,
785+ preserve_aspect = preserve_aspect ,
786+ )
787+
630788 global_sample_offset += batch_size
631789
632790 # Check for any missing rollout sample indices that were requested
@@ -1920,6 +2078,22 @@ def coverage_factory() -> Metric:
19202078 rollout_test_loader ,
19212079 max_rollout_batches ,
19222080 )
2081+ save_rollout_snapshots = bool (eval_cfg .get ("save_rollout_snapshots" , False ))
2082+ rollout_snapshot_dir = (
2083+ _resolve_rollout_snapshot_dir (eval_cfg , video_dir )
2084+ if save_rollout_snapshots
2085+ else None
2086+ )
2087+ rollout_snapshot_timesteps = (
2088+ eval_cfg .get ("rollout_snapshot_timesteps" , [])
2089+ if save_rollout_snapshots
2090+ else None
2091+ )
2092+ rollout_snapshot_channels = (
2093+ eval_cfg .get ("rollout_snapshot_channels" , None )
2094+ if save_rollout_snapshots
2095+ else None
2096+ )
19232097 _render_rollouts (
19242098 model , # pyright: ignore[reportArgumentType]
19252099 rollout_loader ,
@@ -1935,6 +2109,10 @@ def coverage_factory() -> Metric:
19352109 channel_names = rollout_channel_names ,
19362110 preserve_aspect = eval_cfg .get ("preserve_aspect" , False ),
19372111 decode_fn = decode_fn ,
2112+ snapshot_timesteps = rollout_snapshot_timesteps ,
2113+ snapshot_dir = rollout_snapshot_dir ,
2114+ snapshot_format = eval_cfg .get ("rollout_snapshot_format" , "png" ),
2115+ snapshot_channels = rollout_snapshot_channels ,
19382116 )
19392117
19402118 # Prepare metric functions for rollouts
0 commit comments