Skip to content

Commit 441b89a

Browse files
author
Orbax Authors
committed
#p2p Propagate sharding to allow restore with resize
PiperOrigin-RevId: 877213342
1 parent 99bfb4b commit 441b89a

File tree

1 file changed

+26
-2
lines changed
  • checkpoint/orbax/checkpoint/experimental/emergency/p2p

1 file changed

+26
-2
lines changed

checkpoint/orbax/checkpoint/experimental/emergency/p2p/local.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import orbax.checkpoint as ocp
2525
from orbax.checkpoint import args as args_lib
2626
from orbax.checkpoint import checkpoint_manager
27+
from orbax.checkpoint import checkpoint_utils
2728
from orbax.checkpoint import type_handlers
2829
from orbax.checkpoint._src.multihost import multihost
2930
from orbax.checkpoint._src.serialization import type_handler_registry
@@ -109,6 +110,23 @@ class LocalPyGrainRestore(utils.pygrain().PyGrainCheckpointRestore):
109110
item: Any
110111

111112

113+
def _prepare_state_restore_args(
114+
state: args_lib.PyTreeRestore,
115+
) -> args_lib.PyTreeRestore:
116+
"""Ensures restore_args are populated and converted to ArrayRestoreArgs."""
117+
if state.item is None:
118+
return state
119+
120+
restore_args = jax.tree.map(
121+
lambda x: type_handlers.ArrayRestoreArgs(sharding=x.sharding)
122+
if isinstance(x, jax.ShapeDtypeStruct)
123+
else checkpoint_utils.construct_restore_args(x),
124+
state.item,
125+
)
126+
127+
return args_lib.PyTreeRestore(item=state.item, restore_args=restore_args)
128+
129+
112130
@final
113131
class LocalCheckpointManager:
114132
"""Wrapper around Orbax CheckpointManager for local P2P shards."""
@@ -243,6 +261,9 @@ def restore(
243261
) -> p2p_args_lib.Composite:
244262
"""Restores the checkpoint, enforcing process identity check."""
245263
# No need to check for P2P restore directory
264+
if args is None:
265+
raise ValueError('args must be provided for local restore.')
266+
246267
if directory is None:
247268
# 1. Fast Fail: Verify Process Identity
248269
stored_index = utils.detect_process_index(self._directory, step)
@@ -256,13 +277,16 @@ def restore(
256277
)
257278
raise ValueError(error_msg)
258279

280+
args_dict = dict(args.items())
281+
new_state = _prepare_state_restore_args(args.state)
282+
args_dict['state'] = new_state
283+
259284
if utils.pygrain() is not None and args and constants.DATA_ITER_KEY in args:
260285
original_restore = args[constants.DATA_ITER_KEY]
261-
args_dict = dict(args.items())
262286
args_dict[constants.DATA_ITER_KEY] = LocalPyGrainRestore(
263287
original_restore.item
264288
)
265-
args = args_lib.Composite(**args_dict)
289+
args = args_lib.Composite(**args_dict)
266290

267291
# 2. Delegate to Orbax
268292
restored = self._manager.restore(

0 commit comments

Comments
 (0)