2424import orbax .checkpoint as ocp
2525from orbax .checkpoint import args as args_lib
2626from orbax .checkpoint import checkpoint_manager
27+ from orbax .checkpoint import checkpoint_utils
2728from orbax .checkpoint import type_handlers
2829from orbax .checkpoint ._src .multihost import multihost
2930from orbax .checkpoint ._src .serialization import type_handler_registry
@@ -109,6 +110,23 @@ class LocalPyGrainRestore(utils.pygrain().PyGrainCheckpointRestore):
109110 item : Any
110111
111112
113+ def _prepare_state_restore_args (
114+ state : args_lib .PyTreeRestore ,
115+ ) -> args_lib .PyTreeRestore :
116+ """Ensures restore_args are populated and converted to ArrayRestoreArgs."""
117+ if state .item is None :
118+ return state
119+
120+ restore_args = jax .tree .map (
121+ lambda x : type_handlers .ArrayRestoreArgs (sharding = x .sharding )
122+ if isinstance (x , jax .ShapeDtypeStruct )
123+ else checkpoint_utils .construct_restore_args (x ),
124+ state .item ,
125+ )
126+
127+ return args_lib .PyTreeRestore (item = state .item , restore_args = restore_args )
128+
129+
112130@final
113131class LocalCheckpointManager :
114132 """Wrapper around Orbax CheckpointManager for local P2P shards."""
@@ -243,6 +261,9 @@ def restore(
243261 ) -> p2p_args_lib .Composite :
244262 """Restores the checkpoint, enforcing process identity check."""
245263 # No need to check for P2P restore directory
264+ if args is None :
265+ raise ValueError ('args must be provided for local restore.' )
266+
246267 if directory is None :
247268 # 1. Fast Fail: Verify Process Identity
248269 stored_index = utils .detect_process_index (self ._directory , step )
@@ -256,13 +277,16 @@ def restore(
256277 )
257278 raise ValueError (error_msg )
258279
280+ args_dict = dict (args .items ())
281+ new_state = _prepare_state_restore_args (args .state )
282+ args_dict ['state' ] = new_state
283+
259284 if utils .pygrain () is not None and args and constants .DATA_ITER_KEY in args :
260285 original_restore = args [constants .DATA_ITER_KEY ]
261- args_dict = dict (args .items ())
262286 args_dict [constants .DATA_ITER_KEY ] = LocalPyGrainRestore (
263287 original_restore .item
264288 )
265- args = args_lib .Composite (** args_dict )
289+ args = args_lib .Composite (** args_dict )
266290
267291 # 2. Delegate to Orbax
268292 restored = self ._manager .restore (
0 commit comments