Skip to content

Commit d0b6e9c

Browse files
author
Orbax Authors
committed
Separate Pathways tests in Orbax OSS build.
PiperOrigin-RevId: 876687909
1 parent 6f4d675 commit d0b6e9c

File tree

10 files changed

+1029
-243
lines changed

10 files changed

+1029
-243
lines changed

.github/workflows/build.yml

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -344,17 +344,42 @@ jobs:
344344
pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
345345
pip uninstall -y orbax
346346
pip install gcsfs
347-
pip install portpicker pytest chex pyyaml
347+
pip install portpicker pytest chex pyyaml pathwaysutils
348348
if [ "${{ matrix.jax-version }}" = "newest" ]; then
349349
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
350350
elif [ "${{ matrix.jax-version }}" = "nightly" ]; then
351351
pip install -U --pre "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
352352
else
353353
pip install "jax[tpu]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
354354
fi
355-
- name: Run multiprocess tests
355+
- name: Run pathways tests
356+
env:
357+
JAX_DEFAULT_BACKEND: pathways
358+
JAX_PLATFORMS: tpu
359+
# Configures JAX to target a subslice within the TPU allocation.
360+
JAX_BACKEND_TARGET: subslice
361+
# Enables IFRT in Pathways.
362+
PATHWAYS_IFRT: true
363+
# Allows JAX to run even if some TPUs are not utilized.
364+
JAX_ALLOW_UNUSED_TPUS: true
365+
run: |
366+
python -c "import jax; import pathwaysutils; pathwaysutils.initialize(); print('Pathways initialized'); print(jax.devices());" && python orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --pathways=1
367+
- name: Run colocated pathways tests
368+
env:
369+
JAX_PLATFORMS: pathways
370+
JAX_BACKEND_TARGET: subslice
371+
PATHWAYS_IFRT: true
372+
JAX_ALLOW_UNUSED_TPUS: true
373+
PATHWAYS_EXPECTED_INSTANCES: df=1x1,df=1x1,df=1x1,df=1x1
374+
USE_COLOCATED_PYTHON: true
375+
run: |
376+
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=1 --tpu_chips_per_process=8 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --pathways=1
377+
- name: Run 2 multiprocess tests
378+
run: |
379+
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=2 --tpu_chips_per_process=4 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=2
380+
- name: Run 4 multiprocess tests
356381
run: |
357-
python orbax/checkpoint/_src/testing/oss/run_multihost.py orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=4
382+
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=4 --tpu_chips_per_process=2 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=4
358383
- name: Run single process tests
359384
run: |
360385
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=1 --tpu_chips_per_process=8 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=1
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import time
16+
from unittest import mock
17+
18+
from absl.testing import flagsaver
19+
from absl.testing import parameterized
20+
from etils import epath
21+
import jax
22+
import numpy as np
23+
from orbax.checkpoint import test_utils
24+
from orbax.checkpoint._src.multihost import multihost
25+
from orbax.checkpoint._src.testing import multiprocess_test
26+
27+
28+
class MultihostUtilsTestBase:
  """Namespace holding the shared multihost test case.

  NOTE(review): the concrete suites subclass `MultihostUtilsTestBase.Test`;
  the nesting presumably keeps the test loader from collecting the shared
  base directly -- confirm against the test runner.
  """

  class Test(parameterized.TestCase):
    """Barrier and broadcast tests for the `multihost` module.

    `setUp` asserts 8 devices across 4 processes (2 local devices each), so
    every test here runs under that topology.
    """

    def setUp(self):
      """Checks the device topology and creates a per-test temp directory."""
      super().setUp()
      self.assertEqual(jax.device_count(), 8)
      self.assertEqual(jax.process_count(), 4)
      self.assertEqual(jax.local_device_count(), 2)

      if not multihost.is_runtime_to_distributed_ids_initialized():
        multihost.initialize_runtime_to_distributed_ids()

      temp_dir = self.create_tempdir(name='multihost_test')
      self.tmpdir = epath.Path(temp_dir.full_path)
      test_utils.sync_global_processes('setUp')

    def tearDown(self):
      """Synchronizes all processes before the base class tears down."""
      test_utils.sync_global_processes('tearDown')
      super().tearDown()

    def test_process_errors(self):
      """Process 1 syncing on a barrier restricted to {0} raises ValueError."""
      if multihost.process_index() != 1:
        return
      with self.assertRaises(ValueError):
        multihost.sync_global_processes(
            'test_process_errors_1', processes={0}
        )

    def test_sync_global_processes(self):
      """After a global barrier, all processes see process 0's directory."""
      marker = self.tmpdir / 'dummy'
      if multihost.process_index() == 0:
        # Delay the write so the others only pass if the barrier waits.
        time.sleep(2)
        marker.mkdir(parents=False, exist_ok=False)
      multihost.sync_global_processes('test_sync_global_processes')
      self.assertTrue(marker.exists())

    def test_sync_global_processes_partial(self):
      """A barrier over {0, 2} blocks only the processes named in it."""
      members = {0, 2}
      primary_process = 0
      non_primary_process = 1
      rank = multihost.process_index()
      is_primary = rank == primary_process
      is_member = rank in members

      directory = self.tmpdir / 'testdir'
      if is_primary:
        directory.mkdir(parents=False, exist_ok=False)
      test_utils.sync_global_processes(
          'test_sync_global_processes_partial_setup'
      )

      # Two identical phases: the primary writes a subdirectory late, members
      # wait on the partial barrier and must see it, non-members do not wait
      # and must not see it yet.
      phases = (
          ('dummy', 'test_sync_global_processes_partial'),
          ('foo', 'test_sync_global_processes_partial_second'),
      )
      for child_name, barrier_name in phases:
        child = directory / child_name
        if is_primary:
          time.sleep(2)
          child.mkdir(parents=False, exist_ok=False)
        if is_member:
          multihost.sync_global_processes(barrier_name, processes=members)
          self.assertTrue(child.exists())
        else:
          self.assertFalse(child.exists())

      multihost.sync_global_processes('test_sync_global_processes_partial_all')
      # Should a non-primary process slip through the barrier above before
      # everyone has arrived, the primary process would hit an error while
      # creating its subdirectories.
      if rank == non_primary_process:
        directory.rmtree()

    def test_different_barriers(self):
      """Barriers over disjoint process sets do not wait on each other."""
      first_slice = {0, 2}
      second_slice = {1, 3}
      writers = [0, 1]
      rank = multihost.process_index()

      if rank in writers:
        # Process 0 does not sleep while process 1 sleeps 3 seconds, so when
        # the first slice leaves its barrier one directory exists and the
        # other does not.
        time.sleep(3 * rank)
        (self.tmpdir / f'dummy_{rank}').mkdir(parents=False, exist_ok=False)

      from_process_0 = self.tmpdir / 'dummy_0'
      from_process_1 = self.tmpdir / 'dummy_1'
      if rank in first_slice:
        multihost.sync_global_processes(
            'test_different_barriers_slice1',
            processes=first_slice,
        )
        self.assertTrue(from_process_0.exists())
        self.assertFalse(from_process_1.exists())
      else:
        multihost.sync_global_processes(
            'test_different_barriers_slice2',
            processes=second_slice,
        )
        self.assertTrue(from_process_0.exists())
        self.assertTrue(from_process_1.exists())

    def test_broadcast_one_to_all(self):
      """Every process ends up with the tree supplied by process 0."""
      if multihost.process_index() == 0:
        values = [5, 12]
      else:
        values = [0, 0]
      result = multihost.broadcast_one_to_all({'bar': values})

      expected = {'bar': [np.asarray(v, dtype=np.int32) for v in (5, 12)]}
      test_utils.assert_tree_equal(self, expected, result)

    def test_sync_global_processes_with_distributed_barrier(self):
      """With the flag on, the barrier sync fn is used, not device sync."""
      flag_override = flagsaver.flagsaver(
          experimental_orbax_use_distributed_barrier=True
      )
      devices_patch = mock.patch.object(
          multihost.multihost_utils, 'sync_global_devices', autospec=True
      )
      barrier_patch = mock.patch.object(
          multihost, 'get_barrier_sync_fn', autospec=True
      )
      skip_patch = mock.patch.object(
          multihost, 'should_skip_process_sync', return_value=False
      )
      with flag_override, devices_patch as device_sync:
        with barrier_patch as barrier_factory, skip_patch:
          multihost.sync_global_processes('test_barrier')

      device_sync.assert_not_called()
      barrier_factory.assert_called_once_with(processes=None)
      barrier_factory.return_value.assert_called_once_with(
          key='test_barrier', timeout_ms=300000
      )

    def test_sync_global_processes_without_distributed_barrier(self):
      """With the flag off, device sync is used, not the barrier sync fn."""
      flag_override = flagsaver.flagsaver(
          experimental_orbax_use_distributed_barrier=False
      )
      devices_patch = mock.patch.object(
          multihost.multihost_utils, 'sync_global_devices', autospec=True
      )
      barrier_patch = mock.patch.object(
          multihost, 'get_barrier_sync_fn', autospec=True
      )
      skip_patch = mock.patch.object(
          multihost, 'should_skip_process_sync', return_value=False
      )
      with flag_override, devices_patch as device_sync:
        with barrier_patch as barrier_factory, skip_patch:
          multihost.sync_global_processes('test_barrier')

      device_sync.assert_called_once()
      barrier_factory.assert_not_called()
185+
186+
187+
class MultihostUtilsTestStandard(MultihostUtilsTestBase.Test):
  """Shared suite with `experimental_orbax_use_distributed_process_id` off."""

  def setUp(self):
    """Registers the flag override, then runs the shared setUp."""
    flag_override = flagsaver.flagsaver(
        experimental_orbax_use_distributed_process_id=False
    )
    self.enter_context(flag_override)
    super().setUp()

  def test_sync_global_processes_partial(self):
    """Skipped in this configuration."""
    self.skipTest('Fix this scenario.')

  def test_different_barriers(self):
    """Skipped in this configuration."""
    self.skipTest('Fix this scenario.')
200+
201+
202+
class MultihostUtilsTestDistributedId(MultihostUtilsTestBase.Test):
  """Shared suite with `experimental_orbax_use_distributed_process_id` on."""

  def setUp(self):
    """Registers the flag override, then runs the shared setUp."""
    flag_override = flagsaver.flagsaver(
        experimental_orbax_use_distributed_process_id=True
    )
    self.enter_context(flag_override)
    super().setUp()
209+
210+
211+
# Script entry point: hands control to multiprocess_test.main().
if __name__ == '__main__':
  multiprocess_test.main()

0 commit comments

Comments
 (0)