diff --git a/src/autocast/scripts/workflow/constants.py b/src/autocast/scripts/workflow/constants.py
index 3744ad28..330ca526 100644
--- a/src/autocast/scripts/workflow/constants.py
+++ b/src/autocast/scripts/workflow/constants.py
@@ -28,6 +28,7 @@
     "conditioned_navier_stokes": "cns64",
     "gpe_low_complexity": "gpelc64",
     "gpe_high_complexity": "gpehc64",
+    "gpe_laser_only_wake": "gpe64",
     "shallow_water2d": "sw2d64",
     "shallow_water2d_4": "sw2d464",
 }
diff --git a/src/autocast/scripts/workflow/naming.py b/src/autocast/scripts/workflow/naming.py
index ae50f277..c34084c0 100644
--- a/src/autocast/scripts/workflow/naming.py
+++ b/src/autocast/scripts/workflow/naming.py
@@ -122,6 +122,57 @@ def _preset_overrides_for_naming(overrides: list[str]) -> list[str]:
     return hints
 
 
+def _unquote(value: str) -> str:
+    return value.strip().strip('"').strip("'")
+
+
+def _dataset_key_from_data_path(data_path: str) -> str | None:
+    """Infer canonical dataset key from a filesystem data path."""
+    normalized = Path(_unquote(data_path))
+    dataset_dir = normalized.name
+
+    for key in sorted(DATASET_NAME_TOKENS, key=len, reverse=True):
+        if dataset_dir == key or dataset_dir.startswith(f"{key}_"):
+            return key
+
+    if (
+        len(normalized.parts) >= 2
+        and normalized.parts[-2] == "gpe"
+        and dataset_dir.startswith("laser_only_wake")
+    ):
+        return "gpe_laser_only_wake"
+
+    return None
+
+
+def _dataset_key_from_cached_latents(cache_path: str) -> str | None:  # noqa: PLR0911
+    """Infer source dataset key from a cached-latents directory."""
+    cache_dir = Path(_unquote(cache_path)).expanduser()
+    ae_config = cache_dir / "autoencoder_config.yaml"
+    if not ae_config.exists():
+        return None
+
+    loaded = OmegaConf.to_container(OmegaConf.load(ae_config), resolve=True)
+    if not isinstance(loaded, dict):
+        return None
+
+    datamodule_cfg = loaded.get("datamodule")
+    if isinstance(datamodule_cfg, str):
+        return datamodule_cfg
+    if not isinstance(datamodule_cfg, dict):
+        return None
+
+    dataset_name = datamodule_cfg.get("dataset")
+    if isinstance(dataset_name, str) and dataset_name:
+        return dataset_name
+
+    source_data_path = datamodule_cfg.get("data_path")
+    if isinstance(source_data_path, str):
+        return _dataset_key_from_data_path(source_data_path)
+
+    return None
+
+
 # ---------------------------------------------------------------------------
 # Public API
 # ---------------------------------------------------------------------------
@@ -130,7 +181,18 @@ def _preset_overrides_for_naming(overrides: list[str]) -> list[str]:
 def dataset_name_token(dataset: str, overrides: list[str]) -> str:
     """Short token for *dataset* used in auto-generated run names."""
     datamodule_cfg = extract_override_value(overrides, "datamodule") or dataset
-    return sanitize_name_part(DATASET_NAME_TOKENS.get(datamodule_cfg, datamodule_cfg))
+    data_path_override = extract_override_value(overrides, "datamodule.data_path")
+
+    dataset_key = datamodule_cfg
+    if datamodule_cfg == "cached_latents" and data_path_override:
+        inferred = _dataset_key_from_cached_latents(data_path_override)
+        if inferred:
+            dataset_key = inferred
+
+        elif inferred_from_path := _dataset_key_from_data_path(data_path_override):
+            dataset_key = inferred_from_path
+
+    return sanitize_name_part(DATASET_NAME_TOKENS.get(dataset_key, dataset_key))
 
 
 def auto_run_name(kind: str, dataset: str, overrides: list[str]) -> str:
diff --git a/tests/scripts/test_workflow.py b/tests/scripts/test_workflow.py
index 1d610e42..3d2c95a0 100644
--- a/tests/scripts/test_workflow.py
+++ b/tests/scripts/test_workflow.py
@@ -220,6 +220,29 @@ def test_dataset_name_token_datamodule_override_takes_precedence():
     assert dataset_name_token("something_else", overrides) == "rd64"
 
 
+def test_dataset_name_token_handles_gpe_laser_only_wake_alias():
+    assert dataset_name_token("gpe_laser_only_wake", []) == "gpe64"
+
+
+def test_dataset_name_token_ignores_data_path_when_not_cached_latents():
+    overrides = ["datamodule.data_path=/tmp/datasets/reaction_diffusion_e3e8515"]
+    assert dataset_name_token("something_else", overrides) == "something_else"
+
+
+def test_dataset_name_token_cached_latents_uses_saved_autoencoder_dataset(tmp_path):
+    cached_dir = tmp_path / "cached"
+    cached_dir.mkdir(parents=True)
+    (cached_dir / "autoencoder_config.yaml").write_text(
+        "datamodule:\n data_path: /tmp/datasets/reaction_diffusion_e3e8515\n",
+        encoding="utf-8",
+    )
+    overrides = [
+        "datamodule=cached_latents",
+        f"datamodule.data_path={cached_dir}",
+    ]
+    assert dataset_name_token("cached_latents", overrides) == "rd64"
+
+
 def test_auto_run_name_ae():
     with (
         patch("autocast.scripts.workflow.naming._git_hash", return_value="abc1234"),