Skip to content

Commit 5fbb01b

Browse files
angel-core authored and Orbax Authors committed
Introduce a mapping between registered handlers and alternate handler typestrs in the v1 API.
PiperOrigin-RevId: 874792153
1 parent 60b50ba commit 5fbb01b

File tree

510 files changed

+1453
-306
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

510 files changed

+1453
-306
lines changed

checkpoint/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ devices anyway.
4141
`load_checkpointables()` each with their own dedicated loading logic
4242
- Refactor v0 Pytree validation and metadata resolution and add `OrbaxV0Layout`
4343
tests
44+
- Refactor logic for handler resolution and loading checkpointables for
45+
`OrbaxLayout` and `OrbaxV0Layout`, adding additional fallback capabilities for
46+
non-standard checkpoint formats.
4447

4548
### Fixed
4649

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
given checkpointable will be used.
2121
"""
2222

23-
from typing import Type
23+
from typing import Sequence, Type
2424

2525
from orbax.checkpoint.experimental.v1._src.handlers import json_handler
2626
from orbax.checkpoint.experimental.v1._src.handlers import proto_handler
@@ -34,23 +34,45 @@
3434
def _try_register_handler(
3535
handler_type: Type[handler_types.CheckpointableHandler],
3636
name: str | None = None,
37+
secondary_typestrs: Sequence[str] | None = None,
3738
):
39+
"""Tries to register handler globally with name and recognized typestrs."""
3840
try:
39-
registration.global_registry().add(handler_type, name)
41+
registration.global_registry().add(
42+
handler_type,
43+
name,
44+
secondary_typestrs=secondary_typestrs,
45+
)
4046
except registration.AlreadyExistsError:
4147
pass
4248

4349

44-
_try_register_handler(proto_handler.ProtoHandler)
45-
_try_register_handler(json_handler.JsonHandler)
50+
_try_register_handler(
51+
proto_handler.ProtoHandler,
52+
secondary_typestrs=[
53+
'orbax.checkpoint._src.handlers.proto_checkpoint_handler.ProtoCheckpointHandler',
54+
],
55+
)
56+
_try_register_handler(
57+
json_handler.JsonHandler,
58+
secondary_typestrs=[
59+
'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler',
60+
],
61+
)
4662
_try_register_handler(
4763
stateful_checkpointable_handler.StatefulCheckpointableHandler
4864
)
4965
_try_register_handler(
5066
json_handler.MetricsHandler,
5167
checkpoint_layout.METRICS_CHECKPOINTABLE_KEY,
5268
)
53-
_try_register_handler(pytree_handler.PyTreeHandler)
69+
_try_register_handler(
70+
pytree_handler.PyTreeHandler,
71+
secondary_typestrs=[
72+
'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler',
73+
'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler',
74+
],
75+
)
5476
_try_register_handler(
5577
pytree_handler.PyTreeHandler, checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY
5678
)

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py

Lines changed: 112 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -35,25 +35,31 @@
3535
registry.add(BarHandler)
3636
# Scope this handler specifically to checkpointables named 'baz'.
3737
registry.add(BazHandler, 'baz')
38+
# Scope this handler to legacy typestr 'OldBazHandlerTypestr'. Secondary
39+
# typestrs provide a way to map legacy identifiers to a new handler class.
40+
registry.add(BazHandler, secondary_typestrs=['OldBazHandlerTypestr'])
3841
3942
checkpointables_options = ocp.options.CheckpointablesOptions(
4043
registry=registry
4144
)
4245
with ocp.Context(checkpointables_options=checkpointables_options):
4346
ocp.save_checkpointables(...)
4447
45-
If a registered handler is scoped to a specific name (e.g.
46-
`registry.add(BazHandler, 'baz')`), then this handler will always be
47-
prioritized for saving or loading the checkpointable with that name, even if
48-
the handler is not capable of saving/loading the checkpointable.
49-
50-
In the most common case, where a handler is not scoped to a specific name,
51-
a given checkpointable (or abstract_checkpointable) will be resolved to a
52-
handler returning True for `is_handleable` (or `is_abstract_handleable`),
53-
respectively. If multiple handlers are usable, the first usable handler will be
54-
returned. When loading, the handler type used for saving will be recorded in
55-
the metadata, and will be used to resolve the handler, if a corresponding
56-
handler is present in the registry.
48+
Handler resolution for saving/loading follows this logic:
49+
50+
1. If a registered handler is scoped to a specific name
51+
(e.g. `registry.add(BazHandler, 'baz')`), then this handler will always
52+
be prioritized for saving or loading the checkpointable with that name,
53+
even if the handler is not capable of saving/loading the checkpointable.
54+
2. In the absence of an explicit name match, the registry filters for
55+
handlers returning `True` for `is_handleable` (during save) or
56+
`is_abstract_handleable` (during load).
57+
3. [Pertains to loading only] The handler type used for saving will be
58+
recorded in the metadata, and will be used to resolve the handler, if a
59+
corresponding handler is present in the registry. If not, scan the
60+
secondary typestrs of registered handlers for a match.
61+
4. If no metadata match is found (or during saving), the most recently
62+
registered capable handler is returned.
5763
"""
5864

5965
from __future__ import annotations
@@ -73,7 +79,13 @@ def add_all(
7379
) -> CheckpointableHandlerRegistry:
7480
"""Adds all entries from `other_registry` to `registry`."""
7581
for handler, checkpointable in other_registry.get_all_entries():
76-
registry.add(handler, checkpointable)
82+
registry.add(
83+
handler,
84+
checkpointable,
85+
secondary_typestrs=other_registry.get_secondary_typestrs(
86+
handler
87+
),
88+
)
7789
return registry
7890

7991

@@ -87,6 +99,7 @@ def add(
8799
self,
88100
handler_type: Type[CheckpointableHandler],
89101
checkpointable: str | None = None,
102+
secondary_typestrs: Sequence[str] | None = None,
90103
) -> CheckpointableHandlerRegistry:
91104
"""Adds an entry to the registry."""
92105
...
@@ -110,6 +123,13 @@ def get_all_entries(
110123
) -> Sequence[RegistryEntry]:
111124
...
112125

126+
def get_secondary_typestrs(
127+
self,
128+
handler_type: Type[CheckpointableHandler],
129+
) -> Sequence[str]:
130+
"""Returns all secondary typestrs associated with the given handler type."""
131+
...
132+
113133

114134
class AlreadyExistsError(ValueError):
115135
"""Raised when an entry already exists in the registry."""
@@ -126,6 +146,9 @@ def __init__(
126146
self, other_registry: CheckpointableHandlerRegistry | None = None
127147
):
128148
self._registry: list[RegistryEntry] = []
149+
self._secondary_typestrs: dict[
150+
Type[CheckpointableHandler], Sequence[str]
151+
] = {}
129152

130153
# Initialize the registry with entries from other registry.
131154
if other_registry:
@@ -135,14 +158,25 @@ def add(
135158
self,
136159
handler_type: Type[CheckpointableHandler],
137160
checkpointable: str | None = None,
161+
secondary_typestrs: Sequence[str] | None = None,
138162
) -> CheckpointableHandlerRegistry:
139163
"""Adds an entry to the registry.
140164
165+
Adds a primary handler_type to the registry with an optional checkpointable
166+
name and an optional sequence of secondary typestrs that can be used to
167+
identify the handler.
168+
169+
Note: We only guarantee unique handler type entries in the registry and do
170+
not explicitly prevent a primary handler type from being registered and its
171+
typestr being used as a secondary_typestr for itself or another handler.
172+
141173
Args:
142174
handler_type: The handler type.
143175
checkpointable: The checkpointable name. If not-None, the registered
144176
handler will be scoped to that specific name. Otherwise, the handler
145177
will be available for any checkpointable name.
178+
secondary_typestrs: A sequence of alternate typestrs that serve as
179+
secondary identifiers for the handler.
146180
147181
Returns:
148182
The registry itself.
@@ -170,6 +204,8 @@ def add(
170204
f'Handler type {handler_type} already exists in the registry.'
171205
)
172206
self._registry.append((handler_type, checkpointable))
207+
if secondary_typestrs is not None:
208+
self._secondary_typestrs[handler_type] = secondary_typestrs
173209
return self
174210

175211
def get(
@@ -220,6 +256,13 @@ def get_all_entries(
220256
"""Returns all entries in the registry."""
221257
return self._registry
222258

259+
def get_secondary_typestrs(
260+
self,
261+
handler_type: Type[CheckpointableHandler],
262+
) -> Sequence[str]:
263+
"""Returns all secondary typestrs associated with the given handler type."""
264+
return self._secondary_typestrs.get(handler_type, [])
265+
223266
def __repr__(self):
224267
return f'_DefaultCheckpointableHandlerRegistry({self.get_all_entries()})'
225268

@@ -237,6 +280,7 @@ def add(
237280
self,
238281
handler_type: Type[CheckpointableHandler],
239282
checkpointable: str | None = None,
283+
secondary_typestrs: Sequence[str] | None = None,
240284
) -> CheckpointableHandlerRegistry:
241285
raise NotImplementedError('Adding not implemented for read-only registry.')
242286

@@ -257,6 +301,12 @@ def get_all_entries(
257301
) -> Sequence[RegistryEntry]:
258302
return self._registry.get_all_entries()
259303

304+
def get_secondary_typestrs(
305+
self,
306+
handler_type: Type[CheckpointableHandler],
307+
) -> Sequence[str]:
308+
return self._registry.get_secondary_typestrs(handler_type)
309+
260310
def __repr__(self):
261311
return f'ReadOnlyCheckpointableHandlerRegistry({self.get_all_entries()})'
262312

@@ -303,6 +353,8 @@ def local_registry(
303353

304354
def register_handler(
305355
cls: CheckpointableHandlerType,
356+
*,
357+
secondary_typestrs: Sequence[str] | None = None,
306358
) -> CheckpointableHandlerType:
307359
"""Registers a :py:class:`~.v1.handlers.CheckpointableHandler` globally.
308360
@@ -322,16 +374,20 @@ class FooHandler(ocp.handlers.CheckpointableHandler[Foo, AbstractFoo]):
322374
323375
Args:
324376
cls: The handler class.
377+
secondary_typestrs: A sequence of alternate handler typestrs that serve as
378+
secondary identifiers for the handler.
325379
326380
Returns:
327381
The handler class.
328382
"""
329-
_GLOBAL_REGISTRY.add(cls)
383+
_GLOBAL_REGISTRY.add(
384+
cls, secondary_typestrs=secondary_typestrs
385+
)
330386
return cls
331387

332388

333389
def _construct_handler_instance(
334-
name: str,
390+
name: str | None,
335391
handler_type: Type[CheckpointableHandler],
336392
) -> CheckpointableHandler:
337393
"""Attempts to default-construct a handler type if possible."""
@@ -392,6 +448,16 @@ def _get_possible_handlers(
392448
return possible_handlers
393449

394450

451+
def get_registered_handler_by_name(
452+
registry: CheckpointableHandlerRegistry,
453+
name: str,
454+
) -> CheckpointableHandler | None:
455+
"""Returns the handler for the given name if registered."""
456+
if registry.has(name):
457+
return _construct_handler_instance(name, registry.get(name))
458+
return None
459+
460+
395461
def resolve_handler_for_save(
396462
registry: CheckpointableHandlerRegistry,
397463
checkpointable: Any,
@@ -435,7 +501,7 @@ def is_handleable_fn(handler: CheckpointableHandler, ckpt: Any) -> bool:
435501
registry, is_handleable_fn, checkpointable, name
436502
)
437503

438-
# Prefer the first handler in the absence of any other information.
504+
# Prefer the last handler in the absence of any other information.
439505
return possible_handlers[-1]
440506

441507

@@ -444,7 +510,7 @@ def resolve_handler_for_load(
444510
abstract_checkpointable: Any | None,
445511
*,
446512
name: str,
447-
handler_typestr: str,
513+
handler_typestr: str | None = None,
448514
) -> CheckpointableHandler:
449515
"""Resolves a :py:class:`~.v1.handlers.CheckpointableHandler` for loading.
450516
@@ -456,8 +522,9 @@ def resolve_handler_for_load(
456522
4. If multiple handlers are usable, return the handler with the matching
457523
typestr. If no matching typestr is found, then the handler used for saving
458524
may not be available now.
459-
4. Return the *last* usable handler. This allows us to resolve the most
460-
recently-registered handler.
525+
5. Return the *last* usable handler. This allows us to resolve the most
526+
recently-registered handler, unless abstract_checkpointable is None, in
527+
which case raise a NoEntryError.
461528
462529
Raises:
463530
NoEntryError: If no compatible
@@ -471,7 +538,9 @@ def resolve_handler_for_load(
471538
abstract_checkpointable: An abstract checkpointable to resolve.
472539
name: The name of the checkpointable.
473540
handler_typestr: A :py:class:`~.v1.handlers.CheckpointableHandler` typestr
474-
to guide resolution.
541+
to guide resolution. We allow a None value for handler_typestr as it's
542+
possible to find the last registered handler given a specified
543+
abstract_checkpointable.
475544
476545
Returns:
477546
A :py:class:`~.v1.handlers.CheckpointableHandler` instance.
@@ -492,15 +561,30 @@ def is_handleable_fn(
492561
handler_types.typestr(type(handler)) for handler in possible_handlers
493562
]
494563

495-
try:
496-
idx = possible_handler_typestrs.index(handler_typestr)
497-
return possible_handlers[idx]
498-
except ValueError:
564+
if handler_typestr:
565+
if handler_typestr in possible_handler_typestrs:
566+
idx = possible_handler_typestrs.index(handler_typestr)
567+
return possible_handlers[idx]
568+
# Attempt to find a handler with a matching secondary typestr.
569+
for i in reversed(range(len(possible_handlers))):
570+
if handler_typestr in registry.get_secondary_typestrs(
571+
type(possible_handlers[i])
572+
):
573+
return possible_handlers[i]
499574
logging.warning(
500-
'No handler found for typestr %s. The checkpointable may be restored'
501-
' with different handler logic than was used for saving.',
575+
'No handler found for typestr %s (or its converted form). The '
576+
'checkpointable may be restored with different handler logic '
577+
'than was used for saving.',
502578
handler_typestr,
503579
)
504580

505-
# Prefer the first handler in the absence of any other information.
506-
return possible_handlers[-1]
581+
if abstract_checkpointable:
582+
# Prefer the last handler in the absence of any other information.
583+
return possible_handlers[-1]
584+
585+
raise NoEntryError(
586+
f'No entry for checkpointable={name} in the registry, using'
587+
f' handler_typestr={handler_typestr} and'
588+
f' abstract_checkpointable={abstract_checkpointable}. Registry contents:'
589+
f' {registry.get_all_entries()}'
590+
)

0 commit comments

Comments
 (0)