1313# limitations under the License.
1414
1515"""Initialization for multi-tier checkpointing."""
16+
1617import os
1718import time
1819from typing import List , Optional
1920
2021from absl import logging
2122from etils import epath
2223import jax
24+ from jax .experimental import colocated_python
25+ import numpy as np
2326from orbax .checkpoint ._src .multihost import multihost
2427from orbax .checkpoint ._src .multihost import multislice
2528
29+
# Configuration file consumed by the external replicator service; written
# via a temp file and then os.rename()d into place (see
# _create_replicator_file) so the replicator never observes a partial file.
_REPLICATOR_FILE = 'replicator.yaml'
_TEMP_REPLICATOR_FILE_NAME = _REPLICATOR_FILE + '.tmp'
# File used to exchange JAX initialization info (read by
# _initialize_jax_from_mtc); exact contents not visible here — presumably
# process id / coordinator details. TODO(review): confirm schema.
_JAX_INIT_INFO_FILE = 'jax-init-info.txt'
@@ -70,6 +74,108 @@ def _create_replicator_file(
7074 os .rename (temp_file , replicator_file )
7175
7276
def _initialize_mtc_colocated(
    local_checkpoint_directory: epath.Path,
    backup_interval_minutes: int,
    num_slices: int,
    run_name: str,
    data_parallelism: int,
    timeout_seconds: int,
):
  """Initializes multi-tier checkpointing with Colocated Python.

  Rather than having each host run the replicator-file handshake through a
  multi-controller loop, this precomputes every node's rank and peer list on
  the controller, shards those arrays one row per remote host, and dispatches
  a single SPMD closure to one CPU device per host via
  `jax.experimental.colocated_python`. Each host then performs the same
  handshake locally: wait for any stale replicator file to disappear, write
  its own replicator file, wait for it to be consumed, and process any
  restore directory.

  Args:
    local_checkpoint_directory: The local checkpoint directory.
    backup_interval_minutes: The backup interval in minutes.
    num_slices: The number of slices.
    run_name: The run name.
    data_parallelism: The data parallelism.
    timeout_seconds: The timeout in seconds.
  """
  # 1. Obtain CPU devices for all remote hosts.
  cpu_devices = colocated_python.colocated_cpu_devices(jax.devices())

  # Ensure one CPU device per process (worker node). The dict comprehension
  # keeps one device per process_index, in first-seen order of process
  # indices within `cpu_devices`.
  # NOTE(review): the positional rank assignment below (arange over mesh
  # order) assumes this enumeration order corresponds to the intended node
  # ranks — confirm against the replicator's topology expectations.
  unique_cpu_devices = list(
      {dev.process_index: dev for dev in cpu_devices}.values()
  )
  num_nodes = len(unique_cpu_devices)
  # max(1, ...) guards against num_slices > num_nodes producing zero.
  nodes_per_slice = max(1, num_nodes // num_slices)

  # 2. Pre-calculate the node_rank and peer_ranks for EVERY node. A node's
  # peers are the nodes occupying the same in-slice position on every other
  # slice (same `nr % nodes_per_slice`, different slice index).
  all_node_ranks = np.arange(num_nodes, dtype=np.int32)
  all_peer_ranks = []
  for nr in range(num_nodes):
    my_in_pipeline_index = nr % nodes_per_slice
    peers = [
        i * nodes_per_slice + my_in_pipeline_index
        for i in range(num_slices)
        if (i * nodes_per_slice + my_in_pipeline_index) != nr
    ]
    all_peer_ranks.append(peers)

  # Handle the single-slice edge case where every peers list is empty: use
  # an (N, 0) array so the data stays rectangular and shardable. (All rows
  # have equal length num_slices - 1, so checking row 0 suffices.)
  if not all_peer_ranks[0]:
    all_peer_ranks = np.zeros((num_nodes, 0), dtype=np.int32)
  else:
    all_peer_ranks = np.array(all_peer_ranks, dtype=np.int32)

  # 3. Create a 1D mesh over the remote hosts and shard the configuration
  # arrays along it, so each host receives exactly its own row.
  cpu_mesh = jax.sharding.Mesh(np.array(unique_cpu_devices), ('d',))
  sharding = jax.sharding.NamedSharding(
      cpu_mesh, jax.sharding.PartitionSpec('d')
  )

  # JAX distributes these arrays across the workers natively.
  sharded_node_ranks = jax.device_put(all_node_ranks, sharding)
  sharded_peer_ranks = jax.device_put(all_peer_ranks, sharding)

  # 4. Define the SPMD closure that runs on each remote worker. Scalar
  # configuration (run_name, timeouts, num_nodes, ...) is captured by
  # closure; per-node data arrives through the sharded array arguments.
  def _setup(local_nr_arr, local_pr_arr):
    loc_dir = epath.Path(local_checkpoint_directory)

    # JAX sharding slices the arrays into chunks of shape (1,) and (1, P).
    # We must index at [0] to extract the pure scalar and the flat list.
    node_rank = int(np.asarray(local_nr_arr)[0])
    peer_ranks = np.asarray(local_pr_arr)[0].tolist()

    # Wait for any replicator file left over from a previous run to be
    # removed before writing a fresh one.
    _wait_for_replicator_file_to_disappear(
        loc_dir, timeout_seconds=timeout_seconds
    )

    _create_replicator_file(
        loc_dir,
        run_name=run_name,
        num_nodes=num_nodes,
        data_parallelism=data_parallelism,
        node_rank=node_rank,
        peer_ranks=peer_ranks,
        backup_interval_minutes=backup_interval_minutes,
    )

    # Wait for the file to disappear again — presumably the external
    # replicator service deletes it once consumed — then block until any
    # restore directory has been processed.
    _wait_for_replicator_file_to_disappear(
        loc_dir, timeout_seconds=timeout_seconds
    )
    _block_and_process_restore_dir(loc_dir, timeout_seconds=timeout_seconds)

    # Return array to satisfy SPMD device matching.
    return local_nr_arr

  # 5. Wrap and dispatch using native JAX SPMD. out_specs_fn mirrors the
  # first input's sharding so the result lands on the same devices.
  wrapped_setup_fn = colocated_python.colocated_python(_setup)
  wrapped_setup_fn = wrapped_setup_fn.specialize(out_specs_fn=lambda x, y: x)

  # Triggers concurrent execution across all nodes without a thread pool.
  jax.block_until_ready(
      wrapped_setup_fn(sharded_node_ranks, sharded_peer_ranks)
  )

  logging.info(
      'Successfully initialized multi-tier checkpointing on all remote hosts '
      'via Colocated Python.'
  )
177+
178+
73179def _initialize_jax_from_mtc (
74180 local_checkpoint_directory : epath .Path ,
75181 jax_initialization_timeout_seconds : int = 900 ,
@@ -107,6 +213,7 @@ def initialize_multi_tier_checkpointing(
107213 data_parallelism : Optional [int ] = None ,
108214 jax_initialization_timeout_seconds : int = 900 ,
109215 use_mtc_process_ids : bool = True ,
216+ use_colocated_python : bool = False ,
110217):
111218 """Initializes multi-tier checkpointing.
112219
@@ -116,12 +223,34 @@ def initialize_multi_tier_checkpointing(
116223 minutes.
117224 num_slices: The number of slices.
118225 run_name: The name of the run.
119- data_parallelism: Number of identical pipelines in job, should be
120- equal to ICI data parallelism * DCN data parallelism. If not provided, it
121- will be inferred from the number of slices.
226+ data_parallelism: Number of identical pipelines in job, should be equal to
227+ ICI data parallelism * DCN data parallelism. If not provided, it will be
228+ inferred from the number of slices.
122229 jax_initialization_timeout_seconds: The timeout for JAX initialization.
123230 use_mtc_process_ids: Use the MTC rank server to calculate process ids.
231+ use_colocated_python: Whether to use Colocated Python for initialization.
124232 """
233+ run_name = run_name if run_name else os .environ .get ('JOBSET_NAME' )
234+ num_slices = num_slices or multislice .slice_count ()
235+ data_parallelism = data_parallelism or num_slices
236+ if not run_name :
237+ raise ValueError (
238+ 'Run name is not set and JOBSET_NAME is not set in the environment.'
239+ )
240+
241+ if use_colocated_python :
242+ logging .info ('Initializing multi-tier checkpointing via Colocated Python.' )
243+ _initialize_mtc_colocated (
244+ local_checkpoint_directory = local_checkpoint_directory ,
245+ backup_interval_minutes = backup_interval_minutes ,
246+ num_slices = num_slices ,
247+ run_name = run_name ,
248+ data_parallelism = data_parallelism ,
249+ timeout_seconds = jax_initialization_timeout_seconds ,
250+ )
251+ return
252+
253+ # Standard Multi-Controller Path
125254 if use_mtc_process_ids :
126255 process_id = _initialize_jax_from_mtc (
127256 local_checkpoint_directory , jax_initialization_timeout_seconds
@@ -135,14 +264,9 @@ def initialize_multi_tier_checkpointing(
135264 multihost .initialize_runtime_to_distributed_ids ()
136265 multihost .initialize_distributed_to_device_ids ()
137266 _wait_for_replicator_file_to_disappear (local_checkpoint_directory )
138- num_slices = (
139- num_slices
140- or multislice .slice_count ()
141- )
142267 num_nodes = jax .process_count ()
143268 nodes_per_slice = num_nodes // num_slices
144269 node_rank = jax ._src .distributed .global_state .process_id # pylint: disable=protected-access
145- data_parallelism = data_parallelism or num_slices
146270 my_process_index = jax .process_index ()
147271 process_index_to_node_rank = (
148272 multihost .runtime_to_distributed_ids ()
@@ -173,11 +297,7 @@ def initialize_multi_tier_checkpointing(
173297 peer_process_rank = process_index_to_node_rank [peer_process_index ]
174298 peer_ranks .append (peer_process_rank )
175299 logging .info ('Peers for NodeRank %s: %s' , node_rank , peer_ranks )
176- run_name = run_name if run_name else os .environ .get ('JOBSET_NAME' )
177- if not run_name :
178- raise ValueError (
179- 'Run name is not set and JOBSET_NAME is not set in the environment.'
180- )
300+
181301 _create_replicator_file (
182302 local_checkpoint_directory ,
183303 run_name = run_name ,
0 commit comments