2020
2121from absl import logging
2222from orbax .checkpoint ._src import asyncio_utils
23+ from orbax .checkpoint ._src .metadata import step_metadata_serialization
2324from orbax .checkpoint ._src .multihost import multihost
2425from orbax .checkpoint ._src .path import async_path
2526from orbax .checkpoint ._src .path import temporary_paths
2627from orbax .checkpoint .experimental .v1 ._src .context import context as context_lib
2728from orbax .checkpoint .experimental .v1 ._src .handlers import registration
2829from orbax .checkpoint .experimental .v1 ._src .handlers import resolution as handler_resolution
2930from orbax .checkpoint .experimental .v1 ._src .layout import checkpoint_layout
30- from orbax .checkpoint .experimental .v1 ._src .loading import v0_compatibility
31+ from orbax .checkpoint .experimental .v1 ._src .metadata import serialization as metadata_serialization
3132from orbax .checkpoint .experimental .v1 ._src .metadata import types as metadata_types
3233from orbax .checkpoint .experimental .v1 ._src .path import types as path_types
3334from orbax .checkpoint .experimental .v1 ._src .tree import types as tree_types
@@ -41,6 +42,9 @@ class CheckpointVersion(enum.Enum):
4142InvalidLayoutError = checkpoint_layout .InvalidLayoutError
4243Path = path_types .Path
4344CheckpointLayout = checkpoint_layout .CheckpointLayout
45+ InternalCheckpointMetadata = (
46+ step_metadata_serialization .InternalCheckpointMetadata
47+ )
4448
4549PYTREE_METADATA_FILE = "_METADATA"
4650ORBAX_CHECKPOINT_INDICATOR_FILE = "orbax.checkpoint"
@@ -149,6 +153,19 @@ async def _create_orbax_identifier_file(
149153 )
150154
151155
async def _read_checkpoint_metadata(
    directory: path_types.Path,
) -> InternalCheckpointMetadata:
  """Loads the checkpoint-level metadata stored under `directory`.

  Args:
    directory: Checkpoint directory; the metadata file location is derived
      from it via `metadata_serialization.checkpoint_metadata_file_path`.

  Returns:
    The `InternalCheckpointMetadata` deserialized from the file contents. If
    the read yields a falsy value (e.g. `None`), an empty dict is
    deserialized instead.
  """
  metadata_file = metadata_serialization.checkpoint_metadata_file_path(
      directory
  )
  raw_metadata = await metadata_serialization.read(metadata_file)
  # Fall back to an empty mapping so deserialization always gets a dict.
  return InternalCheckpointMetadata.deserialize(raw_metadata or {})
167+
168+
152169class OrbaxLayout (CheckpointLayout ):
153170 """OrbaxLayout.
154171
async def metadata(
    self, path: Path
) -> metadata_types.CheckpointMetadata[dict[str, Any]]:
  """Returns the metadata describing the Orbax checkpoint.

  Reads the checkpoint-level metadata file under `path`, obtains handlers
  through `handler_resolution.get_handlers_for_load`, and asks each handler
  whose checkpointable actually exists in the checkpoint for the metadata of
  its subdirectory (`path / checkpointable_name`).

  Args:
    path: Directory of the checkpoint.

  Returns:
    A `CheckpointMetadata` whose `metadata` maps every existing
    checkpointable name to the metadata produced by its handler, or to `None`
    when no handler was resolved for it. The `metrics` entry is dropped. The
    timestamps and custom metadata are taken from the checkpoint metadata
    file.
  """
  checkpoint_metadata = await _read_checkpoint_metadata(path)
  handlers_for_load = await handler_resolution.get_handlers_for_load(
      path, self._handler_registry, {}, checkpoint_metadata
  )
  existing_checkpointable_names = await _existing_checkpointable_names(path)

  # Default to None for all existing checkpointable names, for
  # subdirectories that we are unable to find a handler for and load.
  item_metadata = {name: None for name in existing_checkpointable_names}
  for checkpointable_name, handler in handlers_for_load.items():
    # Handlers may be resolved for names that have no subdirectory in this
    # checkpoint; only names that exist on disk are queried. Because of this
    # filter, no "missing checkpointable" error can arise here.
    if checkpointable_name not in existing_checkpointable_names:
      continue
    item_metadata[checkpointable_name] = await handler.metadata(
        path / checkpointable_name
    )
  # Exclude `metrics` if present. This is relevant only for
  # `training.Checkpointer`, and is separately added to the
  # `training.CheckpointMetadata` object.
  item_metadata.pop("metrics", None)

  return metadata_types.CheckpointMetadata[dict[str, Any]](
      metadata=item_metadata,
      init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs,
      commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs,
      custom_metadata=checkpoint_metadata.custom_metadata,
  )
193234
194235 async def _validate_pytree (self , path : Path , checkpointable_name : str | None ):
@@ -356,8 +397,14 @@ async def load_checkpointables(
356397 the checkpoint.
357398 """
358399 abstract_checkpointables = abstract_checkpointables or {}
400+ checkpoint_metadata = await _read_checkpoint_metadata (
401+ path
402+ )
359403 handlers_for_load = await handler_resolution .get_handlers_for_load (
360- path , self ._handler_registry , abstract_checkpointables
404+ path ,
405+ self ._handler_registry ,
406+ abstract_checkpointables ,
407+ checkpoint_metadata ,
361408 )
362409 existing_checkpointable_names = await _existing_checkpointable_names (path )
363410 if not abstract_checkpointables :
0 commit comments