Skip to content

Commit 415a3e0

Browse files
committed
Respect explicit grug checkpoint paths
1 parent 0eecf1c commit 415a3e0

2 files changed

Lines changed: 36 additions & 0 deletions

File tree

experiments/grug/checkpointing.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,9 @@ def _get_fs_and_plain_path(path: str) -> tuple[AbstractFileSystem, str]:
2727

2828

2929
def _checkpoint_candidates(checkpoint_path: str, *, additional_paths: list[str] | None = None) -> list[str]:
30+
if _is_checkpoint_dir(checkpoint_path):
31+
return [checkpoint_path]
32+
3033
all_roots = [checkpoint_path] + (additional_paths or [])
3134

3235
candidates: list[tuple[int, str, str]] = []
@@ -41,6 +44,11 @@ def _checkpoint_candidates(checkpoint_path: str, *, additional_paths: list[str]
4144
return ordered_candidates
4245

4346

47+
def _is_checkpoint_dir(checkpoint_path: str) -> bool:
    """Report whether ``checkpoint_path`` directly contains a ``metadata.json`` file.

    The path is resolved to a filesystem and a plain path via
    ``_get_fs_and_plain_path``; the check is a single existence test for
    ``metadata.json`` immediately under that plain path.

    Args:
        checkpoint_path: Path to probe, in whatever form
            ``_get_fs_and_plain_path`` accepts.

    Returns:
        The result of the filesystem's ``exists`` call for the metadata file.
    """
    filesystem, root = _get_fs_and_plain_path(checkpoint_path)
    metadata_file = os.path.join(root, "metadata.json")
    return filesystem.exists(metadata_file)
50+
51+
4452
def _scan_checkpoint_root(root_path: str) -> list[tuple[int, str, str]]:
4553
"""Scan a single root path and return (step, timestamp, path) tuples."""
4654
fs, plain_path = _get_fs_and_plain_path(root_path)

tests/test_grug_checkpointing.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -122,6 +122,34 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
122122
assert loaded == {"loaded_from": str(temp_root / "step-150")}
123123

124124

125+
def test_restore_respects_explicit_checkpoint_path_with_additional_paths(tmp_path: Path):
    """An explicit checkpoint directory is the only path attempted.

    A ``step-100`` checkpoint is passed directly as ``checkpoint_path`` while
    an additional root holds a later ``step-150`` checkpoint. The restore must
    try exactly the explicit path and return what the loader produced for it.
    """
    explicit_checkpoint = tmp_path / "checkpoints" / "step-100"
    temp_root = tmp_path / "checkpoints-temp"
    newer_temp_checkpoint = temp_root / "step-150"

    _write_checkpoint_metadata(explicit_checkpoint, step=100, timestamp="2026-03-17T00:00:00")
    _write_checkpoint_metadata(newer_temp_checkpoint, step=150, timestamp="2026-03-17T06:00:00")

    load_calls: list[str] = []

    def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial):
        # Record every path the restore logic tries, then report success for it.
        load_calls.append(path)
        return {"loaded_from": path}

    expected_path = str(explicit_checkpoint)
    result = restore_grug_state_from_checkpoint(
        {"state": "init"},
        checkpoint_path=expected_path,
        load_checkpoint_setting=True,
        mesh=None,
        allow_partial=False,
        additional_checkpoint_paths=[str(temp_root)],
        _load_fn=fake_load,
    )

    # Only the explicit checkpoint may be attempted, despite the newer step-150.
    assert load_calls == [expected_path]
    assert result == {"loaded_from": expected_path}
151+
152+
125153
def test_restore_falls_back_from_temp_to_permanent(tmp_path: Path):
126154
permanent_root = tmp_path / "checkpoints"
127155
temp_root = tmp_path / "checkpoints-temp"

0 commit comments

Comments
 (0)