1919from absl import logging
2020from etils import epath
2121import jax
22+ import numpy as np
2223import orbax .checkpoint as ocp
2324from orbax .checkpoint import args as args_lib
2425from orbax .checkpoint import checkpoint_manager
3839
3940def _create_persistent_handler (
4041 mp_options : checkpoint_manager .MultiprocessingOptions ,
42+ replica_axis_index : int ,
43+ is_single_slice : bool ,
4144) -> ocp .PyTreeCheckpointHandler :
4245 """Creates a PyTreeCheckpointHandler for persistent storage.
4346
4447 Args:
4548 mp_options: Multiprocessing options for the checkpoint handler.
49+ replica_axis_index: The index of the replica axis in the mesh.
50+ is_single_slice: Whether the mesh is single-slice.
4651
4752 Returns:
4853 A PyTreeCheckpointHandler configured for persistent storage.
4954 """
50- registry = type_handler_registry .create_type_handler_registry ((
51- jax .Array ,
52- type_handlers .ArrayHandler (
53- primary_host = mp_options .primary_host ,
54- replica_id = _PRIMARY_REPLICA_ID ,
55- use_replica_parallel = False ,
55+ handler = type_handlers .SingleReplicaArrayHandler (
56+ replica_axis_index = replica_axis_index ,
57+ broadcast_memory_limit_bytes = 1024 * 1024 * 1000 ,
58+ primary_host = mp_options .primary_host ,
59+ replica_id = _PRIMARY_REPLICA_ID ,
60+ use_replica_parallel = False ,
61+ )
62+ if is_single_slice :
63+ handler = type_handlers .ArrayHandler (
64+ primary_host = mp_options .primary_host ,
65+ replica_id = _PRIMARY_REPLICA_ID ,
66+ use_replica_parallel = False ,
67+ )
68+ registry = type_handler_registry .create_type_handler_registry (
69+ (
70+ jax .Array ,
71+ handler ,
5672 ),
57- ))
73+ )
5874 return ocp .PyTreeCheckpointHandler (
5975 use_ocdbt = True ,
6076 use_zarr3 = True ,
@@ -84,28 +100,10 @@ def __init__(
84100 self ._global_mesh ,
85101 replica_axis_index = self ._replica_axis_index ,
86102 )
87- self ._in_primary_slice = multislice .in_replica (
88- self ._process_index ,
89- global_mesh ,
90- replica_axis_index = self ._replica_axis_index ,
91- replica_id = _PRIMARY_REPLICA_ID ,
92- )
93-
94- replica_devices = multislice .replica_devices (
95- self ._global_mesh ,
96- replica_axis_index = self ._replica_axis_index ,
97- replica_id = self ._replica_id ,
98- )
99- primary_host = multislice .primary_process_in_replica (
100- self ._global_mesh ,
101- replica_axis_index = self ._replica_axis_index ,
102- replica_id = self ._replica_id ,
103- )
104- active_processes = multihost .unique_processes_from_devices (replica_devices )
105103 mp_options = checkpoint_manager .MultiprocessingOptions (
106- primary_host = primary_host ,
107- active_processes = active_processes ,
108- barrier_sync_key_prefix = f'persistent_fallback_ { self . _replica_id } ' ,
104+ primary_host = 0 ,
105+ active_processes = None ,
106+ barrier_sync_key_prefix = 'persistent_fallback ' ,
109107 )
110108
111109 internal_options = checkpoint_manager .CheckpointManagerOptions (
@@ -117,7 +115,13 @@ def __init__(
117115 enable_async_checkpointing = True ,
118116 )
119117
120- item_handlers = dict (state = _create_persistent_handler (mp_options ))
118+ item_handlers = dict (
119+ state = _create_persistent_handler (
120+ mp_options ,
121+ self ._replica_axis_index ,
122+ self ._global_mesh .devices .shape [self ._replica_axis_index ] == 1 ,
123+ )
124+ )
121125 if utils .pygrain () is not None :
122126 item_handlers ['data_iter' ] = utils .pygrain ().PyGrainCheckpointHandler ()
123127
@@ -141,9 +145,7 @@ def save(
141145 * ,
142146 force : bool = False ,
143147 ) -> bool :
144- if self ._in_primary_slice :
145- return self ._manager .save (step , args = args , force = force )
146- return True
148+ return self ._manager .save (step , args = args , force = force )
147149
148150 def restore (
149151 self ,
@@ -166,14 +168,46 @@ def restore(
166168 self ._replica_id ,
167169 )
168170 abstract_state = args .state
171+ if isinstance (args .state , args_lib .PyTreeRestore ):
172+ abstract_state = args .state .item
173+
174+ primary_replica_devices_list = multislice .replica_devices (
175+ self ._global_mesh ,
176+ replica_axis_index = self ._replica_axis_index ,
177+ replica_id = _PRIMARY_REPLICA_ID ,
178+ )
179+ replica_mesh_shape = list (self ._global_mesh .devices .shape )
180+ replica_mesh_shape [self ._replica_axis_index ] = 1
181+ primary_replica_devices = np .array (primary_replica_devices_list ).reshape (
182+ replica_mesh_shape
183+ )
184+ primary_replica_mesh = jax .sharding .Mesh (
185+ primary_replica_devices , axis_names = self ._global_mesh .axis_names
186+ )
187+
188+ def _get_sr_restore_args (x ):
189+ if (
190+ self ._global_mesh .devices .shape [self ._replica_axis_index ] > 1
191+ and isinstance (x , jax .ShapeDtypeStruct )
192+ and isinstance (x .sharding , jax .sharding .NamedSharding )
193+ ):
194+ single_replica_sharding = jax .sharding .NamedSharding (
195+ primary_replica_mesh , x .sharding .spec
196+ )
197+ return type_handlers .SingleReplicaArrayRestoreArgs (
198+ sharding = x .sharding ,
199+ single_replica_sharding = single_replica_sharding ,
200+ global_shape = x .shape ,
201+ dtype = x .dtype ,
202+ )
203+ else :
204+ return checkpoint_utils .construct_restore_args (x )
205+
206+ restore_args_tree = jax .tree .map (_get_sr_restore_args , abstract_state )
169207
170- sharding_tree = jax .tree .map (lambda x : x .sharding , abstract_state )
171- # TODO(exlin): Enable SingleReplicaRestore.
172208 restore_args_obj = args_lib .PyTreeRestore (
173209 item = abstract_state ,
174- restore_args = checkpoint_utils .construct_restore_args (
175- abstract_state , sharding_tree
176- ),
210+ restore_args = restore_args_tree ,
177211 )
178212 restore_kwargs = {'state' : restore_args_obj }
179213 if constants .DATA_ITER_KEY in args :
@@ -183,8 +217,7 @@ def restore(
183217 )
184218
# Deletes the checkpoint for `step` by calling self._manager.delete(step).
# In this diff the `if self._in_primary_slice` guard (the '-' lines below) is
# removed, so the manager call is no longer conditional on that flag.
185219 def delete (self , step : int ):
186- if self ._in_primary_slice :
187- self ._manager .delete (step )
220+ self ._manager .delete (step )
188221
# Forwards directly to self._manager.wait_until_finished() and returns nothing.
# NOTE(review): presumably this blocks until pending async checkpoint work
# completes (the manager is configured with enable_async_checkpointing = True)
# -- confirm against the manager's documentation.
189222 def wait_until_finished (self ):
190223 self ._manager .wait_until_finished ()
0 commit comments