Skip to content

Commit b4f23c1

Browse files
author
Orbax Authors
committed
Include Orbax Checkpoint experimental/emergency tests in OSS.
PiperOrigin-RevId: 874406518
1 parent 60b50ba commit b4f23c1

File tree

6 files changed

+259
-234
lines changed

6 files changed

+259
-234
lines changed

.github/workflows/build.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,12 @@ jobs:
352352
else
353353
pip install "jax[tpu]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
354354
fi
355-
- name: Run multiprocess tests
355+
- name: Run 2 multiprocess tests
356356
run: |
357-
python orbax/checkpoint/_src/testing/oss/run_multihost.py orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=4
357+
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=2 --tpu_chips_per_process=4 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=2
358+
- name: Run 4 multiprocess tests
359+
run: |
360+
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=4 --tpu_chips_per_process=2 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=4
358361
- name: Run single process tests
359362
run: |
360363
python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=1 --tpu_chips_per_process=8 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=1
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import time
16+
from unittest import mock
17+
18+
from absl.testing import flagsaver
19+
from absl.testing import parameterized
20+
from etils import epath
21+
import jax
22+
import numpy as np
23+
from orbax.checkpoint import test_utils
24+
from orbax.checkpoint._src.multihost import multihost
25+
from orbax.checkpoint._src.testing import multiprocess_test
26+
27+
28+
class MultihostUtilsTestBase:
  """Namespace wrapper so the base ``Test`` class is not collected directly.

  Concrete subclasses (standard vs. distributed process id) inherit from
  ``MultihostUtilsTestBase.Test`` and are the classes the runner discovers.
  """

  class Test(parameterized.TestCase):
    """Multihost sync/broadcast tests.

    Assumes a fixed topology (asserted in setUp): 8 devices total across
    4 processes, 2 devices per process.
    """

    def setUp(self):
      super().setUp()
      # These tests are written for one exact topology; fail fast if the
      # harness launched us differently.
      self.assertEqual(jax.device_count(), 8)
      self.assertEqual(jax.process_count(), 4)
      self.assertEqual(jax.local_device_count(), 2)

      if not multihost.is_runtime_to_distributed_ids_initialized():
        multihost.initialize_runtime_to_distributed_ids()

      self.tmpdir = epath.Path(
          self.create_tempdir(name='multihost_test').full_path
      )
      # Barrier so no process starts a test body before all have set up.
      test_utils.sync_global_processes('setUp')

    def tearDown(self):
      # Barrier before teardown so no process deletes shared state while
      # another is still using it.
      test_utils.sync_global_processes('tearDown')
      super().tearDown()

    def test_process_errors(self):
      # A process that is not in the `processes` set must not be allowed to
      # join that barrier.
      if multihost.process_index() == 1:
        with self.assertRaises(ValueError):
          multihost.sync_global_processes(
              'test_process_errors_1', processes={0}
          )

    def test_sync_global_processes(self):
      # Process 0 delays before creating the marker; if the barrier works,
      # every process observes the marker after the sync.
      if multihost.process_index() == 0:
        time.sleep(2)
        (self.tmpdir / 'dummy').mkdir(parents=False, exist_ok=False)
      multihost.sync_global_processes('test_sync_global_processes')
      self.assertTrue((self.tmpdir / 'dummy').exists())

    def test_sync_global_processes_partial(self):
      # Only a subset of processes participates in the barrier; the
      # non-participants must NOT be blocked (and so may race ahead and
      # not see the primary's files).
      participating_processes = {0, 2}
      primary_process = 0
      non_primary_process = 1

      directory = self.tmpdir / 'testdir'
      if multihost.process_index() == primary_process:
        directory.mkdir(parents=False, exist_ok=False)
      test_utils.sync_global_processes(
          'test_sync_global_processes_partial_setup'
      )

      if multihost.process_index() == primary_process:
        time.sleep(2)
        (directory / 'dummy').mkdir(parents=False, exist_ok=False)
      if multihost.process_index() in participating_processes:
        multihost.sync_global_processes(
            'test_sync_global_processes_partial',
            processes=participating_processes,
        )
      if multihost.process_index() in participating_processes:
        self.assertTrue((directory / 'dummy').exists())
      else:
        # Non-participants did not wait, so the delayed marker should not
        # be visible to them yet.
        self.assertFalse((directory / 'dummy').exists())

      if multihost.process_index() == primary_process:
        time.sleep(2)
        (directory / 'foo').mkdir(parents=False, exist_ok=False)
      if multihost.process_index() in participating_processes:
        multihost.sync_global_processes(
            'test_sync_global_processes_partial_second',
            processes=participating_processes,
        )
      if multihost.process_index() in participating_processes:
        self.assertTrue((directory / 'foo').exists())
      else:
        self.assertFalse((directory / 'foo').exists())

      multihost.sync_global_processes('test_sync_global_processes_partial_all')
      # If non-primary processes get past the above barrier without waiting for
      # all, then an error would happen for the primary process when trying to
      # create subdirectories.
      if multihost.process_index() == non_primary_process:
        directory.rmtree()

    def test_different_barriers(self):
      # Two disjoint process groups sync on different barrier keys; each
      # group must only wait for its own members.
      slice1 = {0, 2}
      slice2 = {1, 3}
      primary_processes = [0, 1]

      if multihost.process_index() in primary_processes:
        # Don't sleep for slice1, but do sleep for slice2, so that when slice1
        # finishes waiting at the barrier, one file exists but the other does
        # not.
        time.sleep(3 * multihost.process_index())
        (self.tmpdir / f'dummy_{multihost.process_index()}').mkdir(
            parents=False, exist_ok=False
        )

      if multihost.process_index() in slice1:
        multihost.sync_global_processes(
            'test_different_barriers_slice1',
            processes=slice1,
        )
      else:
        multihost.sync_global_processes(
            'test_different_barriers_slice2',
            processes=slice2,
        )
      if multihost.process_index() in slice1:
        # slice1's primary (0) wrote immediately; slice2's primary (1) was
        # still sleeping when slice1 released.
        self.assertTrue((self.tmpdir / 'dummy_0').exists())
        self.assertFalse((self.tmpdir / 'dummy_1').exists())
      else:
        self.assertTrue((self.tmpdir / 'dummy_0').exists())
        self.assertTrue((self.tmpdir / 'dummy_1').exists())

    def test_broadcast_one_to_all(self):
      # Every process should end up with process 0's values, as int32 arrays.
      if multihost.process_index() == 0:
        tree = {'bar': [5, 12]}
      else:
        tree = {'bar': [0, 0]}
      result = multihost.broadcast_one_to_all(tree)

      expected = {
          'bar': [np.asarray(5, dtype=np.int32), np.asarray(12, dtype=np.int32)]
      }
      test_utils.assert_tree_equal(self, expected, result)

    def test_sync_global_processes_with_distributed_barrier(self):
      # With the distributed-barrier flag on, sync must route through
      # get_barrier_sync_fn instead of jax's sync_global_devices.
      with flagsaver.flagsaver(
          experimental_orbax_use_distributed_barrier=True
      ), mock.patch.object(
          multihost.multihost_utils, 'sync_global_devices', autospec=True
      ) as mock_sync_global_devices, mock.patch.object(
          multihost, 'get_barrier_sync_fn', autospec=True
      ) as mock_get_barrier_sync_fn, mock.patch.object(
          multihost, 'should_skip_process_sync', return_value=False
      ):
        multihost.sync_global_processes('test_barrier')

        mock_sync_global_devices.assert_not_called()
        mock_get_barrier_sync_fn.assert_called_once_with(processes=None)
        # 300s default timeout for the barrier.
        mock_get_barrier_sync_fn.return_value.assert_called_once_with(
            key='test_barrier', timeout_ms=300000
        )

    def test_sync_global_processes_without_distributed_barrier(self):
      # With the flag off, sync must fall back to sync_global_devices.
      with flagsaver.flagsaver(
          experimental_orbax_use_distributed_barrier=False
      ), mock.patch.object(
          multihost.multihost_utils, 'sync_global_devices', autospec=True
      ) as mock_sync_global_devices, mock.patch.object(
          multihost, 'get_barrier_sync_fn', autospec=True
      ) as mock_get_barrier_sync_fn, mock.patch.object(
          multihost, 'should_skip_process_sync', return_value=False
      ):
        multihost.sync_global_processes('test_barrier')

        mock_sync_global_devices.assert_called_once()
        mock_get_barrier_sync_fn.assert_not_called()
185+
186+
187+
class MultihostUtilsTestStandard(MultihostUtilsTestBase.Test):
  """Runs the base tests with the standard (non-distributed) process id."""

  def setUp(self):
    # Force the legacy process-id scheme for this variant before the base
    # setUp runs its topology checks and barriers.
    self.enter_context(
        flagsaver.flagsaver(experimental_orbax_use_distributed_process_id=False)
    )
    super().setUp()

  def test_sync_global_processes_partial(self):
    # Known broken under the standard process-id scheme; skipped until fixed.
    self.skipTest('Fix this scenario.')

  def test_different_barriers(self):
    # Known broken under the standard process-id scheme; skipped until fixed.
    self.skipTest('Fix this scenario.')
200+
201+
202+
class MultihostUtilsTestDistributedId(MultihostUtilsTestBase.Test):
  """Runs the base tests with the distributed process-id scheme enabled."""

  def setUp(self):
    # Enable the distributed process-id scheme before the base setUp runs.
    self.enter_context(
        flagsaver.flagsaver(experimental_orbax_use_distributed_process_id=True)
    )
    super().setUp()
209+
210+
211+
if __name__ == '__main__':
  # multiprocess_test.main() launches/coordinates the worker processes
  # required by the fixed 4-process topology asserted in setUp.
  multiprocess_test.main()

checkpoint/orbax/checkpoint/_src/testing/oss/generate_multiprocess_test.py

Lines changed: 32 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Script to generate YAML file with test targets based on tags."""
1616

1717
import ast
18+
import collections
1819
import os
1920
import sys
2021

@@ -29,7 +30,11 @@
2930
'pytype_strict_contrib_test',
3031
]
3132
EXCLUDED_PATHS = [
32-
'orbax/checkpoint/experimental',
33+
'orbax/checkpoint/experimental/model_surgery',
34+
'orbax/checkpoint/experimental/v1',
35+
'orbax/checkpoint/experimental/emergency/p2p',
36+
'orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py',
37+
'orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py',
3338
'orbax/checkpoint/google',
3439
]
3540

@@ -60,53 +65,15 @@ def get_str_val(node):
6065
return None
6166

6267

63-
def inherits_from_multiprocess_test(test_file_path):
64-
"""Checks if test file inherits from MultiProcessTest."""
65-
try:
66-
with open(test_file_path, 'r') as f:
67-
content = f.read()
68-
except FileNotFoundError:
69-
return False
70-
try:
71-
tree = ast.parse(content, filename=test_file_path)
72-
except SyntaxError:
73-
return False
74-
75-
imported_as_name = None # if imported as `from ... import MultiProcessTest`
76-
imported_as_module = [] # if imported as `from ... import multiprocess_test`
77-
78-
for node in tree.body:
79-
if isinstance(node, ast.ImportFrom):
80-
if node.module == 'orbax.checkpoint._src.testing.multiprocess_test':
81-
for alias in node.names:
82-
if alias.name == 'MultiProcessTest':
83-
imported_as_name = alias.asname or alias.name
84-
elif node.module == 'orbax.checkpoint._src.testing':
85-
for alias in node.names:
86-
if alias.name == 'multiprocess_test':
87-
imported_as_module.append(alias.asname or alias.name)
88-
89-
if not imported_as_name and not imported_as_module:
90-
return False
91-
92-
for node in tree.body:
93-
if isinstance(node, ast.ClassDef):
94-
for base in node.bases:
95-
if (
96-
imported_as_name
97-
and isinstance(base, ast.Name)
98-
and base.id == imported_as_name
99-
):
100-
return True
101-
if (
102-
imported_as_module
103-
and isinstance(base, ast.Attribute)
104-
and isinstance(base.value, ast.Name)
105-
and base.value.id in imported_as_module
106-
and base.attr == 'MultiProcessTest'
107-
):
108-
return True
109-
return False
68+
def get_num_processes(args):
69+
"""Returns num_processes from args."""
70+
for arg in args:
71+
if arg.startswith('--num_processes='):
72+
try:
73+
return int(arg.split('=', 1)[1])
74+
except ValueError:
75+
return None
76+
return None
11077

11178

11279
def get_build_targets(build_file_path):
@@ -135,18 +102,18 @@ def get_build_targets(build_file_path):
135102

136103
if rule_name in TEST_RULES:
137104
kwargs = get_kwargs(call)
138-
if 'name' in kwargs and 'tags' in kwargs:
105+
if 'name' in kwargs:
139106
name = get_str_val(kwargs['name'])
140-
tags = get_list_val(kwargs['tags'])
107+
tags = get_list_val(kwargs['tags']) if 'tags' in kwargs else []
141108
srcs = get_list_val(kwargs['srcs']) if 'srcs' in kwargs else []
142-
if name and tags:
143-
yield name, tags, srcs
109+
args = get_list_val(kwargs['args']) if 'args' in kwargs else []
110+
if name:
111+
yield name, tags, srcs, args
144112

145113

146114
def run(root_dir, output_file):
147115
"""Runs the script to generate tagged tests file."""
148-
tests_by_tag = {tag: [] for tag in TAG_MAPPING.values()}
149-
tests_by_tag['processes:1'] = []
116+
tests_by_tag = collections.defaultdict(list)
150117

151118
count = 0
152119
for dirpath, dirnames, filenames in os.walk(root_dir):
@@ -166,35 +133,33 @@ def run(root_dir, output_file):
166133
count += 1
167134
build_file = os.path.join(dirpath, 'BUILD')
168135
package_path = dirpath.removeprefix('third_party/py/')
169-
for name, tags, srcs in get_build_targets(build_file):
136+
for name, tags, srcs, args in get_build_targets(build_file):
137+
if not any(tag in TAG_MAPPING for tag in tags):
138+
continue
170139
if srcs and any(
171140
os.path.join(dirpath, srcs[0]).startswith(p) for p in EXCLUDED_PATHS
172141
):
173142
continue
174-
is_multiprocess = False
175-
if srcs:
176-
is_multiprocess = inherits_from_multiprocess_test(
177-
os.path.join(dirpath, srcs[0])
178-
)
179143
target_path = f'{package_path}:{name}'
180-
if not is_multiprocess:
181-
tests_by_tag['processes:1'].append(target_path)
144+
num_processes = get_num_processes(args)
145+
if num_processes and num_processes > 1:
146+
tag = f'processes:{num_processes}'
147+
tests_by_tag[tag].append(target_path)
182148
else:
183-
for tag in tags:
184-
if tag in TAG_MAPPING:
185-
tests_by_tag[TAG_MAPPING[tag]].append(target_path)
149+
tests_by_tag['processes:1'].append(target_path)
186150

187151
print(f'Processed {count} BUILD files.')
188152

153+
result_dict = {}
189154
for tag in tests_by_tag:
190-
tests_by_tag[tag] = sorted(list(set(tests_by_tag[tag])))
155+
result_dict[tag] = sorted(list(set(tests_by_tag[tag])))
191156

192157
header = """# DO NOT EDIT!
193158
"""
194159
os.makedirs(os.path.dirname(output_file), exist_ok=True)
195160
with open(output_file, 'w') as f:
196161
f.write(header)
197-
yaml.dump(tests_by_tag, f, default_flow_style=False)
162+
yaml.dump(result_dict, f, default_flow_style=False)
198163
print(f'Output written to {output_file}')
199164

200165

0 commit comments

Comments
 (0)