Skip to content

Commit 2fd468e

Browse files
angel-core authored and Orbax Authors committed
Fix resolution logic to correctly find handler based on checkpointable_name when missing abstract_checkpointable or checkpoint metadata.
PiperOrigin-RevId: 874801701
1 parent 9b22138 commit 2fd468e

File tree

448 files changed

+1449
-270
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

448 files changed

+1449
-270
lines changed

checkpoint/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ devices anyway.
4141
`load_checkpointables()` each with their own dedicated loading logic
4242
- Refactor v0 Pytree validation and metadata resolution and add `OrbaxV0Layout`
4343
tests
44+
- Refactor logic for handler resolution and loading checkpointables for
45+
`OrbaxLayout` and `OrbaxV0Layout`, adding additional fallback capabilities for
46+
non-standard checkpoint formats.
4447

4548
### Fixed
4649

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,16 @@ def _get_possible_handlers(
392392
return possible_handlers
393393

394394

395+
def get_registered_handler_by_name(
396+
registry: CheckpointableHandlerRegistry,
397+
name: str,
398+
) -> CheckpointableHandler | None:
399+
"""Returns the handler for the given name if registered."""
400+
if registry.has(name):
401+
return _construct_handler_instance(name, registry.get(name))
402+
return None
403+
404+
395405
def resolve_handler_for_save(
396406
registry: CheckpointableHandlerRegistry,
397407
checkpointable: Any,
@@ -435,7 +445,7 @@ def is_handleable_fn(handler: CheckpointableHandler, ckpt: Any) -> bool:
435445
registry, is_handleable_fn, checkpointable, name
436446
)
437447

438-
# Prefer the first handler in the absence of any other information.
448+
# Prefer the last handler in the absence of any other information.
439449
return possible_handlers[-1]
440450

441451

@@ -444,7 +454,7 @@ def resolve_handler_for_load(
444454
abstract_checkpointable: Any | None,
445455
*,
446456
name: str,
447-
handler_typestr: str,
457+
handler_typestr: str | None = None,
448458
) -> CheckpointableHandler:
449459
"""Resolves a :py:class:`~.v1.handlers.CheckpointableHandler` for loading.
450460
@@ -471,7 +481,9 @@ def resolve_handler_for_load(
471481
abstract_checkpointable: An abstract checkpointable to resolve.
472482
name: The name of the checkpointable.
473483
handler_typestr: A :py:class:`~.v1.handlers.CheckpointableHandler` typestr
474-
to guide resolution.
484+
to guide resolution. We allow a None value for handler_typestr as it is
485+
possible to find the last registered handler given a specified
486+
abstract_checkpointable.
475487
476488
Returns:
477489
A :py:class:`~.v1.handlers.CheckpointableHandler` instance.
@@ -492,15 +504,23 @@ def is_handleable_fn(
492504
handler_types.typestr(type(handler)) for handler in possible_handlers
493505
]
494506

495-
try:
496-
idx = possible_handler_typestrs.index(handler_typestr)
497-
return possible_handlers[idx]
498-
except ValueError:
499-
logging.warning(
500-
'No handler found for typestr %s. The checkpointable may be restored'
501-
' with different handler logic than was used for saving.',
502-
handler_typestr,
503-
)
504-
505-
# Prefer the first handler in the absence of any other information.
506-
return possible_handlers[-1]
507+
if handler_typestr:
508+
try:
509+
idx = possible_handler_typestrs.index(handler_typestr)
510+
return possible_handlers[idx]
511+
except ValueError:
512+
logging.warning(
513+
'No handler found for typestr %s. The checkpointable may be restored'
514+
' with different handler logic than was used for saving.',
515+
handler_typestr,
516+
)
517+
# For non-None, specified abstract_checkpointable, prefer the last handler in
518+
# the absence of any other information.
519+
if abstract_checkpointable:
520+
return possible_handlers[-1]
521+
522+
raise NoEntryError(
523+
f'No entry for checkpointable={name} found in the registry: {registry},'
524+
f' using abstract_checkpointable={abstract_checkpointable} and'
525+
f' handler_typestr={handler_typestr}'
526+
)

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration_test.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,15 +273,33 @@ def test_resolve_handler_for_load_no_matching_name(self):
273273
handler_typestr='unused',
274274
)
275275

276-
def test_resolve_handler_for_load_checkpointable(self):
276+
def test_resolve_handler_for_load_no_handler_typestr(self):
277277
local_registry = registration.local_registry()
278278
local_registry.add(handler_utils.FooHandler)
279+
resolved = registration.resolve_handler_for_load(
280+
local_registry,
281+
handler_utils.AbstractFoo(),
282+
name='dummy_unregistered_nameame',
283+
handler_typestr=None,
284+
)
285+
self.assertIsInstance(resolved, handler_utils.FooHandler)
286+
279287
with self.assertRaises(registration.NoEntryError):
280288
registration.resolve_handler_for_load(
281289
local_registry,
282-
handler_utils.Foo(1, 'hi'),
283-
name='foo',
284-
handler_typestr='unused',
290+
handler_utils.AbstractBar(),
291+
name='unregistered_name',
292+
handler_typestr=None,
293+
)
294+
295+
def test_resolve_handler_for_load_no_checkpointable_or_handler_typestr(self):
296+
local_registry = registration.local_registry()
297+
with self.assertRaises(registration.NoEntryError):
298+
registration.resolve_handler_for_load(
299+
local_registry,
300+
None,
301+
name='unregistered_name',
302+
handler_typestr=None,
285303
)
286304

287305

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/resolution.py

Lines changed: 146 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -15,38 +15,19 @@
1515
"""Logic for resolving handlers for saving and loading."""
1616
from __future__ import annotations
1717

18-
import itertools
1918
from typing import Any
2019

2120
from absl import logging
2221
from orbax.checkpoint._src.metadata import step_metadata_serialization
2322
from orbax.checkpoint.experimental.v1._src.handlers import registration
2423
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
2524
import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import
26-
from orbax.checkpoint.experimental.v1._src.path import types as path_types
2725

2826
InternalCheckpointMetadata = (
2927
step_metadata_serialization.InternalCheckpointMetadata
3028
)
3129

3230

33-
def _subdirs(directory: path_types.Path, *, limit: int = 3) -> list[str]:
34-
return list(
35-
itertools.islice(
36-
(subdir.name for subdir in directory.iterdir() if subdir.is_dir()),
37-
limit,
38-
)
39-
)
40-
41-
42-
_V0_ERROR_MESSAGE = (
43-
'If your checkpoint was saved with the Orbax V0 API, please follow the'
44-
' instructions at'
45-
' https://orbax.readthedocs.io/en/latest/guides/checkpoint/v1/orbax_v0_to_v1_migration.html'
46-
' to load it with the Orbax V1 API.'
47-
)
48-
49-
5031
def get_handlers_for_save(
5132
handler_registry: registration.CheckpointableHandlerRegistry,
5233
checkpointables: dict[str, Any],
@@ -60,60 +41,162 @@ def get_handlers_for_save(
6041
}
6142

6243

44+
def _resolve_single_handler_for_load(
45+
checkpointable_name: str,
46+
handler_registry: registration.CheckpointableHandlerRegistry,
47+
abstract_checkpointable: Any,
48+
metadata_handler_typestr: str | None,
49+
) -> handler_types.CheckpointableHandler:
50+
"""Logic to resolve a checkpointable's loading handler.
51+
52+
1. registration.resolve_handler_for_load performs handler discovery based on
53+
abstract_checkpointable type and handler_typestr.
54+
2. If this fails or if abstract_checkpointable and handler_typestr are not
55+
available, we try to resolve using the default pytree handler if registered.
56+
57+
Args:
58+
checkpointable_name: The checkpointable name to resolve the handler for.
59+
handler_registry: The handler registry to use for resolution.
60+
abstract_checkpointable: The abstract checkpointable to load.
61+
metadata_handler_typestr: The handler typestr from the checkpoint metadata.
62+
63+
Returns:
64+
The handler for the checkpointable.
65+
66+
Raises:
67+
registration.NoEntryError: If no handler is resolved and 'pytree' name is
68+
not registered.
69+
"""
70+
# 1. Resolve the checkpointable's handler using handler discovery.
71+
try:
72+
return registration.resolve_handler_for_load(
73+
handler_registry,
74+
abstract_checkpointable,
75+
name=checkpointable_name,
76+
handler_typestr=metadata_handler_typestr,
77+
)
78+
except registration.NoEntryError as e:
79+
logging.warning(
80+
"Failed to resolve handler for checkpointable: '%s'. Attempting to"
81+
" load using pytree handler. Error: %s",
82+
checkpointable_name,
83+
e,
84+
)
85+
86+
# 2. If no handler is resolved yet, try to resolve using the default
87+
# pytree handler.
88+
pytree_handler = registration.get_registered_handler_by_name(
89+
handler_registry, "pytree"
90+
)
91+
if not pytree_handler:
92+
raise registration.NoEntryError(
93+
f"Could not resolve a handler for '{checkpointable_name}' and no"
94+
f" 'pytree' handler found in {handler_registry})."
95+
"Please inspect the checkpoint contents via"
96+
" `loading.checkpointables_metadata`. You may need to provide an"
97+
" abstract_checkpointable or register a missing handler for this name"
98+
" or for 'pytree' name which is used as a fallback."
99+
)
100+
return pytree_handler
101+
102+
63103
async def get_handlers_for_load(
64-
directory: path_types.Path,
65104
handler_registry: registration.CheckpointableHandlerRegistry,
66105
abstract_checkpointables: dict[str, Any],
67106
checkpoint_metadata: InternalCheckpointMetadata,
68107
) -> dict[str, handler_types.CheckpointableHandler]:
69-
"""Returns a mapping from checkpointable name to handler."""
70-
existing_checkpointable_names_to_handler_typestrs = (
71-
await _get_saved_handler_typestrs(directory, checkpoint_metadata)
72-
)
73-
abstract_checkpointables = abstract_checkpointables or {
74-
name: None for name in existing_checkpointable_names_to_handler_typestrs
75-
}
76-
77-
loadable_checkpointable_names_to_handlers = {}
78-
for name, abstract_checkpointable in abstract_checkpointables.items():
79-
if name not in existing_checkpointable_names_to_handler_typestrs:
80-
raise KeyError(
81-
f'Checkpointable "{name}" was not found in the checkpoint.'
82-
' Available names:'
83-
f' {existing_checkpointable_names_to_handler_typestrs.keys()}'
84-
)
85-
handler_typestr = existing_checkpointable_names_to_handler_typestrs[name]
86-
handler = registration.resolve_handler_for_load(
108+
"""Returns a mapping from checkpointable name to handler.
109+
110+
Gathers and returns a mapping from checkpointable name to handler by
111+
checking the following in order:
112+
113+
1. Check for handler_typestr in checkpoint metadata item_handlers using
114+
checkpointable_name as key.
115+
2. Find the handler for each checkpointable using
116+
_resolve_single_handler_for_load.
117+
3. If no handler is resolved for a checkpointable, raise a NoEntryError.
118+
119+
Args:
120+
handler_registry: The handler registry to use for resolution.
121+
abstract_checkpointables: The abstract checkpointables to load.
122+
checkpoint_metadata: InternalCheckpointMetadata to read handler_typestr(s)
123+
from.
124+
125+
Returns:
126+
A mapping from checkpointable name to handler.
127+
128+
Raises:
129+
registration.NoEntryError: If no handler is resolved.
130+
"""
131+
handlers_for_load: dict[str, handler_types.CheckpointableHandler] = {}
132+
for (
133+
checkpointable_name,
134+
abstract_checkpointable,
135+
) in abstract_checkpointables.items():
136+
metadata_handler_typestr = _get_saved_handler_typestr(
137+
checkpointable_name, checkpoint_metadata
138+
)
139+
handlers_for_load[checkpointable_name] = _resolve_single_handler_for_load(
140+
checkpointable_name,
87141
handler_registry,
88142
abstract_checkpointable,
89-
name=name,
90-
handler_typestr=handler_typestr,
143+
metadata_handler_typestr,
91144
)
92-
loadable_checkpointable_names_to_handlers[name] = handler
93-
return loadable_checkpointable_names_to_handlers
145+
return handlers_for_load
94146

95147

96-
async def _get_saved_handler_typestrs(
97-
directory: path_types.Path,
148+
async def get_handler_for_load_direct_pytree(
149+
checkpointable_name: str,
150+
handler_registry: registration.CheckpointableHandlerRegistry,
151+
abstract_checkpointable: Any,
98152
checkpoint_metadata: InternalCheckpointMetadata,
99-
) -> dict[str, str]:
153+
) -> handler_types.CheckpointableHandler:
154+
"""Returns a handler for direct load of a pytree checkpoint.
155+
156+
1. Check for checkpointable_name in checkpoint metadata item_handlers.
157+
2. resolve_handler_for_load performs handler discovery based on
158+
abstract_checkpointable type and handler_typestr.
159+
3. Find the handler for the checkpointable using
160+
_resolve_single_handler_for_load.
161+
4. If no handler is resolved for the checkpointable, raise a NoEntryError.
162+
163+
Args:
164+
checkpointable_name: The checkpointable name to resolve the handler for.
165+
handler_registry: The handler registry to use for resolution.
166+
abstract_checkpointable: The abstract checkpointable to load.
167+
checkpoint_metadata: InternalCheckpointMetadata to read handler_typestr
168+
from.
169+
170+
Returns:
171+
The handler for direct load of a pytree checkpoint.
172+
"""
173+
metadata_handler_typestr = _get_saved_handler_typestr_direct_pytree(
174+
checkpoint_metadata
175+
)
176+
return _resolve_single_handler_for_load(
177+
checkpointable_name,
178+
handler_registry,
179+
abstract_checkpointable,
180+
metadata_handler_typestr,
181+
)
182+
183+
184+
def _get_saved_handler_typestr(
185+
checkpointable_name: str,
186+
checkpoint_metadata: InternalCheckpointMetadata,
187+
) -> str | None:
100188
"""Reads from the checkpoint metadata to get saved handler typestrs."""
101-
if checkpoint_metadata.item_handlers:
102-
if isinstance(checkpoint_metadata.item_handlers, dict):
103-
return checkpoint_metadata.item_handlers # found step level metadata.
104-
raise ValueError(
105-
f'Path at {directory} contains subdirectories:'
106-
f' {_subdirs(directory)}, which are expected to'
107-
' match the keys given by the _CHECKPOINT_METADATA file:'
108-
f' {checkpoint_metadata.item_handlers}. If you intended to load a'
109-
' pytree checkpoint from the given path, then please consider using'
110-
' `loading.load_pytree(..., checkpointable_name=None)` instead.'
111-
f' {_V0_ERROR_MESSAGE}'
112-
)
189+
if isinstance(checkpoint_metadata.item_handlers, dict) and (
190+
checkpointable_name in checkpoint_metadata.item_handlers
191+
):
192+
return checkpoint_metadata.item_handlers[checkpointable_name]
193+
return None
113194

114-
logging.warning(
115-
'Given dir does not contain checkpoint metadata file: %s. No handler'
116-
' typestrs found.',
117-
directory,
118-
)
119-
return {}
195+
196+
def _get_saved_handler_typestr_direct_pytree(
197+
checkpoint_metadata: InternalCheckpointMetadata,
198+
) -> str | None:
199+
"""Returns the saved handler typestr if item_handlers is a single string, else None."""
200+
if isinstance(checkpoint_metadata.item_handlers, str):
201+
return checkpoint_metadata.item_handlers
202+
return None

0 commit comments

Comments
 (0)