Skip to content

Commit 143c188

Browse files
angel-core authored and Orbax Authors committed
Adjust logic for notifying user of incorrect loading path (root/child)
PiperOrigin-RevId: 874792039
1 parent 9b22138 commit 143c188

File tree

449 files changed

+1444
-275
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

449 files changed

+1444
-275
lines changed

checkpoint/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ devices anyway.
4141
`load_checkpointables()` each with their own dedicated loading logic
4242
- Refactor v0 Pytree validation and metadata resolution and add `OrbaxV0Layout`
4343
tests
44+
- Refactor logic for handler resolution and loading checkpointables for
45+
`OrbaxLayout` and `OrbaxV0Layout`, adding additional fallback capabilities for
46+
non-standard checkpoint formats.
4447

4548
### Fixed
4649

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,16 @@ def _get_possible_handlers(
392392
return possible_handlers
393393

394394

395+
def get_registered_handler_by_name(
396+
registry: CheckpointableHandlerRegistry,
397+
name: str,
398+
) -> CheckpointableHandler | None:
399+
"""Returns the handler for the given name if registered."""
400+
if registry.has(name):
401+
return _construct_handler_instance(name, registry.get(name))
402+
return None
403+
404+
395405
def resolve_handler_for_save(
396406
registry: CheckpointableHandlerRegistry,
397407
checkpointable: Any,
@@ -444,7 +454,7 @@ def resolve_handler_for_load(
444454
abstract_checkpointable: Any | None,
445455
*,
446456
name: str,
447-
handler_typestr: str,
457+
handler_typestr: str | None = None,
448458
) -> CheckpointableHandler:
449459
"""Resolves a :py:class:`~.v1.handlers.CheckpointableHandler` for loading.
450460
@@ -471,7 +481,9 @@ def resolve_handler_for_load(
471481
abstract_checkpointable: An abstract checkpointable to resolve.
472482
name: The name of the checkpointable.
473483
handler_typestr: A :py:class:`~.v1.handlers.CheckpointableHandler` typestr
474-
to guide resolution.
484+
to guide resolution. We allow a None value for handler_typestr as it's
485+
possible to find the last registered handler given a specified
486+
abstract_checkpointable.
475487
476488
Returns:
477489
A :py:class:`~.v1.handlers.CheckpointableHandler` instance.
@@ -492,15 +504,16 @@ def is_handleable_fn(
492504
handler_types.typestr(type(handler)) for handler in possible_handlers
493505
]
494506

495-
try:
496-
idx = possible_handler_typestrs.index(handler_typestr)
497-
return possible_handlers[idx]
498-
except ValueError:
499-
logging.warning(
500-
'No handler found for typestr %s. The checkpointable may be restored'
501-
' with different handler logic than was used for saving.',
502-
handler_typestr,
503-
)
507+
if handler_typestr:
508+
try:
509+
idx = possible_handler_typestrs.index(handler_typestr)
510+
return possible_handlers[idx]
511+
except ValueError:
512+
logging.warning(
513+
'No handler found for typestr %s. The checkpointable may be restored'
514+
' with different handler logic than was used for saving.',
515+
handler_typestr,
516+
)
504517

505-
# Prefer the first handler in the absence of any other information.
518+
# Prefer the last handler in the absence of any other information.
506519
return possible_handlers[-1]

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration_test.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,15 +273,23 @@ def test_resolve_handler_for_load_no_matching_name(self):
273273
handler_typestr='unused',
274274
)
275275

276-
def test_resolve_handler_for_load_checkpointable(self):
276+
def test_resolve_handler_for_load_no_handler_typestr(self):
277277
local_registry = registration.local_registry()
278278
local_registry.add(handler_utils.FooHandler)
279+
resolved = registration.resolve_handler_for_load(
280+
local_registry,
281+
handler_utils.AbstractFoo(),
282+
name='dummy_unregistered_nameame',
283+
handler_typestr=None,
284+
)
285+
self.assertIsInstance(resolved, handler_utils.FooHandler)
286+
279287
with self.assertRaises(registration.NoEntryError):
280288
registration.resolve_handler_for_load(
281289
local_registry,
282-
handler_utils.Foo(1, 'hi'),
283-
name='foo',
284-
handler_typestr='unused',
290+
handler_utils.AbstractBar(),
291+
name='unregistered_name',
292+
handler_typestr=None,
285293
)
286294

287295

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/resolution.py

Lines changed: 156 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -15,38 +15,19 @@
1515
"""Logic for resolving handlers for saving and loading."""
1616
from __future__ import annotations
1717

18-
import itertools
1918
from typing import Any
2019

2120
from absl import logging
2221
from orbax.checkpoint._src.metadata import step_metadata_serialization
2322
from orbax.checkpoint.experimental.v1._src.handlers import registration
2423
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
2524
import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import
26-
from orbax.checkpoint.experimental.v1._src.path import types as path_types
2725

2826
InternalCheckpointMetadata = (
2927
step_metadata_serialization.InternalCheckpointMetadata
3028
)
3129

3230

33-
def _subdirs(directory: path_types.Path, *, limit: int = 3) -> list[str]:
34-
return list(
35-
itertools.islice(
36-
(subdir.name for subdir in directory.iterdir() if subdir.is_dir()),
37-
limit,
38-
)
39-
)
40-
41-
42-
_V0_ERROR_MESSAGE = (
43-
'If your checkpoint was saved with the Orbax V0 API, please follow the'
44-
' instructions at'
45-
' https://orbax.readthedocs.io/en/latest/guides/checkpoint/v1/orbax_v0_to_v1_migration.html'
46-
' to load it with the Orbax V1 API.'
47-
)
48-
49-
5031
def get_handlers_for_save(
5132
handler_registry: registration.CheckpointableHandlerRegistry,
5233
checkpointables: dict[str, Any],
@@ -60,60 +41,172 @@ def get_handlers_for_save(
6041
}
6142

6243

44+
def _resolve_single_handler_for_load(
45+
checkpointable_name: str,
46+
handler_registry: registration.CheckpointableHandlerRegistry,
47+
abstract_checkpointable: Any,
48+
metadata_handler_typestr: str | None,
49+
) -> handler_types.CheckpointableHandler:
50+
"""Logic to resolve a checkpointable's loading handler.
51+
52+
1. registration.resolve_handler_for_load performs handler discovery based on
53+
abstract_checkpointable type and handler_typestr.
54+
2. If this fails or if abstract_checkpointable and handler_typestr are not
55+
available, we try to resolve using the default pytree handler if registered.
56+
57+
Args:
58+
checkpointable_name: The checkpointable name to resolve the handler for.
59+
handler_registry: The handler registry to use for resolution.
60+
abstract_checkpointable: The abstract checkpointable to load.
61+
metadata_handler_typestr: The handler typestr from the checkpoint metadata.
62+
63+
Returns:
64+
The handler for the checkpointable.
65+
66+
Raises:
67+
registration.NoEntryError: If no handler is resolved and 'pytree' name is
68+
not registered.
69+
"""
70+
# 1. Resolve the handler using handler_typestr and
71+
# abstract_checkpointable type if either is specified.
72+
if abstract_checkpointable or metadata_handler_typestr:
73+
try:
74+
return registration.resolve_handler_for_load(
75+
handler_registry,
76+
abstract_checkpointable,
77+
name=checkpointable_name,
78+
handler_typestr=metadata_handler_typestr,
79+
)
80+
except registration.NoEntryError as e:
81+
logging.warning(
82+
"Failed to resolve handler for checkpointable: '%s'. Attempting to"
83+
" load using pytree handler, otherwise defaulting to a None"
84+
" return value. Error: %s",
85+
checkpointable_name,
86+
e,
87+
)
88+
else:
89+
logging.info(
90+
"No metadata present in checkpoint and no abstract checkpointable"
91+
" provided for checkpointable: '%s'. Attempting to load using"
92+
" pytree handler, otherwise defaulting to a None return value.",
93+
checkpointable_name,
94+
)
95+
96+
# 2. If no handler is resolved yet, try to resolve using the default
97+
# pytree handler.
98+
pytree_handler = registration.get_registered_handler_by_name(
99+
handler_registry, "pytree"
100+
)
101+
if not pytree_handler:
102+
raise registration.NoEntryError(
103+
f"Could not resolve a handler for '{checkpointable_name}' and no"
104+
f"'pytree' handler found in {handler_registry}).\n"
105+
"Please inspect the checkpoint contents via"
106+
" `loading.checkpointables_metadata`. You may need to provide an"
107+
" abstract_checkpointable or register a missing handler for this name"
108+
" or for 'pytree' name which is used as a fallback."
109+
)
110+
return pytree_handler
111+
112+
63113
async def get_handlers_for_load(
64-
directory: path_types.Path,
65114
handler_registry: registration.CheckpointableHandlerRegistry,
66115
abstract_checkpointables: dict[str, Any],
67116
checkpoint_metadata: InternalCheckpointMetadata,
68117
) -> dict[str, handler_types.CheckpointableHandler]:
69-
"""Returns a mapping from checkpointable name to handler."""
70-
existing_checkpointable_names_to_handler_typestrs = (
71-
await _get_saved_handler_typestrs(directory, checkpoint_metadata)
72-
)
73-
abstract_checkpointables = abstract_checkpointables or {
74-
name: None for name in existing_checkpointable_names_to_handler_typestrs
75-
}
76-
77-
loadable_checkpointable_names_to_handlers = {}
78-
for name, abstract_checkpointable in abstract_checkpointables.items():
79-
if name not in existing_checkpointable_names_to_handler_typestrs:
80-
raise KeyError(
81-
f'Checkpointable "{name}" was not found in the checkpoint.'
82-
' Available names:'
83-
f' {existing_checkpointable_names_to_handler_typestrs.keys()}'
84-
)
85-
handler_typestr = existing_checkpointable_names_to_handler_typestrs[name]
86-
handler = registration.resolve_handler_for_load(
118+
"""Returns a mapping from checkpointable name to handler.
119+
120+
Gathers and returns a mapping from checkpointable name to handler by
121+
checking the following in order:
122+
123+
1. Check for handler_typestr in checkpoint metadata item_handlers using
124+
checkpointable_name as key.
125+
2. Find the handler for each checkpointable using
126+
_resolve_single_handler_for_load.
127+
3. If no handler is resolved for a checkpointable, raise a NoEntryError.
128+
129+
Args:
130+
handler_registry: The handler registry to use for resolution.
131+
abstract_checkpointables: The abstract checkpointables to load.
132+
checkpoint_metadata: InternalCheckpointMetadata to read handler_typestr(s)
133+
from.
134+
135+
Returns:
136+
A mapping from checkpointable name to handler.
137+
138+
Raises:
139+
registration.NoEntryError: If no handler is resolved.
140+
"""
141+
handlers_for_load: dict[str, handler_types.CheckpointableHandler] = {}
142+
for (
143+
checkpointable_name,
144+
abstract_checkpointable,
145+
) in abstract_checkpointables.items():
146+
metadata_handler_typestr = _get_saved_handler_typestr(
147+
checkpointable_name, checkpoint_metadata
148+
)
149+
handlers_for_load[checkpointable_name] = _resolve_single_handler_for_load(
150+
checkpointable_name,
87151
handler_registry,
88152
abstract_checkpointable,
89-
name=name,
90-
handler_typestr=handler_typestr,
153+
metadata_handler_typestr,
91154
)
92-
loadable_checkpointable_names_to_handlers[name] = handler
93-
return loadable_checkpointable_names_to_handlers
155+
return handlers_for_load
156+
157+
158+
async def get_handler_for_load_direct_pytree(
159+
checkpointable_name: str,
160+
handler_registry: registration.CheckpointableHandlerRegistry,
161+
abstract_checkpointable: Any,
162+
checkpoint_metadata: InternalCheckpointMetadata,
163+
) -> handler_types.CheckpointableHandler:
164+
"""Returns a handler for direct load of a pytree checkpoint.
165+
166+
1. Check for checkpointable_name in checkpoint metadata item_handlers.
167+
2. resolve_handler_for_load performs handler discovery based on
168+
abstract_checkpointable type and handler_typestr.
169+
3. Find the handler for the checkpointable using
169+
_resolve_single_handler_for_load.
170+
4. If no handler is resolved for the checkpointable, raise a NoEntryError.
172+
173+
Args:
174+
checkpointable_name: The checkpointable name to resolve the handler for.
175+
handler_registry: The handler registry to use for resolution.
176+
abstract_checkpointable: The abstract checkpointable to load.
177+
checkpoint_metadata: InternalCheckpointMetadata to read handler_typestr
178+
from.
179+
180+
Returns:
181+
The handler for direct load of a pytree checkpoint.
182+
"""
183+
metadata_handler_typestr = _get_saved_handler_typestr_direct_pytree(
184+
checkpoint_metadata
185+
)
186+
return _resolve_single_handler_for_load(
187+
checkpointable_name,
188+
handler_registry,
189+
abstract_checkpointable,
190+
metadata_handler_typestr,
191+
)
94192

95193

96-
async def _get_saved_handler_typestrs(
97-
directory: path_types.Path,
194+
def _get_saved_handler_typestr(
195+
checkpointable_name: str,
98196
checkpoint_metadata: InternalCheckpointMetadata,
99-
) -> dict[str, str]:
197+
) -> str | None:
100198
"""Reads from the checkpoint metadata to get saved handler typestrs."""
101-
if checkpoint_metadata.item_handlers:
102-
if isinstance(checkpoint_metadata.item_handlers, dict):
103-
return checkpoint_metadata.item_handlers # found step level metadata.
104-
raise ValueError(
105-
f'Path at {directory} contains subdirectories:'
106-
f' {_subdirs(directory)}, which are expected to'
107-
' match the keys given by the _CHECKPOINT_METADATA file:'
108-
f' {checkpoint_metadata.item_handlers}. If you intended to load a'
109-
' pytree checkpoint from the given path, then please consider using'
110-
' `loading.load_pytree(..., checkpointable_name=None)` instead.'
111-
f' {_V0_ERROR_MESSAGE}'
112-
)
199+
if isinstance(checkpoint_metadata.item_handlers, dict) and (
200+
checkpointable_name in checkpoint_metadata.item_handlers
201+
):
202+
return checkpoint_metadata.item_handlers[checkpointable_name]
203+
return None
113204

114-
logging.warning(
115-
'Given dir does not contain checkpoint metadata file: %s. No handler'
116-
' typestrs found.',
117-
directory,
118-
)
119-
return {}
205+
206+
def _get_saved_handler_typestr_direct_pytree(
207+
checkpoint_metadata: InternalCheckpointMetadata,
208+
) -> str | None:
209+
"""Returns `item_handlers` from the checkpoint metadata if it is a single typestr."""
210+
if isinstance(checkpoint_metadata.item_handlers, str):
211+
return checkpoint_metadata.item_handlers
212+
return None

0 commit comments

Comments
 (0)