Skip to content

Commit 48372a0

Browse files
angel-coreOrbax Authors
authored andcommitted
Refine resolution logic to ensure valid handler matching for abstract checkpointables.
PiperOrigin-RevId: 874801701
1 parent d83491f commit 48372a0

File tree

451 files changed

+1533
-285
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

451 files changed

+1533
-285
lines changed

checkpoint/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ devices anyway.
4141
`load_checkpointables()` each with their own dedicated loading logic
4242
- Refactor v0 Pytree validation and metadata resolution and add `OrbaxV0Layout`
4343
tests
44+
- Refactor logic for handler resolution and loading checkpointables for
45+
`OrbaxLayout` and `OrbaxV0Layout`, adding additional fallback capabilities for
46+
non-standard checkpoint formats.
4447

4548
### Fixed
4649

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,16 @@ def _get_possible_handlers(
392392
return possible_handlers
393393

394394

395+
def get_registered_handler_by_name(
396+
registry: CheckpointableHandlerRegistry,
397+
name: str,
398+
) -> CheckpointableHandler | None:
399+
"""Returns the handler for the given name if registered."""
400+
if registry.has(name):
401+
return _construct_handler_instance(name, registry.get(name))
402+
return None
403+
404+
395405
def resolve_handler_for_save(
396406
registry: CheckpointableHandlerRegistry,
397407
checkpointable: Any,
@@ -435,7 +445,7 @@ def is_handleable_fn(handler: CheckpointableHandler, ckpt: Any) -> bool:
435445
registry, is_handleable_fn, checkpointable, name
436446
)
437447

438-
# Prefer the first handler in the absence of any other information.
448+
# Prefer the last handler in the absence of any other information.
439449
return possible_handlers[-1]
440450

441451

@@ -444,7 +454,7 @@ def resolve_handler_for_load(
444454
abstract_checkpointable: Any | None,
445455
*,
446456
name: str,
447-
handler_typestr: str,
457+
handler_typestr: str | None = None,
448458
) -> CheckpointableHandler:
449459
"""Resolves a :py:class:`~.v1.handlers.CheckpointableHandler` for loading.
450460
@@ -471,7 +481,9 @@ def resolve_handler_for_load(
471481
abstract_checkpointable: An abstract checkpointable to resolve.
472482
name: The name of the checkpointable.
473483
handler_typestr: A :py:class:`~.v1.handlers.CheckpointableHandler` typestr
474-
to guide resolution.
484+
to guide resolution. We allow a None value for handler_typestr as its
485+
possible to find the last registered handler given a specified
486+
abstract_checkpointable.
475487
476488
Returns:
477489
A :py:class:`~.v1.handlers.CheckpointableHandler` instance.
@@ -492,15 +504,24 @@ def is_handleable_fn(
492504
handler_types.typestr(type(handler)) for handler in possible_handlers
493505
]
494506

495-
try:
496-
idx = possible_handler_typestrs.index(handler_typestr)
497-
return possible_handlers[idx]
498-
except ValueError:
499-
logging.warning(
500-
'No handler found for typestr %s. The checkpointable may be restored'
501-
' with different handler logic than was used for saving.',
502-
handler_typestr,
503-
)
507+
if handler_typestr:
508+
try:
509+
idx = possible_handler_typestrs.index(handler_typestr)
510+
return possible_handlers[idx]
511+
except ValueError:
512+
logging.warning(
513+
'No handler found for typestr %s. The checkpointable may be restored'
514+
' with different handler logic than was used for saving.',
515+
handler_typestr,
516+
)
504517

505-
# Prefer the first handler in the absence of any other information.
506-
return possible_handlers[-1]
518+
if abstract_checkpointable:
519+
# Prefer the last handler in the absence of any other information.
520+
return possible_handlers[-1]
521+
522+
raise NoEntryError(
523+
f'No entry for checkpointable={name} in the registry, using'
524+
f' handler_typestr={handler_typestr} and'
525+
f' abstract_checkpointable={abstract_checkpointable}. Registry contents:'
526+
f' {registry.get_all_entries()}'
527+
)

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration_test.py

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -197,18 +197,31 @@ def test_resolve_handler_for_load(
197197

198198
def test_resolve_handler_for_load_resolution_order(self):
199199

200-
class HandlerOne(handler_utils.DictHandler):
201-
pass
200+
class HandlerOne(handler_utils.BazHandler):
201+
def is_abstract_handleable(
202+
self, abstract_checkpointable: handler_utils.AbstractBaz
203+
) -> bool:
204+
return isinstance(abstract_checkpointable, handler_utils.AbstractBaz)
202205

203-
class HandlerTwo(handler_utils.DictHandler):
204-
pass
206+
class HandlerTwo(handler_utils.BazHandler):
207+
def is_abstract_handleable(
208+
self, abstract_checkpointable: handler_utils.AbstractBaz
209+
) -> bool:
210+
return isinstance(abstract_checkpointable, handler_utils.AbstractBaz)
205211

206212
handlers_to_register = [HandlerOne, HandlerTwo]
207213

208214
with self.subTest('globally_registered'):
215+
with self.assertRaises(registration.NoEntryError):
216+
registration.resolve_handler_for_load(
217+
registration.local_registry(),
218+
None,
219+
name='checkpointable_name',
220+
handler_typestr='unknown_class',
221+
)
209222
resolved_handler = registration.resolve_handler_for_load(
210223
registration.local_registry(),
211-
None,
224+
handler_utils.AbstractBaz(),
212225
name='checkpointable_name',
213226
handler_typestr='unknown_class',
214227
)
@@ -219,9 +232,16 @@ class HandlerTwo(handler_utils.DictHandler):
219232
)
220233
for handler in handlers_to_register:
221234
local_registry.add(handler)
235+
with self.assertRaises(registration.NoEntryError):
236+
registration.resolve_handler_for_load(
237+
local_registry,
238+
None,
239+
name='checkpointable_name',
240+
handler_typestr='unknown_class',
241+
)
222242
resolved_handler = registration.resolve_handler_for_load(
223243
local_registry,
224-
None,
244+
handler_utils.AbstractBaz(),
225245
name='checkpointable_name',
226246
handler_typestr='unknown_class',
227247
)
@@ -243,9 +263,16 @@ class HandlerTwo(handler_utils.DictHandler):
243263
)
244264
for handler in reversed(handlers_to_register):
245265
local_registry.add(handler)
266+
with self.assertRaises(registration.NoEntryError):
267+
registration.resolve_handler_for_load(
268+
local_registry,
269+
None,
270+
name='checkpointable_name',
271+
handler_typestr='unknown_class',
272+
)
246273
resolved_handler = registration.resolve_handler_for_load(
247274
local_registry,
248-
None,
275+
handler_utils.AbstractBaz(),
249276
name='checkpointable_name',
250277
handler_typestr='unknown_class',
251278
)
@@ -273,15 +300,33 @@ def test_resolve_handler_for_load_no_matching_name(self):
273300
handler_typestr='unused',
274301
)
275302

276-
def test_resolve_handler_for_load_checkpointable(self):
303+
def test_resolve_handler_for_load_no_handler_typestr(self):
277304
local_registry = registration.local_registry()
278305
local_registry.add(handler_utils.FooHandler)
306+
resolved = registration.resolve_handler_for_load(
307+
local_registry,
308+
handler_utils.AbstractFoo(),
309+
name='unregistered_name',
310+
handler_typestr=None,
311+
)
312+
self.assertIsInstance(resolved, handler_utils.FooHandler)
313+
279314
with self.assertRaises(registration.NoEntryError):
280315
registration.resolve_handler_for_load(
281316
local_registry,
282-
handler_utils.Foo(1, 'hi'),
283-
name='foo',
284-
handler_typestr='unused',
317+
handler_utils.AbstractBar(),
318+
name='unregistered_name',
319+
handler_typestr=None,
320+
)
321+
322+
def test_resolve_handler_for_load_no_checkpointable_no_metadata(self):
323+
local_registry = registration.local_registry()
324+
with self.assertRaises(registration.NoEntryError):
325+
registration.resolve_handler_for_load(
326+
local_registry,
327+
None,
328+
name='unregistered_name',
329+
handler_typestr=None,
285330
)
286331

287332

0 commit comments

Comments
 (0)