Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions experiments/grug/base/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,12 +372,9 @@ def _init_state(model_rng):
state = _init_state(model_key)

checkpointer = trainer.checkpointer.create(run_id)
checkpoint_path = trainer.load_checkpoint_path
if checkpoint_path is None and checkpointer is not None:
checkpoint_path = trainer.checkpointer.expanded_path(run_id)
state = restore_grug_state_from_checkpoint(
state,
checkpoint_path=checkpoint_path,
checkpoint_search_paths=trainer.checkpoint_search_paths(run_id),
load_checkpoint_setting=trainer.load_checkpoint,
mesh=mesh,
allow_partial=trainer.allow_partial_checkpoint,
Expand Down
57 changes: 33 additions & 24 deletions experiments/grug/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
import os
import urllib.parse
from collections.abc import Callable
from collections.abc import Callable, Sequence
from typing import TypeVar

import fsspec
Expand All @@ -26,19 +26,34 @@ def _get_fs_and_plain_path(path: str) -> tuple[AbstractFileSystem, str]:
return fs, plain_path


def _checkpoint_candidates(checkpoint_path: str) -> list[str]:
fs, plain_path = _get_fs_and_plain_path(checkpoint_path)
base_path_protocol = urllib.parse.urlparse(checkpoint_path).scheme
def _checkpoint_candidates(checkpoint_search_paths: Sequence[str]) -> list[str]:
    """Collect candidate checkpoint directories across all search paths.

    Every root in ``checkpoint_search_paths`` is scanned for checkpoint
    directories; the discovered candidates are ordered newest-first by
    ``(step, timestamp)``. Each root itself is then appended as a
    last-resort fallback, unless it was already discovered as a candidate.
    """
    scanned: list[tuple[int, str, str]] = []
    for root in checkpoint_search_paths:
        scanned += _scan_checkpoint_root(root)

    # Highest step first; timestamp string breaks ties between equal steps.
    ordered = [
        path
        for _, _, path in sorted(scanned, key=lambda entry: entry[:2], reverse=True)
    ]

    # Fall back to the raw search paths themselves, preserving caller order.
    for root in checkpoint_search_paths:
        if root not in ordered:
            ordered.append(root)

    return ordered


def _scan_checkpoint_root(root_path: str) -> list[tuple[int, str, str]]:
"""Scan a single root path and return (step, timestamp, path) tuples."""
fs, plain_path = _get_fs_and_plain_path(root_path)
base_path_protocol = urllib.parse.urlparse(root_path).scheme

def maybe_unstrip_protocol(path: str) -> str:
if base_path_protocol != "" and urllib.parse.urlparse(path).scheme == "":
return f"{base_path_protocol}://{path}"
return path

checkpoint_dirs = [maybe_unstrip_protocol(d) for d in fs.glob(os.path.join(plain_path, "*")) if fs.isdir(d)]
checkpoint_dirs.append(checkpoint_path)
checkpoint_dirs.append(root_path)

candidates: list[tuple[int, str, str]] = []
results: list[tuple[int, str, str]] = []
for candidate in checkpoint_dirs:
metadata_path = os.path.join(candidate, "metadata.json")
if not fs.exists(metadata_path):
Expand All @@ -59,34 +74,29 @@ def maybe_unstrip_protocol(path: str) -> str:

timestamp = metadata.get("timestamp")
timestamp_key = str(timestamp) if timestamp is not None else ""
candidates.append((step_num, timestamp_key, candidate))

candidates.sort(key=lambda item: (item[0], item[1]), reverse=True)
ordered_candidates = [candidate for _, _, candidate in candidates]
if checkpoint_path not in ordered_candidates:
ordered_candidates.append(checkpoint_path)
results.append((step_num, timestamp_key, candidate))

return ordered_candidates
return results


def restore_grug_state_from_checkpoint(
state: StateT,
*,
checkpoint_path: str | None,
checkpoint_search_paths: Sequence[str],
load_checkpoint_setting: bool | None,
mesh: jax.sharding.Mesh | None,
allow_partial: bool,
_load_fn: Callable[..., StateT] = load_checkpoint,
) -> StateT:
if checkpoint_path is None:
if not checkpoint_search_paths:
if load_checkpoint_setting:
raise FileNotFoundError("load_checkpoint=True but no checkpoint path is configured.")
raise FileNotFoundError("load_checkpoint=True but no checkpoint search paths are configured.")
return state

if load_checkpoint_setting is False:
return state

candidates = _checkpoint_candidates(checkpoint_path)
candidates = _checkpoint_candidates(checkpoint_search_paths)
last_error: FileNotFoundError | None = None

for candidate in candidates:
Expand All @@ -98,8 +108,8 @@ def restore_grug_state_from_checkpoint(
allow_partial=allow_partial,
load_fn=_load_fn,
)
if candidate != checkpoint_path:
logger.info("Loaded checkpoint %s from %s", checkpoint_path, candidate)
if candidate not in checkpoint_search_paths:
logger.info("Loaded checkpoint from %s while searching %s", candidate, checkpoint_search_paths)
return loaded
except FileNotFoundError as exc:
last_error = exc
Expand All @@ -108,14 +118,15 @@ def restore_grug_state_from_checkpoint(
)

if load_checkpoint_setting is True:
search_path_summary = ", ".join(checkpoint_search_paths)
attempted = ", ".join(candidates)
if last_error is None:
raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}")
raise FileNotFoundError(f"Could not find checkpoint under any of: {search_path_summary}")
raise FileNotFoundError(
f"Could not load a checkpoint from {checkpoint_path}. Attempted: {attempted}"
f"Could not load a checkpoint from search paths {search_path_summary}. Attempted: {attempted}"
) from last_error

logger.info(f"Checkpoint not found at {checkpoint_path}. Starting from scratch.")
logger.info("Checkpoint not found under %s. Starting from scratch.", checkpoint_search_paths)
return state


Expand All @@ -131,7 +142,6 @@ def _load_candidate_state(
return load_fn(
state,
candidate,
discover_latest=False,
axis_mapping=None,
mesh=mesh,
allow_partial=allow_partial,
Expand All @@ -141,7 +151,6 @@ def _load_candidate_state(
wrapped = load_fn(
{"train_state": state},
candidate,
discover_latest=False,
axis_mapping=None,
mesh=mesh,
allow_partial=allow_partial,
Expand Down
5 changes: 1 addition & 4 deletions experiments/grug/modular_opt/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,12 +372,9 @@ def _init_state(model_rng):
state = _init_state(model_key)

checkpointer = trainer.checkpointer.create(run_id)
checkpoint_path = trainer.load_checkpoint_path
if checkpoint_path is None and checkpointer is not None:
checkpoint_path = trainer.checkpointer.expanded_path(run_id)
state = restore_grug_state_from_checkpoint(
state,
checkpoint_path=checkpoint_path,
checkpoint_search_paths=trainer.checkpoint_search_paths(run_id),
load_checkpoint_setting=trainer.load_checkpoint,
mesh=mesh,
allow_partial=trainer.allow_partial_checkpoint,
Expand Down
5 changes: 1 addition & 4 deletions experiments/grug/moe/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,12 +410,9 @@ def _init_state(model_rng):
state = _init_state(model_key)

checkpointer = trainer.checkpointer.create(run_id)
checkpoint_path = trainer.load_checkpoint_path
if checkpoint_path is None and checkpointer is not None:
checkpoint_path = trainer.checkpointer.expanded_path(run_id)
state = restore_grug_state_from_checkpoint(
state,
checkpoint_path=checkpoint_path,
checkpoint_search_paths=trainer.checkpoint_search_paths(run_id),
load_checkpoint_setting=trainer.load_checkpoint,
mesh=mesh,
allow_partial=trainer.allow_partial_checkpoint,
Expand Down
Loading
Loading