Skip to content

Commit 53ee910

Browse files
author
Orbax Authors
committed
Add benchmarks for P2P CheckpointManager.
PiperOrigin-RevId: 873854861
1 parent b192e7f commit 53ee910

File tree

3 files changed

+392
-0
lines changed

3 files changed

+392
-0
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# The name for the entire test suite run.
2+
suite_name: "P2P CheckpointManager Benchmark"
3+
4+
mesh_configs:
5+
- mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
6+
# ICI: Within a slice. Assuming 8 devices per slice.
7+
# DCN: Across slices.
8+
ici_parallelism: {"fsdp": 1, "tensor": 1, "data": 1}
9+
dcn_parallelism: {"data": 1} # num_slices on the axis at replica_axis_index
10+
process_is_granule: true
11+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
12+
ici_parallelism: {"data": 1, "model": 1}
13+
dcn_parallelism: {"data": 4, "model": 1}
14+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
15+
ici_parallelism: {"data": 1, "model": 16}
16+
dcn_parallelism: {"data": 2, "model": 1}
17+
allow_split_physical_axes: true
18+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
19+
ici_parallelism: {"data": 2, "model": 8}
20+
dcn_parallelism: {"data": 2, "model": 1}
21+
allow_split_physical_axes: true
22+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
23+
ici_parallelism: {"data": 2, "model": 4}
24+
dcn_parallelism: {"data": 2, "model": 1}
25+
allow_split_physical_axes: true
26+
27+
checkpoint_config:
28+
spec:
29+
a_1d: {dtype: "float32", shape: [32], sharding: [null]}
30+
b_1d: {dtype: "float32", shape: [32], sharding: ["tensor"]}
31+
c_2d: {dtype: "float32", shape: [32, 32], sharding: [null, "tensor"]}
32+
d_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", null]}
33+
e_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", "fsdp"]}
34+
f_2d: {dtype: "float32", shape: [32, 32], sharding: ["fsdp", "tensor"]}
35+
g_2d: {dtype: "float32", shape: [32, 32], sharding: [null, null]}
36+
h_3d: {dtype: "float32", shape: [32, 32, 32], sharding: ["tensor", null, "fsdp"]}
37+
i_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "tensor"]}
38+
j_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "fsdp"]}
39+
k_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, null]}
40+
custom_array: {dtype: "float32", shape: [8192, 64], sharding: ["tensor", null]}
41+
42+
benchmarks:
43+
- generator: "orbax.checkpoint._src.testing.benchmarks.p2p_checkpoint_manager_benchmark.P2pCheckpointManagerBenchmark"
44+
options:
45+
persistent_save_interval_steps: [2]
46+
persistent_max_to_keep: [5]
47+
local_save_interval_steps: [2]
48+
local_max_to_keep: 2
49+
replica_axis_index: 0
50+
train_steps: 5
51+
experimental_orbax_use_distributed_process_id: true
52+
experimental_use_distributed_id_for_mesh_consistency: true
Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Benchmarks for P2P CheckpointManager.
16+
17+
This module contains benchmarks for
18+
orbax.checkpoint.experimental.emergency.p2p.checkpoint_manager.CheckpointManager.
19+
"""
20+
21+
from collections.abc import Sequence
22+
import dataclasses
23+
import inspect
24+
from typing import Any
25+
from absl import logging
26+
from etils import epath
27+
import jax
28+
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
29+
from orbax.checkpoint._src.multihost import multihost
30+
from orbax.checkpoint._src.serialization import type_handlers
31+
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
32+
from orbax.checkpoint._src.testing.benchmarks.core import mesh_utils
33+
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib
34+
from orbax.checkpoint._src.testing.benchmarks.core import pytree_utils
35+
from orbax.checkpoint._src.tree import utils
36+
from orbax.checkpoint.experimental.emergency.p2p import args as p2p_args_lib
37+
from orbax.checkpoint.experimental.emergency.p2p import checkpoint_manager as p2p_checkpoint_manager
38+
from orbax.checkpoint.experimental.emergency.p2p import options as p2p_options
39+
40+
41+
# ==============================================================================
42+
# 1. Define the Options Dataclass
43+
# ==============================================================================
44+
@dataclasses.dataclass(frozen=True)
45+
class P2pBenchmarkOptions(benchmarks_core.BenchmarkOptions):
46+
"""Configuration options for benchmarks targeting P2P CheckpointManager.
47+
48+
Attributes:
49+
persistent_save_interval_steps: The interval at which persistent checkpoints
50+
should be saved.
51+
persistent_max_to_keep: The maximum number of persistent checkpoints to
52+
keep.
53+
local_save_interval_steps: The interval at which local checkpoints should be
54+
saved.
55+
local_max_to_keep: The maximum number of local checkpoints to keep.
56+
replica_axis_index: The index of the replica axis in the global mesh.
57+
train_steps: The number of training steps to run.
58+
"""
59+
60+
persistent_save_interval_steps: int | Sequence[int] = 5
61+
persistent_max_to_keep: int | Sequence[int] = 5
62+
local_save_interval_steps: int | Sequence[int] = 2
63+
local_max_to_keep: int | Sequence[int] = 2
64+
replica_axis_index: int | Sequence[int] = 0
65+
train_steps: int | Sequence[int] = 10
66+
experimental_use_distributed_id_for_mesh_consistency: (
67+
bool | Sequence[bool]
68+
) = True
69+
experimental_orbax_use_distributed_process_id: bool | Sequence[bool] = True
70+
71+
72+
# ==============================================================================
73+
# 2. Implement the Benchmark Generator
74+
# ==============================================================================
75+
def _create_checkpoint_manager(
76+
local_directory: epath.Path,
77+
persistent_directory: epath.Path,
78+
global_mesh: jax.sharding.Mesh,
79+
abstract_state: Any,
80+
options: P2pBenchmarkOptions,
81+
) -> p2p_checkpoint_manager.CheckpointManager:
82+
"""Creates an P2P CheckpointManager."""
83+
return p2p_checkpoint_manager.CheckpointManager(
84+
local_directory=local_directory,
85+
persistent_directory=persistent_directory,
86+
global_mesh=global_mesh,
87+
abstract_state=abstract_state,
88+
options=p2p_options.CheckpointManagerOptions(
89+
local=p2p_options.LocalCheckpointOptions(
90+
save_interval_steps=options.local_save_interval_steps,
91+
max_to_keep=options.local_max_to_keep,
92+
),
93+
persistent=p2p_options.PersistentCheckpointOptions(
94+
save_interval_steps=options.persistent_save_interval_steps,
95+
max_to_keep=options.persistent_max_to_keep,
96+
),
97+
replica_axis_index=options.replica_axis_index,
98+
),
99+
)
100+
101+
102+
def _restore_and_validate(
103+
manager: p2p_checkpoint_manager.CheckpointManager,
104+
metrics: metric_lib.Metrics,
105+
pytree: Any,
106+
step: int,
107+
local_directory: epath.Path,
108+
restore_args: Any,
109+
test_name: str = '',
110+
delete_before_restore: str = 'local_p0',
111+
):
112+
"""Restores a checkpoint and validates it."""
113+
prefix = f'{test_name}_' if test_name else ''
114+
# Wait for save to complete on all hosts.
115+
with metrics.measure(f'{prefix}sync_global_processes_{step}'):
116+
multihost.sync_global_processes(f'{prefix}save_completed_{step}')
117+
118+
step_dir = local_directory / str(step)
119+
if delete_before_restore == 'local_p0':
120+
if multihost.process_index() == 0 and step_dir.exists():
121+
logging.info(
122+
'Process 0: removing local checkpoint to trigger P2P restore.'
123+
)
124+
step_dir.rmtree()
125+
manager.reload()
126+
elif delete_before_restore == 'local_all':
127+
if step_dir.exists():
128+
logging.info(
129+
'All processes: removing local checkpoint to trigger GCS restore.'
130+
)
131+
step_dir.rmtree()
132+
manager.reload()
133+
elif delete_before_restore == 'none':
134+
logging.info('Skipping deletion of local checkpoint for local restore.')
135+
else:
136+
raise ValueError(
137+
f'Invalid delete_before_restore: {delete_before_restore}'
138+
)
139+
140+
logging.info('Not using restore args: %r', restore_args)
141+
142+
with metrics.measure(f'{prefix}restore_{step}'):
143+
restored = manager.restore(
144+
step,
145+
args=p2p_args_lib.Composite(
146+
state=pytree_checkpoint_handler.PyTreeRestoreArgs(
147+
restore_args=restore_args
148+
)
149+
),
150+
)['state']
151+
pytree_utils.log_pytree('Restored Pytree', restored)
152+
logging.info('Assert Restored Pytree')
153+
pytree_utils.assert_pytree_equal(pytree, restored)
154+
with metrics.measure(f'{prefix}reload_after_restore_{step}'):
155+
manager.reload()
156+
157+
158+
@benchmarks_core.benchmark_options(P2pBenchmarkOptions)
159+
class P2pCheckpointManagerBenchmark(benchmarks_core.BenchmarksGenerator):
160+
"""A generator for benchmarking P2P CheckpointManager."""
161+
162+
def _run_test(
163+
self,
164+
test_name: str,
165+
context: benchmarks_core.TestContext,
166+
metrics: metric_lib.Metrics,
167+
abstract_pytree: Any,
168+
restore_args: Any,
169+
delete_before_restore: str = 'local_p0',
170+
):
171+
"""Runs a single test case."""
172+
logging.info('Running test: %s', test_name)
173+
pytree = context.pytree
174+
persistent_directory = context.path / test_name / 'persistent_p2p_ckpt'
175+
if context.local_path is not None:
176+
local_path = epath.Path(context.local_path) / test_name / 'local_p2p_ckpt'
177+
local_directory = epath.Path(local_path)
178+
else:
179+
local_directory = (
180+
context.path
181+
/ test_name
182+
/ 'local_p2p_ckpt'
183+
/ f'process_{multihost.process_index()}'
184+
)
185+
options = context.options
186+
mesh = context.mesh
187+
assert isinstance(options, P2pBenchmarkOptions)
188+
189+
with metrics.measure(f'{test_name}_create_directories'):
190+
if jax.process_index() == 0:
191+
persistent_directory.mkdir(parents=True, exist_ok=True)
192+
local_directory.mkdir(parents=True, exist_ok=True)
193+
multihost.sync_global_processes(f'{test_name}_create_directories')
194+
195+
with metrics.measure(f'{test_name}_create_checkpoint_manager'):
196+
manager = _create_checkpoint_manager(
197+
local_directory=local_directory,
198+
persistent_directory=persistent_directory,
199+
global_mesh=mesh,
200+
abstract_state=abstract_pytree,
201+
options=options,
202+
)
203+
204+
step = manager.latest_step()
205+
if step is not None:
206+
logging.info('Latest step in test %s: %d', test_name, step)
207+
208+
with metrics.measure(f'{test_name}_restore_and_validate_{step}'):
209+
_restore_and_validate(
210+
manager,
211+
metrics,
212+
pytree,
213+
step,
214+
local_directory,
215+
restore_args,
216+
test_name=test_name,
217+
delete_before_restore=delete_before_restore,
218+
)
219+
220+
start_step = step + 1 if step is not None else 0
221+
with metrics.measure(f'{test_name}_train_loop'):
222+
for step in range(start_step, options.train_steps):
223+
logging.info('Test %s: Training step %d', test_name, step)
224+
with metrics.measure(f'{test_name}_save_{step}'):
225+
manager.save(
226+
step,
227+
args=p2p_args_lib.Composite(
228+
state=pytree_checkpoint_handler.PyTreeSaveArgs(pytree)
229+
),
230+
)
231+
with metrics.measure(f'{test_name}_wait_until_finished_{step}'):
232+
manager.wait_until_finished()
233+
234+
if step % options.local_save_interval_steps == 0 and step != 0:
235+
with metrics.measure(f'{test_name}_restore_and_validate_{step}'):
236+
_restore_and_validate(
237+
manager,
238+
metrics,
239+
pytree,
240+
step,
241+
local_directory,
242+
restore_args,
243+
test_name=test_name,
244+
delete_before_restore=delete_before_restore,
245+
)
246+
247+
manager.close()
248+
249+
def test_local_restore(
250+
self,
251+
context: benchmarks_core.TestContext,
252+
metrics: metric_lib.Metrics,
253+
abstract_pytree: Any,
254+
restore_args: Any,
255+
):
256+
self._run_test(
257+
'test_local_restore',
258+
context,
259+
metrics,
260+
abstract_pytree,
261+
restore_args,
262+
delete_before_restore='none',
263+
)
264+
265+
def test_p2p_restore(
266+
self,
267+
context: benchmarks_core.TestContext,
268+
metrics: metric_lib.Metrics,
269+
abstract_pytree: Any,
270+
restore_args: Any,
271+
):
272+
self._run_test(
273+
'test_p2p_restore',
274+
context,
275+
metrics,
276+
abstract_pytree,
277+
restore_args,
278+
delete_before_restore='local_p0',
279+
)
280+
281+
def test_gcs_restore(
282+
self,
283+
context: benchmarks_core.TestContext,
284+
metrics: metric_lib.Metrics,
285+
abstract_pytree: Any,
286+
restore_args: Any,
287+
):
288+
self._run_test(
289+
'test_gcs_restore',
290+
context,
291+
metrics,
292+
abstract_pytree,
293+
restore_args,
294+
delete_before_restore='local_all',
295+
)
296+
297+
def test_fn(
298+
self, context: benchmarks_core.TestContext
299+
) -> benchmarks_core.TestResult:
300+
"""The core test logic for a single save/restore cycle."""
301+
metrics = metric_lib.Metrics()
302+
pytree = context.pytree
303+
options = context.options
304+
mesh = context.mesh
305+
assert isinstance(options, P2pBenchmarkOptions)
306+
307+
if mesh is None:
308+
raise ValueError(
309+
'Mesh must be provided for P2pCheckpointManagerBenchmark'
310+
)
311+
if not multihost.is_runtime_to_distributed_ids_initialized():
312+
multihost.initialize_runtime_to_distributed_ids()
313+
314+
if not multihost.is_distributed_to_device_ids_initialized():
315+
multihost.initialize_distributed_to_device_ids()
316+
317+
mesh_utils.pretty_log_mesh('Global Mesh: ', mesh)
318+
319+
with metrics.measure('create_abstract_pytree'):
320+
abstract_pytree = jax.tree.map(utils.to_shape_dtype_struct, pytree)
321+
logging.info('abstract_pytree: %r', abstract_pytree)
322+
323+
with metrics.measure('create_restore_args'):
324+
restore_args = type_handlers.SingleReplicaArrayRestoreArgs()
325+
logging.info('restore_args: %r', restore_args)
326+
327+
tests_to_run = []
328+
for name, method in inspect.getmembers(self, predicate=inspect.ismethod):
329+
if name.startswith('test_') and name != 'test_fn':
330+
tests_to_run.append(method)
331+
332+
for test in tests_to_run:
333+
test(context, metrics, abstract_pytree, restore_args)
334+
335+
return benchmarks_core.TestResult(metrics=metrics)

0 commit comments

Comments
 (0)