|
137 | 137 |
|
# Metric names whose computation is comparatively memory-hungry; presumably
# used elsewhere in this module to gate or special-case them — TODO confirm
# at call sites (usage not visible in this chunk).
MEMORY_INTENSIVE_METRICS = {"variogram"}

# Accepted values for the `eval.mode` config knob (see `_normalize_eval_mode`).
EVAL_MODES = ("auto", "ambient", "latent")

# Resolved eval paths exposed for validation / testing. Each corresponds to
# exactly one branch in `run_evaluation`'s model-selection block.
EVAL_PATH_AMBIENT_EPD = "ambient_epd"  # full EPD checkpoint or processor+AE
EVAL_PATH_LATENT_CACHED_WITH_DECODER = "latent_cached_with_decoder"  # Mode 2
EVAL_PATH_LATENT_CACHED_LATENT_ONLY = "latent_cached_latent_only"  # fallback
| 147 | + |
140 | 148 |
|
141 | 149 | def _decode_tensor( |
142 | 150 | x: torch.Tensor, |
@@ -1071,6 +1079,121 @@ def _load_autoencoder_config_from_cache(cache_dir: Path) -> DictConfig | None: |
1071 | 1079 | return None |
1072 | 1080 |
|
1073 | 1081 |
|
| 1082 | +def _normalize_eval_mode(mode: Any) -> str: |
| 1083 | + """Normalize and validate the eval.mode config value.""" |
| 1084 | + if mode is None: |
| 1085 | + return "auto" |
| 1086 | + mode_str = str(mode).strip().lower() |
| 1087 | + if mode_str not in EVAL_MODES: |
| 1088 | + msg = f"Unknown eval.mode={mode!r}. Valid values: {', '.join(EVAL_MODES)}." |
| 1089 | + raise ValueError(msg) |
| 1090 | + return mode_str |
| 1091 | + |
| 1092 | + |
| 1093 | +def _maybe_swap_to_ambient_datamodule( |
| 1094 | + cfg: DictConfig, |
| 1095 | + *, |
| 1096 | + eval_mode: str, |
| 1097 | + example_batch: Any, |
| 1098 | +) -> DictConfig: |
| 1099 | + """Substitute the raw-data datamodule from `autoencoder_config.yaml`. |
| 1100 | +
|
| 1101 | + When the user requests ``eval.mode=ambient`` but the current datamodule |
| 1102 | + yields ``EncodedBatch`` (cached latents), we cannot run encoder->processor |
| 1103 | + ->decoder in ambient space: the encoder needs raw fields. This helper |
| 1104 | + reads the ``autoencoder_config.yaml`` written next to the cached latents |
| 1105 | + by ``autocast cache-latents`` and overwrites ``cfg.datamodule`` with the |
| 1106 | + datamodule the autoencoder was trained on, which guarantees matching |
| 1107 | + normalization and field layout. |
| 1108 | +
|
| 1109 | + Returns the (possibly-modified) ``cfg`` in-place. Raises a descriptive |
| 1110 | + error when the swap is needed but ``autoencoder_config.yaml`` is absent; |
| 1111 | + callers should pass ``datamodule=...`` explicitly in that case. |
| 1112 | + """ |
| 1113 | + if eval_mode != "ambient" or not isinstance(example_batch, EncodedBatch): |
| 1114 | + return cfg |
| 1115 | + |
| 1116 | + data_path = cfg.get("datamodule", {}).get("data_path") |
| 1117 | + if not data_path: |
| 1118 | + msg = ( |
| 1119 | + "eval.mode=ambient requires a raw-data datamodule, but the current " |
| 1120 | + "datamodule yields EncodedBatch and has no data_path to locate the " |
| 1121 | + "original autoencoder config. Pass datamodule=<raw> explicitly." |
| 1122 | + ) |
| 1123 | + raise ValueError(msg) |
| 1124 | + |
| 1125 | + ae_cfg = _load_autoencoder_config_from_cache(Path(data_path)) |
| 1126 | + if ae_cfg is None: |
| 1127 | + msg = ( |
| 1128 | + "eval.mode=ambient requested but the cached-latents directory " |
| 1129 | + f"{data_path} has no 'autoencoder_config.yaml'. Either regenerate " |
| 1130 | + "the cache with a recent `autocast cache-latents` (which saves the " |
| 1131 | + "autoencoder config), or pass datamodule=<raw> explicitly." |
| 1132 | + ) |
| 1133 | + raise FileNotFoundError(msg) |
| 1134 | + |
| 1135 | + ae_datamodule = ae_cfg.get("datamodule") |
| 1136 | + if ae_datamodule is None: |
| 1137 | + msg = ( |
| 1138 | + f"autoencoder_config.yaml at {data_path} is missing a 'datamodule' " |
| 1139 | + "section; cannot auto-wire ambient eval. Pass datamodule=<raw> " |
| 1140 | + "explicitly." |
| 1141 | + ) |
| 1142 | + raise ValueError(msg) |
| 1143 | + |
| 1144 | + log.info( |
| 1145 | + "eval.mode=ambient: substituting cached_latents datamodule with the " |
| 1146 | + "raw-data datamodule from %s/autoencoder_config.yaml so the encoder " |
| 1147 | + "sees the same fields/normalization it was trained on.", |
| 1148 | + data_path, |
| 1149 | + ) |
| 1150 | + with open_dict(cfg): |
| 1151 | + cfg.datamodule = ae_datamodule |
| 1152 | + return cfg |
| 1153 | + |
| 1154 | + |
def _resolve_eval_path(
    *,
    processor_only: bool,
    example_batch: Any,
    has_autoencoder_checkpoint: bool,
    decode_fn_loaded: bool,
) -> str:
    """Map the auto-detected branch in `run_evaluation` to a stable label.

    Ambient EPD eval is selected for a full (non-processor-only) checkpoint,
    or for a processor-only checkpoint paired with a raw ``Batch`` datamodule
    plus an autoencoder checkpoint. Everything else is a cached-latents run,
    split by whether a standalone decoder was loaded.
    """
    ambient = not processor_only or (
        isinstance(example_batch, Batch) and has_autoencoder_checkpoint
    )
    if ambient:
        return EVAL_PATH_AMBIENT_EPD
    if decode_fn_loaded:
        return EVAL_PATH_LATENT_CACHED_WITH_DECODER
    return EVAL_PATH_LATENT_CACHED_LATENT_ONLY
| 1170 | + |
| 1171 | + |
| 1172 | +def _validate_resolved_eval_path(*, eval_mode: str, resolved_path: str) -> None: |
| 1173 | + """Raise if the resolved code path disagrees with the user-requested mode.""" |
| 1174 | + if eval_mode == "auto": |
| 1175 | + return |
| 1176 | + if eval_mode == "ambient" and resolved_path != EVAL_PATH_AMBIENT_EPD: |
| 1177 | + msg = ( |
| 1178 | + "eval.mode=ambient but the resolved eval path is " |
| 1179 | + f"{resolved_path!r}. Ambient eval requires a full EPD checkpoint, " |
| 1180 | + "OR a processor-only checkpoint combined with " |
| 1181 | + "autoencoder_checkpoint=<ae.ckpt> AND a raw-Batch datamodule. " |
| 1182 | + "Double-check eval.checkpoint, autoencoder_checkpoint, and " |
| 1183 | + "datamodule=." |
| 1184 | + ) |
| 1185 | + raise ValueError(msg) |
| 1186 | + if eval_mode == "latent" and resolved_path == EVAL_PATH_AMBIENT_EPD: |
| 1187 | + msg = ( |
| 1188 | + "eval.mode=latent but the resolved eval path is " |
| 1189 | + f"{resolved_path!r}. Latent-space eval requires a processor-only " |
| 1190 | + "checkpoint paired with an EncodedBatch (cached_latents) " |
| 1191 | + "datamodule. Use datamodule=cached_latents and remove " |
| 1192 | + "autoencoder_checkpoint=, or switch to eval.mode=ambient/auto." |
| 1193 | + ) |
| 1194 | + raise ValueError(msg) |
| 1195 | + |
| 1196 | + |
1074 | 1197 | def _try_build_decode_fn( |
1075 | 1198 | cfg: DictConfig, |
1076 | 1199 | ) -> "tuple[Any, Any] | tuple[None, None]": |
@@ -1180,11 +1303,13 @@ def run_evaluation(cfg: DictConfig, work_dir: Path | None = None) -> None: # no |
1180 | 1303 | eval_batch_size: int = eval_cfg.get("batch_size", 1) |
1181 | 1304 | max_test_batches = eval_cfg.get("max_test_batches") |
1182 | 1305 | max_rollout_batches = _resolve_rollout_batch_limit(eval_cfg) |
| 1306 | + eval_mode = _normalize_eval_mode(eval_cfg.get("mode", "auto")) |
1183 | 1307 | log.info( |
1184 | 1308 | "Batch limits: max_test_batches=%s, max_rollout_batches=%s", |
1185 | 1309 | max_test_batches, |
1186 | 1310 | max_rollout_batches, |
1187 | 1311 | ) |
| 1312 | + log.info("eval.mode=%s", eval_mode) |
1188 | 1313 |
|
1189 | 1314 | checkpoint_path = resolve_checkpoint_path( |
1190 | 1315 | eval_cfg, |
@@ -1220,6 +1345,19 @@ def run_evaluation(cfg: DictConfig, work_dir: Path | None = None) -> None: # no |
1220 | 1345 | # Setup datamodule and resolve config |
1221 | 1346 | datamodule, cfg, stats = setup_datamodule(cfg) |
1222 | 1347 |
|
| 1348 | + # If the user asked for ambient eval but the resolved datamodule yields |
| 1349 | + # EncodedBatch (cached_latents), substitute the raw-data datamodule stored |
| 1350 | + # in the cache dir's autoencoder_config.yaml and rebuild. Honors an |
| 1351 | + # explicit `datamodule=...` override implicitly: when the override targets |
| 1352 | + # a raw-Batch datamodule the swap becomes a no-op. |
| 1353 | + cfg = _maybe_swap_to_ambient_datamodule( |
| 1354 | + cfg, |
| 1355 | + eval_mode=eval_mode, |
| 1356 | + example_batch=stats.get("example_batch"), |
| 1357 | + ) |
| 1358 | + if eval_mode == "ambient" and isinstance(stats.get("example_batch"), EncodedBatch): |
| 1359 | + datamodule, cfg, stats = setup_datamodule(cfg) |
| 1360 | + |
1223 | 1361 | # Override model n_members from eval config if specified |
1224 | 1362 | if "n_members" in eval_cfg: |
1225 | 1363 | with open_dict(cfg.model): |
@@ -1318,6 +1456,18 @@ def run_evaluation(cfg: DictConfig, work_dir: Path | None = None) -> None: # no |
1318 | 1456 | ) |
1319 | 1457 | raise RuntimeError(msg) |
1320 | 1458 |
|
| 1459 | + resolved_eval_path = _resolve_eval_path( |
| 1460 | + processor_only=processor_only, |
| 1461 | + example_batch=example_batch, |
| 1462 | + has_autoencoder_checkpoint=bool(cfg.get("autoencoder_checkpoint")), |
| 1463 | + decode_fn_loaded=decode_fn is not None, |
| 1464 | + ) |
| 1465 | + log.info("Resolved eval path: %s", resolved_eval_path) |
| 1466 | + _validate_resolved_eval_path( |
| 1467 | + eval_mode=eval_mode, |
| 1468 | + resolved_path=resolved_eval_path, |
| 1469 | + ) |
| 1470 | + |
1321 | 1471 | # Get eval parameters from config |
1322 | 1472 | metrics_list = eval_cfg.get("metrics", DEFAULT_EVAL_METRICS) |
1323 | 1473 | batch_indices = eval_cfg.get("batch_indices", []) |
|
0 commit comments