Skip to content

Commit 919b8a8

Browse files
angel-coreOrbax Authors
authored andcommitted
#v1 Refactor logic for handler resolution and loading checkpointables + additional fallback capabilities for non-standard checkpoint formats.
PiperOrigin-RevId: 869904948
1 parent daec61e commit 919b8a8

File tree

14 files changed

+444
-743
lines changed

14 files changed

+444
-743
lines changed

checkpoint/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2626
`load_checkpointables()` each with their own dedicated loading logic
2727
- Refactor v0 Pytree validation and metadata resolution and add `OrbaxV0Layout`
2828
tests
29+
- Refactor `CompositeHandler` logic into the orbax layout objects and handler
30+
resolution utility, deprecating and deleting the `CompositeHandler` class.
2931

3032
## [0.11.32] - 2026-01-20
3133

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py

Lines changed: 0 additions & 326 deletions
This file was deleted.

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,16 @@ def _get_possible_handlers(
392392
return possible_handlers
393393

394394

395+
def get_registered_handler_by_name(
396+
registry: CheckpointableHandlerRegistry,
397+
name: str,
398+
) -> CheckpointableHandler | None:
399+
"""Returns the handler for the given name if registered."""
400+
if registry.has(name):
401+
return _construct_handler_instance(name, registry.get(name))
402+
return None
403+
404+
395405
def resolve_handler_for_save(
396406
registry: CheckpointableHandlerRegistry,
397407
checkpointable: Any,
@@ -444,7 +454,7 @@ def resolve_handler_for_load(
444454
abstract_checkpointable: Any | None,
445455
*,
446456
name: str,
447-
handler_typestr: str,
457+
handler_typestr: str | None = None,
448458
) -> CheckpointableHandler:
449459
"""Resolves a :py:class:`~.v1.handlers.CheckpointableHandler` for loading.
450460
@@ -492,15 +502,16 @@ def is_handleable_fn(
492502
handler_types.typestr(type(handler)) for handler in possible_handlers
493503
]
494504

495-
try:
496-
idx = possible_handler_typestrs.index(handler_typestr)
497-
return possible_handlers[idx]
498-
except ValueError:
499-
logging.warning(
500-
'No handler found for typestr %s. The checkpointable may be restored'
501-
' with different handler logic than was used for saving.',
502-
handler_typestr,
503-
)
505+
if handler_typestr:
506+
try:
507+
idx = possible_handler_typestrs.index(handler_typestr)
508+
return possible_handlers[idx]
509+
except ValueError:
510+
logging.warning(
511+
'No handler found for typestr %s. The checkpointable may be restored'
512+
' with different handler logic than was used for saving.',
513+
handler_typestr,
514+
)
504515

505516
# Prefer the first handler in the absence of any other information.
506517
return possible_handlers[-1]

0 commit comments

Comments
 (0)