Skip to content

Commit f486910

Browse files
Orbax Authors committed
#p2p Enable SingleReplicaRestore for PersistentCheckpointManager
PiperOrigin-RevId: 869899634
1 parent cf5a1ce commit f486910

File tree

3 files changed: 133 additions and 74 deletions.

checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,7 +1346,7 @@ async def _single_replica_deserialize_and_broadcast(
13461346
deserialization_elapsed_s,
13471347
)
13481348
logging.info(
1349-
'Finished primary replica deserialization in %.2f',
1349+
'Finished primary replica deserialization in %.2f seconds',
13501350
deserialization_elapsed_s,
13511351
)
13521352
else:
@@ -1379,7 +1379,7 @@ def create_zeros(shape_dtype_tup):
13791379
jax.monitoring.record_event_duration_secs(
13801380
'/jax/checkpoint/read/broadcast_duration_secs', broadcast_elapsed_s
13811381
)
1382-
logging.info('Finished broadcasting in %.2f', broadcast_elapsed_s)
1382+
logging.info('Finished broadcasting in %.2f seconds', broadcast_elapsed_s)
13831383

13841384
return shared_state
13851385

checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent.py

Lines changed: 72 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from absl import logging
2020
from etils import epath
2121
import jax
22+
import numpy as np
2223
import orbax.checkpoint as ocp
2324
from orbax.checkpoint import args as args_lib
2425
from orbax.checkpoint import checkpoint_manager
@@ -38,23 +39,38 @@
3839

3940
def _create_persistent_handler(
4041
mp_options: checkpoint_manager.MultiprocessingOptions,
42+
replica_axis_index: int,
43+
is_single_slice: bool,
4144
) -> ocp.PyTreeCheckpointHandler:
4245
"""Creates a PyTreeCheckpointHandler for persistent storage.
4346
4447
Args:
4548
mp_options: Multiprocessing options for the checkpoint handler.
49+
replica_axis_index: The index of the replica axis in the mesh.
50+
is_single_slice: Whether the mesh is single-slice.
4651
4752
Returns:
4853
A PyTreeCheckpointHandler configured for persistent storage.
4954
"""
50-
registry = type_handler_registry.create_type_handler_registry((
51-
jax.Array,
52-
type_handlers.ArrayHandler(
53-
primary_host=mp_options.primary_host,
54-
replica_id=_PRIMARY_REPLICA_ID,
55-
use_replica_parallel=False,
55+
handler = type_handlers.SingleReplicaArrayHandler(
56+
replica_axis_index=replica_axis_index,
57+
broadcast_memory_limit_bytes=1024 * 1024 * 1000,
58+
primary_host=mp_options.primary_host,
59+
replica_id=_PRIMARY_REPLICA_ID,
60+
use_replica_parallel=False,
61+
)
62+
if is_single_slice:
63+
handler = type_handlers.ArrayHandler(
64+
primary_host=mp_options.primary_host,
65+
replica_id=_PRIMARY_REPLICA_ID,
66+
use_replica_parallel=False,
67+
)
68+
registry = type_handler_registry.create_type_handler_registry(
69+
(
70+
jax.Array,
71+
handler,
5672
),
57-
))
73+
)
5874
return ocp.PyTreeCheckpointHandler(
5975
use_ocdbt=True,
6076
use_zarr3=True,
@@ -84,28 +100,10 @@ def __init__(
84100
self._global_mesh,
85101
replica_axis_index=self._replica_axis_index,
86102
)
87-
self._in_primary_slice = multislice.in_replica(
88-
self._process_index,
89-
global_mesh,
90-
replica_axis_index=self._replica_axis_index,
91-
replica_id=_PRIMARY_REPLICA_ID,
92-
)
93-
94-
replica_devices = multislice.replica_devices(
95-
self._global_mesh,
96-
replica_axis_index=self._replica_axis_index,
97-
replica_id=self._replica_id,
98-
)
99-
primary_host = multislice.primary_process_in_replica(
100-
self._global_mesh,
101-
replica_axis_index=self._replica_axis_index,
102-
replica_id=self._replica_id,
103-
)
104-
active_processes = multihost.unique_processes_from_devices(replica_devices)
105103
mp_options = checkpoint_manager.MultiprocessingOptions(
106-
primary_host=primary_host,
107-
active_processes=active_processes,
108-
barrier_sync_key_prefix=f'persistent_fallback_{self._replica_id}',
104+
primary_host=0,
105+
active_processes=None,
106+
barrier_sync_key_prefix='persistent_fallback',
109107
)
110108

111109
internal_options = checkpoint_manager.CheckpointManagerOptions(
@@ -117,7 +115,13 @@ def __init__(
117115
enable_async_checkpointing=True,
118116
)
119117

120-
item_handlers = dict(state=_create_persistent_handler(mp_options))
118+
item_handlers = dict(
119+
state=_create_persistent_handler(
120+
mp_options,
121+
self._replica_axis_index,
122+
self._global_mesh.devices.shape[self._replica_axis_index] == 1,
123+
)
124+
)
121125
if utils.pygrain() is not None:
122126
item_handlers['data_iter'] = utils.pygrain().PyGrainCheckpointHandler()
123127

@@ -141,9 +145,7 @@ def save(
141145
*,
142146
force: bool = False,
143147
) -> bool:
144-
if self._in_primary_slice:
145-
return self._manager.save(step, args=args, force=force)
146-
return True
148+
return self._manager.save(step, args=args, force=force)
147149

148150
def restore(
149151
self,
@@ -166,14 +168,46 @@ def restore(
166168
self._replica_id,
167169
)
168170
abstract_state = args.state
171+
if isinstance(args.state, args_lib.PyTreeRestore):
172+
abstract_state = args.state.item
173+
174+
primary_replica_devices_list = multislice.replica_devices(
175+
self._global_mesh,
176+
replica_axis_index=self._replica_axis_index,
177+
replica_id=_PRIMARY_REPLICA_ID,
178+
)
179+
replica_mesh_shape = list(self._global_mesh.devices.shape)
180+
replica_mesh_shape[self._replica_axis_index] = 1
181+
primary_replica_devices = np.array(primary_replica_devices_list).reshape(
182+
replica_mesh_shape
183+
)
184+
primary_replica_mesh = jax.sharding.Mesh(
185+
primary_replica_devices, axis_names=self._global_mesh.axis_names
186+
)
187+
188+
def _get_sr_restore_args(x):
189+
if (
190+
self._global_mesh.devices.shape[self._replica_axis_index] > 1
191+
and isinstance(x, jax.ShapeDtypeStruct)
192+
and isinstance(x.sharding, jax.sharding.NamedSharding)
193+
):
194+
single_replica_sharding = jax.sharding.NamedSharding(
195+
primary_replica_mesh, x.sharding.spec
196+
)
197+
return type_handlers.SingleReplicaArrayRestoreArgs(
198+
sharding=x.sharding,
199+
single_replica_sharding=single_replica_sharding,
200+
global_shape=x.shape,
201+
dtype=x.dtype,
202+
)
203+
else:
204+
return checkpoint_utils.construct_restore_args(x)
205+
206+
restore_args_tree = jax.tree.map(_get_sr_restore_args, abstract_state)
169207

170-
sharding_tree = jax.tree.map(lambda x: x.sharding, abstract_state)
171-
# TODO(exlin): Enable SingleReplicaRestore.
172208
restore_args_obj = args_lib.PyTreeRestore(
173209
item=abstract_state,
174-
restore_args=checkpoint_utils.construct_restore_args(
175-
abstract_state, sharding_tree
176-
),
210+
restore_args=restore_args_tree,
177211
)
178212
restore_kwargs = {'state': restore_args_obj}
179213
if constants.DATA_ITER_KEY in args:
@@ -183,8 +217,7 @@ def restore(
183217
)
184218

185219
def delete(self, step: int):
186-
if self._in_primary_slice:
187-
self._manager.delete(step)
220+
self._manager.delete(step)
188221

189222
def wait_until_finished(self):
190223
self._manager.wait_until_finished()

0 commit comments

Comments
 (0)