Commit 9830e1d

Merge pull request #327 from alan-turing-institute/add-eval-modes
Add eval.mode selector for ambient vs latent rollout
2 parents b304223 + d53d411 commit 9830e1d

5 files changed

Lines changed: 527 additions & 1 deletion

src/autocast/configs/eval/README.md

Lines changed: 51 additions & 0 deletions
@@ -44,6 +44,9 @@ python -m autocast.scripts.eval.encoder_processor_decoder \

All eval configs support these parameters:

- `checkpoint`: Path to model checkpoint (required for evaluation)
- `mode`: Evaluation regime (`auto` | `ambient` | `latent`). Controls the
  **rollout space**, not just the metrics space. See
  [Ambient vs latent rollout](#ambient-vs-latent-rollout) below.
- `metrics`: List of metrics to compute (default includes mse/mae/rmse/vrmse,
  power spectrum scores `psrmse*`, cross-correlation spectrum scores `pscc*`,
  and ensemble scores `crps`, `fcrps`, `afcrps`, `energy`, `ssr`; `variogram`

@@ -79,3 +82,51 @@ On SLURM, `srun` propagates `LOCAL_RANK` / `WORLD_SIZE` into the

process so Fabric DDP initialises automatically — no extra flags needed.

- `max_rollout_steps`: Maximum number of rollout steps
- `free_running_only`: Whether to disable teacher forcing

## Ambient vs latent rollout

Processor checkpoints trained on cached latents can be evaluated in two
qualitatively different regimes. The `eval.mode` knob makes the choice
explicit and surfaces clear errors when the rest of the config is
inconsistent with the request.

- `eval.mode=auto` (default) preserves historical behavior: the script picks
  a path based on `(checkpoint type, datamodule batch type,
  autoencoder_checkpoint)`.
- `eval.mode=ambient` forces full `encoder -> processor -> decoder` rollout.
  Each rollout step decodes to ambient fields and re-encodes on the next
  step, so decode/encode drift is included in the metrics. **This is the
  apples-to-apples regime for comparing against baselines that natively roll
  out in data space (e.g. a CRPS comparison against a non-autoencoder
  model).** Requires `autoencoder_checkpoint=<ae.ckpt>` and a raw-Batch
  datamodule. When the current datamodule yields `EncodedBatch` (cached
  latents), eval auto-substitutes the datamodule from
  `<cache_dir>/autoencoder_config.yaml` saved by `autocast cache-latents`.
  Pass `datamodule=...` explicitly to override the default.
- `eval.mode=latent` forces latent-space rollout: the processor's predicted
  latent is fed back as the next latent input; the encoder is invoked only
  once. Metrics are decoded to data space via the decoder saved alongside
  the cached latents when available, otherwise they are reported in latent
  space. Requires an `EncodedBatch` / cached-latents datamodule.

### Running the ambient ablation

Given an autoencoder checkpoint and a processor checkpoint trained on its
cached latents, a minimal invocation is:

```bash
# Ambient (encoder -> processor -> decoder at every rollout step)
autocast eval --workdir <processor_workdir> \
    eval.mode=ambient \
    eval.checkpoint=<processor.ckpt> \
    autoencoder_checkpoint=<autoencoder.ckpt>

# Latent (processor rollout stays in latent space; decoded only for metrics)
autocast eval --workdir <processor_workdir> \
    eval.mode=latent \
    eval.checkpoint=<processor.ckpt>
```

The ambient run will differ from the latent run by exactly the
decode/encode drift accumulated over rollout steps, which is the relevant
delta when comparing against purely-ambient baselines.
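That accumulation of decode/encode drift can be illustrated with a toy linear "autoencoder" and a damped "processor" (a self-contained NumPy sketch; `encode`, `decode`, and `step` here are illustrative stand-ins, not autocast APIs):

```python
import numpy as np

rng = np.random.default_rng(0)
# Lossy linear "autoencoder": 8-dim ambient field <-> 4-dim latent.
# Because W is not orthonormal, encode(decode(z)) != z, so ambient
# rollout accumulates round-trip error that latent rollout never sees.
W = rng.standard_normal((8, 4)) / np.sqrt(8)

def encode(x):
    return x @ W

def decode(z):
    return z @ W.T

def step(z):
    return 0.9 * z  # stand-in processor: damped latent dynamics

x0 = rng.standard_normal(8)
n_steps = 5

# Latent rollout: encode once, iterate in latent space, decode for metrics.
z = encode(x0)
latent_preds = []
for _ in range(n_steps):
    z = step(z)
    latent_preds.append(decode(z))

# Ambient rollout: decode the prediction and re-encode it at every step.
x = x0
ambient_preds = []
for _ in range(n_steps):
    x = decode(step(encode(x)))
    ambient_preds.append(x)

# Step 1 performs identical operations in both regimes; from step 2 on,
# the extra decode->encode round trips make the trajectories diverge.
drift = [float(np.linalg.norm(a - b)) for a, b in zip(ambient_preds, latent_preds)]
print(drift)  # first entry is 0.0, later entries are nonzero
```

In this linear toy the two rollouts agree exactly on the first step and then separate, which is the same qualitative pattern the ambient-vs-latent ablation measures.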

src/autocast/configs/eval/default.yaml

Lines changed: 20 additions & 0 deletions
@@ -2,6 +2,26 @@

# Path to checkpoint for evaluation (required for eval)
checkpoint: null

# Evaluation mode selector (controls rollout space, not just metrics space).
#
# auto     (default) Infer from checkpoint type + batch type + autoencoder_checkpoint.
#          Preserves historical behavior.
# ambient  Force full encoder -> processor -> decoder rollout. Each rollout step
#          decodes and re-encodes, so decode/encode drift is included in the
#          metrics -- this is the apples-to-apples regime for comparing against
#          models that natively roll out in ambient/data space (e.g. CRPS baselines).
#          Requires `autoencoder_checkpoint=<ae.ckpt>` and a raw-Batch datamodule.
#          When the datamodule yields EncodedBatch (cached latents), the eval
#          script auto-substitutes the datamodule from
#          `<cache_dir>/autoencoder_config.yaml` written by `autocast cache-latents`.
#          Pass `datamodule=...` explicitly to override that default.
# latent   Force latent-space rollout (processor predictions are fed back as
#          latents; encoder is not re-invoked). Metrics are decoded to data
#          space via the decoder saved alongside the cached latents if
#          available, otherwise computed in latent space. Requires an
#          EncodedBatch datamodule (cached latents).
mode: auto

# Evaluation metrics to compute
metrics:
  - mse
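The selector's normalization contract (None falls back to `auto`; matching is case- and whitespace-insensitive; anything else fails loudly) mirrors the `_normalize_eval_mode` helper this commit adds to the eval script. Restated here as a standalone sketch for illustration, not imported from autocast:

```python
EVAL_MODES = ("auto", "ambient", "latent")

def normalize_eval_mode(mode):
    """Mirror of the eval script's validation: None -> "auto",
    case/whitespace-insensitive matching, loud error otherwise."""
    if mode is None:
        return "auto"
    mode_str = str(mode).strip().lower()
    if mode_str not in EVAL_MODES:
        raise ValueError(
            f"Unknown eval.mode={mode!r}. Valid values: {', '.join(EVAL_MODES)}."
        )
    return mode_str

print(normalize_eval_mode(" Ambient "))  # -> ambient
print(normalize_eval_mode(None))         # -> auto
```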

src/autocast/scripts/eval/encoder_processor_decoder.py

Lines changed: 150 additions & 0 deletions
@@ -137,6 +137,14 @@

MEMORY_INTENSIVE_METRICS = {"variogram"}

EVAL_MODES = ("auto", "ambient", "latent")

# Resolved eval paths exposed for validation / testing. Each corresponds to
# exactly one branch in `run_evaluation`'s model-selection block.
EVAL_PATH_AMBIENT_EPD = "ambient_epd"  # full EPD checkpoint or processor+AE
EVAL_PATH_LATENT_CACHED_WITH_DECODER = "latent_cached_with_decoder"  # Mode 2
EVAL_PATH_LATENT_CACHED_LATENT_ONLY = "latent_cached_latent_only"  # fallback


def _decode_tensor(
    x: torch.Tensor,

@@ -1071,6 +1079,121 @@ def _load_autoencoder_config_from_cache(cache_dir: Path) -> DictConfig | None:

    return None


def _normalize_eval_mode(mode: Any) -> str:
    """Normalize and validate the eval.mode config value."""
    if mode is None:
        return "auto"
    mode_str = str(mode).strip().lower()
    if mode_str not in EVAL_MODES:
        msg = f"Unknown eval.mode={mode!r}. Valid values: {', '.join(EVAL_MODES)}."
        raise ValueError(msg)
    return mode_str


def _maybe_swap_to_ambient_datamodule(
    cfg: DictConfig,
    *,
    eval_mode: str,
    example_batch: Any,
) -> DictConfig:
    """Substitute the raw-data datamodule from `autoencoder_config.yaml`.

    When the user requests ``eval.mode=ambient`` but the current datamodule
    yields ``EncodedBatch`` (cached latents), we cannot run
    encoder -> processor -> decoder in ambient space: the encoder needs raw
    fields. This helper reads the ``autoencoder_config.yaml`` written next
    to the cached latents by ``autocast cache-latents`` and overwrites
    ``cfg.datamodule`` with the datamodule the autoencoder was trained on,
    which guarantees matching normalization and field layout.

    Modifies ``cfg`` in place and returns it. Raises a descriptive error
    when the swap is needed but ``autoencoder_config.yaml`` is absent;
    callers should pass ``datamodule=...`` explicitly in that case.
    """
    if eval_mode != "ambient" or not isinstance(example_batch, EncodedBatch):
        return cfg

    data_path = cfg.get("datamodule", {}).get("data_path")
    if not data_path:
        msg = (
            "eval.mode=ambient requires a raw-data datamodule, but the current "
            "datamodule yields EncodedBatch and has no data_path to locate the "
            "original autoencoder config. Pass datamodule=<raw> explicitly."
        )
        raise ValueError(msg)

    ae_cfg = _load_autoencoder_config_from_cache(Path(data_path))
    if ae_cfg is None:
        msg = (
            "eval.mode=ambient requested but the cached-latents directory "
            f"{data_path} has no 'autoencoder_config.yaml'. Either regenerate "
            "the cache with a recent `autocast cache-latents` (which saves the "
            "autoencoder config), or pass datamodule=<raw> explicitly."
        )
        raise FileNotFoundError(msg)

    ae_datamodule = ae_cfg.get("datamodule")
    if ae_datamodule is None:
        msg = (
            f"autoencoder_config.yaml at {data_path} is missing a 'datamodule' "
            "section; cannot auto-wire ambient eval. Pass datamodule=<raw> "
            "explicitly."
        )
        raise ValueError(msg)

    log.info(
        "eval.mode=ambient: substituting cached_latents datamodule with the "
        "raw-data datamodule from %s/autoencoder_config.yaml so the encoder "
        "sees the same fields/normalization it was trained on.",
        data_path,
    )
    with open_dict(cfg):
        cfg.datamodule = ae_datamodule
    return cfg


def _resolve_eval_path(
    *,
    processor_only: bool,
    example_batch: Any,
    has_autoencoder_checkpoint: bool,
    decode_fn_loaded: bool,
) -> str:
    """Map the auto-detected branch in `run_evaluation` to a stable label."""
    if not processor_only:
        return EVAL_PATH_AMBIENT_EPD
    if isinstance(example_batch, Batch) and has_autoencoder_checkpoint:
        return EVAL_PATH_AMBIENT_EPD
    if decode_fn_loaded:
        return EVAL_PATH_LATENT_CACHED_WITH_DECODER
    return EVAL_PATH_LATENT_CACHED_LATENT_ONLY


def _validate_resolved_eval_path(*, eval_mode: str, resolved_path: str) -> None:
    """Raise if the resolved code path disagrees with the user-requested mode."""
    if eval_mode == "auto":
        return
    if eval_mode == "ambient" and resolved_path != EVAL_PATH_AMBIENT_EPD:
        msg = (
            "eval.mode=ambient but the resolved eval path is "
            f"{resolved_path!r}. Ambient eval requires a full EPD checkpoint, "
            "OR a processor-only checkpoint combined with "
            "autoencoder_checkpoint=<ae.ckpt> AND a raw-Batch datamodule. "
            "Double-check eval.checkpoint, autoencoder_checkpoint, and "
            "datamodule=."
        )
        raise ValueError(msg)
    if eval_mode == "latent" and resolved_path == EVAL_PATH_AMBIENT_EPD:
        msg = (
            "eval.mode=latent but the resolved eval path is "
            f"{resolved_path!r}. Latent-space eval requires a processor-only "
            "checkpoint paired with an EncodedBatch (cached_latents) "
            "datamodule. Use datamodule=cached_latents and remove "
            "autoencoder_checkpoint=, or switch to eval.mode=ambient/auto."
        )
        raise ValueError(msg)


def _try_build_decode_fn(
    cfg: DictConfig,
) -> "tuple[Any, Any] | tuple[None, None]":

@@ -1180,11 +1303,13 @@ def run_evaluation(cfg: DictConfig, work_dir: Path | None = None) -> None:  # no

    eval_batch_size: int = eval_cfg.get("batch_size", 1)
    max_test_batches = eval_cfg.get("max_test_batches")
    max_rollout_batches = _resolve_rollout_batch_limit(eval_cfg)
    eval_mode = _normalize_eval_mode(eval_cfg.get("mode", "auto"))
    log.info(
        "Batch limits: max_test_batches=%s, max_rollout_batches=%s",
        max_test_batches,
        max_rollout_batches,
    )
    log.info("eval.mode=%s", eval_mode)

    checkpoint_path = resolve_checkpoint_path(
        eval_cfg,

@@ -1220,6 +1345,19 @@ def run_evaluation(cfg: DictConfig, work_dir: Path | None = None) -> None:  # no

    # Setup datamodule and resolve config
    datamodule, cfg, stats = setup_datamodule(cfg)

    # If the user asked for ambient eval but the resolved datamodule yields
    # EncodedBatch (cached_latents), substitute the raw-data datamodule stored
    # in the cache dir's autoencoder_config.yaml and rebuild. Honors an
    # explicit `datamodule=...` override implicitly: when the override targets
    # a raw-Batch datamodule the swap becomes a no-op.
    cfg = _maybe_swap_to_ambient_datamodule(
        cfg,
        eval_mode=eval_mode,
        example_batch=stats.get("example_batch"),
    )
    if eval_mode == "ambient" and isinstance(stats.get("example_batch"), EncodedBatch):
        datamodule, cfg, stats = setup_datamodule(cfg)

    # Override model n_members from eval config if specified
    if "n_members" in eval_cfg:
        with open_dict(cfg.model):

@@ -1318,6 +1456,18 @@ def run_evaluation(cfg: DictConfig, work_dir: Path | None = None) -> None:  # no

        )
        raise RuntimeError(msg)

    resolved_eval_path = _resolve_eval_path(
        processor_only=processor_only,
        example_batch=example_batch,
        has_autoencoder_checkpoint=bool(cfg.get("autoencoder_checkpoint")),
        decode_fn_loaded=decode_fn is not None,
    )
    log.info("Resolved eval path: %s", resolved_eval_path)
    _validate_resolved_eval_path(
        eval_mode=eval_mode,
        resolved_path=resolved_eval_path,
    )

    # Get eval parameters from config
    metrics_list = eval_cfg.get("metrics", DEFAULT_EVAL_METRICS)
    batch_indices = eval_cfg.get("batch_indices", [])
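The branch selection above can be summarized as a small truth table. The sketch below is a simplified, self-contained mirror of `_resolve_eval_path` for illustration; the real helper lives in `autocast.scripts.eval.encoder_processor_decoder` and checks `isinstance(example_batch, Batch)` rather than taking a boolean flag:

```python
EVAL_PATH_AMBIENT_EPD = "ambient_epd"
EVAL_PATH_LATENT_CACHED_WITH_DECODER = "latent_cached_with_decoder"
EVAL_PATH_LATENT_CACHED_LATENT_ONLY = "latent_cached_latent_only"

def resolve_eval_path(*, processor_only, batch_is_raw, has_ae_ckpt, decoder_loaded):
    if not processor_only:
        return EVAL_PATH_AMBIENT_EPD  # full EPD checkpoint: always ambient
    if batch_is_raw and has_ae_ckpt:
        return EVAL_PATH_AMBIENT_EPD  # processor + AE checkpoint on raw batches
    if decoder_loaded:
        return EVAL_PATH_LATENT_CACHED_WITH_DECODER  # decode only for metrics
    return EVAL_PATH_LATENT_CACHED_LATENT_ONLY  # metrics stay in latent space

# A processor-only checkpoint on cached latents with a saved decoder:
print(resolve_eval_path(
    processor_only=True, batch_is_raw=False,
    has_ae_ckpt=False, decoder_loaded=True,
))  # -> latent_cached_with_decoder
```

With this table in hand, `_validate_resolved_eval_path` is just a consistency check: `ambient` must land on the first label, `latent` must not.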

tests/models/test_encoder_processor_decoder.py

Lines changed: 83 additions & 0 deletions
@@ -265,3 +265,86 @@ def test_encoder_processor_decoder_rollout_handles_short_trajectory(

    # Ground truth only for windows where data was available
    assert gts is not None
    assert gts.shape == (batch_size, expected_gt_windows * n_steps_output, 32, 32, 1)


class CountingPermuteConcat(PermuteConcat):
    """PermuteConcat encoder that tracks how many times ``encode`` is called."""

    def __init__(
        self, in_channels: int, n_steps_input: int, with_constants: bool = False
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            n_steps_input=n_steps_input,
            with_constants=with_constants,
        )
        self.encode_calls = 0

    def encode(self, batch: Batch) -> Tensor:  # type: ignore[override]
        self.encode_calls += 1
        return super().encode(batch)


def test_encoder_processor_decoder_rollout_re_encodes_each_step(make_toy_batch):
    """Ambient rollout must re-invoke the encoder at every rollout step.

    This is the invariant the whole ``eval.mode=ambient`` path rests on: in
    ambient rollout each step decodes the prediction and re-encodes it as
    the next input, so decode/encode drift accumulates. If a future refactor
    ever collapsed this into a latent-only loop, latent and ambient eval
    would silently report the same numbers and ambient-vs-latent ablations
    would be meaningless. This test pins the contract.
    """
    max_rollout_steps = 3
    n_steps_input = 2
    n_steps_output = 2
    stride = 2
    batch_size = 2
    trajectory_length = 20

    batch = make_toy_batch(
        batch_size=batch_size,
        t_in=n_steps_input,
        t_out=trajectory_length - n_steps_input,
    )
    output_channels = batch.output_fields.shape[-1]
    merged_input_channels = output_channels * n_steps_input
    merged_output_channels = output_channels * n_steps_output

    encoder = CountingPermuteConcat(
        in_channels=output_channels,
        n_steps_input=n_steps_input,
        with_constants=False,
    )
    decoder = ChannelsLast(output_channels=output_channels, time_steps=n_steps_output)
    loss = nn.MSELoss()
    encoder_decoder = EncoderDecoder(encoder=encoder, decoder=decoder, loss_func=loss)
    processor = TinyProcessor(
        in_channels=merged_input_channels, out_channels=merged_output_channels
    )
    model = EncoderProcessorDecoder(
        encoder_decoder=encoder_decoder,
        processor=processor,
        loss_func=loss,
        optimizer_config=get_optimizer_config(),
        stride=stride,
        max_rollout_steps=max_rollout_steps,
    )
    model.eval()

    calls_before = encoder.encode_calls
    preds, _ = model.rollout(
        batch,
        stride=stride,
        max_rollout_steps=max_rollout_steps,
        free_running_only=True,
    )
    calls_during = encoder.encode_calls - calls_before

    assert calls_during >= max_rollout_steps, (
        "Ambient rollout must invoke the encoder at least once per rollout "
        f"step; got {calls_during} encode calls for "
        f"{max_rollout_steps} rollout steps."
    )
    assert preds.shape[0] == batch_size
    assert preds.shape[1] == max_rollout_steps * n_steps_output
