Skip to content

Commit 534544b

Browse files
Authored by: claude[bot], github-actions[bot], and dlwh
[levanter] Separate temporary checkpoint base path and use Marin temp buckets (#4387)
Add temporary_base_path to CheckpointerConfig and Checkpointer so that time-policy checkpoints are routed separately while step-policy checkpoints stay durable. Marin sends temporary checkpoints to region-local temp buckets with a 14-day TTL, and both Trainer restore and Grug restore search the configured permanent and temporary roots for the newest valid checkpoint.

Fixes #4386

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: David Hall <david.hall@openathena.ai>
1 parent 8775c5b commit 534544b

20 files changed

Lines changed: 451 additions & 163 deletions

File tree

experiments/grug/base/train.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -372,12 +372,9 @@ def _init_state(model_rng):
372372
state = _init_state(model_key)
373373

374374
checkpointer = trainer.checkpointer.create(run_id)
375-
checkpoint_path = trainer.load_checkpoint_path
376-
if checkpoint_path is None and checkpointer is not None:
377-
checkpoint_path = trainer.checkpointer.expanded_path(run_id)
378375
state = restore_grug_state_from_checkpoint(
379376
state,
380-
checkpoint_path=checkpoint_path,
377+
checkpoint_search_paths=trainer.checkpoint_search_paths(run_id),
381378
load_checkpoint_setting=trainer.load_checkpoint,
382379
mesh=mesh,
383380
allow_partial=trainer.allow_partial_checkpoint,

experiments/grug/checkpointing.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
import os
99
import urllib.parse
10-
from collections.abc import Callable
10+
from collections.abc import Callable, Sequence
1111
from typing import TypeVar
1212

1313
import fsspec
@@ -26,19 +26,34 @@ def _get_fs_and_plain_path(path: str) -> tuple[AbstractFileSystem, str]:
2626
return fs, plain_path
2727

2828

29-
def _checkpoint_candidates(checkpoint_path: str) -> list[str]:
30-
fs, plain_path = _get_fs_and_plain_path(checkpoint_path)
31-
base_path_protocol = urllib.parse.urlparse(checkpoint_path).scheme
29+
def _checkpoint_candidates(checkpoint_search_paths: Sequence[str]) -> list[str]:
30+
candidates: list[tuple[int, str, str]] = []
31+
for search_path in checkpoint_search_paths:
32+
candidates.extend(_scan_checkpoint_root(search_path))
33+
34+
candidates.sort(key=lambda item: (item[0], item[1]), reverse=True)
35+
ordered_candidates = [candidate for _, _, candidate in candidates]
36+
37+
for search_path in checkpoint_search_paths:
38+
if search_path not in ordered_candidates:
39+
ordered_candidates.append(search_path)
40+
return ordered_candidates
41+
42+
43+
def _scan_checkpoint_root(root_path: str) -> list[tuple[int, str, str]]:
44+
"""Scan a single root path and return (step, timestamp, path) tuples."""
45+
fs, plain_path = _get_fs_and_plain_path(root_path)
46+
base_path_protocol = urllib.parse.urlparse(root_path).scheme
3247

3348
def maybe_unstrip_protocol(path: str) -> str:
3449
if base_path_protocol != "" and urllib.parse.urlparse(path).scheme == "":
3550
return f"{base_path_protocol}://{path}"
3651
return path
3752

3853
checkpoint_dirs = [maybe_unstrip_protocol(d) for d in fs.glob(os.path.join(plain_path, "*")) if fs.isdir(d)]
39-
checkpoint_dirs.append(checkpoint_path)
54+
checkpoint_dirs.append(root_path)
4055

41-
candidates: list[tuple[int, str, str]] = []
56+
results: list[tuple[int, str, str]] = []
4257
for candidate in checkpoint_dirs:
4358
metadata_path = os.path.join(candidate, "metadata.json")
4459
if not fs.exists(metadata_path):
@@ -59,34 +74,29 @@ def maybe_unstrip_protocol(path: str) -> str:
5974

6075
timestamp = metadata.get("timestamp")
6176
timestamp_key = str(timestamp) if timestamp is not None else ""
62-
candidates.append((step_num, timestamp_key, candidate))
63-
64-
candidates.sort(key=lambda item: (item[0], item[1]), reverse=True)
65-
ordered_candidates = [candidate for _, _, candidate in candidates]
66-
if checkpoint_path not in ordered_candidates:
67-
ordered_candidates.append(checkpoint_path)
77+
results.append((step_num, timestamp_key, candidate))
6878

69-
return ordered_candidates
79+
return results
7080

7181

7282
def restore_grug_state_from_checkpoint(
7383
state: StateT,
7484
*,
75-
checkpoint_path: str | None,
85+
checkpoint_search_paths: Sequence[str],
7686
load_checkpoint_setting: bool | None,
7787
mesh: jax.sharding.Mesh | None,
7888
allow_partial: bool,
7989
_load_fn: Callable[..., StateT] = load_checkpoint,
8090
) -> StateT:
81-
if checkpoint_path is None:
91+
if not checkpoint_search_paths:
8292
if load_checkpoint_setting:
83-
raise FileNotFoundError("load_checkpoint=True but no checkpoint path is configured.")
93+
raise FileNotFoundError("load_checkpoint=True but no checkpoint search paths are configured.")
8494
return state
8595

8696
if load_checkpoint_setting is False:
8797
return state
8898

89-
candidates = _checkpoint_candidates(checkpoint_path)
99+
candidates = _checkpoint_candidates(checkpoint_search_paths)
90100
last_error: FileNotFoundError | None = None
91101

92102
for candidate in candidates:
@@ -98,8 +108,8 @@ def restore_grug_state_from_checkpoint(
98108
allow_partial=allow_partial,
99109
load_fn=_load_fn,
100110
)
101-
if candidate != checkpoint_path:
102-
logger.info("Loaded checkpoint %s from %s", checkpoint_path, candidate)
111+
if candidate not in checkpoint_search_paths:
112+
logger.info("Loaded checkpoint from %s while searching %s", candidate, checkpoint_search_paths)
103113
return loaded
104114
except FileNotFoundError as exc:
105115
last_error = exc
@@ -108,14 +118,15 @@ def restore_grug_state_from_checkpoint(
108118
)
109119

110120
if load_checkpoint_setting is True:
121+
search_path_summary = ", ".join(checkpoint_search_paths)
111122
attempted = ", ".join(candidates)
112123
if last_error is None:
113-
raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}")
124+
raise FileNotFoundError(f"Could not find checkpoint under any of: {search_path_summary}")
114125
raise FileNotFoundError(
115-
f"Could not load a checkpoint from {checkpoint_path}. Attempted: {attempted}"
126+
f"Could not load a checkpoint from search paths {search_path_summary}. Attempted: {attempted}"
116127
) from last_error
117128

118-
logger.info(f"Checkpoint not found at {checkpoint_path}. Starting from scratch.")
129+
logger.info("Checkpoint not found under %s. Starting from scratch.", checkpoint_search_paths)
119130
return state
120131

121132

@@ -131,7 +142,6 @@ def _load_candidate_state(
131142
return load_fn(
132143
state,
133144
candidate,
134-
discover_latest=False,
135145
axis_mapping=None,
136146
mesh=mesh,
137147
allow_partial=allow_partial,
@@ -141,7 +151,6 @@ def _load_candidate_state(
141151
wrapped = load_fn(
142152
{"train_state": state},
143153
candidate,
144-
discover_latest=False,
145154
axis_mapping=None,
146155
mesh=mesh,
147156
allow_partial=allow_partial,

experiments/grug/modular_opt/train.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -372,12 +372,9 @@ def _init_state(model_rng):
372372
state = _init_state(model_key)
373373

374374
checkpointer = trainer.checkpointer.create(run_id)
375-
checkpoint_path = trainer.load_checkpoint_path
376-
if checkpoint_path is None and checkpointer is not None:
377-
checkpoint_path = trainer.checkpointer.expanded_path(run_id)
378375
state = restore_grug_state_from_checkpoint(
379376
state,
380-
checkpoint_path=checkpoint_path,
377+
checkpoint_search_paths=trainer.checkpoint_search_paths(run_id),
381378
load_checkpoint_setting=trainer.load_checkpoint,
382379
mesh=mesh,
383380
allow_partial=trainer.allow_partial_checkpoint,

experiments/grug/moe/train.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -410,12 +410,9 @@ def _init_state(model_rng):
410410
state = _init_state(model_key)
411411

412412
checkpointer = trainer.checkpointer.create(run_id)
413-
checkpoint_path = trainer.load_checkpoint_path
414-
if checkpoint_path is None and checkpointer is not None:
415-
checkpoint_path = trainer.checkpointer.expanded_path(run_id)
416413
state = restore_grug_state_from_checkpoint(
417414
state,
418-
checkpoint_path=checkpoint_path,
415+
checkpoint_search_paths=trainer.checkpoint_search_paths(run_id),
419416
load_checkpoint_setting=trainer.load_checkpoint,
420417
mesh=mesh,
421418
allow_partial=trainer.allow_partial_checkpoint,

0 commit comments

Comments (0)