Skip to content

Commit d9c51ad

Browse files
lyglstOrbax Authors
authored and committed
Add flags for PyTorch & DCP support in the Orbax checkpoint benchmark launcher.
PiperOrigin-RevId: 865621844
1 parent 60b50ba commit d9c51ad

File tree

10 files changed

+646
-20
lines changed

10 files changed

+646
-20
lines changed

checkpoint/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
### Added
1111

1212
- #v1 Add `use_load_and_broadcast` option.
13+
- Add PyTorch DCP (Distributed Checkpoint) to the benchmark suite.
1314

1415
### Removed
1516

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525
from absl import logging
2626
from etils import epath
2727
import jax
28-
from orbax.checkpoint._src.multihost import multihost
2928
from orbax.checkpoint._src.testing.benchmarks.core import checkpoint_generation
3029
from orbax.checkpoint._src.testing.benchmarks.core import configs
3130
from orbax.checkpoint._src.testing.benchmarks.core import device_mesh
3231
from orbax.checkpoint._src.testing.benchmarks.core import directory_setup
3332
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib
33+
from orbax.checkpoint._src.testing.benchmarks.core import multihost
3434

3535

3636
@dataclasses.dataclass(frozen=True)
@@ -148,7 +148,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:
148148
name += f"_repeat_{repeat_index}"
149149
logging.info(
150150
"[process_id=%s] Setting up test: %s",
151-
multihost.process_index(),
151+
multihost.get_process_index(),
152152
name,
153153
)
154154

@@ -193,7 +193,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:
193193

194194
logging.info(
195195
"[process_id=%s] Executing test function: %s",
196-
multihost.process_index(),
196+
multihost.get_process_index(),
197197
name,
198198
)
199199
try:
@@ -203,13 +203,13 @@ def run(self, repeat_index: int | None = None) -> TestResult:
203203
# execution is recorded in the TestResult.
204204
if sys.version_info >= (3, 11):
205205
e.add_note(
206-
f"[process_id={multihost.process_index()}],"
206+
f"[process_id={multihost.get_process_index()}],"
207207
f" {test_context_summary[:100]}"
208208
)
209209
logging.error(
210210
"[process_id=%s] Test function '%s' context: %s, raised an"
211211
" exception: %s",
212-
multihost.process_index(),
212+
multihost.get_process_index(),
213213
name,
214214
test_context_summary[:100],
215215
e,
@@ -223,7 +223,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:
223223

224224
logging.info(
225225
"[process_id=%s] Test finished: %s",
226-
multihost.process_index(),
226+
multihost.get_process_index(),
227227
name,
228228
)
229229

@@ -306,13 +306,13 @@ def _get_options_product(self) -> Sequence[BenchmarkOptions]:
306306
option_instances.append(option_instance)
307307
logging.info(
308308
"[process_id=%s] Generating valid option combination: %s",
309-
multihost.process_index(),
309+
multihost.get_process_index(),
310310
option_instance,
311311
)
312312
else:
313313
logging.info(
314314
"[process_id=%s] Skipping invalid option combination: %s",
315-
multihost.process_index(),
315+
multihost.get_process_index(),
316316
option_instance,
317317
)
318318
return option_instances

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import dataclasses
16+
import os
1617
from typing import List
1718
from unittest import mock
1819

@@ -23,12 +24,14 @@
2324
import jax
2425
import numpy as np
2526
from orbax.checkpoint import test_utils
27+
from orbax.checkpoint._src.multihost import multihost
2628
from orbax.checkpoint._src.testing.benchmarks.core import checkpoint_generation
2729
from orbax.checkpoint._src.testing.benchmarks.core import configs
2830
from orbax.checkpoint._src.testing.benchmarks.core import core
2931
from orbax.checkpoint._src.testing.benchmarks.core import device_mesh
3032
from orbax.checkpoint._src.testing.benchmarks.core import directory_setup
3133
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib
34+
import torch.distributed as dist
3235

3336

3437
@dataclasses.dataclass(frozen=True)
@@ -472,8 +475,13 @@ def test_run_no_benchmarks_generated(
472475
@mock.patch.object(directory_setup, 'setup_test_directory')
473476
@mock.patch.object(checkpoint_generation, 'generate_checkpoint')
474477
@mock.patch.object(logging, 'info')
478+
@mock.patch.object(
479+
multihost,
480+
'sync_global_processes',
481+
)
475482
def test_run_generates_report_with_failures(
476483
self,
484+
mock_sync_global_processes,
477485
mock_logging_info,
478486
mock_generate_checkpoint,
479487
mock_setup_test_directory,
@@ -499,6 +507,11 @@ def test_fn(self, test_context: core.TestContext) -> core.TestResult:
499507
suite = core.TestSuite(name='report_suite', benchmarks_generators=[gen])
500508
suite.run()
501509

510+
mock_sync_global_processes.assert_any_call('benchmark:run')
511+
mock_sync_global_processes.assert_any_call('benchmark:setup_test_directory')
512+
mock_sync_global_processes.assert_any_call('benchmark:setup_pytree')
513+
mock_sync_global_processes.assert_any_call('test_suite:run_end')
514+
502515
# 3 benchmarks generated: (1,a), (1,b), (2,b).
503516
# (1,a), (1,b) pass. (2,b) fails because (2,a) is invalid.
504517

@@ -512,6 +525,24 @@ def test_fn(self, test_context: core.TestContext) -> core.TestResult:
512525
self.assertIn('--- Failed Runs ---', report_log)
513526
self.assertIn("Error: ValueError('opt1=2, opt2=b failed')", report_log)
514527

528+
@mock.patch.object(core.Benchmark, 'run')
@mock.patch.object(dist, 'barrier')
def test_run_with_torch(self, mock_dist_barrier, mock_benchmark_run):
  """Checks that `TestSuite.run` hits `dist.barrier` once torch is initialized.

  A single-process `gloo` group is created so that `dist.is_initialized()` is
  true while the suite runs; `dist.barrier` and `core.Benchmark.run` are
  mocked.
  """
  # Initialize torch.distributed for testing.
  os.environ.setdefault('MASTER_ADDR', 'localhost')
  os.environ.setdefault('MASTER_PORT', '12355')
  dist.init_process_group(backend='gloo', rank=0, world_size=1)
  # Tear the default group down when the test ends (pass or fail). Otherwise
  # it stays initialized for the rest of the process: later tests would see
  # `dist.is_initialized()` return True and a second `init_process_group`
  # call would raise.
  self.addCleanup(dist.destroy_process_group)

  gen = MyGenerator(
      checkpoint_configs=[configs.CheckpointConfig(spec={})],
      options=MyBenchmarkOptions(opt1=[1, 2]),
  )
  suite = core.TestSuite(name='my_suite', benchmarks_generators=[gen])

  suite.run()
  mock_dist_barrier.assert_any_call()

  self.assertEqual(mock_benchmark_run.call_count, 1)
545+
515546

516547
if __name__ == '__main__':
517548
absltest.main()

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/directory_setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from absl import logging
1818
from etils import epath
19-
import jax
19+
from orbax.checkpoint._src.testing.benchmarks.core import multihost
2020

2121

2222
def setup_test_directory(
@@ -39,7 +39,7 @@ def setup_test_directory(
3939
if repeat_index is not None:
4040
path = path / f"repeat_{repeat_index}"
4141
logging.info("Setting up test directory at: %s", path)
42-
if jax.process_index() == 0:
42+
if multihost.get_process_index() == 0:
4343
if path.exists():
4444
logging.warning("Test directory %s already exists. Deleting it.", path)
4545
path.rmtree()

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/directory_setup_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
from absl.testing import absltest
1818
from etils import epath
19-
import jax
2019
from orbax.checkpoint._src.testing.benchmarks.core import directory_setup
20+
from orbax.checkpoint._src.testing.benchmarks.core import multihost
2121

2222

2323
class DirectorySetupTest(absltest.TestCase):
@@ -51,7 +51,7 @@ def test_setup_test_directory_already_exists(self):
5151
self.assertTrue(path.exists())
5252
self.assertFalse((path / 'some_file').exists())
5353

54-
@mock.patch.object(jax, 'process_index', return_value=1)
54+
@mock.patch.object(multihost, 'get_process_index', return_value=1)
5555
def test_setup_test_directory_non_zero_process_index_does_not_exist(self, _):
5656
temp_dir = self.create_tempdir()
5757

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/metric.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from clu import metric_writers
3131
from etils import epath
3232
import numpy as np
33-
from orbax.checkpoint._src.multihost import multihost
33+
from orbax.checkpoint._src.testing.benchmarks.core import multihost
3434
import psutil
3535
import tensorstore as ts
3636

@@ -47,7 +47,7 @@ def start(self):
4747
self._start_time = time.perf_counter()
4848
logging.info(
4949
"[process_id=%s] Starting metric: '%s'...",
50-
multihost.process_index(),
50+
multihost.get_process_index(),
5151
self.name,
5252
)
5353

@@ -56,7 +56,7 @@ def stop(self) -> dict[str, tuple[Any, str]]:
5656
duration = time.perf_counter() - self._start_time
5757
logging.info(
5858
"[process_id=%s] Finished metric: '%s' (took %.4fs)",
59-
multihost.process_index(),
59+
multihost.get_process_index(),
6060
self.name,
6161
duration,
6262
)
@@ -168,7 +168,7 @@ def stop(self) -> dict[str, tuple[Any, str]]:
168168

169169
self._log_tracemalloc_snapshot_diff(
170170
self.name,
171-
multihost.process_index(),
171+
multihost.get_process_index(),
172172
self._start_snapshot,
173173
end_snapshot,
174174
top_n=15,
@@ -285,7 +285,7 @@ def stop(self) -> dict[str, tuple[Any, str]]:
285285
diff = self._diff_metrics(self._start_metrics, end_metrics)
286286
logging.info(
287287
"[process_id=%s] Finished metric: %s, num_diffs=%d",
288-
multihost.process_index(),
288+
multihost.get_process_index(),
289289
self.name,
290290
len(diff),
291291
)
@@ -423,12 +423,12 @@ def report(self):
423423
"""Logs a formatted report of all collected metrics."""
424424
report_lines = []
425425
report_lines.append(
426-
f"---[process_id={multihost.process_index()}] {self.name} Metrics"
426+
f"---[process_id={multihost.get_process_index()}] {self.name} Metrics"
427427
" Report ---"
428428
)
429429
if not self.results:
430430
report_lines.append(
431-
f"[process_id={multihost.process_index()}] No metrics recorded."
431+
f"[process_id={multihost.get_process_index()}] No metrics recorded."
432432
)
433433
else:
434434
for name, (value, unit) in sorted(self.results.items()):
@@ -565,7 +565,7 @@ def add_result(
565565
def _get_writer(self, benchmark_name: str) -> Any:
566566
"""Gets or creates a TensorBoard writer for the given benchmark."""
567567
if benchmark_name not in self._writers:
568-
is_primary_host = multihost.process_index() == 0
568+
is_primary_host = multihost.get_process_index() == 0
569569
self._writers[benchmark_name] = metric_writers.create_default_writer(
570570
self._tensorboard_dir,
571571
just_logging=not is_primary_host,
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Multihost utilities for benchmarks."""
16+
17+
import threading
18+
19+
from absl import logging
20+
from orbax.checkpoint._src.multihost import multihost
21+
22+
23+
def get_process_index() -> int:
  """Returns the process index from torch if active, else from multihost.

  The torch rank is used only when `torch.distributed` can be imported, is
  available in this torch build, and has an initialized default process
  group. In every other case the value of `multihost.process_index()` is
  returned.

  Returns:
    `dist.get_rank()` when a torch process group is initialized, otherwise
    `multihost.process_index()`.
  """
  # Keep the try body to the import only, so unrelated errors from the torch
  # calls below are not masked.
  try:
    import torch.distributed as dist  # pylint: disable=g-import-not-at-top
  except ImportError:
    dist = None

  # `is_available()` must be checked first: on torch builds without
  # distributed support the module exposes no other APIs, so calling
  # `is_initialized()` directly would raise AttributeError instead of falling
  # back to multihost.
  if dist is not None and dist.is_available() and dist.is_initialized():
    return dist.get_rank()
  return multihost.process_index()
33+
34+
35+
def sync_global_processes(
    name: str,
) -> None:
  """Syncs global processes using torch if active, else multihost.

  When `torch.distributed` can be imported, is available in this torch build,
  and has an initialized default process group, this calls `dist.barrier()`.
  Otherwise it delegates to `multihost.sync_global_processes(name)`.

  Args:
    name: Label for this synchronization point. It is only logged on the torch
      path and is forwarded to `multihost.sync_global_processes` on the
      fallback path.
  """
  # Keep the try body to the import only, so unrelated errors from the torch
  # calls below are not masked.
  try:
    import torch.distributed as dist  # pylint: disable=g-import-not-at-top
  except ImportError:
    dist = None

  # `is_available()` must be checked first: on torch builds without
  # distributed support the module exposes no other APIs, so calling
  # `is_initialized()` directly would raise AttributeError instead of falling
  # back to multihost.
  if dist is not None and dist.is_available() and dist.is_initialized():
    logging.vlog(
        1,
        "[process=%s][thread=%s] sync_global_processes with torch"
        " barrier: %s",
        dist.get_rank(),
        threading.current_thread().name,
        name,
    )
    dist.barrier()
    return
  multihost.sync_global_processes(name)

0 commit comments

Comments
 (0)