Skip to content

Commit 9b22138

Browse files
angel-core, Orbax Authors
authored and committed
#v1 Refactor OrbaxLayout metadata retrieval and the v1 Checkpointer to use internal free-function logic, removing the dependency on the v0 checkpointer. Also refactor so that checkpoint metadata is read outside of get_handlers_for_load and passed in by the caller.
PiperOrigin-RevId: 874781759
1 parent 306242b commit 9b22138

File tree

4 files changed

+66
-176
lines changed

4 files changed

+66
-176
lines changed

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/resolution.py

Lines changed: 7 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -20,11 +20,9 @@
2020

2121
from absl import logging
2222
from orbax.checkpoint._src.metadata import step_metadata_serialization
23-
from orbax.checkpoint._src.path import async_path
2423
from orbax.checkpoint.experimental.v1._src.handlers import registration
2524
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
2625
import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import
27-
from orbax.checkpoint.experimental.v1._src.metadata import serialization as metadata_serialization
2826
from orbax.checkpoint.experimental.v1._src.path import types as path_types
2927

3028
InternalCheckpointMetadata = (
@@ -49,21 +47,6 @@ def _subdirs(directory: path_types.Path, *, limit: int = 3) -> list[str]:
4947
)
5048

5149

52-
async def read_checkpoint_metadata(
53-
directory: path_types.Path,
54-
) -> InternalCheckpointMetadata:
55-
"""Returns the step metadata for a given path, normalized for V1."""
56-
57-
serialized_metadata = (
58-
await metadata_serialization.read(
59-
metadata_serialization.checkpoint_metadata_file_path(directory)
60-
)
61-
or {}
62-
)
63-
64-
return InternalCheckpointMetadata.deserialize(serialized_metadata)
65-
66-
6750
def get_handlers_for_save(
6851
handler_registry: registration.CheckpointableHandlerRegistry,
6952
checkpointables: dict[str, Any],
@@ -81,10 +64,11 @@ async def get_handlers_for_load(
8164
directory: path_types.Path,
8265
handler_registry: registration.CheckpointableHandlerRegistry,
8366
abstract_checkpointables: dict[str, Any],
67+
checkpoint_metadata: InternalCheckpointMetadata,
8468
) -> dict[str, handler_types.CheckpointableHandler]:
8569
"""Returns a mapping from checkpointable name to handler."""
8670
existing_checkpointable_names_to_handler_typestrs = (
87-
await _get_saved_handler_typestrs(directory)
71+
await _get_saved_handler_typestrs(directory, checkpoint_metadata)
8872
)
8973
abstract_checkpointables = abstract_checkpointables or {
9074
name: None for name in existing_checkpointable_names_to_handler_typestrs
@@ -111,13 +95,10 @@ async def get_handlers_for_load(
11195

11296
async def _get_saved_handler_typestrs(
11397
directory: path_types.Path,
98+
checkpoint_metadata: InternalCheckpointMetadata,
11499
) -> dict[str, str]:
115100
"""Reads from the checkpoint metadata to get saved handler typestrs."""
116-
checkpoint_metadata_file_path = (
117-
metadata_serialization.checkpoint_metadata_file_path(directory)
118-
)
119-
if await async_path.exists(checkpoint_metadata_file_path):
120-
checkpoint_metadata = await read_checkpoint_metadata(directory)
101+
if checkpoint_metadata.item_handlers:
121102
if isinstance(checkpoint_metadata.item_handlers, dict):
122103
return checkpoint_metadata.item_handlers # found step level metadata.
123104
raise ValueError(
@@ -131,34 +112,8 @@ async def _get_saved_handler_typestrs(
131112
)
132113

133114
logging.warning(
134-
'Given dir does not contain checkpoint metadata file: %s. Trying to get'
135-
' saved handlers from checkpoint metadata in each of the checkpointable'
136-
' subdirectory.',
115+
'Given dir does not contain checkpoint metadata file: %s. No handler'
116+
' typestrs found.',
137117
directory,
138118
)
139-
140-
# TODO(b/475265289): Currently, we rely solely on CHECKPOINT_METADATA to
141-
# find available checkpointables, ignoring valid subdirectories. We
142-
# should update the composite handler to validate subdirectories to
143-
# check if any either represents a valid pytree checkpointable or has a
144-
# name that is registered in the handler registry.
145-
saved_handler_typestrs: dict[str, str] = {}
146-
for checkpointable_path in await async_path.iterdir(directory):
147-
if not await async_path.is_dir(checkpointable_path):
148-
continue
149-
checkpoint_metadata = await read_checkpoint_metadata(checkpointable_path)
150-
if isinstance(checkpoint_metadata.item_handlers, dict):
151-
raise ValueError(
152-
f'Path at {directory} contains subdirectories:'
153-
f' {_subdirs(directory)}, which are expected to'
154-
' match the keys given by the _CHECKPOINT_METADATA file:'
155-
f' {checkpoint_metadata.item_handlers}. If you intended to load a'
156-
' pytree checkpoint from the given path, then please consider using'
157-
' `loading.load_pytree(..., checkpointable_name=None)` instead.'
158-
f' {_V0_ERROR_MESSAGE}'
159-
)
160-
item_handlers = checkpoint_metadata.item_handlers
161-
if item_handlers is not None:
162-
checkpointable_name = checkpointable_path.name
163-
saved_handler_typestrs[checkpointable_name] = item_handlers
164-
return saved_handler_typestrs
119+
return {}

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py

Lines changed: 57 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -20,14 +20,15 @@
2020

2121
from absl import logging
2222
from orbax.checkpoint._src import asyncio_utils
23+
from orbax.checkpoint._src.metadata import step_metadata_serialization
2324
from orbax.checkpoint._src.multihost import multihost
2425
from orbax.checkpoint._src.path import async_path
2526
from orbax.checkpoint._src.path import temporary_paths
2627
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
2728
from orbax.checkpoint.experimental.v1._src.handlers import registration
2829
from orbax.checkpoint.experimental.v1._src.handlers import resolution as handler_resolution
2930
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
30-
from orbax.checkpoint.experimental.v1._src.loading import v0_compatibility
31+
from orbax.checkpoint.experimental.v1._src.metadata import serialization as metadata_serialization
3132
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
3233
from orbax.checkpoint.experimental.v1._src.path import types as path_types
3334
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
@@ -41,6 +42,9 @@ class CheckpointVersion(enum.Enum):
4142
InvalidLayoutError = checkpoint_layout.InvalidLayoutError
4243
Path = path_types.Path
4344
CheckpointLayout = checkpoint_layout.CheckpointLayout
45+
InternalCheckpointMetadata = (
46+
step_metadata_serialization.InternalCheckpointMetadata
47+
)
4448

4549
PYTREE_METADATA_FILE = "_METADATA"
4650
ORBAX_CHECKPOINT_INDICATOR_FILE = "orbax.checkpoint"
@@ -149,6 +153,19 @@ async def _create_orbax_identifier_file(
149153
)
150154

151155

156+
async def _read_checkpoint_metadata(
157+
directory: path_types.Path,
158+
) -> InternalCheckpointMetadata:
159+
"""Returns the step metadata for a given path."""
160+
serialized_metadata = (
161+
await metadata_serialization.read(
162+
metadata_serialization.checkpoint_metadata_file_path(directory)
163+
)
164+
or {}
165+
)
166+
return InternalCheckpointMetadata.deserialize(serialized_metadata)
167+
168+
152169
class OrbaxLayout(CheckpointLayout):
153170
"""OrbaxLayout.
154171
@@ -172,23 +189,47 @@ async def metadata(
172189
self, path: Path
173190
) -> metadata_types.CheckpointMetadata[dict[str, Any]]:
174191
"""Returns the metadata describing the Orbax checkpoint."""
175-
# Uses the v0 checkpointer to get v0 StepMetadata
176-
checkpointer, _ = v0_compatibility.get_v0_checkpointer_and_args(
177-
path, None, context=context_lib.get_context()
192+
checkpoint_metadata = await _read_checkpoint_metadata(
193+
path
178194
)
179-
step_metadata = checkpointer.metadata(path)
195+
handlers_for_load = await handler_resolution.get_handlers_for_load(
196+
path, self._handler_registry, {}, checkpoint_metadata
197+
)
198+
existing_checkpointable_names = await _existing_checkpointable_names(path)
199+
abstract_checkpointables = {
200+
name: None
201+
for name in handlers_for_load.keys()
202+
if name in existing_checkpointable_names
203+
}
204+
if any(
205+
name not in existing_checkpointable_names
206+
for name in abstract_checkpointables.keys()
207+
):
208+
raise KeyError(
209+
"Inferred checkpointables from metadata:"
210+
f" {abstract_checkpointables.keys()} for loading were not found in"
211+
" the checkpoint. Available checkpointables:"
212+
f" {existing_checkpointable_names}"
213+
)
180214

181-
item_metadata = {k: v for k, v in step_metadata.item_metadata.items()}
215+
# Default to none for all existing checkpointable names, for
216+
# subdirectories that we are unable to find a handler for and load.
217+
item_metadata = {name: None for name in existing_checkpointable_names}
218+
for checkpointable_name in abstract_checkpointables.keys():
219+
handler = handlers_for_load[checkpointable_name]
220+
item_metadata[checkpointable_name] = await handler.metadata(
221+
path / checkpointable_name
222+
)
182223
# Exclude `metrics` if present. This is relevant only for
183224
# `training.Checkpointer`, and is separately added to the
184225
# `training.CheckpointMetadata` object.
185226
item_metadata.pop("metrics", None)
186227

187228
return metadata_types.CheckpointMetadata[dict[str, Any]](
188229
metadata=item_metadata,
189-
init_timestamp_nsecs=step_metadata.init_timestamp_nsecs,
190-
commit_timestamp_nsecs=step_metadata.commit_timestamp_nsecs,
191-
custom_metadata=step_metadata.custom_metadata,
230+
init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs,
231+
commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs,
232+
custom_metadata=checkpoint_metadata.custom_metadata,
192233
)
193234

194235
async def _validate_pytree(self, path: Path, checkpointable_name: str | None):
@@ -356,8 +397,14 @@ async def load_checkpointables(
356397
the checkpoint.
357398
"""
358399
abstract_checkpointables = abstract_checkpointables or {}
400+
checkpoint_metadata = await _read_checkpoint_metadata(
401+
path
402+
)
359403
handlers_for_load = await handler_resolution.get_handlers_for_load(
360-
path, self._handler_registry, abstract_checkpointables
404+
path,
405+
self._handler_registry,
406+
abstract_checkpointables,
407+
checkpoint_metadata,
361408
)
362409
existing_checkpointable_names = await _existing_checkpointable_names(path)
363410
if not abstract_checkpointables:

checkpoint/orbax/checkpoint/experimental/v1/_src/loading/v0_compatibility.py

Lines changed: 0 additions & 105 deletions
This file was deleted.

checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
2626
import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import
2727
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
28-
from orbax.checkpoint.experimental.v1._src.loading import v0_compatibility as v0_loading_utils
28+
from orbax.checkpoint.experimental.v1._src.loading import loading
2929
from orbax.checkpoint.experimental.v1._src.metadata import loading as metadata_loading
3030
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
3131
from orbax.checkpoint.experimental.v1._src.path import step as path_step_lib
@@ -444,17 +444,10 @@ def load_checkpointables(
444444
) -> dict[str, Any]:
445445
"""Loads a set of checkpointables at the given step."""
446446
step = self._resolve_existing_checkpoint(step).step
447-
checkpointer, args = v0_loading_utils.get_v0_checkpointer_and_args(
447+
return loading.load_checkpointables(
448448
self.directory / self._step_name_format.build_name(step),
449449
abstract_checkpointables,
450-
context=context_lib.get_context(),
451450
)
452-
self._manager._checkpointer = checkpointer # pylint: disable=protected-access
453-
restored = self._manager.restore(
454-
step,
455-
args=args,
456-
)
457-
return {k: v for k, v in zip(restored.keys(), restored.values())}
458451

459452
def load_pytree_async(
460453
self,

0 commit comments

Comments
 (0)