Skip to content

Commit a93bf7a

Browse files
committed
Use explicit grug checkpoint search paths
1 parent 415a3e0 commit a93bf7a

5 files changed

Lines changed: 58 additions & 67 deletions

File tree

experiments/grug/base/train.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -372,20 +372,21 @@ def _init_state(model_rng):
372372
state = _init_state(model_key)
373373

374374
checkpointer = trainer.checkpointer.create(run_id)
375-
checkpoint_path = trainer.load_checkpoint_path
376-
if checkpoint_path is None and checkpointer is not None:
377-
checkpoint_path = trainer.checkpointer.expanded_path(run_id)
378-
additional_checkpoint_paths = []
379-
temp_path = trainer.checkpointer.expanded_temporary_path(run_id)
380-
if temp_path is not None:
381-
additional_checkpoint_paths.append(temp_path)
375+
if trainer.load_checkpoint_path is not None:
376+
checkpoint_search_paths = [trainer.load_checkpoint_path]
377+
elif checkpointer is not None:
378+
checkpoint_search_paths = [trainer.checkpointer.expanded_path(run_id)]
379+
temp_path = trainer.checkpointer.expanded_temporary_path(run_id)
380+
if temp_path is not None:
381+
checkpoint_search_paths.append(temp_path)
382+
else:
383+
checkpoint_search_paths = []
382384
state = restore_grug_state_from_checkpoint(
383385
state,
384-
checkpoint_path=checkpoint_path,
386+
checkpoint_search_paths=checkpoint_search_paths,
385387
load_checkpoint_setting=trainer.load_checkpoint,
386388
mesh=mesh,
387389
allow_partial=trainer.allow_partial_checkpoint,
388-
additional_checkpoint_paths=additional_checkpoint_paths,
389390
)
390391

391392
levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)})

experiments/grug/checkpointing.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
import os
99
import urllib.parse
10-
from collections.abc import Callable
10+
from collections.abc import Callable, Sequence
1111
from typing import TypeVar
1212

1313
import fsspec
@@ -26,29 +26,20 @@ def _get_fs_and_plain_path(path: str) -> tuple[AbstractFileSystem, str]:
2626
return fs, plain_path
2727

2828

29-
def _checkpoint_candidates(checkpoint_path: str, *, additional_paths: list[str] | None = None) -> list[str]:
30-
if _is_checkpoint_dir(checkpoint_path):
31-
return [checkpoint_path]
32-
33-
all_roots = [checkpoint_path] + (additional_paths or [])
34-
29+
def _checkpoint_candidates(checkpoint_search_paths: Sequence[str]) -> list[str]:
3530
candidates: list[tuple[int, str, str]] = []
36-
for root in all_roots:
37-
candidates.extend(_scan_checkpoint_root(root))
31+
for search_path in checkpoint_search_paths:
32+
candidates.extend(_scan_checkpoint_root(search_path))
3833

3934
candidates.sort(key=lambda item: (item[0], item[1]), reverse=True)
4035
ordered_candidates = [candidate for _, _, candidate in candidates]
41-
if checkpoint_path not in ordered_candidates:
42-
ordered_candidates.append(checkpoint_path)
4336

37+
for search_path in checkpoint_search_paths:
38+
if search_path not in ordered_candidates:
39+
ordered_candidates.append(search_path)
4440
return ordered_candidates
4541

4642

47-
def _is_checkpoint_dir(checkpoint_path: str) -> bool:
48-
fs, plain_path = _get_fs_and_plain_path(checkpoint_path)
49-
return fs.exists(os.path.join(plain_path, "metadata.json"))
50-
51-
5243
def _scan_checkpoint_root(root_path: str) -> list[tuple[int, str, str]]:
5344
"""Scan a single root path and return (step, timestamp, path) tuples."""
5445
fs, plain_path = _get_fs_and_plain_path(root_path)
@@ -91,22 +82,21 @@ def maybe_unstrip_protocol(path: str) -> str:
9182
def restore_grug_state_from_checkpoint(
9283
state: StateT,
9384
*,
94-
checkpoint_path: str | None,
85+
checkpoint_search_paths: Sequence[str],
9586
load_checkpoint_setting: bool | None,
9687
mesh: jax.sharding.Mesh | None,
9788
allow_partial: bool,
98-
additional_checkpoint_paths: list[str] | None = None,
9989
_load_fn: Callable[..., StateT] = load_checkpoint,
10090
) -> StateT:
101-
if checkpoint_path is None:
91+
if not checkpoint_search_paths:
10292
if load_checkpoint_setting:
103-
raise FileNotFoundError("load_checkpoint=True but no checkpoint path is configured.")
93+
raise FileNotFoundError("load_checkpoint=True but no checkpoint search paths are configured.")
10494
return state
10595

10696
if load_checkpoint_setting is False:
10797
return state
10898

109-
candidates = _checkpoint_candidates(checkpoint_path, additional_paths=additional_checkpoint_paths or [])
99+
candidates = _checkpoint_candidates(checkpoint_search_paths)
110100
last_error: FileNotFoundError | None = None
111101

112102
for candidate in candidates:
@@ -118,8 +108,8 @@ def restore_grug_state_from_checkpoint(
118108
allow_partial=allow_partial,
119109
load_fn=_load_fn,
120110
)
121-
if candidate != checkpoint_path:
122-
logger.info("Loaded checkpoint %s from %s", checkpoint_path, candidate)
111+
if candidate not in checkpoint_search_paths:
112+
logger.info("Loaded checkpoint from %s while searching %s", candidate, checkpoint_search_paths)
123113
return loaded
124114
except FileNotFoundError as exc:
125115
last_error = exc
@@ -128,14 +118,15 @@ def restore_grug_state_from_checkpoint(
128118
)
129119

130120
if load_checkpoint_setting is True:
121+
search_path_summary = ", ".join(checkpoint_search_paths)
131122
attempted = ", ".join(candidates)
132123
if last_error is None:
133-
raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}")
124+
raise FileNotFoundError(f"Could not find checkpoint under any of: {search_path_summary}")
134125
raise FileNotFoundError(
135-
f"Could not load a checkpoint from {checkpoint_path}. Attempted: {attempted}"
126+
f"Could not load a checkpoint from search paths {search_path_summary}. Attempted: {attempted}"
136127
) from last_error
137128

138-
logger.info(f"Checkpoint not found at {checkpoint_path}. Starting from scratch.")
129+
logger.info("Checkpoint not found under %s. Starting from scratch.", checkpoint_search_paths)
139130
return state
140131

141132

experiments/grug/modular_opt/train.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -372,20 +372,21 @@ def _init_state(model_rng):
372372
state = _init_state(model_key)
373373

374374
checkpointer = trainer.checkpointer.create(run_id)
375-
checkpoint_path = trainer.load_checkpoint_path
376-
if checkpoint_path is None and checkpointer is not None:
377-
checkpoint_path = trainer.checkpointer.expanded_path(run_id)
378-
additional_checkpoint_paths = []
379-
temp_path = trainer.checkpointer.expanded_temporary_path(run_id)
380-
if temp_path is not None:
381-
additional_checkpoint_paths.append(temp_path)
375+
if trainer.load_checkpoint_path is not None:
376+
checkpoint_search_paths = [trainer.load_checkpoint_path]
377+
elif checkpointer is not None:
378+
checkpoint_search_paths = [trainer.checkpointer.expanded_path(run_id)]
379+
temp_path = trainer.checkpointer.expanded_temporary_path(run_id)
380+
if temp_path is not None:
381+
checkpoint_search_paths.append(temp_path)
382+
else:
383+
checkpoint_search_paths = []
382384
state = restore_grug_state_from_checkpoint(
383385
state,
384-
checkpoint_path=checkpoint_path,
386+
checkpoint_search_paths=checkpoint_search_paths,
385387
load_checkpoint_setting=trainer.load_checkpoint,
386388
mesh=mesh,
387389
allow_partial=trainer.allow_partial_checkpoint,
388-
additional_checkpoint_paths=additional_checkpoint_paths,
389390
)
390391

391392
levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)})

experiments/grug/moe/train.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -410,20 +410,21 @@ def _init_state(model_rng):
410410
state = _init_state(model_key)
411411

412412
checkpointer = trainer.checkpointer.create(run_id)
413-
checkpoint_path = trainer.load_checkpoint_path
414-
if checkpoint_path is None and checkpointer is not None:
415-
checkpoint_path = trainer.checkpointer.expanded_path(run_id)
416-
additional_checkpoint_paths = []
417-
temp_path = trainer.checkpointer.expanded_temporary_path(run_id)
418-
if temp_path is not None:
419-
additional_checkpoint_paths.append(temp_path)
413+
if trainer.load_checkpoint_path is not None:
414+
checkpoint_search_paths = [trainer.load_checkpoint_path]
415+
elif checkpointer is not None:
416+
checkpoint_search_paths = [trainer.checkpointer.expanded_path(run_id)]
417+
temp_path = trainer.checkpointer.expanded_temporary_path(run_id)
418+
if temp_path is not None:
419+
checkpoint_search_paths.append(temp_path)
420+
else:
421+
checkpoint_search_paths = []
420422
state = restore_grug_state_from_checkpoint(
421423
state,
422-
checkpoint_path=checkpoint_path,
424+
checkpoint_search_paths=checkpoint_search_paths,
423425
load_checkpoint_setting=trainer.load_checkpoint,
424426
mesh=mesh,
425427
allow_partial=trainer.allow_partial_checkpoint,
426-
additional_checkpoint_paths=additional_checkpoint_paths,
427428
)
428429

429430
levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)})

tests/test_grug_checkpointing.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
3434

3535
loaded = restore_grug_state_from_checkpoint(
3636
{"state": "init"},
37-
checkpoint_path=str(checkpoint_root),
37+
checkpoint_search_paths=[str(checkpoint_root)],
3838
load_checkpoint_setting=True,
3939
mesh=None,
4040
allow_partial=False,
@@ -61,7 +61,7 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
6161

6262
loaded = restore_grug_state_from_checkpoint(
6363
{"state": "init"},
64-
checkpoint_path=str(checkpoint_root),
64+
checkpoint_search_paths=[str(checkpoint_root)],
6565
load_checkpoint_setting=None,
6666
mesh=None,
6767
allow_partial=False,
@@ -86,15 +86,15 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
8686
with pytest.raises(FileNotFoundError, match="Could not load a checkpoint"):
8787
restore_grug_state_from_checkpoint(
8888
{"state": "init"},
89-
checkpoint_path=str(checkpoint_root),
89+
checkpoint_search_paths=[str(checkpoint_root)],
9090
load_checkpoint_setting=True,
9191
mesh=None,
9292
allow_partial=False,
9393
_load_fn=fake_load,
9494
)
9595

9696

97-
def test_restore_discovers_candidates_across_additional_paths(tmp_path: Path):
97+
def test_restore_discovers_candidates_across_search_paths(tmp_path: Path):
9898
permanent_root = tmp_path / "checkpoints"
9999
temp_root = tmp_path / "checkpoints-temp"
100100

@@ -109,20 +109,19 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
109109

110110
loaded = restore_grug_state_from_checkpoint(
111111
{"state": "init"},
112-
checkpoint_path=str(permanent_root),
112+
checkpoint_search_paths=[str(permanent_root), str(temp_root)],
113113
load_checkpoint_setting=True,
114114
mesh=None,
115115
allow_partial=False,
116-
additional_checkpoint_paths=[str(temp_root)],
117116
_load_fn=fake_load,
118117
)
119118

120-
# step-150 from temp root should be preferred (highest step)
119+
# step-150 from temp root should be preferred (highest step).
121120
assert attempted == [str(temp_root / "step-150")]
122121
assert loaded == {"loaded_from": str(temp_root / "step-150")}
123122

124123

125-
def test_restore_respects_explicit_checkpoint_path_with_additional_paths(tmp_path: Path):
124+
def test_restore_respects_explicit_checkpoint_path_as_single_search_path(tmp_path: Path):
126125
permanent_root = tmp_path / "checkpoints"
127126
temp_root = tmp_path / "checkpoints-temp"
128127
explicit_checkpoint = permanent_root / "step-100"
@@ -138,11 +137,10 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
138137

139138
loaded = restore_grug_state_from_checkpoint(
140139
{"state": "init"},
141-
checkpoint_path=str(explicit_checkpoint),
140+
checkpoint_search_paths=[str(explicit_checkpoint)],
142141
load_checkpoint_setting=True,
143142
mesh=None,
144143
allow_partial=False,
145-
additional_checkpoint_paths=[str(temp_root)],
146144
_load_fn=fake_load,
147145
)
148146

@@ -167,11 +165,10 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
167165

168166
loaded = restore_grug_state_from_checkpoint(
169167
{"state": "init"},
170-
checkpoint_path=str(permanent_root),
168+
checkpoint_search_paths=[str(permanent_root), str(temp_root)],
171169
load_checkpoint_setting=None,
172170
mesh=None,
173171
allow_partial=False,
174-
additional_checkpoint_paths=[str(temp_root)],
175172
_load_fn=fake_load,
176173
)
177174

@@ -197,7 +194,7 @@ def test_restore_supports_legacy_wrapped_and_current_checkpoint_formats(tmp_path
197194

198195
loaded_legacy = restore_grug_state_from_checkpoint(
199196
template_state,
200-
checkpoint_path=str(checkpoint_root),
197+
checkpoint_search_paths=[str(checkpoint_root)],
201198
load_checkpoint_setting=True,
202199
mesh=None,
203200
allow_partial=False,
@@ -210,7 +207,7 @@ def test_restore_supports_legacy_wrapped_and_current_checkpoint_formats(tmp_path
210207

211208
loaded_current = restore_grug_state_from_checkpoint(
212209
template_state,
213-
checkpoint_path=str(checkpoint_root),
210+
checkpoint_search_paths=[str(checkpoint_root)],
214211
load_checkpoint_setting=True,
215212
mesh=None,
216213
allow_partial=False,

0 commit comments

Comments
 (0)