Skip to content

Commit 04c9e57

Browse files
author
Orbax Authors
committed
Add benchmarks for P2P CheckpointManager.
PiperOrigin-RevId: 873854861
1 parent 99bfb4b commit 04c9e57

File tree

5 files changed

+499
-0
lines changed

5 files changed

+499
-0
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# The name for the entire test suite run.
2+
suite_name: "P2P CheckpointManager Benchmark"
3+
4+
mesh_configs:
5+
- mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
6+
# ICI: parallelism within a slice (assumes 8 devices per slice).
7+
# DCN: parallelism across slices.
8+
ici_parallelism: {"fsdp": 1, "tensor": 1, "data": 1}
9+
dcn_parallelism: {"data": 1} # num_slices on the axis at replica_axis_index
10+
process_is_granule: true
11+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
12+
ici_parallelism: {"data": 1, "model": 1}
13+
dcn_parallelism: {"data": 4, "model": 1}
14+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
15+
ici_parallelism: {"data": 1, "model": 16}
16+
dcn_parallelism: {"data": 2, "model": 1}
17+
allow_split_physical_axes: true
18+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
19+
ici_parallelism: {"data": 2, "model": 8}
20+
dcn_parallelism: {"data": 2, "model": 1}
21+
allow_split_physical_axes: true
22+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
23+
ici_parallelism: {"data": 2, "model": 4}
24+
dcn_parallelism: {"data": 2, "model": 1}
25+
allow_split_physical_axes: true
26+
27+
checkpoint_config:
28+
spec:
29+
a_1d: {dtype: "float32", shape: [32], sharding: [null]}
30+
b_1d: {dtype: "float32", shape: [32], sharding: ["tensor"]}
31+
c_2d: {dtype: "float32", shape: [32, 32], sharding: [null, "tensor"]}
32+
d_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", null]}
33+
e_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", "fsdp"]}
34+
f_2d: {dtype: "float32", shape: [32, 32], sharding: ["fsdp", "tensor"]}
35+
g_2d: {dtype: "float32", shape: [32, 32], sharding: [null, null]}
36+
h_3d: {dtype: "float32", shape: [32, 32, 32], sharding: ["tensor", null, "fsdp"]}
37+
i_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "tensor"]}
38+
j_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "fsdp"]}
39+
k_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, null]}
40+
custom_array: {dtype: "float32", shape: [8192, 64], sharding: ["tensor", null]}
41+
42+
benchmarks:
43+
- generator: "orbax.checkpoint._src.testing.benchmarks.p2p_checkpoint_manager_benchmark.P2pCheckpointManagerBenchmark"
44+
options:
45+
persistent_save_interval_steps: [2]
46+
persistent_max_to_keep: [5]
47+
local_save_interval_steps: [2]
48+
local_max_to_keep: 2
49+
replica_axis_index: 0
50+
train_steps: 5
51+
experimental_orbax_use_distributed_process_id: true
52+
experimental_use_distributed_id_for_mesh_consistency: true
Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Benchmarks for P2P CheckpointManager.
16+
17+
This module contains benchmarks for
18+
orbax.checkpoint.experimental.emergency.p2p.checkpoint_manager.CheckpointManager.
19+
"""
20+
21+
from collections.abc import Sequence
22+
import dataclasses
23+
import inspect
24+
from typing import Any
25+
from absl import logging
26+
from etils import epath
27+
import jax
28+
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
29+
from orbax.checkpoint._src.multihost import multihost
30+
from orbax.checkpoint._src.serialization import type_handlers
31+
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
32+
from orbax.checkpoint._src.testing.benchmarks.core import mesh_utils
33+
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib
34+
from orbax.checkpoint._src.testing.benchmarks.core import pytree_utils
35+
from orbax.checkpoint._src.tree import utils
36+
from orbax.checkpoint.experimental.emergency.p2p import args as p2p_args_lib
37+
from orbax.checkpoint.experimental.emergency.p2p import checkpoint_manager as p2p_checkpoint_manager
38+
from orbax.checkpoint.experimental.emergency.p2p import options as p2p_options
39+
40+
41+
# ==============================================================================
42+
# 1. Define the Options Dataclass
43+
# ==============================================================================
44+
@dataclasses.dataclass(frozen=True)
class P2pBenchmarkOptions(benchmarks_core.BenchmarkOptions):
  """Configuration options for benchmarks targeting P2P CheckpointManager.

  Every field is annotated as either a single value or a sequence of values.
  The benchmark code in this module consumes the values as scalars (e.g. as a
  modulus and as a `range` bound), so a sequence is presumably expanded into
  one scalar per generated benchmark by the framework before the test body
  runs -- TODO confirm against `benchmarks_core`.

  Attributes:
    persistent_save_interval_steps: The interval at which persistent checkpoints
      should be saved.
    persistent_max_to_keep: The maximum number of persistent checkpoints to
      keep.
    local_save_interval_steps: The interval at which local checkpoints should be
      saved.
    local_max_to_keep: The maximum number of local checkpoints to keep.
    replica_axis_index: The index of the replica axis in the global mesh.
    train_steps: The number of training steps to run.
    experimental_use_distributed_id_for_mesh_consistency: Not read by the
      benchmark code visible in this module. NOTE(review): presumably consumed
      by the benchmark framework when setting up the mesh -- confirm.
    experimental_orbax_use_distributed_process_id: Not read by the benchmark
      code visible in this module. NOTE(review): presumably consumed by the
      benchmark framework to select distributed process ids -- confirm.
  """

  persistent_save_interval_steps: int | Sequence[int] = 5
  persistent_max_to_keep: int | Sequence[int] = 5
  local_save_interval_steps: int | Sequence[int] = 2
  local_max_to_keep: int | Sequence[int] = 2
  replica_axis_index: int | Sequence[int] = 0
  train_steps: int | Sequence[int] = 10
  experimental_use_distributed_id_for_mesh_consistency: (
      bool | Sequence[bool]
  ) = True
  experimental_orbax_use_distributed_process_id: bool | Sequence[bool] = True
70+
71+
72+
# ==============================================================================
73+
# 2. Implement the Benchmark Generator
74+
# ==============================================================================
75+
def _create_checkpoint_manager(
    local_directory: epath.Path,
    persistent_directory: epath.Path,
    global_mesh: jax.sharding.Mesh,
    abstract_state: Any,
    options: P2pBenchmarkOptions,
) -> p2p_checkpoint_manager.CheckpointManager:
  """Builds a P2P CheckpointManager configured from the benchmark options.

  Args:
    local_directory: Forwarded as the manager's `local_directory`.
    persistent_directory: Forwarded as the manager's `persistent_directory`.
    global_mesh: Forwarded as the manager's `global_mesh`.
    abstract_state: Forwarded as the manager's `abstract_state`.
    options: Benchmark options supplying the local and persistent save
      intervals, the local and persistent `max_to_keep` values, and the
      replica axis index.

  Returns:
    The newly constructed `p2p_checkpoint_manager.CheckpointManager`.
  """
  local_options = p2p_options.LocalCheckpointOptions(
      save_interval_steps=options.local_save_interval_steps,
      max_to_keep=options.local_max_to_keep,
  )
  persistent_options = p2p_options.PersistentCheckpointOptions(
      save_interval_steps=options.persistent_save_interval_steps,
      max_to_keep=options.persistent_max_to_keep,
  )
  manager_options = p2p_options.CheckpointManagerOptions(
      local=local_options,
      persistent=persistent_options,
      replica_axis_index=options.replica_axis_index,
  )
  return p2p_checkpoint_manager.CheckpointManager(
      local_directory=local_directory,
      persistent_directory=persistent_directory,
      global_mesh=global_mesh,
      abstract_state=abstract_state,
      options=manager_options,
  )
100+
101+
102+
def _restore_and_validate(
    manager: p2p_checkpoint_manager.CheckpointManager,
    metrics: metric_lib.Metrics,
    pytree: Any,
    abstract_pytree: Any,
    step: int,
    local_directory: epath.Path,
    restore_args: Any,
    test_name: str = '',
    delete_before_restore: str = 'local_p0',
):
  """Restores `step` through `manager` and checks it against `pytree`.

  The sequence is: synchronize all processes, optionally delete the local
  checkpoint directory for `step`, restore, compare the restored tree with
  `pytree`, and finally reload the manager.

  NOTE(review): the statement nesting of this function was reconstructed from
  a listing without indentation. In particular, `manager.reload()` is assumed
  to run only when the local step directory was actually removed, and only
  the `manager.restore` call is assumed to be timed by the restore metric --
  confirm both against the upstream file.

  Args:
    manager: The P2P CheckpointManager to restore from.
    metrics: Collector used to time the sync, restore and final reload.
    pytree: The expected tree; the restored tree must equal it.
    abstract_pytree: Passed as the `item` of the restore args.
    step: The step to restore.
    local_directory: Directory whose `str(step)` child is the local checkpoint
      that may be deleted before restoring.
    restore_args: Passed as the `restore_args` of the restore args.
    test_name: When non-empty, `'<test_name>_'` is prepended to every metric
      and barrier name.
    delete_before_restore: One of `'local_p0'` (only process 0 deletes its
      local step directory), `'local_all'` (every process deletes its local
      step directory) or `'none'` (nothing is deleted).

  Raises:
    ValueError: If `delete_before_restore` is not one of the values above.
  """
  metric_prefix = f'{test_name}_' if test_name else ''

  # Barrier: every host must have finished saving before anything is deleted
  # or restored.
  with metrics.measure(f'{metric_prefix}sync_global_processes_{step}'):
    multihost.sync_global_processes(f'{metric_prefix}save_completed_{step}')

  local_step_dir = local_directory / str(step)
  if delete_before_restore == 'none':
    logging.info('Skipping deletion of local checkpoint for local restore.')
  elif delete_before_restore == 'local_p0':
    # `exists()` is only evaluated on process 0 thanks to short-circuiting.
    if multihost.process_index() == 0 and local_step_dir.exists():
      logging.info(
          'Process 0: removing local checkpoint to trigger P2P restore.'
      )
      local_step_dir.rmtree()
      manager.reload()
  elif delete_before_restore == 'local_all':
    if local_step_dir.exists():
      logging.info(
          'All processes: removing local checkpoint to trigger GCS restore.'
      )
      local_step_dir.rmtree()
      manager.reload()
  else:
    raise ValueError(
        f'Invalid delete_before_restore: {delete_before_restore}'
    )

  # NOTE(review): the message says the restore args are not used, yet they are
  # forwarded to `PyTreeRestoreArgs` below -- confirm which one is intended.
  logging.info('Not using restore args: %r', restore_args)

  with metrics.measure(f'{metric_prefix}restore_{step}'):
    restored_items = manager.restore(
        step,
        args=p2p_args_lib.Composite(
            state=pytree_checkpoint_handler.PyTreeRestoreArgs(
                restore_args=restore_args,
                item=abstract_pytree,
            )
        ),
    )
    restored_state = restored_items['state']

  pytree_utils.log_pytree('Restored Pytree', restored_state)
  logging.info('Assert Restored Pytree')
  pytree_utils.assert_pytree_equal(pytree, restored_state)

  with metrics.measure(f'{metric_prefix}reload_after_restore_{step}'):
    manager.reload()
158+
159+
160+
@benchmarks_core.benchmark_options(P2pBenchmarkOptions)
class P2pCheckpointManagerBenchmark(benchmarks_core.BenchmarksGenerator):
  """Benchmark generator exercising the P2P CheckpointManager.

  `test_fn` is the entry point. It discovers every bound method whose name
  starts with `test_` (other than `test_fn` itself) and runs each one as a
  separate scenario; all scenarios share one metrics collector.

  NOTE(review): statement nesting in this class was reconstructed from a
  listing without indentation -- confirm against the upstream file.
  """

  def _run_test(
      self,
      test_name: str,
      context: benchmarks_core.TestContext,
      metrics: metric_lib.Metrics,
      abstract_pytree: Any,
      restore_args: Any,
      delete_before_restore: str = 'local_p0',
  ):
    """Runs one save / restore-and-validate scenario.

    Creates the scenario's persistent and local directories, builds a manager,
    restores from the latest existing step (if any), then saves on every
    remaining step up to `options.train_steps`, validating a restore on every
    non-zero step that is a multiple of `options.local_save_interval_steps`.

    Args:
      test_name: Scenario name; used as a directory component and as the
        prefix of every metric and barrier name.
      context: Supplies `pytree`, `path`, `local_path`, `options` and `mesh`.
      metrics: Collector that receives all timings.
      abstract_pytree: Abstract counterpart of `context.pytree`.
      restore_args: Forwarded to `_restore_and_validate`.
      delete_before_restore: Forwarded to `_restore_and_validate`.
    """
    logging.info('Running test: %s', test_name)
    expected_pytree = context.pytree
    persistent_directory = context.path / test_name / 'persistent_p2p_ckpt'
    if context.local_path is None:
      # No dedicated local storage: give each process its own sub-directory
      # under the shared path.
      local_directory = (
          context.path
          / test_name
          / 'local_p2p_ckpt'
          / f'process_{multihost.process_index()}'
      )
    else:
      local_root = epath.Path(context.local_path) / test_name / 'local_p2p_ckpt'
      local_directory = epath.Path(local_root)
    options = context.options
    global_mesh = context.mesh
    assert isinstance(options, P2pBenchmarkOptions)

    def validate(at_step):
      """Times one restore-and-validate pass for `at_step`."""
      with metrics.measure(f'{test_name}_restore_and_validate_{at_step}'):
        _restore_and_validate(
            manager,
            metrics,
            expected_pytree,
            abstract_pytree,
            at_step,
            local_directory,
            restore_args,
            test_name=test_name,
            delete_before_restore=delete_before_restore,
        )

    with metrics.measure(f'{test_name}_create_directories'):
      # NOTE(review): this gate uses `jax.process_index()` while the rest of
      # the module uses `multihost.process_index()` -- confirm that is
      # intentional.
      if jax.process_index() == 0:
        persistent_directory.mkdir(parents=True, exist_ok=True)
      local_directory.mkdir(parents=True, exist_ok=True)
      multihost.sync_global_processes(f'{test_name}_create_directories')

    with metrics.measure(f'{test_name}_create_checkpoint_manager'):
      manager = _create_checkpoint_manager(
          local_directory=local_directory,
          persistent_directory=persistent_directory,
          global_mesh=global_mesh,
          abstract_state=abstract_pytree,
          options=options,
      )

    # Resume after the newest existing checkpoint, validating it first.
    latest_step = manager.latest_step()
    if latest_step is None:
      start_step = 0
    else:
      logging.info('Latest step in test %s: %d', test_name, latest_step)
      validate(latest_step)
      start_step = latest_step + 1

    with metrics.measure(f'{test_name}_train_loop'):
      for step in range(start_step, options.train_steps):
        logging.info('Test %s: Training step %d', test_name, step)
        with metrics.measure(f'{test_name}_save_{step}'):
          manager.save(
              step,
              args=p2p_args_lib.Composite(
                  state=pytree_checkpoint_handler.PyTreeSaveArgs(
                      expected_pytree
                  )
              ),
          )
        with metrics.measure(f'{test_name}_wait_until_finished_{step}'):
          manager.wait_until_finished()

        on_local_interval = step % options.local_save_interval_steps == 0
        if on_local_interval and step != 0:
          validate(step)

    manager.close()

  def test_local_restore(
      self,
      context: benchmarks_core.TestContext,
      metrics: metric_lib.Metrics,
      abstract_pytree: Any,
      restore_args: Any,
  ):
    """Scenario in which no local checkpoint is deleted before restoring."""
    self._run_test(
        test_name='test_local_restore',
        context=context,
        metrics=metrics,
        abstract_pytree=abstract_pytree,
        restore_args=restore_args,
        delete_before_restore='none',
    )

  def test_p2p_restore(
      self,
      context: benchmarks_core.TestContext,
      metrics: metric_lib.Metrics,
      abstract_pytree: Any,
      restore_args: Any,
  ):
    """Scenario in which process 0 deletes its local checkpoint first."""
    self._run_test(
        test_name='test_p2p_restore',
        context=context,
        metrics=metrics,
        abstract_pytree=abstract_pytree,
        restore_args=restore_args,
        delete_before_restore='local_p0',
    )

  def test_gcs_restore(
      self,
      context: benchmarks_core.TestContext,
      metrics: metric_lib.Metrics,
      abstract_pytree: Any,
      restore_args: Any,
  ):
    """Scenario in which every process deletes its local checkpoint first."""
    self._run_test(
        test_name='test_gcs_restore',
        context=context,
        metrics=metrics,
        abstract_pytree=abstract_pytree,
        restore_args=restore_args,
        delete_before_restore='local_all',
    )

  def test_fn(
      self, context: benchmarks_core.TestContext
  ) -> benchmarks_core.TestResult:
    """Runs every `test_*` scenario of this class and returns the metrics.

    Args:
      context: Supplies the pytree, options and mesh for the run.

    Returns:
      A `TestResult` wrapping the metrics gathered across all scenarios.

    Raises:
      ValueError: If `context.mesh` is None.
    """
    metrics = metric_lib.Metrics()
    concrete_pytree = context.pytree
    options = context.options
    global_mesh = context.mesh
    assert isinstance(options, P2pBenchmarkOptions)

    if global_mesh is None:
      raise ValueError(
          'Mesh must be provided for P2pCheckpointManagerBenchmark'
      )

    # Initialize each id mapping only if it has not been initialized yet.
    if not multihost.is_runtime_to_distributed_ids_initialized():
      multihost.initialize_runtime_to_distributed_ids()
    if not multihost.is_distributed_to_device_ids_initialized():
      multihost.initialize_distributed_to_device_ids()

    mesh_utils.pretty_log_mesh('Global Mesh: ', global_mesh)

    with metrics.measure('create_abstract_pytree'):
      abstract_pytree = jax.tree.map(
          utils.to_shape_dtype_struct, concrete_pytree
      )
      logging.info('abstract_pytree: %r', abstract_pytree)

    with metrics.measure('create_restore_args'):
      restore_args = type_handlers.SingleReplicaArrayRestoreArgs()
      logging.info('restore_args: %r', restore_args)

    # `inspect.getmembers` returns members sorted by name, so the scenarios
    # run in alphabetical order.
    bound_methods = inspect.getmembers(self, predicate=inspect.ismethod)
    scenarios = [
        method
        for name, method in bound_methods
        if name.startswith('test_') and name != 'test_fn'
    ]
    for scenario in scenarios:
      scenario(context, metrics, abstract_pytree, restore_args)

    return benchmarks_core.TestResult(metrics=metrics)

0 commit comments

Comments
 (0)