Skip to content

Commit 6765279

Browse files
angel-coreOrbax Authors
authored andcommitted
Create v1 ocp.load_checkpointables backwards compatibility tests against static v0 and v1 checkpoints.
PiperOrigin-RevId: 875725609
1 parent 144e79c commit 6765279

File tree

438 files changed

+2328
-291
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

438 files changed

+2328
-291
lines changed

checkpoint/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ devices anyway.
4141
`load_checkpointables()` each with their own dedicated loading logic
4242
- Refactor v0 Pytree validation and metadata resolution and add `OrbaxV0Layout`
4343
tests
44+
- Refactor logic for handler resolution and loading checkpointables for
45+
`OrbaxLayout` and `OrbaxV0Layout`, adding additional fallback capabilities for
46+
non-standard checkpoint formats.
4447

4548
### Fixed
4649

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
given checkpointable will be used.
2121
"""
2222

23-
from typing import Type
23+
from typing import Sequence, Type
2424

2525
from orbax.checkpoint.experimental.v1._src.handlers import json_handler
2626
from orbax.checkpoint.experimental.v1._src.handlers import proto_handler
@@ -34,23 +34,45 @@
3434
def _try_register_handler(
    handler_type: Type[handler_types.CheckpointableHandler],
    name: str | None = None,
    recognized_handler_typestrs: Sequence[str] | None = None,
):
  """Tries to register a handler globally with a name and recognized typestrs.

  Args:
    handler_type: The handler class to add to the global registry.
    name: Optional name forwarded to the registry's `add` as its second
      positional argument.
    recognized_handler_typestrs: Optional alternate handler typestrs forwarded
      to the registry's `add` for this handler.
  """
  try:
    # Register exactly once, with the typestrs attached. A preceding bare
    # `add(handler_type, name)` would make this call raise AlreadyExistsError
    # and the typestrs would be silently lost.
    registration.global_registry().add(
        handler_type,
        name,
        recognized_handler_typestrs=recognized_handler_typestrs,
    )
  except registration.AlreadyExistsError:
    # Deliberate best-effort: an existing entry is left untouched.
    pass
4248

4349

44-
# Default global registrations. Order is preserved from the original: handler
# resolution elsewhere in this package prefers the most recently added entry.
#
# The `recognized_handler_typestrs` below are typestrs under the non-v1
# `orbax.checkpoint._src.handlers` package that should map onto the v1 handler.
# Each handler is registered only once per (handler, name) pair; a bare
# registration ahead of these would cause the typestr-carrying call to be
# dropped as a duplicate.
_try_register_handler(
    proto_handler.ProtoHandler,
    recognized_handler_typestrs=[
        'orbax.checkpoint._src.handlers.proto_checkpoint_handler.ProtoCheckpointHandler',
    ],
)
_try_register_handler(
    json_handler.JsonHandler,
    recognized_handler_typestrs=[
        'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler',
    ],
)
_try_register_handler(
    stateful_checkpointable_handler.StatefulCheckpointableHandler
)
_try_register_handler(
    json_handler.MetricsHandler,
    checkpoint_layout.METRICS_CHECKPOINTABLE_KEY,
)
_try_register_handler(
    pytree_handler.PyTreeHandler,
    recognized_handler_typestrs=[
        'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler',
        'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler',
    ],
)
_try_register_handler(
    pytree_handler.PyTreeHandler, checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY
)

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py

Lines changed: 88 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,13 @@ def add_all(
7373
) -> CheckpointableHandlerRegistry:
7474
"""Adds all entries from `other_registry` to `registry`."""
7575
for handler, checkpointable in other_registry.get_all_entries():
76-
registry.add(handler, checkpointable)
76+
registry.add(
77+
handler,
78+
checkpointable,
79+
recognized_handler_typestrs=other_registry.get_recognized_handler_typestrs(
80+
handler
81+
),
82+
)
7783
return registry
7884

7985

@@ -87,6 +93,7 @@ def add(
8793
self,
8894
handler_type: Type[CheckpointableHandler],
8995
checkpointable: str | None = None,
96+
recognized_handler_typestrs: Sequence[str] | None = None,
9097
) -> CheckpointableHandlerRegistry:
9198
"""Adds an entry to the registry."""
9299
...
@@ -110,6 +117,13 @@ def get_all_entries(
110117
) -> Sequence[RegistryEntry]:
111118
...
112119

120+
def get_recognized_handler_typestrs(
    self, handler_type: Type[CheckpointableHandler]
) -> Sequence[str]:
  """Returns the recognized handler typestrs for a given handler_type.

  Args:
    handler_type: The handler class whose recognized typestrs are requested.

  Returns:
    The sequence of typestrs recognized for `handler_type`.
  """
  ...
126+
113127

114128
class AlreadyExistsError(ValueError):
115129
"""Raised when an entry already exists in the registry."""
@@ -126,6 +140,9 @@ def __init__(
126140
self, other_registry: CheckpointableHandlerRegistry | None = None
127141
):
128142
self._registry: list[RegistryEntry] = []
143+
self._recognized_handler_typestrs: dict[
144+
Type[CheckpointableHandler], Sequence[str]
145+
] = {}
129146

130147
# Initialize the registry with entries from other registry.
131148
if other_registry:
@@ -135,6 +152,7 @@ def add(
135152
self,
136153
handler_type: Type[CheckpointableHandler],
137154
checkpointable: str | None = None,
155+
recognized_handler_typestrs: Sequence[str] | None = None,
138156
) -> CheckpointableHandlerRegistry:
139157
"""Adds an entry to the registry.
140158
@@ -143,6 +161,8 @@ def add(
143161
checkpointable: The checkpointable name. If not-None, the registered
144162
handler will be scoped to that specific name. Otherwise, the handler
145163
will be available for any checkpointable name.
164+
recognized_handler_typestrs: A sequence of alternate typestrs that are
165+
recognized and mapped to this handler.
146166
147167
Returns:
148168
The registry itself.
@@ -170,6 +190,10 @@ def add(
170190
f'Handler type {handler_type} already exists in the registry.'
171191
)
172192
self._registry.append((handler_type, checkpointable))
193+
if recognized_handler_typestrs is not None:
194+
self._recognized_handler_typestrs[handler_type] = (
195+
recognized_handler_typestrs
196+
)
173197
return self
174198

175199
def get(
@@ -220,6 +244,13 @@ def get_all_entries(
220244
"""Returns all entries in the registry."""
221245
return self._registry
222246

247+
def get_recognized_handler_typestrs(
    self, handler_type: Type[CheckpointableHandler]
) -> Sequence[str]:
  """Returns the recognized handler typestrs for a given handler_type.

  Args:
    handler_type: The handler class to look up in
      `self._recognized_handler_typestrs`.

  Returns:
    The stored typestrs for `handler_type`, or an empty list when the handler
    has no stored typestrs.
  """
  try:
    return self._recognized_handler_typestrs[handler_type]
  except KeyError:
    # Nothing was stored for this handler type.
    return []
253+
223254
def __repr__(self):
224255
return f'_DefaultCheckpointableHandlerRegistry({self.get_all_entries()})'
225256

@@ -237,6 +268,7 @@ def add(
237268
self,
238269
handler_type: Type[CheckpointableHandler],
239270
checkpointable: str | None = None,
271+
recognized_handler_typestrs: Sequence[str] | None = None,
240272
) -> CheckpointableHandlerRegistry:
241273
raise NotImplementedError('Adding not implemented for read-only registry.')
242274

@@ -257,6 +289,12 @@ def get_all_entries(
257289
) -> Sequence[RegistryEntry]:
258290
return self._registry.get_all_entries()
259291

292+
def get_recognized_handler_typestrs(
    self, handler_type: Type[CheckpointableHandler]
) -> Sequence[str]:
  """Returns the wrapped registry's recognized typestrs for `handler_type`."""
  wrapped_registry = self._registry
  return wrapped_registry.get_recognized_handler_typestrs(handler_type)
297+
260298
def __repr__(self):
261299
return f'ReadOnlyCheckpointableHandlerRegistry({self.get_all_entries()})'
262300

@@ -303,6 +341,8 @@ def local_registry(
303341

304342
def register_handler(
305343
cls: CheckpointableHandlerType,
344+
*,
345+
recognized_handler_typestrs: Sequence[str] | None = None,
306346
) -> CheckpointableHandlerType:
307347
"""Registers a :py:class:`~.v1.handlers.CheckpointableHandler` globally.
308348
@@ -322,11 +362,15 @@ class FooHandler(ocp.handlers.CheckpointableHandler[Foo, AbstractFoo]):
322362
323363
Args:
324364
cls: The handler class.
365+
recognized_handler_typestrs: A sequence of alternate handler typestrs that
366+
are recognized and mapped to this handler.
325367
326368
Returns:
327369
The handler class.
328370
"""
329-
_GLOBAL_REGISTRY.add(cls)
371+
_GLOBAL_REGISTRY.add(
372+
cls, recognized_handler_typestrs=recognized_handler_typestrs
373+
)
330374
return cls
331375

332376

@@ -392,6 +436,16 @@ def _get_possible_handlers(
392436
return possible_handlers
393437

394438

439+
def get_registered_handler_by_name(
    registry: CheckpointableHandlerRegistry,
    name: str,
) -> CheckpointableHandler | None:
  """Returns the handler for the given name if registered.

  Args:
    registry: The registry to query.
    name: The name to look up via `registry.has` / `registry.get`.

  Returns:
    The result of `_construct_handler_instance` for the registered entry, or
    None when `registry.has(name)` is false.
  """
  if not registry.has(name):
    return None
  registered = registry.get(name)
  return _construct_handler_instance(name, registered)
447+
448+
395449
def resolve_handler_for_save(
396450
registry: CheckpointableHandlerRegistry,
397451
checkpointable: Any,
@@ -435,7 +489,7 @@ def is_handleable_fn(handler: CheckpointableHandler, ckpt: Any) -> bool:
435489
registry, is_handleable_fn, checkpointable, name
436490
)
437491

438-
# Prefer the first handler in the absence of any other information.
492+
# Prefer the last handler in the absence of any other information.
439493
return possible_handlers[-1]
440494

441495

@@ -444,7 +498,7 @@ def resolve_handler_for_load(
444498
abstract_checkpointable: Any | None,
445499
*,
446500
name: str,
447-
handler_typestr: str,
501+
handler_typestr: str | None = None,
448502
) -> CheckpointableHandler:
449503
"""Resolves a :py:class:`~.v1.handlers.CheckpointableHandler` for loading.
450504
@@ -471,7 +525,9 @@ def resolve_handler_for_load(
471525
abstract_checkpointable: An abstract checkpointable to resolve.
472526
name: The name of the checkpointable.
473527
handler_typestr: A :py:class:`~.v1.handlers.CheckpointableHandler` typestr
474-
to guide resolution.
528+
to guide resolution. We allow a None value for handler_typestr as it's
529+
possible to find the last registered handler given a specified
530+
abstract_checkpointable.
475531
476532
Returns:
477533
A :py:class:`~.v1.handlers.CheckpointableHandler` instance.
@@ -492,15 +548,34 @@ def is_handleable_fn(
492548
handler_types.typestr(type(handler)) for handler in possible_handlers
493549
]
494550

495-
try:
496-
idx = possible_handler_typestrs.index(handler_typestr)
497-
return possible_handlers[idx]
498-
except ValueError:
551+
if handler_typestr:
552+
if handler_typestr in possible_handler_typestrs:
553+
idx = possible_handler_typestrs.index(handler_typestr)
554+
return possible_handlers[idx]
555+
556+
# Check if handler_typestr is recognized by any possible handler.
557+
# Check backwards to prioritize most recently added handlers.
558+
for i in reversed(range(len(possible_handlers))):
559+
if handler_typestr in registry.get_recognized_handler_typestrs(
560+
type(possible_handlers[i])
561+
):
562+
return possible_handlers[i]
563+
564+
# Neither an exact nor a recognized typestr matched; log the warning and fall through.
499565
logging.warning(
500-
'No handler found for typestr %s. The checkpointable may be restored'
501-
' with different handler logic than was used for saving.',
566+
'No handler found for typestr %s (or its converted form). The '
567+
'checkpointable may be restored with different handler logic '
568+
'than was used for saving.',
502569
handler_typestr,
503570
)
504571

505-
# Prefer the first handler in the absence of any other information.
506-
return possible_handlers[-1]
572+
if abstract_checkpointable:
573+
# Prefer the last handler in the absence of any other information.
574+
return possible_handlers[-1]
575+
576+
raise NoEntryError(
577+
f'No entry for checkpointable={name} in the registry, using'
578+
f' handler_typestr={handler_typestr} and'
579+
f' abstract_checkpointable={abstract_checkpointable}. Registry contents:'
580+
f' {registry.get_all_entries()}'
581+
)

0 commit comments

Comments
 (0)