Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions experiments/grug/base/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,12 +375,17 @@ def _init_state(model_rng):
checkpoint_path = trainer.load_checkpoint_path
if checkpoint_path is None and checkpointer is not None:
checkpoint_path = trainer.checkpointer.expanded_path(run_id)
additional_checkpoint_paths = []
temp_path = trainer.checkpointer.expanded_temporary_path(run_id)
if temp_path is not None:
additional_checkpoint_paths.append(temp_path)
state = restore_grug_state_from_checkpoint(
state,
checkpoint_path=checkpoint_path,
load_checkpoint_setting=trainer.load_checkpoint,
mesh=mesh,
allow_partial=trainer.allow_partial_checkpoint,
additional_checkpoint_paths=additional_checkpoint_paths,
)

levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)})
Expand Down
38 changes: 25 additions & 13 deletions experiments/grug/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,35 @@ def _get_fs_and_plain_path(path: str) -> tuple[AbstractFileSystem, str]:
return fs, plain_path


def _checkpoint_candidates(checkpoint_path: str) -> list[str]:
fs, plain_path = _get_fs_and_plain_path(checkpoint_path)
base_path_protocol = urllib.parse.urlparse(checkpoint_path).scheme
def _checkpoint_candidates(checkpoint_path: str, *, additional_paths: list[str] | None = None) -> list[str]:
    """Return checkpoint directories to try loading, ordered newest first.

    Args:
        checkpoint_path: either a root directory containing ``step-*`` checkpoint
            subdirectories, or a concrete checkpoint directory itself.
        additional_paths: extra root directories (e.g. a temporary-checkpoint root)
            whose checkpoints should also be considered during discovery.

    Returns:
        Candidate checkpoint directories sorted by (step, timestamp) descending.
        ``checkpoint_path`` itself is always included as a final fallback.

    Note:
        If ``checkpoint_path`` points directly at a checkpoint (it has its own
        ``metadata.json``), it is treated as explicitly pinned and
        ``additional_paths`` are ignored. Otherwise a newer temporary checkpoint
        could silently be loaded instead of the checkpoint the user requested,
        breaking reproducible resumes from a pinned step.
    """
    fs, plain_path = _get_fs_and_plain_path(checkpoint_path)
    if fs.exists(os.path.join(plain_path, "metadata.json")):
        # Explicitly pinned checkpoint directory: respect it exactly.
        return [checkpoint_path]

    all_roots = [checkpoint_path] + (additional_paths or [])

    candidates: list[tuple[int, str, str]] = []
    for root in all_roots:
        candidates.extend(_scan_checkpoint_root(root))

    # Newest first: highest step wins, timestamp breaks ties.
    candidates.sort(key=lambda item: (item[0], item[1]), reverse=True)
    ordered_candidates = [candidate for _, _, candidate in candidates]
    if checkpoint_path not in ordered_candidates:
        ordered_candidates.append(checkpoint_path)

    return ordered_candidates


def _scan_checkpoint_root(root_path: str) -> list[tuple[int, str, str]]:
"""Scan a single root path and return (step, timestamp, path) tuples."""
fs, plain_path = _get_fs_and_plain_path(root_path)
base_path_protocol = urllib.parse.urlparse(root_path).scheme

def maybe_unstrip_protocol(path: str) -> str:
if base_path_protocol != "" and urllib.parse.urlparse(path).scheme == "":
return f"{base_path_protocol}://{path}"
return path

checkpoint_dirs = [maybe_unstrip_protocol(d) for d in fs.glob(os.path.join(plain_path, "*")) if fs.isdir(d)]
checkpoint_dirs.append(checkpoint_path)
checkpoint_dirs.append(root_path)

candidates: list[tuple[int, str, str]] = []
results: list[tuple[int, str, str]] = []
for candidate in checkpoint_dirs:
metadata_path = os.path.join(candidate, "metadata.json")
if not fs.exists(metadata_path):
Expand All @@ -59,14 +75,9 @@ def maybe_unstrip_protocol(path: str) -> str:

timestamp = metadata.get("timestamp")
timestamp_key = str(timestamp) if timestamp is not None else ""
candidates.append((step_num, timestamp_key, candidate))
results.append((step_num, timestamp_key, candidate))

candidates.sort(key=lambda item: (item[0], item[1]), reverse=True)
ordered_candidates = [candidate for _, _, candidate in candidates]
if checkpoint_path not in ordered_candidates:
ordered_candidates.append(checkpoint_path)

return ordered_candidates
return results


def restore_grug_state_from_checkpoint(
Expand All @@ -76,6 +87,7 @@ def restore_grug_state_from_checkpoint(
load_checkpoint_setting: bool | None,
mesh: jax.sharding.Mesh | None,
allow_partial: bool,
additional_checkpoint_paths: list[str] | None = None,
_load_fn: Callable[..., StateT] = load_checkpoint,
) -> StateT:
if checkpoint_path is None:
Expand All @@ -86,7 +98,7 @@ def restore_grug_state_from_checkpoint(
if load_checkpoint_setting is False:
return state

candidates = _checkpoint_candidates(checkpoint_path)
candidates = _checkpoint_candidates(checkpoint_path, additional_paths=additional_checkpoint_paths or [])
last_error: FileNotFoundError | None = None

for candidate in candidates:
Expand Down
5 changes: 5 additions & 0 deletions experiments/grug/modular_opt/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,12 +375,17 @@ def _init_state(model_rng):
checkpoint_path = trainer.load_checkpoint_path
if checkpoint_path is None and checkpointer is not None:
checkpoint_path = trainer.checkpointer.expanded_path(run_id)
additional_checkpoint_paths = []
temp_path = trainer.checkpointer.expanded_temporary_path(run_id)
if temp_path is not None:
additional_checkpoint_paths.append(temp_path)
state = restore_grug_state_from_checkpoint(
state,
checkpoint_path=checkpoint_path,
load_checkpoint_setting=trainer.load_checkpoint,
mesh=mesh,
allow_partial=trainer.allow_partial_checkpoint,
additional_checkpoint_paths=additional_checkpoint_paths,
)

levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)})
Expand Down
5 changes: 5 additions & 0 deletions experiments/grug/moe/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,12 +401,17 @@ def _init_state(model_rng):
checkpoint_path = trainer.load_checkpoint_path
if checkpoint_path is None and checkpointer is not None:
checkpoint_path = trainer.checkpointer.expanded_path(run_id)
additional_checkpoint_paths = []
temp_path = trainer.checkpointer.expanded_temporary_path(run_id)
if temp_path is not None:
additional_checkpoint_paths.append(temp_path)
state = restore_grug_state_from_checkpoint(
state,
checkpoint_path=checkpoint_path,
load_checkpoint_setting=trainer.load_checkpoint,
mesh=mesh,
allow_partial=trainer.allow_partial_checkpoint,
additional_checkpoint_paths=additional_checkpoint_paths,
)

levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)})
Expand Down
99 changes: 82 additions & 17 deletions lib/levanter/src/levanter/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
save_interval: Optional[datetime.timedelta],
step_policies: Sequence[CheckpointInterval],
*,
temporary_base_path: Optional[PathLike] = None,
keep_params: PyTree[FilterSpec] = True,
dt_now_injection: Optional[Callable[[], datetime.datetime]] = None,
delete_old_temp_checkpoints: bool = True,
Expand All @@ -86,11 +87,15 @@ def __init__(
base_path: the base path to save checkpoints to. may be gcs, local, or anything that tensorstore supports
save_interval: the minimum amount of time between checkpoints (for time)
step_policies: the step policies to use
temporary_base_path: separate base path for time-policy (temporary) checkpoints. When set,
temporary checkpoints are written here instead of base_path. Permanent (step-policy)
checkpoints always go to base_path. If None, all checkpoints go to base_path.
keep_params: a PyTree of FilterSpecs that specifies which parameters to keep in the checkpoint
dt_now_injection: a function that returns the current time. useful for testing
delete_old_temp_checkpoints: if True, delete old checkpoints when saving a new one
"""
self.base_path = str(base_path)
self.temporary_base_path = str(temporary_base_path) if temporary_base_path is not None else None
self.save_interval = save_interval
self.step_policies = list(step_policies)
self.keep_params = keep_params
Expand Down Expand Up @@ -124,15 +129,21 @@ def __init__(

# discover latest checkpoint and see if it's temporary
self._last_temporary_checkpoint = None
latest_checkpoint = discover_latest_checkpoint(self.base_path)
if latest_checkpoint is not None and delete_old_temp_checkpoints:
metadata = _load_metadata(latest_checkpoint)
if metadata.get("is_temporary", False):
logger.info(
f"Found prior temporary checkpoint {latest_checkpoint}. We will delete it after"
" saving a new checkpoint."
)
self._last_temporary_checkpoint = latest_checkpoint
# Check both base_path and temporary_base_path for prior temporary checkpoints
search_paths = [self.base_path]
if self.temporary_base_path is not None:
search_paths.append(self.temporary_base_path)
for search_path in search_paths:
latest_checkpoint = discover_latest_checkpoint(search_path)
if latest_checkpoint is not None and delete_old_temp_checkpoints:
metadata = _load_metadata(latest_checkpoint)
if metadata.get("is_temporary", False):
logger.info(
f"Found prior temporary checkpoint {latest_checkpoint}. We will delete it after"
" saving a new checkpoint."
)
self._last_temporary_checkpoint = latest_checkpoint
break

def load_checkpoint(
self,
Expand All @@ -144,6 +155,12 @@ def load_checkpoint(
mesh: Optional[haliax.partitioning.Mesh] = None,
) -> Optional[M]:
if path is None:
# When temporary_base_path is set, discover the newest checkpoint across both roots
if discover_latest and self.temporary_base_path is not None:
latest = discover_latest_checkpoint(self.base_path, self.temporary_base_path)
if latest is not None:
return load_checkpoint(state, latest, discover_latest=False, axis_mapping=axis_mapping, mesh=mesh)
return None
path = self.base_path
return load_checkpoint(state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh)

Expand Down Expand Up @@ -216,8 +233,14 @@ def on_step(self, *, tree: PyTree, step: int, force: bool = False):
last_checkpoint = self._last_temporary_checkpoint
destination = f"step-{step}"

# Route temporary checkpoints to temporary_base_path when configured
if not save_permanent_ckpt and self.temporary_base_path is not None:
save_base_path = self.temporary_base_path
else:
save_base_path = self.base_path

if not save_permanent_ckpt:
self._last_temporary_checkpoint = os.path.join(self.base_path, destination)
self._last_temporary_checkpoint = os.path.join(save_base_path, destination)
else:
self._last_temporary_checkpoint = None

Expand Down Expand Up @@ -248,6 +271,7 @@ def callback():
destination=destination,
commit_callback=callback,
is_temporary=not save_permanent_ckpt,
base_path_override=save_base_path,
)

def _get_current_step_save_interval(self, step):
Expand Down Expand Up @@ -290,8 +314,10 @@ def save_checkpoint(
commit_callback: Optional[Callable[[], None]] = None,
*,
is_temporary: bool = False,
base_path_override: Optional[str] = None,
):
path = os.path.join(self.base_path, destination)
base = base_path_override if base_path_override is not None else self.base_path
path = os.path.join(base, destination)
logger.info(f"Saving checkpoint at step {step} to {path}")

save_checkpoint(
Expand Down Expand Up @@ -539,12 +565,39 @@ def _load_metadata(checkpoint_path, fs=None):
return metadata


def discover_latest_checkpoint(checkpoint_path: PathLike) -> Optional[str]:
def discover_latest_checkpoint(checkpoint_path: PathLike, *additional_paths: PathLike) -> Optional[str]:
    """
    Discover the latest checkpoint across one or more root paths.

    When additional_paths are provided, all roots are searched and the newest
    valid checkpoint (by timestamp then step) across all roots is returned.
    """
    all_paths = [str(checkpoint_path)] + [str(p) for p in additional_paths]

    newest: Optional[str] = None
    newest_key: Optional[tuple] = None
    for root in all_paths:
        candidate = _discover_latest_checkpoint_single(root)
        if candidate is None:
            continue
        try:
            metadata = _load_metadata(candidate)
            candidate_key = (datetime.datetime.fromisoformat(metadata["timestamp"]), metadata["step"])
        except Exception:
            # Best-effort: skip checkpoints whose metadata is missing or malformed.
            continue
        if newest_key is None or candidate_key > newest_key:
            newest, newest_key = candidate, candidate_key

    if newest is None:
        logger.warning(f"No checkpoints found in {all_paths}")
    else:
        logger.info(f"Discovered latest checkpoint at {newest}")
    return newest


def _discover_latest_checkpoint_single(checkpoint_path: str) -> Optional[str]:
"""Discover the latest checkpoint in a single root path."""
fs: AbstractFileSystem
fs, _ = _get_fs_and_plain_path(checkpoint_path)

Expand All @@ -567,10 +620,8 @@ def checkpoint_sort_key(ckpt_dir):

if len(ckpt_dirs) > 0:
out = max(ckpt_dirs, key=checkpoint_sort_key)
logger.info(f"Discovered latest checkpoint from {checkpoint_path} at {out}")
return out
else:
logger.warning(f"No checkpoints found in {checkpoint_path}")
return None


Expand All @@ -585,6 +636,10 @@ def _get_fs_and_plain_path(path, fs=None):
@dataclass
class CheckpointerConfig:
base_path: str = "checkpoints/"
temporary_base_path: Optional[str] = None
"""Separate base path for temporary (time-policy) checkpoints. When set, temporary checkpoints
are written here instead of base_path, allowing use of region-local storage with lifecycle TTL."""

save_interval: timedelta = timedelta(minutes=15)
# TODO: I'd like to write this, but it's not supported by draccus
# keep: List[CheckpointInterval] = field(default_factory=lambda: [CheckpointInterval(every=1000)])
Expand All @@ -605,19 +660,29 @@ def expanded_path(self, run_id) -> str:
return os.path.expanduser(os.path.join(self.base_path, run_id))
return os.path.expanduser(self.base_path)

def expanded_temporary_path(self, run_id) -> Optional[str]:
    """Return the user-expanded temporary checkpoint path for run_id, or None if unset."""
    base = self.temporary_base_path
    if base is None:
        return None
    if self.append_run_id_to_base_path:
        base = os.path.join(base, run_id)
    return os.path.expanduser(base)

def create(self, run_id) -> Checkpointer:
    """Build a Checkpointer for run_id using this config's paths and policies."""
    step_policies = [CheckpointInterval(**spec) for spec in self.keep]
    return Checkpointer(
        base_path=self.expanded_path(run_id),
        save_interval=self.save_interval,
        step_policies=step_policies,
        temporary_base_path=self.expanded_temporary_path(run_id),
        delete_old_temp_checkpoints=self.delete_old_temp_checkpoints,
    )

def __post_init__(self):
# Workaround for Executor using placeholder types.
if isinstance(self.base_path, str):
self.base_path = os.path.expanduser(self.base_path)
if isinstance(self.temporary_base_path, str):
self.temporary_base_path = os.path.expanduser(self.temporary_base_path)

# validate the checkpoint intervals.
# we want to make sure that the intervals are monotonic. only the last one can be None
Expand Down
Loading
Loading