Skip to content

Commit 7592df0

Browse files
author
Orbax Authors
committed
Add escape hatch for jax_init_info\
PiperOrigin-RevId: 869850598
1 parent cf5a1ce commit 7592df0

File tree

3 files changed

+119
-27
lines changed

3 files changed

+119
-27
lines changed

checkpoint/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
### Added
1111

1212
- `uvloop` dependency for improved event loop performance
13+
- Add escape hatch for multi-tier checkpointing initialization by
14+
setting the env var IGNORE_MTC_PROCESS_IDS=true
1315

1416
### Removed
1517

checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization.py

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,34 @@ def _create_replicator_file(
7070
os.rename(temp_file, replicator_file)
7171

7272

73+
def _initialize_jax_from_mtc(
74+
local_checkpoint_directory: epath.Path,
75+
jax_initialization_timeout_seconds: int = 900,
76+
) -> str:
77+
"""Initialize jax with jax_init_info."""
78+
local_checkpoint_directory = epath.Path(local_checkpoint_directory)
79+
process_id, coordinator_address = _retrieve_jax_init_info(
80+
local_checkpoint_directory
81+
)
82+
if not process_id or not coordinator_address:
83+
raise ValueError(
84+
'Data is missing from the JAX init info file: Current values:'
85+
f' process_id: {process_id}, coordinator_address: {coordinator_address}'
86+
)
87+
logging.info(
88+
'Using process_id %s and coordinator_address %s to initialize JAX'
89+
' distributed runtime...',
90+
process_id,
91+
coordinator_address,
92+
)
93+
jax.distributed.initialize(
94+
process_id=int(process_id),
95+
coordinator_address=coordinator_address,
96+
initialization_timeout=jax_initialization_timeout_seconds,
97+
)
98+
return process_id
99+
100+
73101
def initialize_multi_tier_checkpointing(
74102
local_checkpoint_directory: epath.Path,
75103
*,
@@ -92,26 +120,20 @@ def initialize_multi_tier_checkpointing(
92120
will be inferred from the number of slices.
93121
jax_initialization_timeout_seconds: The timeout for JAX initialization.
94122
"""
95-
local_checkpoint_directory = epath.Path(local_checkpoint_directory)
96-
process_id, coordinator_address = _retrieve_jax_init_info(
97-
local_checkpoint_directory
98-
)
99-
if not process_id or not coordinator_address:
100-
raise ValueError(
101-
'Data is missing from the JAX init info file: Current values:'
102-
f' process_id: {process_id}, coordinator_address: {coordinator_address}'
123+
use_mtc_init = True
124+
if os.getenv('IGNORE_MTC_PROCESS_IDS', '').lower() == 'true':
125+
use_mtc_init = False
126+
127+
if use_mtc_init:
128+
process_id = _initialize_jax_from_mtc(
129+
local_checkpoint_directory, jax_initialization_timeout_seconds
103130
)
104-
logging.info(
105-
'Using process_id %s and coordinator_address %s to initialize JAX'
106-
' distributed runtime...',
107-
process_id,
108-
coordinator_address,
109-
)
110-
jax.distributed.initialize(
111-
process_id=int(process_id),
112-
coordinator_address=coordinator_address,
113-
initialization_timeout=jax_initialization_timeout_seconds,
114-
)
131+
else:
132+
process_id = None
133+
jax.distributed.initialize(
134+
initialization_timeout=jax_initialization_timeout_seconds,
135+
)
136+
115137
multihost.initialize_runtime_to_distributed_ids()
116138
multihost.initialize_distributed_to_device_ids()
117139
_wait_for_replicator_file_to_disappear(local_checkpoint_directory)
@@ -127,14 +149,24 @@ def initialize_multi_tier_checkpointing(
127149
process_index_to_node_rank = (
128150
multihost.runtime_to_distributed_ids()
129151
)
130-
logging.info(
131-
'Mapping of IDs: jax-init-info.txt=%s, NodeRank=%s, ProcessIndex=%s,'
132-
' ProcessIndex->NodeRank=%s',
133-
process_id,
134-
node_rank,
135-
my_process_index,
136-
process_index_to_node_rank,
137-
)
152+
if use_mtc_init:
153+
logging.info(
154+
'Mapping of IDs: jax-init-info.txt=%s, NodeRank=%s, ProcessIndex=%s,'
155+
' ProcessIndex->NodeRank=%s',
156+
process_id,
157+
node_rank,
158+
my_process_index,
159+
process_index_to_node_rank,
160+
)
161+
else:
162+
logging.info(
163+
'Mapping of IDs (jax-init-info not used): NodeRank=%s, ProcessIndex=%s,'
164+
' ProcessIndex->NodeRank=%s',
165+
node_rank,
166+
my_process_index,
167+
process_index_to_node_rank,
168+
)
169+
138170
my_in_pipeline_index = my_process_index % nodes_per_slice
139171
peer_ranks = []
140172
for i in range(num_slices):

checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization_test.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Initialization test for multi-tier checkpointing."""
1616

17+
import os
1718
import tempfile
1819
from unittest import mock
1920

@@ -288,6 +289,63 @@ def test_initialize_multi_tier_checkpointing_run_name_not_set(
288289
mock_initialize_distributed_to_device_ids.assert_called_once()
289290
self.assertEqual(mock_wait_for_replicator_file_to_disappear.call_count, 1)
290291

292+
@mock.patch.object(
293+
initialization, "_wait_for_replicator_file_to_disappear", autospec=True
294+
)
295+
@mock.patch.object(initialization, "_create_replicator_file", autospec=True)
296+
@mock.patch.object(jax.distributed, "initialize", autospec=True)
297+
@mock.patch.object(
298+
multihost, "initialize_runtime_to_distributed_ids", autospec=True
299+
)
300+
@mock.patch.object(
301+
multihost, "initialize_distributed_to_device_ids", autospec=True
302+
)
303+
@mock.patch.object(multihost, "runtime_to_distributed_ids", autospec=True)
304+
def test_initialize_multi_tier_checkpointing_skip_init_info(
305+
self,
306+
mock_runtime_to_distributed_ids,
307+
mock_initialize_distributed_to_device_ids,
308+
mock_initialize_runtime_to_distributed_ids,
309+
mock_jax_distributed_initialize,
310+
mock_create_replicator_file,
311+
mock_wait_for_replicator_file_to_disappear,
312+
):
313+
mock_runtime_to_distributed_ids.return_value = [0, 1]
314+
mock_jax_distributed_initialize.return_value = None
315+
mock_initialize_runtime_to_distributed_ids.return_value = [None, None]
316+
mock_initialize_distributed_to_device_ids.return_value = None
317+
mock_create_replicator_file.return_value = [None, None]
318+
mock_wait_for_replicator_file_to_disappear.return_value = False
319+
320+
with tempfile.TemporaryDirectory() as tmp_dir:
321+
epath.Path(tmp_dir).mkdir(parents=True, exist_ok=True)
322+
replicator_file = epath.Path(tmp_dir) / initialization._REPLICATOR_FILE
323+
replicator_file.write_text("replicator.yaml")
324+
self.assertTrue(replicator_file.exists())
325+
326+
restore_dir = epath.Path(tmp_dir) / "test-run-s1-n0-w0.restore"
327+
restore_dir.write_text("restore_dir")
328+
self.assertTrue(restore_dir.exists())
329+
330+
os.environ["IGNORE_MTC_PROCESS_IDS"] = "true"
331+
initialization.initialize_multi_tier_checkpointing(
332+
epath.Path(tmp_dir),
333+
num_slices=1,
334+
run_name="test-run",
335+
data_parallelism=1,
336+
)
337+
mock_jax_distributed_initialize.assert_called_once_with(
338+
initialization_timeout=900,
339+
)
340+
mock_initialize_runtime_to_distributed_ids.assert_called_once()
341+
mock_initialize_distributed_to_device_ids.assert_called_once()
342+
self.assertEqual(mock_wait_for_replicator_file_to_disappear.call_count, 2)
343+
mock_create_replicator_file.assert_called_once()
344+
expected_restore_dir = epath.Path(tmp_dir) / "1"
345+
self.assertTrue(expected_restore_dir.exists())
346+
347+
del os.environ["IGNORE_MTC_PROCESS_IDS"]
348+
291349

292350
if __name__ == "__main__":
293351
absltest.main()

0 commit comments

Comments
 (0)