Skip to content

Commit bc92c67

Browse files
angel-core and Orbax Authors
authored and committed
Refactor CheckpointLayout splitting load() into load_pytree() and load_checkpointables() each with their own dedicated loading logic.
PiperOrigin-RevId: 869447682
1 parent ebffb79 commit bc92c67

19 files changed

+560
-392
lines changed

checkpoint/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222
### Changed
2323

2424
- #v1 Make most V1 public concrete classes final.
25+
- Refactor `CheckpointLayout` splitting `load()` into `load_pytree()` and
26+
`load_checkpointables()`, each with its own dedicated loading logic.
2527

2628
## [0.11.32] - 2026-01-20
2729

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,8 @@ def _get_saved_handler_typestrs(
293293

294294
saved_handler_typestrs: dict[str, str] = {}
295295
for checkpointable_path in directory.iterdir():
296+
if not checkpointable_path.is_dir():
297+
continue
296298
serialized_metadata = self._metadata_store.read(
297299
checkpoint_metadata.step_metadata_file_path(checkpointable_path)
298300
)

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/checkpoint_layout.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,24 +89,21 @@ async def validate_pytree(
8989
"""
9090
...
9191

92-
async def load(
92+
async def load_pytree(
9393
self,
9494
path: Path,
95-
abstract_checkpointables: dict[str, Any] | None = None,
96-
) -> Awaitable[dict[str, Any]]:
97-
"""Loads the checkpoint from the given directory.
95+
checkpointable_name: str | None = None,
96+
abstract_pytree: Any | None = None,
97+
) -> Awaitable[Any]:
98+
"""Loads a PyTree from the checkpoint.
9899
99100
Args:
100101
path: The path to the checkpoint.
101-
abstract_checkpointables: A dictionary of abstract checkpointables.
102-
Dictionary keys represent the names of the checkpointables, while the
103-
values are the abstract checkpointable objects themselves.
102+
checkpointable_name: The name of the checkpointable to load.
103+
abstract_pytree: The abstract PyTree structure.
104104
105105
Returns:
106-
An awaitable dictionary of checkpointables. Dictionary keys represent the
107-
names of
108-
the checkpointables, while the values are the checkpointable objects
109-
themselves.
106+
An awaitable PyTree.
110107
"""
111108
...
112109

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/numpy_layout.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _load_leaf(leaf: Any, abstract_leaf: jax.ShapeDtypeStruct):
9898
async def _load_numpy(
9999
path: Path,
100100
abstract_pytree: tree_types.PyTreeOf[jax.ShapeDtypeStruct] | None = None,
101-
) -> dict[str, Any]:
101+
) -> Any:
102102
"""Loads numpy checkpoint as numpy arrays or sharded jax arrays."""
103103
npz_file = await asyncio.to_thread(np.load, path, allow_pickle=True)
104104
try:
@@ -112,7 +112,7 @@ async def _load_numpy(
112112
finally:
113113
npz_file.close()
114114

115-
return {checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY: restored_pytree}
115+
return restored_pytree
116116

117117

118118
class NumpyLayout(CheckpointLayout):
@@ -193,20 +193,30 @@ def _read_metadata_sync():
193193
commit_timestamp_nsecs=commit_timestamp_nsecs,
194194
)
195195

196-
async def load(
196+
async def load_pytree(
197197
self,
198198
path: Path,
199-
abstract_checkpointables: (
200-
dict[str, tree_types.PyTreeOf[jax.ShapeDtypeStruct]] | None
201-
) = None,
202-
) -> Awaitable[dict[str, tree_types.PyTreeOf[Any]]]:
203-
"""Loads a NumPy checkpoint file."""
204-
abstract_pytree = None
205-
if abstract_checkpointables:
206-
abstract_pytree = abstract_checkpointables.get(
207-
checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY
208-
)
209-
return _load_numpy(path, abstract_pytree)
199+
checkpointable_name: str | None = None,
200+
abstract_pytree: Any | None = None,
201+
) -> Awaitable[tree_types.PyTreeOf[Any]]:
202+
"""Loads a NumPy checkpoint file.
203+
204+
If `abstract_pytree` is provided, it attempts to load numpy arrays as
205+
sharded `jax.Arrays` onto devices.
206+
207+
Args:
208+
path: The path to the checkpoint.
209+
checkpointable_name: The name of the pytree checkpointable to load,
210+
unused in this case.
211+
abstract_pytree: An optional PyTree of abstract arrays specifying sharding
212+
information.
213+
214+
Returns:
215+
An awaitable of the loaded PyTree.
216+
"""
217+
del checkpointable_name
218+
load_awaitable = _load_numpy(path, abstract_pytree)
219+
return load_awaitable
210220

211221
async def save(
212222
self,

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/numpy_layout_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,8 @@ async def test_load_numpy_checkpoint(self, dtype: np.dtype):
9595

9696
# Load the checkpoint
9797
layout = NumpyLayout()
98-
restore_fn = await layout.load(test_path)
99-
restored_checkpointables = await restore_fn
100-
pytree = restored_checkpointables['pytree']
98+
restore_fn = await layout.load_pytree(test_path)
99+
pytree = await restore_fn
101100

102101
# Verify restored data
103102
if np.issubdtype(dtype, np.floating):

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py

Lines changed: 143 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,29 @@
1515
"""Defines `OrbaxLayout`, a class to handle Orbax checkpoint formats."""
1616

1717
import asyncio
18+
import enum
1819
from typing import Any, Awaitable
1920

2021
from absl import logging
22+
from orbax.checkpoint._src.metadata import checkpoint as checkpoint_metadata
23+
from orbax.checkpoint._src.metadata import step_metadata_serialization
2124
from orbax.checkpoint._src.path import async_path
2225
from orbax.checkpoint._src.path import temporary_paths
2326
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
2427
from orbax.checkpoint.experimental.v1._src.handlers import composite_handler
28+
from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler
2529
from orbax.checkpoint.experimental.v1._src.handlers import registration
2630
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
2731
from orbax.checkpoint.experimental.v1._src.loading import v0_compatibility
2832
from orbax.checkpoint.experimental.v1._src.metadata import serialization as metadata_serialization
2933
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
3034
from orbax.checkpoint.experimental.v1._src.path import types as path_types
35+
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
36+
37+
38+
class CheckpointVersion(enum.Enum):
39+
V0 = 0
40+
V1 = 1
3141

3242

3343
InvalidLayoutError = checkpoint_layout.InvalidLayoutError
@@ -55,6 +65,14 @@
5565
)
5666

5767

68+
def checkpoint_version(path: path_types.PathLike) -> CheckpointVersion:
69+
"""Returns the checkpoint version of the given path."""
70+
if (path / ORBAX_CHECKPOINT_INDICATOR_FILE).exists():
71+
return CheckpointVersion.V1
72+
else:
73+
return CheckpointVersion.V0
74+
75+
5876
async def _subpaths(directory: Path) -> list[Path]:
5977
"""Returns all entries directly under the given directory."""
6078
return list(await async_path.iterdir(directory))
@@ -96,6 +114,11 @@ async def has_pytree_metadata_file(path: Path) -> bool:
96114
return await async_path.exists(path / PYTREE_METADATA_FILE)
97115

98116

117+
async def has_indicator_file(path: Path) -> bool:
118+
"""Checks if the indicator file exists in the given path."""
119+
return await async_path.exists(path / ORBAX_CHECKPOINT_INDICATOR_FILE)
120+
121+
99122
class OrbaxLayout(CheckpointLayout):
100123
"""OrbaxLayout.
101124
@@ -114,10 +137,7 @@ def __init__(self):
114137
include_global_registry=False,
115138
)
116139
self._composite_handler = CompositeHandler(self._handler_registry)
117-
118-
async def has_indicator_file(self, path: Path) -> bool:
119-
"""Checks if the indicator file exists in the given path."""
120-
return await async_path.exists(path / ORBAX_CHECKPOINT_INDICATOR_FILE)
140+
self._metadata_store = checkpoint_metadata.metadata_store(enable_write=True)
121141

122142
async def metadata(
123143
self, path: Path
@@ -157,34 +177,53 @@ async def _validate_pytree(self, path: Path, checkpointable_name: str | None):
157177
in the directory
158178
ValueError: If the PyTree checkpoint is malformed.
159179
"""
180+
# TODO(b/476156780): Remove v0 logic from V1 OrbaxLayout
181+
182+
# If it's a V1 checkpoint, it's not valid for the PyTree to be saved
183+
# directly to the checkpoint directory.
184+
if (
185+
checkpoint_version(path) == CheckpointVersion.V1
186+
and checkpointable_name is None
187+
):
188+
raise FileNotFoundError(
189+
"Cannot load a V1 checkpoint directly as a PyTree checkpointable."
190+
)
191+
192+
# Determine the directory, either root or checkpointable.
160193
pytree_dir = (
161194
path if checkpointable_name is None else path / checkpointable_name
162195
)
163-
if checkpointable_name is not None and not await async_path.exists(
196+
197+
# Check if the directory exists and has PyTree metadata.
198+
if not await async_path.exists(
164199
pytree_dir
165-
):
166-
subdirs = [
167-
d.name for d in await _subpaths(path) if await async_path.is_dir(d)
168-
]
169-
raise FileNotFoundError(
170-
f"Checkpoint path {path} must contain a subdirectory named"
171-
f' "{checkpointable_name}". Found subdirectories:'
172-
f" {subdirs}."
173-
" Please try inspecting the checkpointable metadata using"
174-
" `ocp.checkpointables_metadata()` or try loading the checkpoint"
175-
" using"
176-
" `ocp.load_checkpointables()`."
177-
)
178-
if not await has_pytree_metadata_file(pytree_dir):
179-
# TODO(niketkb): Add following details to the error message:
200+
) or not await has_pytree_metadata_file(pytree_dir):
180201
# 1. we should check other available subdirectories and see if any of them
181202
# look like PyTree checkpoints, and instruct the user to consider
182203
# whether they meant to specify any of those.
183-
# 2. we need to check the directory - if it contains PyTree files, suggest
204+
205+
pytree_checkpointable_names = []
206+
for subdir in await _subpaths(path):
207+
if await has_pytree_metadata_file(subdir):
208+
pytree_checkpointable_names.append(subdir.name)
209+
# 2. Check checkpoint root directory if it is a PyTree checkpoint, suggest
184210
# loading with checkpointable_name=None
211+
if await has_pytree_metadata_file(path):
212+
pytree_checkpointable_names.append(None)
213+
214+
if pytree_checkpointable_names:
215+
raise FileNotFoundError(
216+
"checkpointable_name either does not exist or is missing Pytree"
217+
" checkpoint metadata. Please consider using one of the following"
218+
" valid pytree checkpointable_names:"
219+
f" {pytree_checkpointable_names}"
220+
)
185221
raise FileNotFoundError(
186-
f"Checkpoint path {path} does not contain a PyTree metadata file."
222+
"checkpointable_name either does not exist or is missing Pytree"
223+
" checkpoint metadata. There are no valid pytree checkpointables in"
224+
" this checkpoint"
187225
)
226+
188227
if not await has_tensorstore_data_files(pytree_dir):
189228
logging.warning(
190229
"TensorStore data files not found in checkpoint path %s. This may be"
@@ -214,11 +253,13 @@ async def _validate(self, path: Path):
214253
NotADirectoryError: If the path is not a directory.
215254
ValueError: If the checkpoint is incomplete.
216255
"""
217-
256+
# TODO(b/476156780): Remove v0 logic from V1 OrbaxLayout
218257
if not await async_path.exists(path):
219258
raise FileNotFoundError(f"Checkpoint path {path} does not exist.")
259+
220260
if not await async_path.is_dir(path):
221261
raise NotADirectoryError(f"Checkpoint path {path} is not a directory.")
262+
222263
if await temporary_paths.is_path_temporary(
223264
path,
224265
temporary_path_cls=self._context.file_options.temporary_path_class,
@@ -231,20 +272,20 @@ async def _validate(self, path: Path):
231272
if ORBAX_CHECKPOINT_INDICATOR_FILE in [p.name for p in subpaths]:
232273
return
233274

234-
# Path points to a single step checkpoint with valid metadata.
275+
# Path points to a checkpoint with valid metadata.
235276
if await async_path.exists(
236277
metadata_serialization.checkpoint_metadata_file_path(path)
237278
):
238279
return
239280

240281
# The path itself points to a PyTree checkpointable.
241-
if await async_path.exists(path / PYTREE_METADATA_FILE):
282+
if await has_pytree_metadata_file(path):
242283
return
243284
# The path points to a directory containing at least one PyTree
244285
# checkpointable.
245286
for subpath in subpaths:
246-
if await async_path.is_dir(subpath) and await async_path.exists(
247-
subpath / PYTREE_METADATA_FILE
287+
if await async_path.is_dir(subpath) and await has_pytree_metadata_file(
288+
subpath
248289
):
249290
return
250291

@@ -281,7 +322,81 @@ async def validate_pytree(
281322
f" checkpoint. {_GENERAL_ERROR_MESSAGE}"
282323
) from e
283324

284-
async def load(
325+
def _get_typestr(
326+
self, path: Path, checkpointable_name: str | None
327+
) -> str | None:
328+
"""Gets the typestr for the given path, falling back to parent if needed."""
329+
# TODO(b/476156780): Remove complex V0 handler resolution logic out of V1
330+
# OrbaxLayout and re-evaluate implementation
331+
332+
# Attempt to get typestr from the step metadata file in the current
333+
# checkpoint path.
334+
metadata_path = checkpoint_metadata.step_metadata_file_path(path)
335+
if metadata_path.exists():
336+
serialized = self._metadata_store.read(metadata_path)
337+
if serialized:
338+
metadata = step_metadata_serialization.deserialize(serialized or {})
339+
# If checkpoint is V0 and pytree is saved directly to checkpoint,
340+
# we expect a single string type for a PyTree in the metadata.
341+
if checkpointable_name is None:
342+
if isinstance(metadata.item_handlers, str):
343+
return metadata.item_handlers
344+
else:
345+
if isinstance(metadata.item_handlers, dict):
346+
return metadata.item_handlers.get(checkpointable_name)
347+
348+
# For pytree checkpointable directory, if direct path didn't yield a typestr
349+
# we try the parent path.
350+
if checkpointable_name is None:
351+
parent_metadata_path = checkpoint_metadata.step_metadata_file_path(
352+
path.parent
353+
)
354+
if parent_metadata_path.exists():
355+
serialized = self._metadata_store.read(parent_metadata_path)
356+
if serialized:
357+
metadata = step_metadata_serialization.deserialize(serialized or {})
358+
if isinstance(metadata.item_handlers, dict):
359+
return metadata.item_handlers.get(path.name)
360+
return None
361+
362+
async def load_pytree(
363+
self,
364+
path: Path,
365+
checkpointable_name: str | None = None,
366+
abstract_pytree: (
367+
tree_types.PyTreeOf[tree_types.AbstractLeafType] | None
368+
) = None,
369+
) -> Awaitable[Any]:
370+
typestr = self._get_typestr(path, checkpointable_name)
371+
name_for_registration = checkpointable_name or path.name
372+
373+
if typestr:
374+
handler = registration.resolve_handler_for_load(
375+
self._handler_registry,
376+
abstract_pytree,
377+
name=name_for_registration,
378+
handler_typestr=typestr,
379+
)
380+
# TODO(b/476156780): Remove from V1 OrbaxLayout and re-evaluate resolution
381+
# logic
382+
383+
# If missing _CHECKPOINT_METADATA and it's a V0 pytree checkpoint, check
384+
# if it has _METADATA
385+
elif checkpointable_name is None and await has_pytree_metadata_file(path):
386+
handler = pytree_handler.PyTreeHandler(context=self._context)
387+
else:
388+
raise ValueError(
389+
"Could not find handler information for the given checkpointable"
390+
f" name: {checkpointable_name} in path: {path}."
391+
)
392+
393+
pytree_dir = (
394+
path if checkpointable_name is None else path / checkpointable_name
395+
)
396+
load_awaitable = await handler.load(pytree_dir, abstract_pytree)
397+
return load_awaitable
398+
399+
async def load_checkpointables(
285400
self,
286401
path: Path,
287402
abstract_checkpointables: dict[str, Any] | None = None,

0 commit comments

Comments
 (0)