Skip to content

Commit 069bb5f

Browse files
Nikhil BansalOrbax Authors
authored andcommitted
Perf testing P2P
PiperOrigin-RevId: 878947396
1 parent 36f6735 commit 069bb5f

File tree

7 files changed

+883
-5
lines changed

7 files changed

+883
-5
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# The name for the entire test suite run.
2+
suite_name: "P2P CheckpointManager Benchmark"
3+
4+
mesh_configs:
5+
- mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
6+
# ICI: Within a slice. Assuming 8 devices per slice.
7+
# DCN: Across slices.
8+
ici_parallelism: {"fsdp": 1, "tensor": 1, "data": 1}
9+
dcn_parallelism: {"data": 1} # num_slices on the axis at replica_axis_index
10+
process_is_granule: true
11+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
12+
ici_parallelism: {"data": 1, "model": 1}
13+
dcn_parallelism: {"data": 4, "model": 1}
14+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
15+
ici_parallelism: {"data": 1, "model": 16}
16+
dcn_parallelism: {"data": 2, "model": 2}
17+
allow_split_physical_axes: true
18+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
19+
ici_parallelism: {"data": 1, "model": 16}
20+
dcn_parallelism: {"data": 2, "model": 1}
21+
allow_split_physical_axes: true
22+
- mesh_axes: ["data", "model", "tensor", "fsdp"]
23+
ici_parallelism: {"data": 2, "model": 4}
24+
dcn_parallelism: {"data": 2, "model": 1}
25+
allow_split_physical_axes: true
26+
27+
# checkpoint_config:
28+
# spec:
29+
# a_1d: {dtype: "float32", shape: [32], sharding: [null]}
30+
# b_1d: {dtype: "float32", shape: [32], sharding: ["tensor"]}
31+
# c_2d: {dtype: "float32", shape: [32, 32], sharding: [null, "tensor"]}
32+
# d_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", null]}
33+
# e_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", "fsdp"]}
34+
# f_2d: {dtype: "float32", shape: [32, 32], sharding: ["fsdp", "tensor"]}
35+
# g_2d: {dtype: "float32", shape: [32, 32], sharding: [null, null]}
36+
# h_3d: {dtype: "float32", shape: [32, 32, 32], sharding: ["tensor", null, "fsdp"]}
37+
# i_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "tensor"]}
38+
# j_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "fsdp"]}
39+
# k_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, null]}
40+
# custom_array: {dtype: "float32", shape: [8192, 64], sharding: ["tensor", null]}
41+
42+
checkpoint_config:
43+
path: "gs://safetensor-kimi-central/test_model_orbax/llama-3.1-70B-checkpoints/0/items/items"
44+
45+
benchmarks:
46+
- generator: "orbax.checkpoint._src.testing.benchmarks.p2p_checkpoint_manager_benchmark.P2pCheckpointManagerBenchmark"
47+
options:
48+
persistent_save_interval_steps: [2]
49+
persistent_max_to_keep: [5]
50+
local_save_interval_steps: [2]
51+
local_max_to_keep: 1
52+
replica_axis_index: 0
53+
train_steps: 3
54+
experimental_orbax_use_distributed_process_id: true
55+
experimental_use_distributed_id_for_mesh_consistency: true
56+
tests_to_run: ["test_local_restore"]

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@ def run(self, repeat_index: int | None = None) -> TestResult:
159159
path = directory_setup.setup_test_directory(
160160
self.name, self.output_dir, repeat_index
161161
)
162+
local_path = None
163+
if self.local_directory is not None:
164+
local_path = epath.Path(self.local_directory) / name
165+
if repeat_index is not None:
166+
local_path = local_path / f"repeat_{repeat_index}"
162167

163168
with benchmark_metrics.measure(
164169
"sync_global_processes:benchmark:setup_test_directory"
@@ -185,7 +190,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:
185190
options=self.options,
186191
mesh=self.mesh,
187192
repeat_index=repeat_index,
188-
local_path=self.local_directory,
193+
local_path=local_path,
189194
)
190195

191196
test_context_summary = self._build_test_context_summary(context)

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/directory_setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def setup_test_directory(
4040
path = path / f"repeat_{repeat_index}"
4141
logging.info("Setting up test directory at: %s", path)
4242
if multihost.get_process_index() == 0:
43-
if path.exists():
43+
if path.exists() and not base_path.startswith("gs://"):
4444
logging.warning("Test directory %s already exists. Deleting it.", path)
4545
path.rmtree()
46-
path.mkdir(parents=True, exist_ok=False)
46+
path.mkdir(parents=True, exist_ok=True)
4747
return path

0 commit comments

Comments
 (0)