
Commit 98d2e85

SujeethJinesh authored and Orbax Authors committed

Add Multi-tiered Checkpointing Support to Pathways Single Controller

PiperOrigin-RevId: 874359124
1 parent 6f4d675 commit 98d2e85

7 files changed: +320 additions, -59 deletions

checkpoint/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - #v1 Add `use_load_and_broadcast` option.
+- Add Multi-tiered checkpointing support for Pathways
 
 ### Removed
 

checkpoint/orbax/checkpoint/_src/multihost/dispatchers.py

Lines changed: 4 additions & 2 deletions
@@ -335,9 +335,11 @@ def _cp_wrapper(inp: PyTree) -> PyTree:
         input_arrays, abstract=True
     )
     cpu_result_specs = self._transform_pytree_shardings(result_specs)
-    _cp_wrapper.specialize(out_specs_fn=lambda _: cpu_result_specs)
+    specialized_wrapper = _cp_wrapper.specialize(
+        out_specs_fn=lambda _: cpu_result_specs
+    )
 
-    result = _cp_wrapper(self.to_colocated_python(input_arrays))
+    result = specialized_wrapper(self.to_colocated_python(input_arrays))
     return self._to_final_specs(result, result_specs)
 
 
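
The fix above matters because `specialize()` does not mutate the wrapped callable; it returns a new specialized callable, and the old code discarded that return value. A minimal sketch of the corrected pattern, assuming the `jax.experimental.colocated_python` API as used in this diff (the function `_echo` is hypothetical, and actually running it requires a backend with colocated CPU devices):

    import jax
    from jax.experimental import colocated_python

    def _echo(x):
      # Runs on the colocated CPU host rather than the controller.
      return x

    wrapped = colocated_python.colocated_python(_echo)

    # Bug pattern: the specialized callable is thrown away; `wrapped` is unchanged.
    wrapped.specialize(out_specs_fn=lambda in_specs: in_specs)

    # Fixed pattern: keep and call the returned specialized callable.
    specialized = wrapped.specialize(out_specs_fn=lambda in_specs: in_specs)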

checkpoint/orbax/checkpoint/experimental/emergency/mesh_consistency.py

Lines changed: 15 additions & 4 deletions
@@ -65,23 +65,34 @@ def read_process_metadata(directory: epath.Path):
   return distributed_to_device_ids, device_ids
 
 
-async def save_process_metadata(
+def write_process_metadata(
     directory: epath.Path,
-    global_mesh: jax.sharding.Mesh,
+    device_ids: List[int],
     distributed_to_device_ids: List[List[int]],
 ):
-  """Saves process metadata to local storage. Runs on every process."""
+  """Synchronously writes process metadata to local storage."""
   metadata_folder = process_metadata_folder(directory)
+  metadata_folder.mkdir(parents=True, exist_ok=True)
   logging.info('Saving process index metadata at %s', metadata_folder)
 
   (metadata_folder / _GLOBAL_PROCESS_METADATA_FILE_NAME).write_text(
       json.dumps(distributed_to_device_ids)
   )
   (metadata_folder / _MESH_METADATA_FILE_NAME).write_text(
-      json.dumps([int(id) for id in global_mesh.device_ids.flatten()])
+      json.dumps(device_ids)
   )
 
 
+async def save_process_metadata(
+    directory: epath.Path,
+    global_mesh: jax.sharding.Mesh,
+    distributed_to_device_ids: List[List[int]],
+):
+  """Saves process metadata to local storage. Runs on every process."""
+  device_ids = [int(id) for id in global_mesh.device_ids.flatten()]
+  write_process_metadata(directory, device_ids, distributed_to_device_ids)
+
+
 def consistent_restore_mesh_from_metadata(
     global_mesh: jax.sharding.Mesh,
     current_distributed_to_device_ids: List[List[int]],
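
This refactor splits the synchronous write out of the async entry point so that callers without an event loop (such as a colocated worker) can write metadata directly, while existing async call sites keep their signature. A usage sketch, with a hypothetical directory and device layout:

    from etils import epath
    from orbax.checkpoint.experimental.emergency import mesh_consistency

    directory = epath.Path('/tmp/local-ckpt')  # hypothetical local checkpoint dir

    # Synchronous path: device ids are passed pre-flattened.
    mesh_consistency.write_process_metadata(
        directory,
        device_ids=[0, 1, 2, 3],
        distributed_to_device_ids=[[0, 1], [2, 3]],
    )

    # Async path, unchanged signature, now a thin wrapper:
    #   await mesh_consistency.save_process_metadata(
    #       directory, global_mesh, distributed_to_device_ids)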

checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization.py

Lines changed: 133 additions & 13 deletions
@@ -13,16 +13,20 @@
 # limitations under the License.
 
 """Initialization for multi-tier checkpointing."""
+
 import os
 import time
 from typing import List, Optional
 
 from absl import logging
 from etils import epath
 import jax
+from jax.experimental import colocated_python
+import numpy as np
 from orbax.checkpoint._src.multihost import multihost
 from orbax.checkpoint._src.multihost import multislice
 
+
 _REPLICATOR_FILE = 'replicator.yaml'
 _TEMP_REPLICATOR_FILE_NAME = _REPLICATOR_FILE + '.tmp'
 _JAX_INIT_INFO_FILE = 'jax-init-info.txt'
@@ -70,6 +74,108 @@ def _create_replicator_file(
   os.rename(temp_file, replicator_file)
 
 
+def _initialize_mtc_colocated(
+    local_checkpoint_directory: epath.Path,
+    backup_interval_minutes: int,
+    num_slices: int,
+    run_name: str,
+    data_parallelism: int,
+    timeout_seconds: int,
+):
+  """Initializes multi-tier checkpointing with Colocated Python.
+
+  Args:
+    local_checkpoint_directory: The local checkpoint directory.
+    backup_interval_minutes: The backup interval in minutes.
+    num_slices: The number of slices.
+    run_name: The run name.
+    data_parallelism: The data parallelism.
+    timeout_seconds: The timeout in seconds.
+  """
+  # 1. Obtain CPU devices for all remote hosts.
+  cpu_devices = colocated_python.colocated_cpu_devices(jax.devices())
+
+  # Ensure one CPU device per process (worker node).
+  unique_cpu_devices = list(
+      {dev.process_index: dev for dev in cpu_devices}.values()
+  )
+  num_nodes = len(unique_cpu_devices)
+  nodes_per_slice = max(1, num_nodes // num_slices)
+
+  # 2. Pre-calculate the node_rank and peer_ranks for every node.
+  all_node_ranks = np.arange(num_nodes, dtype=np.int32)
+  all_peer_ranks = []
+  for nr in range(num_nodes):
+    my_in_pipeline_index = nr % nodes_per_slice
+    peers = [
+        i * nodes_per_slice + my_in_pipeline_index
+        for i in range(num_slices)
+        if (i * nodes_per_slice + my_in_pipeline_index) != nr
+    ]
+    all_peer_ranks.append(peers)
+
+  # Handle the single-slice edge case where the peers lists are empty.
+  if not all_peer_ranks[0]:
+    all_peer_ranks = np.zeros((num_nodes, 0), dtype=np.int32)
+  else:
+    all_peer_ranks = np.array(all_peer_ranks, dtype=np.int32)
+
+  # 3. Create a 1D mesh over the remote hosts and shard the config arrays.
+  cpu_mesh = jax.sharding.Mesh(np.array(unique_cpu_devices), ('d',))
+  sharding = jax.sharding.NamedSharding(
+      cpu_mesh, jax.sharding.PartitionSpec('d')
+  )
+
+  # JAX distributes these arrays across the workers natively.
+  sharded_node_ranks = jax.device_put(all_node_ranks, sharding)
+  sharded_peer_ranks = jax.device_put(all_peer_ranks, sharding)
+
+  # 4. Define the SPMD closure that runs on each remote worker.
+  def _setup(local_nr_arr, local_pr_arr):
+    loc_dir = epath.Path(local_checkpoint_directory)
+
+    # JAX sharding slices the arrays into chunks of shape (1,) and (1, P),
+    # so index at [0] to extract the scalar and the flat peer list.
+    node_rank = int(np.asarray(local_nr_arr)[0])
+    peer_ranks = np.asarray(local_pr_arr)[0].tolist()
+
+    _wait_for_replicator_file_to_disappear(
+        loc_dir, timeout_seconds=timeout_seconds
+    )
+
+    _create_replicator_file(
+        loc_dir,
+        run_name=run_name,
+        num_nodes=num_nodes,
+        data_parallelism=data_parallelism,
+        node_rank=node_rank,
+        peer_ranks=peer_ranks,
+        backup_interval_minutes=backup_interval_minutes,
+    )
+
+    _wait_for_replicator_file_to_disappear(
+        loc_dir, timeout_seconds=timeout_seconds
+    )
+    _block_and_process_restore_dir(loc_dir, timeout_seconds=timeout_seconds)
+
+    # Return an array to satisfy SPMD device matching.
+    return local_nr_arr
+
+  # 5. Wrap and dispatch using native JAX SPMD.
+  wrapped_setup_fn = colocated_python.colocated_python(_setup)
+  wrapped_setup_fn = wrapped_setup_fn.specialize(out_specs_fn=lambda x, y: x)
+
+  # Triggers concurrent execution across all nodes without a thread pool.
+  jax.block_until_ready(
+      wrapped_setup_fn(sharded_node_ranks, sharded_peer_ranks)
+  )
+
+  logging.info(
+      'Successfully initialized multi-tier checkpointing on all remote hosts '
+      'via Colocated Python.'
+  )
+
+
 def _initialize_jax_from_mtc(
     local_checkpoint_directory: epath.Path,
     jax_initialization_timeout_seconds: int = 900,
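
The peer-rank layout computed in `_initialize_mtc_colocated` pairs each node with the node holding the same in-slice position in every other slice. A worked example of that calculation under assumed topology numbers (8 nodes across 2 slices, hence 4 nodes per slice):

    import numpy as np

    num_nodes, num_slices = 8, 2
    nodes_per_slice = max(1, num_nodes // num_slices)

    all_peer_ranks = []
    for nr in range(num_nodes):
      pos = nr % nodes_per_slice  # this node's position within its slice
      all_peer_ranks.append([
          s * nodes_per_slice + pos
          for s in range(num_slices)
          if s * nodes_per_slice + pos != nr
      ])

    print(all_peer_ranks)
    # [[4], [5], [6], [7], [0], [1], [2], [3]]: node 0 is peered with node 4,
    # node 1 with node 5, and so on across the two slices.

    peer_array = np.array(all_peer_ranks, dtype=np.int32)  # shape (8, 1), ready to shard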
@@ -107,6 +213,7 @@ def initialize_multi_tier_checkpointing(
     data_parallelism: Optional[int] = None,
     jax_initialization_timeout_seconds: int = 900,
     use_mtc_process_ids: bool = True,
+    use_colocated_python: bool = False,
 ):
   """Initializes multi-tier checkpointing.
 
@@ -116,12 +223,34 @@
       minutes.
     num_slices: The number of slices.
     run_name: The name of the run.
-    data_parallelism: Number of identical pipelines in job, should be
-      equal to ICI data parallelism * DCN data parallelism. If not provided, it
-      will be inferred from the number of slices.
+    data_parallelism: Number of identical pipelines in job, should be equal to
+      ICI data parallelism * DCN data parallelism. If not provided, it will be
+      inferred from the number of slices.
     jax_initialization_timeout_seconds: The timeout for JAX initialization.
     use_mtc_process_ids: Use the MTC rank server to calculate process ids.
+    use_colocated_python: Whether to use Colocated Python for initialization.
   """
+  run_name = run_name if run_name else os.environ.get('JOBSET_NAME')
+  num_slices = num_slices or multislice.slice_count()
+  data_parallelism = data_parallelism or num_slices
+  if not run_name:
+    raise ValueError(
+        'Run name is not set and JOBSET_NAME is not set in the environment.'
+    )
+
+  if use_colocated_python:
+    logging.info('Initializing multi-tier checkpointing via Colocated Python.')
+    _initialize_mtc_colocated(
+        local_checkpoint_directory=local_checkpoint_directory,
+        backup_interval_minutes=backup_interval_minutes,
+        num_slices=num_slices,
+        run_name=run_name,
+        data_parallelism=data_parallelism,
+        timeout_seconds=jax_initialization_timeout_seconds,
+    )
+    return
+
+  # Standard Multi-Controller Path
   if use_mtc_process_ids:
     process_id = _initialize_jax_from_mtc(
         local_checkpoint_directory, jax_initialization_timeout_seconds
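
With the defaults and the run-name validation hoisted to the top of `initialize_multi_tier_checkpointing`, opting into the single-controller path is a single flag. A hedged call sketch (the directory and run name are hypothetical):

    from etils import epath
    from orbax.checkpoint.experimental.emergency.multi_tier_checkpointing import (
        initialization,
    )

    initialization.initialize_multi_tier_checkpointing(
        epath.Path('/tmp/local-ckpt'),  # hypothetical local checkpoint directory
        backup_interval_minutes=15,
        num_slices=2,
        run_name='my-run',  # falls back to $JOBSET_NAME when empty
        use_colocated_python=True,  # dispatch via _initialize_mtc_colocated
    )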
@@ -135,14 +264,9 @@
   multihost.initialize_runtime_to_distributed_ids()
   multihost.initialize_distributed_to_device_ids()
   _wait_for_replicator_file_to_disappear(local_checkpoint_directory)
-  num_slices = (
-      num_slices
-      or multislice.slice_count()
-  )
   num_nodes = jax.process_count()
   nodes_per_slice = num_nodes // num_slices
   node_rank = jax._src.distributed.global_state.process_id  # pylint: disable=protected-access
-  data_parallelism = data_parallelism or num_slices
   my_process_index = jax.process_index()
   process_index_to_node_rank = (
       multihost.runtime_to_distributed_ids()
@@ -173,11 +297,7 @@
     peer_process_rank = process_index_to_node_rank[peer_process_index]
     peer_ranks.append(peer_process_rank)
   logging.info('Peers for NodeRank %s: %s', node_rank, peer_ranks)
-  run_name = run_name if run_name else os.environ.get('JOBSET_NAME')
-  if not run_name:
-    raise ValueError(
-        'Run name is not set and JOBSET_NAME is not set in the environment.'
-    )
+
   _create_replicator_file(
       local_checkpoint_directory,
       run_name=run_name,

checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization_test.py

Lines changed: 36 additions & 7 deletions
@@ -279,14 +279,43 @@ def test_initialize_multi_tier_checkpointing_run_name_not_set(
           num_slices=1,
           run_name="",
       )
-    mock_jax_distributed_initialize.assert_called_once_with(
-        process_id=0,
-        coordinator_address="coordinator_address",
-        initialization_timeout=900,
+
+    mock_jax_distributed_initialize.assert_not_called()
+    mock_initialize_runtime_to_distributed_ids.assert_not_called()
+    mock_initialize_distributed_to_device_ids.assert_not_called()
+    self.assertEqual(mock_wait_for_replicator_file_to_disappear.call_count, 0)
+
+  @mock.patch.object(initialization, "_initialize_mtc_colocated", autospec=True)
+  @mock.patch.object(jax.distributed, "initialize", autospec=True)
+  def test_initialize_multi_tier_checkpointing_colocated_success(
+      self,
+      mock_jax_distributed_initialize,
+      mock_init_mtc_colocated,
+  ):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+      tmp_dir_path = epath.Path(tmp_dir)
+
+      initialization.initialize_multi_tier_checkpointing(
+          tmp_dir_path,
+          num_slices=1,
+          run_name="test-colocated-run",
+          data_parallelism=1,
+          use_colocated_python=True,
+          backup_interval_minutes=15,
       )
-    mock_initialize_runtime_to_distributed_ids.assert_called_once()
-    mock_initialize_distributed_to_device_ids.assert_called_once()
-    self.assertEqual(mock_wait_for_replicator_file_to_disappear.call_count, 1)
+
+      # Verify colocated Python path is taken
+      mock_init_mtc_colocated.assert_called_once_with(
+          local_checkpoint_directory=tmp_dir_path,
+          backup_interval_minutes=15,
+          num_slices=1,
+          run_name="test-colocated-run",
+          data_parallelism=1,
+          timeout_seconds=900,
+      )
+
+      # Verify standard multi-controller JAX init is bypassed
+      mock_jax_distributed_initialize.assert_not_called()
 
   @mock.patch.object(
       initialization, "_wait_for_replicator_file_to_disappear", autospec=True
