1515"""Defines `OrbaxLayout`, a class to handle Orbax checkpoint formats."""
1616
1717import asyncio
18+ import enum
1819from typing import Any , Awaitable
1920
2021from absl import logging
22+ from orbax .checkpoint ._src .metadata import checkpoint as checkpoint_metadata
23+ from orbax .checkpoint ._src .metadata import step_metadata_serialization
2124from orbax .checkpoint ._src .path import async_path
2225from orbax .checkpoint ._src .path import temporary_paths
2326from orbax .checkpoint .experimental .v1 ._src .context import context as context_lib
2427from orbax .checkpoint .experimental .v1 ._src .handlers import composite_handler
28+ from orbax .checkpoint .experimental .v1 ._src .handlers import pytree_handler
2529from orbax .checkpoint .experimental .v1 ._src .handlers import registration
2630from orbax .checkpoint .experimental .v1 ._src .layout import checkpoint_layout
2731from orbax .checkpoint .experimental .v1 ._src .loading import v0_compatibility
2832from orbax .checkpoint .experimental .v1 ._src .metadata import serialization as metadata_serialization
2933from orbax .checkpoint .experimental .v1 ._src .metadata import types as metadata_types
3034from orbax .checkpoint .experimental .v1 ._src .path import types as path_types
35+ from orbax .checkpoint .experimental .v1 ._src .tree import types as tree_types
36+
37+
class CheckpointVersion(enum.Enum):
  """Checkpoint format version, as reported by `checkpoint_version`."""

  # Reported when the Orbax checkpoint indicator file is absent.
  V0 = 0
  # Reported when the Orbax checkpoint indicator file is present.
  V1 = 1
3141
3242
3343InvalidLayoutError = checkpoint_layout .InvalidLayoutError
5565)
5666
5767
def checkpoint_version(path: path_types.PathLike) -> CheckpointVersion:
  """Returns the checkpoint version of the given path.

  A directory containing the Orbax checkpoint indicator file is reported as
  `CheckpointVersion.V1`; every other path is reported as
  `CheckpointVersion.V0`.

  The existence check is called without `await`, unlike the coroutine
  `has_indicator_file` in this module.

  Args:
    path: Checkpoint directory to inspect.

  Returns:
    The `CheckpointVersion` detected for `path`.
  """
  # The parameter is annotated as `PathLike`, which presumably admits plain
  # strings (TODO confirm against `path_types`). `str / str` raises TypeError,
  # so coerce to a concrete path object before joining.
  indicator_file = Path(path) / ORBAX_CHECKPOINT_INDICATOR_FILE
  if indicator_file.exists():
    return CheckpointVersion.V1
  return CheckpointVersion.V0
74+
75+
async def _subpaths(directory: Path) -> list[Path]:
  """Lists the entries found directly under `directory`.

  Args:
    directory: Directory whose children are enumerated.

  Returns:
    A list built from whatever `async_path.iterdir` yields for `directory`.
    No filtering or truncation is applied here.
  """
  children = await async_path.iterdir(directory)
  return list(children)
@@ -96,6 +114,11 @@ async def has_pytree_metadata_file(path: Path) -> bool:
96114 return await async_path .exists (path / PYTREE_METADATA_FILE )
97115
98116
async def has_indicator_file(path: Path) -> bool:
  """Reports whether `path` contains the Orbax checkpoint indicator file.

  Args:
    path: Directory to look in.

  Returns:
    The result of `async_path.exists` for the indicator file under `path`.
  """
  indicator_file = path / ORBAX_CHECKPOINT_INDICATOR_FILE
  return await async_path.exists(indicator_file)
120+
121+
99122class OrbaxLayout (CheckpointLayout ):
100123 """OrbaxLayout.
101124
@@ -114,10 +137,7 @@ def __init__(self):
114137 include_global_registry = False ,
115138 )
116139 self ._composite_handler = CompositeHandler (self ._handler_registry )
117-
118- async def has_indicator_file (self , path : Path ) -> bool :
119- """Checks if the indicator file exists in the given path."""
120- return await async_path .exists (path / ORBAX_CHECKPOINT_INDICATOR_FILE )
140+ self ._metadata_store = checkpoint_metadata .metadata_store (enable_write = True )
121141
122142 async def metadata (
123143 self , path : Path
@@ -157,34 +177,53 @@ async def _validate_pytree(self, path: Path, checkpointable_name: str | None):
157177 in the directory
158178 ValueError: If the PyTree checkpoint is malformed.
159179 """
180+ # TODO(b/476156780): Remove v0 logic from V1 OrbaxLayout
181+
182+ # If it's a V1 checkpoint, it's not valid for the PyTree to be saved
183+ # directly to the checkpoint directory.
184+ if (
185+ checkpoint_version (path ) == CheckpointVersion .V1
186+ and checkpointable_name is None
187+ ):
188+ raise FileNotFoundError (
189+ "Cannot load a V1 checkpoint directly as a PyTree checkpointable."
190+ )
191+
192+ # Determine the directory, either root or checkpointable.
160193 pytree_dir = (
161194 path if checkpointable_name is None else path / checkpointable_name
162195 )
163- if checkpointable_name is not None and not await async_path .exists (
196+
197+ # Check if the directory exists and has PyTree metadata.
198+ if not await async_path .exists (
164199 pytree_dir
165- ):
166- subdirs = [
167- d .name for d in await _subpaths (path ) if await async_path .is_dir (d )
168- ]
169- raise FileNotFoundError (
170- f"Checkpoint path { path } must contain a subdirectory named"
171- f' "{ checkpointable_name } ". Found subdirectories:'
172- f" { subdirs } ."
173- " Please try inspecting the checkpointable metadata using"
174- " `ocp.checkpointables_metadata()` or try loading the checkpoint"
175- " using"
176- " `ocp.load_checkpointables()`."
177- )
178- if not await has_pytree_metadata_file (pytree_dir ):
179- # TODO(niketkb): Add following details to the error message:
200+ ) or not await has_pytree_metadata_file (pytree_dir ):
180201 # 1. we should check other available subdirectories and see if any of them
181202 # look like PyTree checkpoints, and instruct the user to consider
182203 # whether they meant to specify any of those.
183- # 2. we need to check the directory - if it contains PyTree files, suggest
204+
205+ pytree_checkpointable_names = []
206+ for subdir in await _subpaths (path ):
207+ if await has_pytree_metadata_file (subdir ):
208+ pytree_checkpointable_names .append (subdir .name )
209+ # 2. Check checkpoint root directory if it is a PyTree checkpoint, suggest
184210 # loading with checkpointable_name=None
211+ if await has_pytree_metadata_file (path ):
212+ pytree_checkpointable_names .append (None )
213+
214+ if pytree_checkpointable_names :
215+ raise FileNotFoundError (
216+ "checkpointable_name either does not exist or is missing Pytree"
217+ " checkpoint metadata. Please consider using one of the following"
218+ " valid pytree checkpointable_names:"
219+ f" { pytree_checkpointable_names } "
220+ )
185221 raise FileNotFoundError (
186- f"Checkpoint path { path } does not contain a PyTree metadata file."
222+ "checkpointable_name either does not exist or is missing Pytree"
223+ " checkpoint metadata. There are no valid pytree checkpointables in"
224+ " this checkpoint"
187225 )
226+
188227 if not await has_tensorstore_data_files (pytree_dir ):
189228 logging .warning (
190229 "TensorStore data files not found in checkpoint path %s. This may be"
@@ -214,11 +253,13 @@ async def _validate(self, path: Path):
214253 NotADirectoryError: If the path is not a directory.
215254 ValueError: If the checkpoint is incomplete.
216255 """
217-
256+ # TODO(b/476156780): Remove v0 logic from V1 OrbaxLayout
218257 if not await async_path .exists (path ):
219258 raise FileNotFoundError (f"Checkpoint path { path } does not exist." )
259+
220260 if not await async_path .is_dir (path ):
221261 raise NotADirectoryError (f"Checkpoint path { path } is not a directory." )
262+
222263 if await temporary_paths .is_path_temporary (
223264 path ,
224265 temporary_path_cls = self ._context .file_options .temporary_path_class ,
@@ -231,20 +272,20 @@ async def _validate(self, path: Path):
231272 if ORBAX_CHECKPOINT_INDICATOR_FILE in [p .name for p in subpaths ]:
232273 return
233274
234- # Path points to a single step checkpoint with valid metadata.
275+ # Path points to a checkpoint with valid metadata.
235276 if await async_path .exists (
236277 metadata_serialization .checkpoint_metadata_file_path (path )
237278 ):
238279 return
239280
240281 # The path itself points to a PyTree checkpointable.
241- if await async_path . exists (path / PYTREE_METADATA_FILE ):
282+ if await has_pytree_metadata_file (path ):
242283 return
243284 # The path points to a directory containing at least one PyTree
244285 # checkpointable.
245286 for subpath in subpaths :
246- if await async_path .is_dir (subpath ) and await async_path . exists (
247- subpath / PYTREE_METADATA_FILE
287+ if await async_path .is_dir (subpath ) and await has_pytree_metadata_file (
288+ subpath
248289 ):
249290 return
250291
@@ -281,7 +322,81 @@ async def validate_pytree(
281322 f" checkpoint. { _GENERAL_ERROR_MESSAGE } "
282323 ) from e
283324
284- async def load (
def _get_typestr(
    self, path: Path, checkpointable_name: str | None
) -> str | None:
  """Looks up the handler typestr recorded for a checkpointable.

  The step metadata stored in `path` is consulted first. When
  `checkpointable_name` is None and that metadata does not hold a single
  string entry, the step metadata of `path.parent` is consulted instead, keyed
  by `path.name`.

  Args:
    path: Checkpoint directory, or the checkpointable directory itself when
      `checkpointable_name` is None.
    checkpointable_name: Name of the checkpointable under `path`, or None when
      `path` is treated as the PyTree location.

  Returns:
    The recorded typestr, or None when no matching entry is found.
  """
  # TODO(b/476156780): Remove complex V0 handler resolution logic out of V1
  # OrbaxLayout and re-evaluate implementation

  def _item_handlers_in(directory: Path) -> Any:
    """Returns `item_handlers` from step metadata in `directory`, else None."""
    metadata_file = checkpoint_metadata.step_metadata_file_path(directory)
    if not metadata_file.exists():
      return None
    serialized = self._metadata_store.read(metadata_file)
    if not serialized:
      return None
    return step_metadata_serialization.deserialize(serialized).item_handlers

  handlers = _item_handlers_in(path)

  # Named checkpointable: only a per-name mapping in this directory's metadata
  # can answer the question; there is no parent fallback in this case.
  if checkpointable_name is not None:
    if isinstance(handlers, dict):
      return handlers.get(checkpointable_name)
    return None

  # If checkpoint is V0 and the pytree is saved directly to the checkpoint,
  # a single string type is expected in the metadata.
  if isinstance(handlers, str):
    return handlers

  # For a pytree checkpointable directory, fall back to the parent's metadata
  # and look this directory up by name.
  parent_handlers = _item_handlers_in(path.parent)
  if isinstance(parent_handlers, dict):
    return parent_handlers.get(path.name)
  return None
361+
async def load_pytree(
    self,
    path: Path,
    checkpointable_name: str | None = None,
    abstract_pytree: (
        tree_types.PyTreeOf[tree_types.AbstractLeafType] | None
    ) = None,
) -> Awaitable[Any]:
  """Resolves a handler for a PyTree checkpointable and starts loading it.

  Args:
    path: Checkpoint directory, or the PyTree directory itself when
      `checkpointable_name` is None.
    checkpointable_name: Name of the checkpointable under `path`; None means
      the PyTree is read from `path` directly.
    abstract_pytree: Optional abstract tree forwarded to handler resolution
      and to the handler's `load` call.

  Returns:
    The value obtained by awaiting the handler's `load` call (annotated as an
    awaitable; presumably it resolves to the loaded PyTree - TODO confirm).

  Raises:
    ValueError: If no typestr is recorded and the PyTree-metadata fallback
      does not apply.
  """
  handler_typestr = self._get_typestr(path, checkpointable_name)

  if handler_typestr:
    handler = registration.resolve_handler_for_load(
        self._handler_registry,
        abstract_pytree,
        name=checkpointable_name or path.name,
        handler_typestr=handler_typestr,
    )
  else:
    # TODO(b/476156780): Remove from V1 OrbaxLayout and re-evaluate resolution
    # logic

    # If missing _CHECKPOINT_METADATA and it is a V0 pytree checkpoint, check
    # whether it has _METADATA.
    is_bare_pytree = (
        checkpointable_name is None and await has_pytree_metadata_file(path)
    )
    if not is_bare_pytree:
      raise ValueError(
          "Could not find handler information for the given checkpointable"
          f" name: { checkpointable_name } in path: { path } ."
      )
    handler = pytree_handler.PyTreeHandler(context=self._context)

  if checkpointable_name is None:
    target_dir = path
  else:
    target_dir = path / checkpointable_name
  return await handler.load(target_dir, abstract_pytree)
398+
399+ async def load_checkpointables (
285400 self ,
286401 path : Path ,
287402 abstract_checkpointables : dict [str , Any ] | None = None ,
0 commit comments