Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/autocast/scripts/workflow/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"conditioned_navier_stokes": "cns64",
"gpe_low_complexity": "gpelc64",
"gpe_high_complexity": "gpehc64",
"gpe_laser_only_wake": "gpe64",
"shallow_water2d": "sw2d64",
"shallow_water2d_4": "sw2d464",
}
64 changes: 63 additions & 1 deletion src/autocast/scripts/workflow/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,57 @@ def _preset_overrides_for_naming(overrides: list[str]) -> list[str]:
return hints


def _unquote(value: str) -> str:
return value.strip().strip('"').strip("'")


def _dataset_key_from_data_path(data_path: str) -> str | None:
    """Infer canonical dataset key from a filesystem data path."""
    path = Path(_unquote(data_path))
    leaf = path.name

    # Try longest keys first so e.g. "gpe_high_complexity" is matched before
    # a shorter key that happens to be its prefix.
    for candidate in sorted(DATASET_NAME_TOKENS, key=len, reverse=True):
        if leaf == candidate or leaf.startswith(candidate + "_"):
            return candidate

    # Special case: a "laser_only_wake*" directory nested under a "gpe"
    # directory maps to the canonical alias key.
    parts = path.parts
    if len(parts) >= 2 and parts[-2] == "gpe" and leaf.startswith("laser_only_wake"):
        return "gpe_laser_only_wake"

    return None


def _dataset_key_from_cached_latents(cache_path: str) -> str | None:  # noqa: PLR0911
    """Infer source dataset key from a cached-latents directory."""
    config_file = Path(_unquote(cache_path)).expanduser() / "autoencoder_config.yaml"
    if not config_file.exists():
        return None

    raw = OmegaConf.to_container(OmegaConf.load(config_file), resolve=True)
    if not isinstance(raw, dict):
        return None

    datamodule = raw.get("datamodule")
    # The datamodule entry may already be the dataset key itself.
    if isinstance(datamodule, str):
        return datamodule
    if not isinstance(datamodule, dict):
        return None

    # Prefer an explicitly recorded dataset name over path-based inference.
    explicit_name = datamodule.get("dataset")
    if isinstance(explicit_name, str) and explicit_name:
        return explicit_name

    recorded_path = datamodule.get("data_path")
    if isinstance(recorded_path, str):
        return _dataset_key_from_data_path(recorded_path)

    return None


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
Expand All @@ -130,7 +181,18 @@ def _preset_overrides_for_naming(overrides: list[str]) -> list[str]:
def dataset_name_token(dataset: str, overrides: list[str]) -> str:
    """Short token for *dataset* used in auto-generated run names."""
    datamodule_cfg = extract_override_value(overrides, "datamodule") or dataset
    data_path_override = extract_override_value(overrides, "datamodule.data_path")

    dataset_key = datamodule_cfg
    if datamodule_cfg == "cached_latents" and data_path_override:
        # Resolve the dataset the cached latents were produced from, preferring
        # the saved autoencoder config and falling back to the path heuristic.
        resolved = _dataset_key_from_cached_latents(
            data_path_override
        ) or _dataset_key_from_data_path(data_path_override)
        if resolved:
            dataset_key = resolved

    return sanitize_name_part(DATASET_NAME_TOKENS.get(dataset_key, dataset_key))


def auto_run_name(kind: str, dataset: str, overrides: list[str]) -> str:
Expand Down
23 changes: 23 additions & 0 deletions tests/scripts/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,29 @@ def test_dataset_name_token_datamodule_override_takes_precedence():
assert dataset_name_token("something_else", overrides) == "rd64"


def test_dataset_name_token_handles_gpe_laser_only_wake_alias():
    # The alias key must resolve to its registered short token.
    token = dataset_name_token("gpe_laser_only_wake", [])
    assert token == "gpe64"


def test_dataset_name_token_ignores_data_path_when_not_cached_latents():
    # A data_path override alone must not trigger dataset inference.
    overrides = ["datamodule.data_path=/tmp/datasets/reaction_diffusion_e3e8515"]
    token = dataset_name_token("something_else", overrides)
    assert token == "something_else"


def test_dataset_name_token_cached_latents_uses_saved_autoencoder_dataset(tmp_path):
    # Persist a minimal autoencoder config recording the source data path.
    cached_dir = tmp_path / "cached"
    cached_dir.mkdir(parents=True)
    config_body = "datamodule:\n data_path: /tmp/datasets/reaction_diffusion_e3e8515\n"
    (cached_dir / "autoencoder_config.yaml").write_text(config_body, encoding="utf-8")

    overrides = ["datamodule=cached_latents", f"datamodule.data_path={cached_dir}"]
    assert dataset_name_token("cached_latents", overrides) == "rd64"


def test_auto_run_name_ae():
with (
patch("autocast.scripts.workflow.naming._git_hash", return_value="abc1234"),
Expand Down
Loading