From f3f7cf1bb1caa8c241af011a7fa6471f88c46091 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Fri, 3 Apr 2026 05:48:57 +0000 Subject: [PATCH 1/6] [levanter] Separate temporary checkpoint base path from permanent Add temporary_base_path to CheckpointerConfig and Checkpointer so time-policy checkpoints route to a separate directory (e.g. region-local temp buckets with lifecycle TTL) while step-policy checkpoints stay on the durable base_path. Update discover_latest_checkpoint to search across multiple roots, update grug restore to merge candidates from both paths, and wire Marin training wrapper to use marin_temp_bucket. Fixes #4386 --- experiments/grug/base/train.py | 5 ++ experiments/grug/checkpointing.py | 38 +++++---- experiments/grug/modular_opt/train.py | 5 ++ experiments/grug/moe/train.py | 5 ++ lib/levanter/src/levanter/checkpoint.py | 99 ++++++++++++++++++++---- lib/levanter/tests/test_checkpoint.py | 84 ++++++++++++++++++++ lib/marin/src/marin/training/training.py | 1 + tests/test_grug_checkpointing.py | 57 ++++++++++++++ 8 files changed, 264 insertions(+), 30 deletions(-) diff --git a/experiments/grug/base/train.py b/experiments/grug/base/train.py index 40e4b712f3..f41bfa9a32 100644 --- a/experiments/grug/base/train.py +++ b/experiments/grug/base/train.py @@ -375,12 +375,17 @@ def _init_state(model_rng): checkpoint_path = trainer.load_checkpoint_path if checkpoint_path is None and checkpointer is not None: checkpoint_path = trainer.checkpointer.expanded_path(run_id) + additional_checkpoint_paths = [] + temp_path = trainer.checkpointer.expanded_temporary_path(run_id) + if temp_path is not None: + additional_checkpoint_paths.append(temp_path) state = restore_grug_state_from_checkpoint( state, checkpoint_path=checkpoint_path, load_checkpoint_setting=trainer.load_checkpoint, mesh=mesh, allow_partial=trainer.allow_partial_checkpoint, + additional_checkpoint_paths=additional_checkpoint_paths, ) levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)}) diff --git a/experiments/grug/checkpointing.py b/experiments/grug/checkpointing.py index b82e03945c..83154a358c 100644 --- a/experiments/grug/checkpointing.py +++ b/experiments/grug/checkpointing.py @@ -26,9 +26,25 @@ def _get_fs_and_plain_path(path: str) -> tuple[AbstractFileSystem, str]: return fs, plain_path -def _checkpoint_candidates(checkpoint_path: str) -> list[str]: - fs, plain_path = _get_fs_and_plain_path(checkpoint_path) - base_path_protocol = urllib.parse.urlparse(checkpoint_path).scheme +def _checkpoint_candidates(checkpoint_path: str, *, additional_paths: list[str] | None = None) -> list[str]: + all_roots = [checkpoint_path] + (additional_paths or []) + + candidates: list[tuple[int, str, str]] = [] + for root in all_roots: + candidates.extend(_scan_checkpoint_root(root)) + + candidates.sort(key=lambda item: (item[0], item[1]), reverse=True) + ordered_candidates = [candidate for _, _, candidate in candidates] + if checkpoint_path not in ordered_candidates: + ordered_candidates.append(checkpoint_path) + + return ordered_candidates + + +def _scan_checkpoint_root(root_path: str) -> list[tuple[int, str, str]]: + """Scan a single root path and return (step, timestamp, path) tuples.""" + fs, plain_path = _get_fs_and_plain_path(root_path) + base_path_protocol = urllib.parse.urlparse(root_path).scheme def maybe_unstrip_protocol(path: str) -> str: if base_path_protocol != "" and urllib.parse.urlparse(path).scheme == "": @@ -36,9 +52,9 @@ def maybe_unstrip_protocol(path: str) -> str: return path checkpoint_dirs = [maybe_unstrip_protocol(d) for d in fs.glob(os.path.join(plain_path, "*")) if fs.isdir(d)] - checkpoint_dirs.append(checkpoint_path) + checkpoint_dirs.append(root_path) - candidates: list[tuple[int, str, str]] = [] + results: list[tuple[int, str, str]] = [] for candidate in checkpoint_dirs: metadata_path = os.path.join(candidate, "metadata.json") if not fs.exists(metadata_path): @@ -59,14 +75,9 @@ def maybe_unstrip_protocol(path: str) -> str: timestamp = metadata.get("timestamp") timestamp_key = str(timestamp) if timestamp is not None else "" - candidates.append((step_num, timestamp_key, candidate)) + results.append((step_num, timestamp_key, candidate)) - candidates.sort(key=lambda item: (item[0], item[1]), reverse=True) - ordered_candidates = [candidate for _, _, candidate in candidates] - if checkpoint_path not in ordered_candidates: - ordered_candidates.append(checkpoint_path) - - return ordered_candidates + return results def restore_grug_state_from_checkpoint( @@ -76,6 +87,7 @@ def restore_grug_state_from_checkpoint( load_checkpoint_setting: bool | None, mesh: jax.sharding.Mesh | None, allow_partial: bool, + additional_checkpoint_paths: list[str] | None = None, _load_fn: Callable[..., StateT] = load_checkpoint, ) -> StateT: if checkpoint_path is None: @@ -86,7 +98,7 @@ def restore_grug_state_from_checkpoint( if load_checkpoint_setting is False: return state - candidates = _checkpoint_candidates(checkpoint_path) + candidates = _checkpoint_candidates(checkpoint_path, additional_paths=additional_checkpoint_paths or []) last_error: FileNotFoundError | None = None for candidate in candidates: diff --git a/experiments/grug/modular_opt/train.py b/experiments/grug/modular_opt/train.py index 113b607526..6d2833c677 100644 --- a/experiments/grug/modular_opt/train.py +++ b/experiments/grug/modular_opt/train.py @@ -375,12 +375,17 @@ def _init_state(model_rng): checkpoint_path = trainer.load_checkpoint_path if checkpoint_path is None and checkpointer is not None: checkpoint_path = trainer.checkpointer.expanded_path(run_id) + additional_checkpoint_paths = [] + temp_path = trainer.checkpointer.expanded_temporary_path(run_id) + if temp_path is not None: + additional_checkpoint_paths.append(temp_path) state = restore_grug_state_from_checkpoint( state, checkpoint_path=checkpoint_path, load_checkpoint_setting=trainer.load_checkpoint, mesh=mesh, allow_partial=trainer.allow_partial_checkpoint, + additional_checkpoint_paths=additional_checkpoint_paths, ) levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)}) diff --git a/experiments/grug/moe/train.py b/experiments/grug/moe/train.py index acec78c972..27ac04d039 100644 --- a/experiments/grug/moe/train.py +++ b/experiments/grug/moe/train.py @@ -401,12 +401,17 @@ def _init_state(model_rng): checkpoint_path = trainer.load_checkpoint_path if checkpoint_path is None and checkpointer is not None: checkpoint_path = trainer.checkpointer.expanded_path(run_id) + additional_checkpoint_paths = [] + temp_path = trainer.checkpointer.expanded_temporary_path(run_id) + if temp_path is not None: + additional_checkpoint_paths.append(temp_path) state = restore_grug_state_from_checkpoint( state, checkpoint_path=checkpoint_path, load_checkpoint_setting=trainer.load_checkpoint, mesh=mesh, allow_partial=trainer.allow_partial_checkpoint, + additional_checkpoint_paths=additional_checkpoint_paths, ) levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)}) diff --git a/lib/levanter/src/levanter/checkpoint.py b/lib/levanter/src/levanter/checkpoint.py index 588aba6c35..ee0022b767 100644 --- a/lib/levanter/src/levanter/checkpoint.py +++ b/lib/levanter/src/levanter/checkpoint.py @@ -70,6 +70,7 @@ def __init__( save_interval: Optional[datetime.timedelta], step_policies: Sequence[CheckpointInterval], *, + temporary_base_path: Optional[PathLike] = None, keep_params: PyTree[FilterSpec] = True, dt_now_injection: Optional[Callable[[], datetime.datetime]] = None, delete_old_temp_checkpoints: bool = True, @@ -86,11 +87,15 @@ def __init__( base_path: the base path to save checkpoints to. may be gcs, local, or anything that tensorstore supports save_interval: the minimum amount of time between checkpoints (for time) step_policies: the step policies to use + temporary_base_path: separate base path for time-policy (temporary) checkpoints. When set, + temporary checkpoints are written here instead of base_path. Permanent (step-policy) + checkpoints always go to base_path. If None, all checkpoints go to base_path. keep_params: a PyTree of FilterSpecs that specifies which parameters to keep in the checkpoint dt_now_injection: a function that returns the current time. useful for testing delete_old_temp_checkpoints: if True, delete old checkpoints when saving a new one """ self.base_path = str(base_path) + self.temporary_base_path = str(temporary_base_path) if temporary_base_path is not None else None self.save_interval = save_interval self.step_policies = list(step_policies) self.keep_params = keep_params @@ -124,15 +129,21 @@ def __init__( # discover latest checkpoint and see if it's temporary self._last_temporary_checkpoint = None - latest_checkpoint = discover_latest_checkpoint(self.base_path) - if latest_checkpoint is not None and delete_old_temp_checkpoints: - metadata = _load_metadata(latest_checkpoint) - if metadata.get("is_temporary", False): - logger.info( - f"Found prior temporary checkpoint {latest_checkpoint}. We will delete it after" - " saving a new checkpoint." - ) - self._last_temporary_checkpoint = latest_checkpoint + # Check both base_path and temporary_base_path for prior temporary checkpoints + search_paths = [self.base_path] + if self.temporary_base_path is not None: + search_paths.append(self.temporary_base_path) + for search_path in search_paths: + latest_checkpoint = discover_latest_checkpoint(search_path) + if latest_checkpoint is not None and delete_old_temp_checkpoints: + metadata = _load_metadata(latest_checkpoint) + if metadata.get("is_temporary", False): + logger.info( + f"Found prior temporary checkpoint {latest_checkpoint}. We will delete it after" + " saving a new checkpoint." + ) + self._last_temporary_checkpoint = latest_checkpoint + break def load_checkpoint( self, @@ -144,6 +155,12 @@ def load_checkpoint( mesh: Optional[haliax.partitioning.Mesh] = None, ) -> Optional[M]: if path is None: + # When temporary_base_path is set, discover the newest checkpoint across both roots + if discover_latest and self.temporary_base_path is not None: + latest = discover_latest_checkpoint(self.base_path, self.temporary_base_path) + if latest is not None: + return load_checkpoint(state, latest, discover_latest=False, axis_mapping=axis_mapping, mesh=mesh) + return None path = self.base_path return load_checkpoint(state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh) @@ -216,8 +233,14 @@ def on_step(self, *, tree: PyTree, step: int, force: bool = False): last_checkpoint = self._last_temporary_checkpoint destination = f"step-{step}" + # Route temporary checkpoints to temporary_base_path when configured + if not save_permanent_ckpt and self.temporary_base_path is not None: + save_base_path = self.temporary_base_path + else: + save_base_path = self.base_path + if not save_permanent_ckpt: - self._last_temporary_checkpoint = os.path.join(self.base_path, destination) + self._last_temporary_checkpoint = os.path.join(save_base_path, destination) else: self._last_temporary_checkpoint = None @@ -248,6 +271,7 @@ def callback(): destination=destination, commit_callback=callback, is_temporary=not save_permanent_ckpt, + base_path_override=save_base_path, ) def _get_current_step_save_interval(self, step): @@ -290,8 +314,10 @@ def save_checkpoint( commit_callback: Optional[Callable[[], None]] = None, *, is_temporary: bool = False, + base_path_override: Optional[str] = None, ): - path = os.path.join(self.base_path, destination) + base = base_path_override if base_path_override is not None else self.base_path + path = os.path.join(base, destination) logger.info(f"Saving checkpoint at step {step} to {path}") save_checkpoint( @@ -539,12 +565,39 @@ def _load_metadata(checkpoint_path, fs=None): return metadata -def discover_latest_checkpoint(checkpoint_path: PathLike) -> Optional[str]: +def discover_latest_checkpoint(checkpoint_path: PathLike, *additional_paths: PathLike) -> Optional[str]: """ - Discover the latest checkpoint in a given path. + Discover the latest checkpoint across one or more root paths. + + When additional_paths are provided, all roots are searched and the newest + valid checkpoint (by timestamp then step) across all roots is returned. """ - checkpoint_path = str(checkpoint_path) - # need to use fsspec for this, as glob.glob doesn't work on gs:// + all_paths = [str(checkpoint_path)] + [str(p) for p in additional_paths] + best: Optional[str] = None + best_key: Optional[tuple] = None + + for cp_path in all_paths: + found = _discover_latest_checkpoint_single(cp_path) + if found is None: + continue + try: + metadata = _load_metadata(found) + key = (datetime.datetime.fromisoformat(metadata["timestamp"]), metadata["step"]) + except Exception: + continue + if best_key is None or key > best_key: + best = found + best_key = key + + if best is not None: + logger.info(f"Discovered latest checkpoint at {best}") + else: + logger.warning(f"No checkpoints found in {all_paths}") + return best + + +def _discover_latest_checkpoint_single(checkpoint_path: str) -> Optional[str]: + """Discover the latest checkpoint in a single root path.""" fs: AbstractFileSystem fs, _ = _get_fs_and_plain_path(checkpoint_path) @@ -567,10 +620,8 @@ def checkpoint_sort_key(ckpt_dir): if len(ckpt_dirs) > 0: out = max(ckpt_dirs, key=checkpoint_sort_key) - logger.info(f"Discovered latest checkpoint from {checkpoint_path} at {out}") return out else: - logger.warning(f"No checkpoints found in {checkpoint_path}") return None @@ -585,6 +636,10 @@ def _get_fs_and_plain_path(path, fs=None): @dataclass class CheckpointerConfig: base_path: str = "checkpoints/" + temporary_base_path: Optional[str] = None + """Separate base path for temporary (time-policy) checkpoints. When set, temporary checkpoints + are written here instead of base_path, allowing use of region-local storage with lifecycle TTL.""" + save_interval: timedelta = timedelta(minutes=15) # TODO: I'd like to write this, but it's not supported by draccus # keep: List[CheckpointInterval] = field(default_factory=lambda: [CheckpointInterval(every=1000)]) @@ -605,12 +660,20 @@ def expanded_path(self, run_id) -> str: return os.path.expanduser(os.path.join(self.base_path, run_id)) return os.path.expanduser(self.base_path) + def expanded_temporary_path(self, run_id) -> Optional[str]: + if self.temporary_base_path is None: + return None + if self.append_run_id_to_base_path: + return os.path.expanduser(os.path.join(self.temporary_base_path, run_id)) + return os.path.expanduser(self.temporary_base_path) + def create(self, run_id) -> Checkpointer: keeps = [CheckpointInterval(**k) for k in self.keep] return Checkpointer( base_path=self.expanded_path(run_id), save_interval=self.save_interval, step_policies=keeps, + temporary_base_path=self.expanded_temporary_path(run_id), delete_old_temp_checkpoints=self.delete_old_temp_checkpoints, ) @@ -618,6 +681,8 @@ def __post_init__(self): # Workaround for Executor using placeholder types. if isinstance(self.base_path, str): self.base_path = os.path.expanduser(self.base_path) + if isinstance(self.temporary_base_path, str): + self.temporary_base_path = os.path.expanduser(self.temporary_base_path) # validate the checkpoint intervals. # we want to make sure that the intervals are monotonic. only the last one can be None diff --git a/lib/levanter/tests/test_checkpoint.py b/lib/levanter/tests/test_checkpoint.py index 1892b9bb48..c8f21fb276 100644 --- a/lib/levanter/tests/test_checkpoint.py +++ b/lib/levanter/tests/test_checkpoint.py @@ -26,6 +26,7 @@ from levanter.callbacks import StepInfo from levanter.checkpoint import ( Checkpointer, + CheckpointerConfig, CheckpointInterval, _load_metadata, discover_latest_checkpoint, @@ -259,6 +260,89 @@ def test_checkpoint_discovery(): assert discover_latest_checkpoint("file:///tmp/does-not-exist") is None +def test_checkpoint_discovery_across_multiple_paths(): + with tempfile.TemporaryDirectory() as permanent_dir, tempfile.TemporaryDirectory() as temp_dir: + save_checkpoint(dict(model=1), step=10, checkpoint_path=f"{permanent_dir}/step-10", is_temporary=False) + save_checkpoint(dict(model=2), step=15, checkpoint_path=f"{temp_dir}/step-15", is_temporary=True) + + # Without additional paths, only permanent_dir is searched + latest_single = discover_latest_checkpoint(permanent_dir) + assert latest_single == f"{permanent_dir}/step-10" + + # With additional paths, the newer checkpoint in temp_dir wins + latest_both = discover_latest_checkpoint(permanent_dir, temp_dir) + assert latest_both == f"{temp_dir}/step-15" + + +def test_checkpointer_temporary_base_path_routes_temp_checkpoints(): + fake_now = datetime.datetime(2021, 1, 1, 0, 0, 0) + tick = 10 + + def advance_time(delta_seconds): + nonlocal fake_now + fake_now += timedelta(seconds=delta_seconds) + + with tempfile.TemporaryDirectory() as permanent_dir, tempfile.TemporaryDirectory() as temp_dir: + checkpointer = Checkpointer( + permanent_dir, + timedelta(seconds=tick), + [CheckpointInterval(every=5, until=None)], + temporary_base_path=temp_dir, + dt_now_injection=lambda: fake_now, + ) + + # Step 0 doesn't save + _on_step(checkpointer, 0) + + # Time-based save goes to temp_dir + advance_time(tick) + _on_step(checkpointer, 1) + checkpointer.wait_until_finished() + assert _get_checkpoint_steps(temp_dir) == [1] + assert _get_checkpoint_steps(permanent_dir) == [] + + # Step-based save goes to permanent_dir + advance_time(tick) + _on_step(checkpointer, 5) + checkpointer.wait_until_finished() + assert _get_checkpoint_steps(permanent_dir) == [5] + # Old temp checkpoint should be deleted + assert _get_checkpoint_steps(temp_dir) == [] + + # Another time-based save goes to temp_dir + advance_time(tick) + _on_step(checkpointer, 6) + checkpointer.wait_until_finished() + assert _get_checkpoint_steps(temp_dir) == [6] + assert _get_checkpoint_steps(permanent_dir) == [5] + + +def test_checkpointer_config_temporary_base_path(): + config = dataclasses.replace( + CheckpointerConfig(), + base_path="/tmp/test-perm", + temporary_base_path="/tmp/test-temp", + append_run_id_to_base_path=False, + ) + assert config.expanded_path("run1") == "/tmp/test-perm" + assert config.expanded_temporary_path("run1") == "/tmp/test-temp" + + config_with_run_id = dataclasses.replace( + CheckpointerConfig(), + base_path="/tmp/test-perm", + temporary_base_path="/tmp/test-temp", + append_run_id_to_base_path=True, + ) + assert config_with_run_id.expanded_path("run1") == "/tmp/test-perm/run1" + assert config_with_run_id.expanded_temporary_path("run1") == "/tmp/test-temp/run1" + + +def test_checkpointer_config_no_temporary_base_path(): + config = CheckpointerConfig() + assert config.temporary_base_path is None + assert config.expanded_temporary_path("run1") is None + + def test_checkpointer_deletes_previous_checkpoints(): fake_now = datetime.datetime(2021, 1, 1, 0, 0, 0) diff --git a/lib/marin/src/marin/training/training.py b/lib/marin/src/marin/training/training.py index a0ddd4627e..e4c6a5c812 100644 --- a/lib/marin/src/marin/training/training.py +++ b/lib/marin/src/marin/training/training.py @@ -108,6 +108,7 @@ def _update_config_to_use_out_path(pod_config: TrainOnPodConfigT) -> TrainOnPodC checkpointer=replace( pod_config.train_config.trainer.checkpointer, base_path=os.path.join(pod_config.output_path, DEFAULT_CHECKPOINTS_PATH), + temporary_base_path=marin_temp_bucket(ttl_days=14, prefix="checkpoints-temp"), ), ) diff --git a/tests/test_grug_checkpointing.py b/tests/test_grug_checkpointing.py index 1837227c8b..314331544a 100644 --- a/tests/test_grug_checkpointing.py +++ b/tests/test_grug_checkpointing.py @@ -94,6 +94,63 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial ) +def test_restore_discovers_candidates_across_additional_paths(tmp_path: Path): + permanent_root = tmp_path / "checkpoints" + temp_root = tmp_path / "checkpoints-temp" + + _write_checkpoint_metadata(permanent_root / "step-100", step=100, timestamp="2026-03-17T00:00:00") + _write_checkpoint_metadata(temp_root / "step-150", step=150, timestamp="2026-03-17T06:00:00") + + attempted: list[str] = [] + + def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial): + attempted.append(path) + return {"loaded_from": path} + + loaded = restore_grug_state_from_checkpoint( + {"state": "init"}, + checkpoint_path=str(permanent_root), + load_checkpoint_setting=True, + mesh=None, + allow_partial=False, + additional_checkpoint_paths=[str(temp_root)], + _load_fn=fake_load, + ) + + # step-150 from temp root should be preferred (highest step) + assert attempted == [str(temp_root / "step-150")] + assert loaded == {"loaded_from": str(temp_root / "step-150")} + + +def test_restore_falls_back_from_temp_to_permanent(tmp_path: Path): + permanent_root = tmp_path / "checkpoints" + temp_root = tmp_path / "checkpoints-temp" + + _write_checkpoint_metadata(permanent_root / "step-100", step=100, timestamp="2026-03-17T00:00:00") + _write_checkpoint_metadata(temp_root / "step-150", step=150, timestamp="2026-03-17T06:00:00") + + attempted: list[str] = [] + + def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial): + attempted.append(path) + if "step-150" in path: + raise FileNotFoundError(path) + return {"loaded_from": path} + + loaded = restore_grug_state_from_checkpoint( + {"state": "init"}, + checkpoint_path=str(permanent_root), + load_checkpoint_setting=None, + mesh=None, + allow_partial=False, + additional_checkpoint_paths=[str(temp_root)], + _load_fn=fake_load, + ) + + # Should fall back to step-100 from permanent root + assert loaded == {"loaded_from": str(permanent_root / "step-100")} + + def test_restore_supports_legacy_wrapped_and_current_checkpoint_formats(tmp_path: Path): checkpoint_root = tmp_path / "checkpoints" From 415a3e053e02e300438f3b9d258898591f9947a8 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 21 Apr 2026 15:00:23 -0700 Subject: [PATCH 2/6] Respect explicit grug checkpoint paths --- experiments/grug/checkpointing.py | 8 ++++++++ tests/test_grug_checkpointing.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/experiments/grug/checkpointing.py b/experiments/grug/checkpointing.py index 83154a358c..f5d042bfcf 100644 --- a/experiments/grug/checkpointing.py +++ b/experiments/grug/checkpointing.py @@ -27,6 +27,9 @@ def _get_fs_and_plain_path(path: str) -> tuple[AbstractFileSystem, str]: def _checkpoint_candidates(checkpoint_path: str, *, additional_paths: list[str] | None = None) -> list[str]: + if _is_checkpoint_dir(checkpoint_path): + return [checkpoint_path] + all_roots = [checkpoint_path] + (additional_paths or []) candidates: list[tuple[int, str, str]] = [] @@ -41,6 +44,11 @@ def _checkpoint_candidates(checkpoint_path: str, *, additional_paths: list[str] return ordered_candidates +def _is_checkpoint_dir(checkpoint_path: str) -> bool: + fs, plain_path = _get_fs_and_plain_path(checkpoint_path) + return fs.exists(os.path.join(plain_path, "metadata.json")) + + def _scan_checkpoint_root(root_path: str) -> list[tuple[int, str, str]]: """Scan a single root path and return (step, timestamp, path) tuples.""" fs, plain_path = _get_fs_and_plain_path(root_path) diff --git a/tests/test_grug_checkpointing.py b/tests/test_grug_checkpointing.py index 314331544a..cbba918cc1 100644 --- a/tests/test_grug_checkpointing.py +++ b/tests/test_grug_checkpointing.py @@ -122,6 +122,34 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial assert loaded == {"loaded_from": str(temp_root / "step-150")} +def test_restore_respects_explicit_checkpoint_path_with_additional_paths(tmp_path: Path): + permanent_root = tmp_path / "checkpoints" + temp_root = tmp_path / "checkpoints-temp" + explicit_checkpoint = permanent_root / "step-100" + + _write_checkpoint_metadata(explicit_checkpoint, step=100, timestamp="2026-03-17T00:00:00") + _write_checkpoint_metadata(temp_root / "step-150", step=150, timestamp="2026-03-17T06:00:00") + + attempted: list[str] = [] + + def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial): + attempted.append(path) + return {"loaded_from": path} + + loaded = restore_grug_state_from_checkpoint( + {"state": "init"}, + checkpoint_path=str(explicit_checkpoint), + load_checkpoint_setting=True, + mesh=None, + allow_partial=False, + additional_checkpoint_paths=[str(temp_root)], + _load_fn=fake_load, + ) + + assert attempted == [str(explicit_checkpoint)] + assert loaded == {"loaded_from": str(explicit_checkpoint)} + + def test_restore_falls_back_from_temp_to_permanent(tmp_path: Path): permanent_root = tmp_path / "checkpoints" temp_root = tmp_path / "checkpoints-temp" From a93bf7a36a95f68f3ac722d1d702a5ee33c2248b Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 21 Apr 2026 16:03:54 -0700 Subject: [PATCH 3/6] Use explicit grug checkpoint search paths --- experiments/grug/base/train.py | 19 ++++++------ experiments/grug/checkpointing.py | 43 +++++++++++---------------- experiments/grug/modular_opt/train.py | 19 ++++++------ experiments/grug/moe/train.py | 19 ++++++------ tests/test_grug_checkpointing.py | 25 +++++++--------- 5 files changed, 58 insertions(+), 67 deletions(-) diff --git a/experiments/grug/base/train.py b/experiments/grug/base/train.py index f41bfa9a32..98fb8b6439 100644 --- a/experiments/grug/base/train.py +++ b/experiments/grug/base/train.py @@ -372,20 +372,21 @@ def _init_state(model_rng): state = _init_state(model_key) checkpointer = trainer.checkpointer.create(run_id) - checkpoint_path = trainer.load_checkpoint_path - if checkpoint_path is None and checkpointer is not None: - checkpoint_path = trainer.checkpointer.expanded_path(run_id) - additional_checkpoint_paths = [] - temp_path = trainer.checkpointer.expanded_temporary_path(run_id) - if temp_path is not None: - additional_checkpoint_paths.append(temp_path) + if trainer.load_checkpoint_path is not None: + checkpoint_search_paths = [trainer.load_checkpoint_path] + elif checkpointer is not None: + checkpoint_search_paths = [trainer.checkpointer.expanded_path(run_id)] + temp_path = trainer.checkpointer.expanded_temporary_path(run_id) + if temp_path is not None: + checkpoint_search_paths.append(temp_path) + else: + checkpoint_search_paths = [] state = restore_grug_state_from_checkpoint( state, - checkpoint_path=checkpoint_path, + checkpoint_search_paths=checkpoint_search_paths, load_checkpoint_setting=trainer.load_checkpoint, mesh=mesh, allow_partial=trainer.allow_partial_checkpoint, - additional_checkpoint_paths=additional_checkpoint_paths, ) levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)}) diff --git a/experiments/grug/checkpointing.py b/experiments/grug/checkpointing.py index f5d042bfcf..24b209aed1 100644 --- a/experiments/grug/checkpointing.py +++ b/experiments/grug/checkpointing.py @@ -7,7 +7,7 @@ import logging import os import urllib.parse -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import TypeVar import fsspec @@ -26,29 +26,20 @@ def _get_fs_and_plain_path(path: str) -> tuple[AbstractFileSystem, str]: return fs, plain_path -def _checkpoint_candidates(checkpoint_path: str, *, additional_paths: list[str] | None = None) -> list[str]: - if _is_checkpoint_dir(checkpoint_path): - return [checkpoint_path] - - all_roots = [checkpoint_path] + (additional_paths or []) - +def _checkpoint_candidates(checkpoint_search_paths: Sequence[str]) -> list[str]: candidates: list[tuple[int, str, str]] = [] - for root in all_roots: - candidates.extend(_scan_checkpoint_root(root)) + for search_path in checkpoint_search_paths: + candidates.extend(_scan_checkpoint_root(search_path)) candidates.sort(key=lambda item: (item[0], item[1]), reverse=True) ordered_candidates = [candidate for _, _, candidate in candidates] - if checkpoint_path not in ordered_candidates: - ordered_candidates.append(checkpoint_path) + for search_path in checkpoint_search_paths: + if search_path not in ordered_candidates: + ordered_candidates.append(search_path) return ordered_candidates -def _is_checkpoint_dir(checkpoint_path: str) -> bool: - fs, plain_path = _get_fs_and_plain_path(checkpoint_path) - return fs.exists(os.path.join(plain_path, "metadata.json")) - - def _scan_checkpoint_root(root_path: str) -> list[tuple[int, str, str]]: """Scan a single root path and return (step, timestamp, path) tuples.""" fs, plain_path = _get_fs_and_plain_path(root_path) @@ -91,22 +82,21 @@ def maybe_unstrip_protocol(path: str) -> str: def restore_grug_state_from_checkpoint( state: StateT, *, - checkpoint_path: str | None, + checkpoint_search_paths: Sequence[str], load_checkpoint_setting: bool | None, mesh: jax.sharding.Mesh | None, allow_partial: bool, - additional_checkpoint_paths: list[str] | None = None, _load_fn: Callable[..., StateT] = load_checkpoint, ) -> StateT: - if checkpoint_path is None: + if not checkpoint_search_paths: if load_checkpoint_setting: - raise FileNotFoundError("load_checkpoint=True but no checkpoint path is configured.") + raise FileNotFoundError("load_checkpoint=True but no checkpoint search paths are configured.") return state if load_checkpoint_setting is False: return state - candidates = _checkpoint_candidates(checkpoint_path, additional_paths=additional_checkpoint_paths or []) + candidates = _checkpoint_candidates(checkpoint_search_paths) last_error: FileNotFoundError | None = None for candidate in candidates: @@ -118,8 +108,8 @@ def restore_grug_state_from_checkpoint( allow_partial=allow_partial, load_fn=_load_fn, ) - if candidate != checkpoint_path: - logger.info("Loaded checkpoint %s from %s", checkpoint_path, candidate) + if candidate not in checkpoint_search_paths: + logger.info("Loaded checkpoint from %s while searching %s", candidate, checkpoint_search_paths) return loaded except FileNotFoundError as exc: last_error = exc @@ -128,14 +118,15 @@ def restore_grug_state_from_checkpoint( ) if load_checkpoint_setting is True: + search_path_summary = ", ".join(checkpoint_search_paths) attempted = ", ".join(candidates) if last_error is None: - raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}") + raise FileNotFoundError(f"Could not find checkpoint under any of: {search_path_summary}") raise FileNotFoundError( - f"Could not load a checkpoint from {checkpoint_path}. Attempted: {attempted}" + f"Could not load a checkpoint from search paths {search_path_summary}. Attempted: {attempted}" ) from last_error - logger.info(f"Checkpoint not found at {checkpoint_path}. Starting from scratch.") + logger.info("Checkpoint not found under %s. Starting from scratch.", checkpoint_search_paths) return state diff --git a/experiments/grug/modular_opt/train.py b/experiments/grug/modular_opt/train.py index 6d2833c677..887b45dfa0 100644 --- a/experiments/grug/modular_opt/train.py +++ b/experiments/grug/modular_opt/train.py @@ -372,20 +372,21 @@ def _init_state(model_rng): state = _init_state(model_key) checkpointer = trainer.checkpointer.create(run_id) - checkpoint_path = trainer.load_checkpoint_path - if checkpoint_path is None and checkpointer is not None: - checkpoint_path = trainer.checkpointer.expanded_path(run_id) - additional_checkpoint_paths = [] - temp_path = trainer.checkpointer.expanded_temporary_path(run_id) - if temp_path is not None: - additional_checkpoint_paths.append(temp_path) + if trainer.load_checkpoint_path is not None: + checkpoint_search_paths = [trainer.load_checkpoint_path] + elif checkpointer is not None: + checkpoint_search_paths = [trainer.checkpointer.expanded_path(run_id)] + temp_path = trainer.checkpointer.expanded_temporary_path(run_id) + if temp_path is not None: + checkpoint_search_paths.append(temp_path) + else: + checkpoint_search_paths = [] state = restore_grug_state_from_checkpoint( state, - checkpoint_path=checkpoint_path, + checkpoint_search_paths=checkpoint_search_paths, load_checkpoint_setting=trainer.load_checkpoint, mesh=mesh, allow_partial=trainer.allow_partial_checkpoint, - additional_checkpoint_paths=additional_checkpoint_paths, ) levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)}) diff --git a/experiments/grug/moe/train.py b/experiments/grug/moe/train.py index e539406a0b..ef47ccc457 100644 --- a/experiments/grug/moe/train.py +++ b/experiments/grug/moe/train.py @@ -410,20 +410,21 @@ def _init_state(model_rng): state = _init_state(model_key) checkpointer = trainer.checkpointer.create(run_id) - checkpoint_path = trainer.load_checkpoint_path - if checkpoint_path is None and checkpointer is not None: - checkpoint_path = trainer.checkpointer.expanded_path(run_id) - additional_checkpoint_paths = [] - temp_path = trainer.checkpointer.expanded_temporary_path(run_id) - if temp_path is not None: - additional_checkpoint_paths.append(temp_path) + if trainer.load_checkpoint_path is not None: + checkpoint_search_paths = [trainer.load_checkpoint_path] + elif checkpointer is not None: + checkpoint_search_paths = [trainer.checkpointer.expanded_path(run_id)] + temp_path = trainer.checkpointer.expanded_temporary_path(run_id) + if temp_path is not None: + checkpoint_search_paths.append(temp_path) + else: + checkpoint_search_paths = [] state = restore_grug_state_from_checkpoint( state, - checkpoint_path=checkpoint_path, + checkpoint_search_paths=checkpoint_search_paths, load_checkpoint_setting=trainer.load_checkpoint, mesh=mesh, allow_partial=trainer.allow_partial_checkpoint, - additional_checkpoint_paths=additional_checkpoint_paths, ) levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)}) diff --git a/tests/test_grug_checkpointing.py b/tests/test_grug_checkpointing.py index cbba918cc1..caa7461de0 100644 --- a/tests/test_grug_checkpointing.py +++ b/tests/test_grug_checkpointing.py @@ -34,7 +34,7 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial loaded = restore_grug_state_from_checkpoint( {"state": "init"}, - checkpoint_path=str(checkpoint_root), + checkpoint_search_paths=[str(checkpoint_root)], load_checkpoint_setting=True, mesh=None, allow_partial=False, @@ -61,7 +61,7 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial loaded = restore_grug_state_from_checkpoint( {"state": "init"}, - checkpoint_path=str(checkpoint_root), + checkpoint_search_paths=[str(checkpoint_root)], load_checkpoint_setting=None, mesh=None, allow_partial=False, @@ -86,7 +86,7 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial with pytest.raises(FileNotFoundError, match="Could not load a checkpoint"): restore_grug_state_from_checkpoint( {"state": "init"}, - checkpoint_path=str(checkpoint_root), + checkpoint_search_paths=[str(checkpoint_root)], load_checkpoint_setting=True, mesh=None, allow_partial=False, @@ -94,7 +94,7 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial ) -def test_restore_discovers_candidates_across_additional_paths(tmp_path: Path): +def test_restore_discovers_candidates_across_search_paths(tmp_path: Path): permanent_root = tmp_path / "checkpoints" temp_root = tmp_path / "checkpoints-temp" @@ -109,20 +109,19 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial loaded = restore_grug_state_from_checkpoint( {"state": "init"}, - checkpoint_path=str(permanent_root), + checkpoint_search_paths=[str(permanent_root), str(temp_root)], load_checkpoint_setting=True, mesh=None, allow_partial=False, - additional_checkpoint_paths=[str(temp_root)], _load_fn=fake_load, ) - # step-150 from temp root should be preferred (highest step) + # step-150 from temp root should be preferred (highest step). assert attempted == [str(temp_root / "step-150")] assert loaded == {"loaded_from": str(temp_root / "step-150")} -def test_restore_respects_explicit_checkpoint_path_with_additional_paths(tmp_path: Path): +def test_restore_respects_explicit_checkpoint_path_as_single_search_path(tmp_path: Path): permanent_root = tmp_path / "checkpoints" temp_root = tmp_path / "checkpoints-temp" explicit_checkpoint = permanent_root / "step-100" @@ -138,11 +137,10 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial loaded = restore_grug_state_from_checkpoint( {"state": "init"}, - checkpoint_path=str(explicit_checkpoint), + checkpoint_search_paths=[str(explicit_checkpoint)], load_checkpoint_setting=True, mesh=None, allow_partial=False, - additional_checkpoint_paths=[str(temp_root)], _load_fn=fake_load, ) @@ -167,11 +165,10 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial loaded = restore_grug_state_from_checkpoint( {"state": "init"}, - checkpoint_path=str(permanent_root), + checkpoint_search_paths=[str(permanent_root), str(temp_root)], load_checkpoint_setting=None, mesh=None, allow_partial=False, - additional_checkpoint_paths=[str(temp_root)], _load_fn=fake_load, ) @@ -197,7 +194,7 @@ def test_restore_supports_legacy_wrapped_and_current_checkpoint_formats(tmp_path loaded_legacy = restore_grug_state_from_checkpoint( template_state, - checkpoint_path=str(checkpoint_root), + checkpoint_search_paths=[str(checkpoint_root)], load_checkpoint_setting=True, mesh=None, allow_partial=False, @@ -210,7 +207,7 @@ def test_restore_supports_legacy_wrapped_and_current_checkpoint_formats(tmp_path loaded_current = restore_grug_state_from_checkpoint( template_state, - checkpoint_path=str(checkpoint_root), + checkpoint_search_paths=[str(checkpoint_root)], load_checkpoint_setting=True, mesh=None, allow_partial=False, From 72ffd637f39c09a06022d16a06fb88986724c36a Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 21 Apr 2026 16:09:19 -0700 Subject: [PATCH 4/6] Centralize checkpoint search paths --- experiments/grug/base/train.py | 11 +---------- experiments/grug/modular_opt/train.py | 11 +---------- experiments/grug/moe/train.py | 11 +---------- lib/levanter/src/levanter/trainer.py | 11 +++++++++++ lib/levanter/tests/test_checkpoint.py | 16 ++++++++++++++++ 5 files changed, 30 insertions(+), 30 deletions(-) diff --git a/experiments/grug/base/train.py b/experiments/grug/base/train.py index 98fb8b6439..85ed20aab8 100644 --- a/experiments/grug/base/train.py +++ b/experiments/grug/base/train.py @@ -372,18 +372,9 @@ def _init_state(model_rng): state = _init_state(model_key) checkpointer = trainer.checkpointer.create(run_id) - if trainer.load_checkpoint_path is not None: - checkpoint_search_paths = [trainer.load_checkpoint_path] - elif checkpointer is not None: - checkpoint_search_paths = [trainer.checkpointer.expanded_path(run_id)] - temp_path = trainer.checkpointer.expanded_temporary_path(run_id) - if temp_path is not None: - checkpoint_search_paths.append(temp_path) - else: - checkpoint_search_paths = [] state = restore_grug_state_from_checkpoint( state, - checkpoint_search_paths=checkpoint_search_paths, + checkpoint_search_paths=trainer.checkpoint_search_paths(run_id), load_checkpoint_setting=trainer.load_checkpoint, mesh=mesh, allow_partial=trainer.allow_partial_checkpoint, diff --git a/experiments/grug/modular_opt/train.py b/experiments/grug/modular_opt/train.py index 887b45dfa0..40c6d962ff 100644 --- a/experiments/grug/modular_opt/train.py +++ b/experiments/grug/modular_opt/train.py @@ -372,18 +372,9 @@ def _init_state(model_rng): state = _init_state(model_key) checkpointer = trainer.checkpointer.create(run_id) - if trainer.load_checkpoint_path is not None: - checkpoint_search_paths = [trainer.load_checkpoint_path] - elif checkpointer is not None: - checkpoint_search_paths = [trainer.checkpointer.expanded_path(run_id)] - temp_path = trainer.checkpointer.expanded_temporary_path(run_id) - if temp_path is not None: - checkpoint_search_paths.append(temp_path) - else: - checkpoint_search_paths = [] state = restore_grug_state_from_checkpoint( state, - checkpoint_search_paths=checkpoint_search_paths, + checkpoint_search_paths=trainer.checkpoint_search_paths(run_id), load_checkpoint_setting=trainer.load_checkpoint, mesh=mesh, allow_partial=trainer.allow_partial_checkpoint, diff --git a/experiments/grug/moe/train.py b/experiments/grug/moe/train.py index ef47ccc457..a931123671 100644 --- a/experiments/grug/moe/train.py +++ b/experiments/grug/moe/train.py @@ -410,18 +410,9 @@ def _init_state(model_rng): state = _init_state(model_key) checkpointer = trainer.checkpointer.create(run_id) - if trainer.load_checkpoint_path is not None: - checkpoint_search_paths = [trainer.load_checkpoint_path] - elif checkpointer is not None: - checkpoint_search_paths = [trainer.checkpointer.expanded_path(run_id)] - temp_path = trainer.checkpointer.expanded_temporary_path(run_id) - if temp_path is not None: - checkpoint_search_paths.append(temp_path) - else: - checkpoint_search_paths = [] state = restore_grug_state_from_checkpoint( state, - checkpoint_search_paths=checkpoint_search_paths, + checkpoint_search_paths=trainer.checkpoint_search_paths(run_id), load_checkpoint_setting=trainer.load_checkpoint, mesh=mesh, allow_partial=trainer.allow_partial_checkpoint, diff --git a/lib/levanter/src/levanter/trainer.py b/lib/levanter/src/levanter/trainer.py index dcc119a04c..4e0993874f 100644 --- a/lib/levanter/src/levanter/trainer.py +++ b/lib/levanter/src/levanter/trainer.py @@ -854,6 +854,17 @@ def batch_axis_name(self) -> str | None: """if None (default), we'll load a checkpoint if it exists. If true, we must load a checkpoint""" load_checkpoint_path: Optional[str] = None """can be a parent (to find latest) or a specific checkpoint. if None, will set to checkpointer.base_path.""" + + def checkpoint_search_paths(self, run_id: str) -> list[str]: + if self.load_checkpoint_path is not None: + return [self.load_checkpoint_path] + + paths = [self.checkpointer.expanded_path(run_id)] + temp_path = self.checkpointer.expanded_temporary_path(run_id) + if temp_path is not None: + paths.append(temp_path) + return paths + initialize_from: Optional[str] = None # Levanter trainer checkpoint to initialize from """Load and continue training from a checkpoint. If None, will initialize from model_init.""" allow_partial_checkpoint: bool = False diff --git a/lib/levanter/tests/test_checkpoint.py b/lib/levanter/tests/test_checkpoint.py index fed4da61f1..cd35501cc7 100644 --- a/lib/levanter/tests/test_checkpoint.py +++ b/lib/levanter/tests/test_checkpoint.py @@ -39,6 +39,7 @@ save_checkpoint, unregister_debug_checkpointer_state_provider, ) +from levanter.trainer import TrainerConfig from levanter.trainer_state import TrainerState @@ -348,6 +349,21 @@ def test_checkpointer_config_no_temporary_base_path(): assert config.expanded_temporary_path("run1") is None +def test_trainer_config_checkpoint_search_paths(): + config = dataclasses.replace( + TrainerConfig(), + checkpointer=CheckpointerConfig( + base_path="/tmp/test-perm", + temporary_base_path="/tmp/test-temp", + append_run_id_to_base_path=True, + ), + ) + assert config.checkpoint_search_paths("run1") == ["/tmp/test-perm/run1", "/tmp/test-temp/run1"] + + pinned_config = dataclasses.replace(config, load_checkpoint_path="/tmp/test-perm/run1/step-100") + assert pinned_config.checkpoint_search_paths("run1") == ["/tmp/test-perm/run1/step-100"] + + def test_checkpointer_config_propagates_debug_settings(): config = CheckpointerConfig( base_path="/tmp/checkpoints", From 45935c897685970b192623799b2566f85a32cf3d Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 22 Apr 2026 10:55:37 -0700 Subject: [PATCH 5/6] Search temporary checkpoints from trainer restore --- lib/levanter/src/levanter/checkpoint.py | 17 ++++++++---- lib/levanter/src/levanter/trainer.py | 20 +++++++------ lib/levanter/tests/test_checkpoint.py | 37 +++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 14 deletions(-) diff --git a/lib/levanter/src/levanter/checkpoint.py b/lib/levanter/src/levanter/checkpoint.py index 1627c925a5..8024d1b612 100644 --- a/lib/levanter/src/levanter/checkpoint.py +++ b/lib/levanter/src/levanter/checkpoint.py @@ -779,6 +779,7 @@ def load_checkpoint( tree: M, checkpoint_path: PathLike, *, + additional_checkpoint_paths: Sequence[PathLike] = (), subpath: Optional[str] = None, discover_latest=True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, @@ -798,6 +799,7 @@ def load_checkpoint( Args: tree: an exemplar of the tree to load. Can be a PyTree[ShapeDTypeStruct] instead of a PyTree[Any] checkpoint_path: the path to load the checkpoint from + additional_checkpoint_paths: extra roots to search when discover_latest is True subpath: the subpath to load from the checkpoint discover_latest: whether to discover the latest checkpoint in the given path axis_mapping: the axis mapping to use for loading the checkpoint @@ -807,20 +809,19 @@ def load_checkpoint( the loaded checkpoint, with the same structure as the exemplar tree """ - fs: AbstractFileSystem - fs, _ = _get_fs_and_plain_path(checkpoint_path) - checkpoint_path = str(checkpoint_path) if is_in_jit(): logger.warning("Loading checkpoint in jit. This is not recommended and probably won't work.") if discover_latest: - discovered_checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore + discovered_checkpoint_path = discover_latest_checkpoint(checkpoint_path, *additional_checkpoint_paths) else: + if additional_checkpoint_paths: + raise ValueError("additional_checkpoint_paths only applies when discover_latest=True") discovered_checkpoint_path = checkpoint_path - if discovered_checkpoint_path is None or not fs.exists(discovered_checkpoint_path): + if discovered_checkpoint_path is None or not fsspec_utils.exists(discovered_checkpoint_path): raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}") checkpoint_path = discovered_checkpoint_path @@ -842,6 +843,7 @@ def load_checkpoint_or_initialize( init_fn: Callable[Sig, M], checkpoint_path: PathLike, *, + additional_checkpoint_paths: Sequence[PathLike] = (), subpath: Optional[str] = None, discover_latest=True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, @@ -874,6 +876,7 @@ def load_checkpoint_or_initialize( Args: init_fn: a function to initialize if needed checkpoint_path: the path to load the checkpoint from + additional_checkpoint_paths: extra roots to search when discover_latest is True subpath: the subpath to load from the checkpoint discover_latest: whether to discover the latest checkpoint in the given path axis_mapping: the axis mapping to use for loading the checkpoint @@ -921,6 +924,7 @@ def load_or_init(*args, **kwargs): loaded_state = load_checkpoint( filtered_state_shape, checkpoint_path, + additional_checkpoint_paths=additional_checkpoint_paths, subpath=subpath, discover_latest=discover_latest, axis_mapping=axis_mapping, @@ -956,7 +960,7 @@ def discover_latest_checkpoint(checkpoint_path: PathLike, *additional_paths: Pat """ all_paths = [str(checkpoint_path)] + [str(p) for p in additional_paths] best: Optional[str] = None - best_key: Optional[tuple] = None + best_key: tuple[datetime.datetime, int] | None = None for cp_path in all_paths: found = _discover_latest_checkpoint_single(cp_path) @@ -966,6 +970,7 @@ def discover_latest_checkpoint(checkpoint_path: PathLike, *additional_paths: Pat metadata = _load_metadata(found) key = (datetime.datetime.fromisoformat(metadata["timestamp"]), metadata["step"]) except Exception: + logger.exception("Error loading metadata for discovered checkpoint %s", found) continue if best_key is None or key > best_key: best = found diff --git a/lib/levanter/src/levanter/trainer.py b/lib/levanter/src/levanter/trainer.py index 4e0993874f..c118cbaefc 100644 --- a/lib/levanter/src/levanter/trainer.py +++ b/lib/levanter/src/levanter/trainer.py @@ -419,20 +419,22 @@ def initial_state( assert model_init is not None # first try to load a full trainer state checkpoint - checkpoint_path = self.checkpoint_path + checkpoint_search_paths = self.checkpoint_search_paths + checkpoint_path = checkpoint_search_paths[0] load_checkpoint = self.config.load_checkpoint # we don't save the full trainer state, so we need to filter out the non-trainable parameters - if load_checkpoint is True and not fsspec_utils.exists(checkpoint_path): - raise FileNotFoundError(f"Checkpoint {checkpoint_path} does not exist") + if load_checkpoint is True and not any(fsspec_utils.exists(path) for path in checkpoint_search_paths): + raise FileNotFoundError(f"Checkpoint search paths do not exist: {checkpoint_search_paths}") elif load_checkpoint is None: - load_checkpoint = levanter.checkpoint.is_checkpoint_path(checkpoint_path) + load_checkpoint = any(levanter.checkpoint.is_checkpoint_path(path) for path in checkpoint_search_paths) if load_checkpoint is False and self.config.initialize_from is not None: # we're not going to load a checkpoint from this run, so instead we can initialize from a different run logger.info(f"Initializing from {self.config.initialize_from}") load_checkpoint = True checkpoint_path = self.config.initialize_from + checkpoint_search_paths = [checkpoint_path] if not is_checkpoint_path(checkpoint_path): raise ValueError(f"initialize_from must be a checkpoint path, got {checkpoint_path}") @@ -456,6 +458,7 @@ def init_state_and_model(model_init, training_key): state = load_checkpoint_or_initialize( init_state_and_model, checkpoint_path, + additional_checkpoint_paths=checkpoint_search_paths[1:], axis_mapping=self.parameter_axis_mapping, mesh=self.device_mesh, is_checkpointed=saveable_train_state, @@ -465,12 +468,13 @@ def init_state_and_model(model_init, training_key): return state + @property + def checkpoint_search_paths(self) -> list[str]: + return self.config.checkpoint_search_paths(self.run_id) + @property def checkpoint_path(self) -> str: - checkpoint_path = self.config.load_checkpoint_path - if checkpoint_path is None: - checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) - return checkpoint_path + return self.checkpoint_search_paths[0] def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]: """ diff --git a/lib/levanter/tests/test_checkpoint.py b/lib/levanter/tests/test_checkpoint.py index cd35501cc7..9589f47a9d 100644 --- a/lib/levanter/tests/test_checkpoint.py +++ b/lib/levanter/tests/test_checkpoint.py @@ -594,6 +594,43 @@ def init_fn(key): ) +def test_load_from_checkpoint_or_initialize_searches_additional_paths(): + In = Axis("in", 2) + Out = Axis("out", 1) + + def init_fn(key): + return hax.nn.MLP.init(In, Out, 2, 1, key=key, use_bias=False, use_final_bias=False) + + with use_test_mesh(), tempfile.TemporaryDirectory() as permanent_dir, tempfile.TemporaryDirectory() as temp_dir: + k0 = jax.random.PRNGKey(0) + k1 = jax.random.PRNGKey(1) + model0 = eqx.filter_jit(init_fn)(k0) + model1 = eqx.filter_jit(init_fn)(k1) + + is_checkpointed = hax.tree_util.tree_map(lambda _: False, model0) + is_checkpointed = eqx.tree_at(lambda t: t.layers[-1], is_checkpointed, replace=True) + + filtered = eqx.filter(model0, is_checkpointed) + save_checkpoint(filtered, step=0, checkpoint_path=temp_dir) + + loaded = load_checkpoint_or_initialize( + init_fn, + permanent_dir, + additional_checkpoint_paths=[temp_dir], + is_checkpointed=is_checkpointed, + donate_args=False, + )(k1) + + assert_trees_all_equal( + jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed))), + jax.tree_util.tree_leaves(arrays_only(eqx.filter(model0, is_checkpointed))), + ) + assert_trees_all_equal( + jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed, inverse=True))), + jax.tree_util.tree_leaves(arrays_only(eqx.filter(model1, is_checkpointed, inverse=True))), + ) + + def test_load_from_checkpoint_or_initialize_works_if_file_not_found(): In = Axis("in", 2) Out = Axis("out", 1) From 5dcbd850660bac7ace22f2bba8137766a4468c99 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 22 Apr 2026 11:29:19 -0700 Subject: [PATCH 6/6] Separate checkpoint discovery from loading --- experiments/grug/checkpointing.py | 2 - lib/levanter/src/levanter/checkpoint.py | 90 ++++++++----------- lib/levanter/src/levanter/eval_harness.py | 5 +- lib/levanter/src/levanter/main/eval_lm.py | 5 +- .../src/levanter/main/export_lm_to_hf.py | 7 +- .../src/levanter/main/inference_repl.py | 21 ++--- .../src/levanter/main/perplexity_gap.py | 5 +- lib/levanter/src/levanter/main/sample_lm.py | 5 +- lib/levanter/src/levanter/main/train_dpo.py | 10 +-- lib/levanter/src/levanter/main/train_lm.py | 5 +- .../src/levanter/main/viz_logprobs.py | 8 +- lib/levanter/src/levanter/trainer.py | 4 +- lib/levanter/tests/test_checkpoint.py | 23 ++--- .../src/marin/evaluation/save_logprobs.py | 5 +- lib/marin/src/marin/rl/model_utils.py | 5 +- tests/test_grug_checkpointing.py | 14 ++- 16 files changed, 98 insertions(+), 116 deletions(-) diff --git a/experiments/grug/checkpointing.py b/experiments/grug/checkpointing.py index 24b209aed1..40db3d35ad 100644 --- a/experiments/grug/checkpointing.py +++ b/experiments/grug/checkpointing.py @@ -142,7 +142,6 @@ def _load_candidate_state( return load_fn( state, candidate, - discover_latest=False, axis_mapping=None, mesh=mesh, allow_partial=allow_partial, @@ -152,7 +151,6 @@ def _load_candidate_state( wrapped = load_fn( {"train_state": state}, candidate, - discover_latest=False, axis_mapping=None, mesh=mesh, allow_partial=allow_partial, diff --git a/lib/levanter/src/levanter/checkpoint.py b/lib/levanter/src/levanter/checkpoint.py index 8024d1b612..2a4b8a8336 100644 --- a/lib/levanter/src/levanter/checkpoint.py +++ b/lib/levanter/src/levanter/checkpoint.py @@ -476,40 +476,26 @@ def __init__( def load_checkpoint( self, state: M, - path: Optional[PathLike] = None, + checkpoint_path: PathLike, *, - discover_latest: bool = True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[M]: - if path is None: - # When temporary_base_path is set, discover the newest checkpoint across both roots - if discover_latest and self.temporary_base_path is not None: - latest = discover_latest_checkpoint(self.base_path, self.temporary_base_path) - if latest is not None: - return load_checkpoint(state, latest, discover_latest=False, axis_mapping=axis_mapping, mesh=mesh) - return None - path = self.base_path - return load_checkpoint(state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh) + ) -> M: + return load_checkpoint(state, checkpoint_path, axis_mapping=axis_mapping, mesh=mesh) def load_model( self, model: M, - path: Optional[str] = None, + checkpoint_path: PathLike, *, - discover_latest: bool = True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[M]: + ) -> M: """ Convenience method/holdover from previous API for loading checkpoints. - Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint. + Loads just the model assuming the model is in the `model` subdir of the checkpoint. """ - ret_dict = self.load_checkpoint( - {"model": model}, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh - ) - if ret_dict is None: - return None + ret_dict = self.load_checkpoint({"model": model}, checkpoint_path, axis_mapping=axis_mapping, mesh=mesh) return ret_dict["model"] def on_step(self, *, tree: PyTree, step: int, force: bool = False): @@ -779,9 +765,7 @@ def load_checkpoint( tree: M, checkpoint_path: PathLike, *, - additional_checkpoint_paths: Sequence[PathLike] = (), subpath: Optional[str] = None, - discover_latest=True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[jax.sharding.Mesh] = None, allow_partial: bool = False, @@ -792,16 +776,14 @@ def load_checkpoint( Supports both OCDBT (new format) and non-OCDBT (old format) checkpoints through automatic format detection. - If discover_latest is True, then the latest checkpoint in a subdirectory of the given path - will be loaded. If subpath is not None, then the checkpoint loads only that subpath of the - checkpoint. This is useful for loading, e.g., just the model and not the entire training state. + This function expects ``checkpoint_path`` to already point at a concrete checkpoint directory. + Use ``discover_latest_checkpoint`` or ``latest_checkpoint_path`` before calling when accepting + a parent directory. Args: tree: an exemplar of the tree to load. Can be a PyTree[ShapeDTypeStruct] instead of a PyTree[Any] - checkpoint_path: the path to load the checkpoint from - additional_checkpoint_paths: extra roots to search when discover_latest is True + checkpoint_path: the concrete checkpoint directory to load from subpath: the subpath to load from the checkpoint - discover_latest: whether to discover the latest checkpoint in the given path axis_mapping: the axis mapping to use for loading the checkpoint mesh: the mesh to use for loading the checkpoint allow_partial: if True, allow partial loading of the checkpoint. If False, all parameters must be present in the checkpoint. @@ -814,18 +796,9 @@ def load_checkpoint( if is_in_jit(): logger.warning("Loading checkpoint in jit. This is not recommended and probably won't work.") - if discover_latest: - discovered_checkpoint_path = discover_latest_checkpoint(checkpoint_path, *additional_checkpoint_paths) - else: - if additional_checkpoint_paths: - raise ValueError("additional_checkpoint_paths only applies when discover_latest=True") - discovered_checkpoint_path = checkpoint_path - - if discovered_checkpoint_path is None or not fsspec_utils.exists(discovered_checkpoint_path): + if not fsspec_utils.exists(checkpoint_path): raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}") - checkpoint_path = discovered_checkpoint_path - logger.info(f"Loading checkpoint from {checkpoint_path}") if subpath: @@ -841,9 +814,8 @@ def load_checkpoint( def load_checkpoint_or_initialize( init_fn: Callable[Sig, M], - checkpoint_path: PathLike, + checkpoint_search_paths: Sequence[PathLike], *, - additional_checkpoint_paths: Sequence[PathLike] = (), subpath: Optional[str] = None, discover_latest=True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, @@ -855,10 +827,10 @@ def load_checkpoint_or_initialize( allow_partial: bool = False, ) -> Callable[Sig, M]: """ - Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint - in a subdirectory of the given path will be loaded. If subpath is not None, then the checkpoint - loads only that subpath of the checkpoint. This is useful for loading, e.g., just the model and not - the entire training state. + Load from checkpoint search paths, or initialize from scratch when no checkpoint is available. + If discover_latest is True, the latest checkpoint across the search paths will be loaded. If + subpath is not None, only that subpath of the checkpoint is loaded. This is useful for loading, + e.g., just the model and not the entire training state. This function supports "partial" checkpoint loading, where only a subset of the parameters of the state is loaded from the checkpoint. This is useful for initializing just some parameters. @@ -875,10 +847,9 @@ def load_checkpoint_or_initialize( Args: init_fn: a function to initialize if needed - checkpoint_path: the path to load the checkpoint from - additional_checkpoint_paths: extra roots to search when discover_latest is True + checkpoint_search_paths: paths to search for a checkpoint. If discover_latest is False, this must contain exactly one concrete checkpoint path. subpath: the subpath to load from the checkpoint - discover_latest: whether to discover the latest checkpoint in the given path + discover_latest: whether to discover the latest checkpoint in the search paths axis_mapping: the axis mapping to use for loading the checkpoint mesh: the mesh to use for loading the checkpoint is_checkpointed: a FilterSpec that specifies which parameters are checkpointed @@ -892,6 +863,9 @@ def load_checkpoint_or_initialize( loaded state. """ + if len(checkpoint_search_paths) == 0: + raise ValueError("checkpoint_search_paths must contain at least one path") + checkpoint_search_paths = [str(path) for path in checkpoint_search_paths] # some state might not be initialized, so we need to initialize it # JAX will be smart and only do the compute for things we actually need @@ -921,12 +895,17 @@ def load_or_init(*args, **kwargs): if do_load is not False: # now we can load the checkpoint try: + if discover_latest: + checkpoint_path = latest_checkpoint_path(checkpoint_search_paths[0], *checkpoint_search_paths[1:]) + else: + if len(checkpoint_search_paths) != 1: + raise ValueError("discover_latest=False requires exactly one checkpoint search path") + checkpoint_path = checkpoint_search_paths[0] + loaded_state = load_checkpoint( filtered_state_shape, checkpoint_path, - additional_checkpoint_paths=additional_checkpoint_paths, subpath=subpath, - discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh, allow_partial=allow_partial, @@ -934,7 +913,7 @@ def load_or_init(*args, **kwargs): except FileNotFoundError: if do_load is True: raise - logger.info(f"Checkpoint not found at {checkpoint_path}. Initializing from scratch.") + logger.info(f"Checkpoint not found in {checkpoint_search_paths}. Initializing from scratch.") state = init_and_merge(loaded_state, *args, **kwargs) @@ -983,6 +962,15 @@ def discover_latest_checkpoint(checkpoint_path: PathLike, *additional_paths: Pat return best +def latest_checkpoint_path(checkpoint_path: PathLike, *additional_paths: PathLike) -> str: + """Return the latest concrete checkpoint path across one or more search roots.""" + latest = discover_latest_checkpoint(checkpoint_path, *additional_paths) + if latest is None: + search_paths = [str(checkpoint_path)] + [str(path) for path in additional_paths] + raise FileNotFoundError(f"Could not discover checkpoint under any of: {search_paths}") + return latest + + def _discover_latest_checkpoint_single(checkpoint_path: str) -> Optional[str]: """Discover the latest checkpoint in a single root path.""" fs: AbstractFileSystem diff --git a/lib/levanter/src/levanter/eval_harness.py b/lib/levanter/src/levanter/eval_harness.py index 9cffde485e..698d08931d 100644 --- a/lib/levanter/src/levanter/eval_harness.py +++ b/lib/levanter/src/levanter/eval_harness.py @@ -77,7 +77,7 @@ import levanter.config from levanter.callbacks import StepInfo -from levanter.checkpoint import load_checkpoint +from levanter.checkpoint import latest_checkpoint_path, load_checkpoint from levanter.data import batched from levanter.data.loader import stack_batches from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel @@ -1475,9 +1475,10 @@ def run_eval_harness_main(config: EvalHarnessMainConfig): else: with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) + checkpoint_path = latest_checkpoint_path(config.checkpoint_path) model = load_checkpoint( model, - config.checkpoint_path, + checkpoint_path, subpath="model", axis_mapping=parameter_axis_mapping, ) diff --git a/lib/levanter/src/levanter/main/eval_lm.py b/lib/levanter/src/levanter/main/eval_lm.py index 404b8f0921..552eef33dc 100644 --- a/lib/levanter/src/levanter/main/eval_lm.py +++ b/lib/levanter/src/levanter/main/eval_lm.py @@ -16,7 +16,7 @@ from haliax.partitioning import round_axis_for_partitioning import levanter -from levanter.checkpoint import load_checkpoint +from levanter.checkpoint import latest_checkpoint_path, load_checkpoint from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef from levanter.data import DataLoader from levanter.data.text import LmDataConfig @@ -127,7 +127,8 @@ def compute_logits(model: LmHeadModel, example: LmExample): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: can't load the EMA model with current setup here. Not a big deal for now. # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - model = load_checkpoint(model, config.checkpoint_path, subpath="model") + checkpoint_path = latest_checkpoint_path(config.checkpoint_path) + model = load_checkpoint(model, checkpoint_path, subpath="model") model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) elif config.hf_checkpoint is not None: diff --git a/lib/levanter/src/levanter/main/export_lm_to_hf.py b/lib/levanter/src/levanter/main/export_lm_to_hf.py index 586d92d7be..66b6478857 100644 --- a/lib/levanter/src/levanter/main/export_lm_to_hf.py +++ b/lib/levanter/src/levanter/main/export_lm_to_hf.py @@ -13,7 +13,7 @@ from haliax import Axis import levanter -from levanter.checkpoint import load_checkpoint +from levanter.checkpoint import latest_checkpoint_path, load_checkpoint from levanter.compat.hf_checkpoints import RepoRef, load_tokenizer, HFCompatConfig from levanter.models.llama import LlamaConfig from levanter.models.lm_model import LmConfig, LmHeadModel @@ -67,8 +67,9 @@ def main(config: ConvertLmConfig): model: LmHeadModel = eqx.filter_eval_shape(config.model.build, Vocab, key=key) trainable, non_trainable = eqx.partition(model, is_inexact_arrayish) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - logger.info(f"Loading checkpoint from {config.checkpoint_path}...") - trainable = load_checkpoint(trainable, config.checkpoint_path, subpath="model") + checkpoint_path = latest_checkpoint_path(config.checkpoint_path) + logger.info(f"Loading checkpoint from {checkpoint_path}...") + trainable = load_checkpoint(trainable, checkpoint_path, subpath="model") assert trainable is not None model = eqx.combine(trainable, non_trainable) diff --git a/lib/levanter/src/levanter/main/inference_repl.py b/lib/levanter/src/levanter/main/inference_repl.py index 5e7a18aedc..f33753254e 100644 --- a/lib/levanter/src/levanter/main/inference_repl.py +++ b/lib/levanter/src/levanter/main/inference_repl.py @@ -37,7 +37,7 @@ from rich.panel import Panel import levanter -from levanter.checkpoint import load_checkpoint +from levanter.checkpoint import latest_checkpoint_path, load_checkpoint from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer from levanter.tokenizers import MarinTokenizer from levanter.inference.engine import InferenceEngineConfig @@ -52,7 +52,6 @@ ) from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.trainer import TrainerConfig -from levanter.utils.jax_utils import use_cpu_device from levanter.utils.mesh import MeshConfig logger = logging.getLogger(__name__) @@ -64,17 +63,6 @@ jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir") -def weight_loader(server, server_config, current_model: LmHeadModel) -> LmHeadModel: - with use_cpu_device(): - key = jrandom.PRNGKey(server_config.seed) - vocab_size = server.inference_context.tokenizer.vocab_size - Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), server_config.trainer.param_axis_mapping) - model = eqx.filter_eval_shape(server_config.model.build, Vocab, key=key) - model = load_checkpoint(model, model, subpath="model") - model = server_config.trainer.mp.cast_to_compute(model) - return model - - def _load_model( trainer_config: TrainerConfig, model_config: LmConfig, @@ -99,9 +87,10 @@ def _load_model( if levanter_checkpoint is not None: model = eqx.filter_eval_shape(model_config.build, Vocab, key=key) + checkpoint_path = latest_checkpoint_path(levanter_checkpoint) model = load_checkpoint( model, - levanter_checkpoint, + checkpoint_path, subpath="model", axis_mapping=trainer_config.parameter_axis_mapping, ) @@ -248,8 +237,8 @@ def load(self, path: str, tokenizer: Optional[str] = None, **kwargs): if self.server is not None: - def _reload(current_model: LmHeadModel) -> LmHeadModel: - return weight_loader(self.server, self.config.server, current_model) + def _reload(_current_model: LmHeadModel) -> LmHeadModel: + return model self.server.reload(_reload) else: diff --git a/lib/levanter/src/levanter/main/perplexity_gap.py b/lib/levanter/src/levanter/main/perplexity_gap.py index 78565f78ba..bf19bf2da6 100644 --- a/lib/levanter/src/levanter/main/perplexity_gap.py +++ b/lib/levanter/src/levanter/main/perplexity_gap.py @@ -30,7 +30,7 @@ tokenize_text_with_byte_spans, write_report_files, ) -from levanter.checkpoint import load_checkpoint +from levanter.checkpoint import latest_checkpoint_path, load_checkpoint from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig from levanter.data.text import DatasetComponent from levanter.data.text.examples import GrugLmExample, named_lm_example_from_grug @@ -265,7 +265,8 @@ def compute_losses(model: LmHeadModel, batch: GrugLmExample): else: with use_cpu_device(): model = eqx.filter_eval_shape(spec.model.build, Vocab, key=key) - model = load_checkpoint(model, spec.checkpoint_path, subpath="model") + checkpoint_path = latest_checkpoint_path(spec.checkpoint_path) + model = load_checkpoint(model, checkpoint_path, subpath="model") model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) label = _model_label(spec) diff --git a/lib/levanter/src/levanter/main/sample_lm.py b/lib/levanter/src/levanter/main/sample_lm.py index 312be6c99c..a6bc12679c 100644 --- a/lib/levanter/src/levanter/main/sample_lm.py +++ b/lib/levanter/src/levanter/main/sample_lm.py @@ -15,7 +15,7 @@ import levanter from levanter.callbacks import profile_ctx -from levanter.checkpoint import load_checkpoint +from levanter.checkpoint import latest_checkpoint_path, load_checkpoint from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef, load_tokenizer from levanter.inference.engine import InferenceEngine, InferenceEngineConfig, Request from levanter.inference.jit_scheduler import SeqDecodingParams @@ -77,9 +77,10 @@ def _load_model(config: SampleLmConfig, Vocab: Axis, *, key) -> LmHeadModel: if config.checkpoint_path is not None: with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) + checkpoint_path = latest_checkpoint_path(config.checkpoint_path) model = load_checkpoint( model, - config.checkpoint_path, + checkpoint_path, subpath="model", axis_mapping=config.trainer.parameter_axis_mapping, ) diff --git a/lib/levanter/src/levanter/main/train_dpo.py b/lib/levanter/src/levanter/main/train_dpo.py index 59f8099eff..bfb166a037 100644 --- a/lib/levanter/src/levanter/main/train_dpo.py +++ b/lib/levanter/src/levanter/main/train_dpo.py @@ -17,7 +17,7 @@ import levanter import levanter.callbacks from levanter import callbacks -from levanter.checkpoint import load_checkpoint +from levanter.checkpoint import latest_checkpoint_path, load_checkpoint from levanter.compat.hf_checkpoints import HFCompatConfig, build_generation_config from levanter.data.dataset import AsyncDataset from levanter.data.mixture import MixtureDataset @@ -379,9 +379,8 @@ def loss_function(model: DpoModel, example: DpoExample, *, key=None): elif config.initialize_from_checkpoint_path is not None: with use_cpu_device(): policy_model = eqx.filter_eval_shape(config.model.build, Vocab, key=init_policy_key) - policy_model = load_checkpoint( - policy_model, config.initialize_from_checkpoint_path, subpath="model" - ) + checkpoint_path = latest_checkpoint_path(config.initialize_from_checkpoint_path) + policy_model = load_checkpoint(policy_model, checkpoint_path, subpath="model") policy_model = hax.shard(policy_model, parameter_axis_mapping) policy_model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(policy_model) else: @@ -406,7 +405,8 @@ def loss_function(model: DpoModel, example: DpoExample, *, key=None): else: with use_cpu_device(): reference_model = eqx.filter_eval_shape(config.model.build, Vocab, key=model_key) - reference_model = load_checkpoint(reference_model, config.reference_model_path, subpath="model") + checkpoint_path = latest_checkpoint_path(config.reference_model_path) + reference_model = load_checkpoint(reference_model, checkpoint_path, subpath="model") reference_model = hax.shard(reference_model, parameter_axis_mapping) reference_model = cast(LmHeadModel, inference_mode(reference_model, True)) diff --git a/lib/levanter/src/levanter/main/train_lm.py b/lib/levanter/src/levanter/main/train_lm.py index 84fce12fb6..fa404fad76 100644 --- a/lib/levanter/src/levanter/main/train_lm.py +++ b/lib/levanter/src/levanter/main/train_lm.py @@ -20,7 +20,7 @@ import levanter.eval_harness from levanter import callbacks from levanter.callbacks.tensorstore_callbacks import install_tensorstore_metrics_hook_if_enabled -from levanter.checkpoint import load_checkpoint +from levanter.checkpoint import latest_checkpoint_path, load_checkpoint from levanter.compat.hf_checkpoints import HFCompatConfig, build_generation_config, save_hf_checkpoint_callback from levanter.data.mixture import MixtureDataset from levanter.data.text import LmDataConfig @@ -174,7 +174,8 @@ def loss_function(model: LmHeadModel, example: LmExample, *, key=None): state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) if int(state.step) == 0 and config.initialize_from_checkpoint_path is not None: - state = load_checkpoint(state, config.initialize_from_checkpoint_path) + checkpoint_path = latest_checkpoint_path(config.initialize_from_checkpoint_path) + state = load_checkpoint(state, checkpoint_path) # reset to step 0, we're just initializing weights here state = dataclasses.replace(state, step=jnp.array(0)) diff --git a/lib/levanter/src/levanter/main/viz_logprobs.py b/lib/levanter/src/levanter/main/viz_logprobs.py index 8367f73c72..6548779542 100644 --- a/lib/levanter/src/levanter/main/viz_logprobs.py +++ b/lib/levanter/src/levanter/main/viz_logprobs.py @@ -15,7 +15,7 @@ from haliax.partitioning import round_axis_for_partitioning import levanter -from levanter.checkpoint import load_checkpoint +from levanter.checkpoint import latest_checkpoint_path, load_checkpoint from levanter.compat.hf_checkpoints import HFCheckpointConverter from levanter.data import DataLoader from levanter.data.text import LmDataConfig @@ -115,7 +115,8 @@ def compute_log_probs(model: LmHeadModel, example: LmExample): else: with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) - model = load_checkpoint(model, config.checkpoint_path, subpath="model") + checkpoint_path = latest_checkpoint_path(config.checkpoint_path) + model = load_checkpoint(model, checkpoint_path, subpath="model") model = hax.shard(model, parameter_axis_mapping) model = typing.cast(LmHeadModel, inference_mode(model, True)) @@ -134,7 +135,8 @@ def compute_log_probs(model: LmHeadModel, example: LmExample): else: with use_cpu_device(): comparison_model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) - comparison_model = load_checkpoint(comparison_model, config.comparison_model_path, subpath="model") + comparison_checkpoint_path = latest_checkpoint_path(config.comparison_model_path) + comparison_model = load_checkpoint(comparison_model, comparison_checkpoint_path, subpath="model") comparison_model = hax.shard(comparison_model, parameter_axis_mapping) comparison_model = typing.cast(LmHeadModel, inference_mode(comparison_model, True)) else: diff --git a/lib/levanter/src/levanter/trainer.py b/lib/levanter/src/levanter/trainer.py index c118cbaefc..f0ea22d0ed 100644 --- a/lib/levanter/src/levanter/trainer.py +++ b/lib/levanter/src/levanter/trainer.py @@ -420,7 +420,6 @@ def initial_state( # first try to load a full trainer state checkpoint checkpoint_search_paths = self.checkpoint_search_paths - checkpoint_path = checkpoint_search_paths[0] load_checkpoint = self.config.load_checkpoint # we don't save the full trainer state, so we need to filter out the non-trainable parameters @@ -457,8 +456,7 @@ def init_state_and_model(model_init, training_key): state = load_checkpoint_or_initialize( init_state_and_model, - checkpoint_path, - additional_checkpoint_paths=checkpoint_search_paths[1:], + checkpoint_search_paths, axis_mapping=self.parameter_axis_mapping, mesh=self.device_mesh, is_checkpointed=saveable_train_state, diff --git a/lib/levanter/tests/test_checkpoint.py b/lib/levanter/tests/test_checkpoint.py index 9589f47a9d..26d7f20821 100644 --- a/lib/levanter/tests/test_checkpoint.py +++ b/lib/levanter/tests/test_checkpoint.py @@ -207,7 +207,6 @@ def test_checkpoint_simple(): restored_state = load_checkpoint( rep_state, checkpoint_path=tmpdir, - discover_latest=False, ) assert_trees_all_equal( @@ -246,7 +245,7 @@ def loss_fn(model, data): with tempfile.TemporaryDirectory() as tmpdir: save_checkpoint(state, step=3, checkpoint_path=tmpdir) - restored_state = load_checkpoint(rep_state, checkpoint_path=tmpdir, discover_latest=False) + restored_state = load_checkpoint(rep_state, checkpoint_path=tmpdir) assert_trees_all_equal( jax.tree_util.tree_leaves(arrays_only(restored_state)), @@ -562,10 +561,14 @@ def init_fn(key): filtered = eqx.filter(model0, is_checkpointed) save_checkpoint(filtered, step=0, checkpoint_path=tmpdir) - loaded = load_checkpoint_or_initialize(init_fn, tmpdir, is_checkpointed=is_checkpointed, donate_args=False)(k1) + loaded = load_checkpoint_or_initialize(init_fn, [tmpdir], is_checkpointed=is_checkpointed, donate_args=False)( + k1 + ) assert not any(jax.tree_util.tree_leaves(eqx.filter(loaded, lambda x: isinstance(x, ShapeDtypeStruct)))) - loaded2 = load_checkpoint(eqx.filter(model1, is_checkpointed), tmpdir, discover_latest=True) + latest_checkpoint = discover_latest_checkpoint(tmpdir) + assert latest_checkpoint is not None + loaded2 = load_checkpoint(eqx.filter(model1, is_checkpointed), latest_checkpoint) loaded2 = eqx.combine(loaded2, model1) assert_trees_all_equal( @@ -615,8 +618,7 @@ def init_fn(key): loaded = load_checkpoint_or_initialize( init_fn, - permanent_dir, - additional_checkpoint_paths=[temp_dir], + [permanent_dir, temp_dir], is_checkpointed=is_checkpointed, donate_args=False, )(k1) @@ -647,9 +649,9 @@ def init_fn(key): is_checkpointed = jtu.tree_map(lambda _: False, model0) is_checkpointed = eqx.tree_at(lambda t: t.layers[-1], is_checkpointed, replace=True) - loaded = load_checkpoint_or_initialize(init_fn, "kanmfklafnmjlkanfjklanfjkh", is_checkpointed=is_checkpointed)( - k1 - ) + loaded = load_checkpoint_or_initialize( + init_fn, ["kanmfklafnmjlkanfjklanfjkh"], is_checkpointed=is_checkpointed + )(k1) assert not any(jax.tree_util.tree_leaves(eqx.filter(loaded, lambda x: isinstance(x, ShapeDtypeStruct)))) # should be the same as model1 @@ -684,7 +686,7 @@ def init_fn(key, use_b): loaded = load_checkpoint_or_initialize( init_fn, - tmpdir, + [tmpdir], is_checkpointed=is_checkpointed, allow_partial=True, )(k1, True) @@ -767,7 +769,6 @@ def test_backward_compatibility_with_ocdbt(): restored_state = load_checkpoint( rep_state, checkpoint_path=tmpdir, - discover_latest=False, ) # Verify the data was loaded correctly diff --git a/lib/marin/src/marin/evaluation/save_logprobs.py b/lib/marin/src/marin/evaluation/save_logprobs.py index 204133b982..d03fad1a56 100644 --- a/lib/marin/src/marin/evaluation/save_logprobs.py +++ b/lib/marin/src/marin/evaluation/save_logprobs.py @@ -30,7 +30,7 @@ from haliax.partitioning import round_axis_for_partitioning import levanter -from levanter.checkpoint import load_checkpoint +from levanter.checkpoint import latest_checkpoint_path, load_checkpoint from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef from levanter.data import DataLoader from levanter.data.text import DatasetComponent, LmDataConfig, LMMixtureDatasetConfig @@ -137,7 +137,8 @@ def compute_top(logprobs: hax.NamedArray, k: int): if config.checkpoint_path is not None and not config.checkpoint_is_hf: with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) - model = load_checkpoint(model, config.checkpoint_path, subpath="model") + checkpoint_path = latest_checkpoint_path(config.checkpoint_path) + model = load_checkpoint(model, checkpoint_path, subpath="model") model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) elif hf_checkpoint is not None: model_config = config.model diff --git a/lib/marin/src/marin/rl/model_utils.py b/lib/marin/src/marin/rl/model_utils.py index 1b3a8517d3..36c69809ed 100644 --- a/lib/marin/src/marin/rl/model_utils.py +++ b/lib/marin/src/marin/rl/model_utils.py @@ -15,7 +15,7 @@ import haliax as hax import jax from jax.sharding import Mesh -from levanter.checkpoint import load_checkpoint +from levanter.checkpoint import latest_checkpoint_path, load_checkpoint from levanter.compat.hf_checkpoints import ( HFCheckpointConverter, PYTORCH_WEIGHTS_INDEX_NAME, @@ -116,6 +116,7 @@ def load_model_from_checkpoint( else: # Load from local Levanter checkpoint model = eqx.filter_eval_shape(model_config.build, vocab_axis, key=key) - model = load_checkpoint(model, checkpoint, subpath="model", axis_mapping=axis_mapping, mesh=mesh) + checkpoint_path = latest_checkpoint_path(checkpoint) + model = load_checkpoint(model, checkpoint_path, subpath="model", axis_mapping=axis_mapping, mesh=mesh) model = mp.cast_to_compute(model) return model diff --git a/tests/test_grug_checkpointing.py b/tests/test_grug_checkpointing.py index caa7461de0..536bfd8cfe 100644 --- a/tests/test_grug_checkpointing.py +++ b/tests/test_grug_checkpointing.py @@ -27,9 +27,8 @@ def test_restore_prefers_highest_step_over_latest_timestamp(tmp_path: Path): attempted: list[str] = [] - def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial): + def fake_load(state, path, *, axis_mapping, mesh, allow_partial): attempted.append(path) - assert discover_latest is False return {"loaded_from": path} loaded = restore_grug_state_from_checkpoint( @@ -52,9 +51,8 @@ def test_restore_falls_back_to_older_checkpoint_when_latest_fails(tmp_path: Path attempted: list[str] = [] - def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial): + def fake_load(state, path, *, axis_mapping, mesh, allow_partial): attempted.append(path) - assert discover_latest is False if path.endswith("step-100"): raise FileNotFoundError(path) return {"loaded_from": path} @@ -80,7 +78,7 @@ def test_restore_raises_when_required_and_no_checkpoint_loads(tmp_path: Path): checkpoint_root = tmp_path / "checkpoints" _write_checkpoint_metadata(checkpoint_root / "step-100", step=100, timestamp="2026-03-17T10:00:00") - def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial): + def fake_load(state, path, *, axis_mapping, mesh, allow_partial): raise FileNotFoundError(path) with pytest.raises(FileNotFoundError, match="Could not load a checkpoint"): @@ -103,7 +101,7 @@ def test_restore_discovers_candidates_across_search_paths(tmp_path: Path): attempted: list[str] = [] - def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial): + def fake_load(state, path, *, axis_mapping, mesh, allow_partial): attempted.append(path) return {"loaded_from": path} @@ -131,7 +129,7 @@ def test_restore_respects_explicit_checkpoint_path_as_single_search_path(tmp_pat attempted: list[str] = [] - def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial): + def fake_load(state, path, *, axis_mapping, mesh, allow_partial): attempted.append(path) return {"loaded_from": path} @@ -157,7 +155,7 @@ def test_restore_falls_back_from_temp_to_permanent(tmp_path: Path): attempted: list[str] = [] - def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial): + def fake_load(state, path, *, axis_mapping, mesh, allow_partial): attempted.append(path) if "step-150" in path: raise FileNotFoundError(path)