Skip to content

Commit 0738d47

Browse files
author
Orbax Authors
committed
#p2p Specify all processes in multiprocess option for local
PiperOrigin-RevId: 872465040
1 parent 4f7a86b commit 0738d47

File tree

1 file changed

+11
-4
lines changed
  • checkpoint/orbax/checkpoint/experimental/emergency/p2p

1 file changed

+11
-4
lines changed

checkpoint/orbax/checkpoint/experimental/emergency/p2p/local.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,18 @@ def __init__(
123123
self._global_mesh = global_mesh
124124
self._process_index = multihost.process_index()
125125

126-
barrier_sync_key_prefix = f'p2p_shard_{self._process_index}'
126+
# While forming smaller process groups for synchronization via
127+
# `active_processes` (e.g., per-replica) might improve barrier
128+
# performance at scale, it would require more complex coordination
129+
# than currently implemented.
127130
mp_options = ocp.options.MultiprocessingOptions(
128-
primary_host=None, # Symmetric read/write
129-
active_processes={self._process_index}, # Only I write to my shard
130-
barrier_sync_key_prefix=barrier_sync_key_prefix,
131+
# `primary_host` is None because all hosts save to local storage
132+
# independently. This causes a local checkpoint to be saved on all
133+
# processes.
134+
# It is the caller's responsibility to make sure we are not doubling the
135+
# memory pressure by also saving a persistent checkpoint on the same step.
136+
primary_host=None,
137+
barrier_sync_key_prefix='local',
131138
)
132139

133140
p2p_specific_options = checkpoint_manager.CheckpointManagerOptions(

0 commit comments

Comments
 (0)