Skip to content

Commit 44fbb39

Browse files
author
Orbax Authors
committed
Add benchmarks for P2P CheckpointManager.
PiperOrigin-RevId: 873854861
1 parent 99bfb4b commit 44fbb39

File tree

6 files changed

+892
-1
lines changed

6 files changed

+892
-1
lines changed
Lines changed: 52 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,52 @@
1+
# The name for the entire test suite run.
2+
suite_name: "P2P CheckpointManager Benchmark"
3+
4+
mesh_configs:
5+
- mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
6+
# ICI (inter-chip interconnect): parallelism within a slice. Assumes 8 devices per slice.
7+
# DCN (data-center network): parallelism across slices.
8+
ici_parallelism: {"fsdp": 1, "tensor": 1, "data": 1}
9+
dcn_parallelism: {"data": 1} # Number of slices along the axis at replica_axis_index.
10+
process_is_granule: true
11+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
12+
ici_parallelism: {"data": 1, "model": 1}
13+
dcn_parallelism: {"data": 4, "model": 1}
14+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
15+
ici_parallelism: {"data": 1, "model": 16}
16+
dcn_parallelism: {"data": 2, "model": 1}
17+
allow_split_physical_axes: true
18+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
19+
ici_parallelism: {"data": 2, "model": 8}
20+
dcn_parallelism: {"data": 4, "model": 1}
21+
allow_split_physical_axes: true
22+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
23+
ici_parallelism: {"data": 2, "model": 4}
24+
dcn_parallelism: {"data": 2, "model": 1}
25+
allow_split_physical_axes: true
26+
27+
checkpoint_config:
28+
spec:
29+
a_1d: {dtype: "float32", shape: [32], sharding: [null]}
30+
b_1d: {dtype: "float32", shape: [32], sharding: ["tensor"]}
31+
c_2d: {dtype: "float32", shape: [32, 32], sharding: [null, "tensor"]}
32+
d_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", null]}
33+
e_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", "fsdp"]}
34+
f_2d: {dtype: "float32", shape: [32, 32], sharding: ["fsdp", "tensor"]}
35+
g_2d: {dtype: "float32", shape: [32, 32], sharding: [null, null]}
36+
h_3d: {dtype: "float32", shape: [32, 32, 32], sharding: ["tensor", null, "fsdp"]}
37+
i_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "tensor"]}
38+
j_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "fsdp"]}
39+
k_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, null]}
40+
custom_array: {dtype: "float32", shape: [8192, 64], sharding: ["tensor", null]}
41+
42+
benchmarks:
43+
- generator: "orbax.checkpoint._src.testing.benchmarks.p2p_checkpoint_manager_benchmark.P2pCheckpointManagerBenchmark"
44+
options:
45+
persistent_save_interval_steps: [2]
46+
persistent_max_to_keep: [5]
47+
local_save_interval_steps: [2]
48+
local_max_to_keep: 2
49+
replica_axis_index: 0
50+
train_steps: 5
51+
experimental_orbax_use_distributed_process_id: true
52+
experimental_use_distributed_id_for_mesh_consistency: true

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -159,6 +159,9 @@ def run(self, repeat_index: int | None = None) -> TestResult:
159159
path = directory_setup.setup_test_directory(
160160
self.name, self.output_dir, repeat_index
161161
)
162+
local_path = epath.Path(self.local_directory) / name
163+
if repeat_index is not None:
164+
local_path = local_path / f"repeat_{repeat_index}"
162165

163166
with benchmark_metrics.measure(
164167
"sync_global_processes:benchmark:setup_test_directory"
@@ -185,7 +188,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:
185188
options=self.options,
186189
mesh=self.mesh,
187190
repeat_index=repeat_index,
188-
local_path=self.local_directory,
191+
local_path=local_path,
189192
)
190193

191194
test_context_summary = self._build_test_context_summary(context)

0 commit comments

Comments
 (0)