Skip to content

Commit 28876f6

Browse files
cpgaffney1Orbax Authors
authored and committed
Restore and broadcast benchmark.
PiperOrigin-RevId: 873114079
1 parent 9b22138 commit 28876f6

35 files changed

+1157
-56
lines changed

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-405b.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ checkpoint_config:
1313
path: "gs://orbax-benchmarks/checkpoints/llama-3.1-405B-checkpoints/0/items"
1414

1515
benchmarks:
16-
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_benchmark.V1Benchmark"
16+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark"
1717
options:
1818
# --- Generator Options ---
19-
# These keys must match the attributes of the `V1BenchmarkOptions` class
20-
# associated with the `V1Benchmark` generator.
19+
# These keys must match the attributes of the `BenchmarkOptions` class
20+
# associated with the `Benchmark` generator.
2121
async_enabled: true
2222
use_ocdbt: true
2323
use_zarr3: true

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-70b.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ checkpoint_config:
1313
path: "gs://orbax-benchmarks/checkpoints/llama-3.1-70B-checkpoints/0/items"
1414

1515
benchmarks:
16-
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_benchmark.V1Benchmark"
16+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark"
1717
options:
1818
# --- Generator Options ---
19-
# These keys must match the attributes of the `V1BenchmarkOptions` class
20-
# associated with the `V1Benchmark` generator.
19+
# These keys must match the attributes of the `BenchmarkOptions` class
20+
# associated with the `Benchmark` generator.
2121
async_enabled: true
2222
use_ocdbt: true
2323
use_zarr3: true

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-8b.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ checkpoint_config:
1212
path: "gs://orbax-benchmarks/checkpoints/llama-3.1-8B-checkpoints/0/items"
1313

1414
benchmarks:
15-
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_benchmark.V1Benchmark"
15+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark"
1616
options:
1717
# --- Generator Options ---
18-
# These keys must match the attributes of the `V1BenchmarkOptions` class
19-
# associated with the `V1Benchmark` generator.
18+
# These keys must match the attributes of the `BenchmarkOptions` class
19+
# associated with the `Benchmark` generator.
2020
async_enabled: true
2121
use_ocdbt: true
2222
use_zarr3: true
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# The name for the entire test suite run.
2+
# Assumes v5p-256 (128 chips)
3+
suite_name: "Llama 3.1 70B"
4+
num_repeats: 1
5+
6+
7+
mesh_config:
8+
mesh_axes: ["data", "fsdp", "tensor"]
9+
# Should match sharding_config_path.
10+
ici_parallelism: {"data": 4, "fsdp": 8, "tensor": 4}
11+
12+
# Checkpoint path and sharding config path used by this suite.
13+
checkpoint_config:
14+
path: "gs://orbax-benchmarks/checkpoints/llama-3.1-70B-checkpoints/0/items"
15+
sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-256-data-4-fsdp-8-tensor-4/abstract_state.json"
16+
17+
benchmarks:
18+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark"
19+
options:
20+
# --- Generator Options ---
21+
# These keys must match the attributes of the `BenchmarkOptions` class
22+
# associated with the `Benchmark` generator.
23+
async_enabled: true
24+
use_ocdbt: true
25+
use_zarr3: true
26+
use_replica_parallel: false
27+
use_compression: true
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# The name for the entire test suite run.
2+
# Assumes v5p-256 (128 chips)
3+
suite_name: "Llama 3.1 70B"
4+
num_repeats: 1
5+
6+
7+
mesh_config:
8+
mesh_axes: ["data", "fsdp", "tensor"]
9+
# Should match sharding_config_path.
10+
ici_parallelism: {"data": 4, "fsdp": 8, "tensor": 4}
11+
12+
# Checkpoint path and sharding config path used by this suite.
13+
checkpoint_config:
14+
path: "gs://orbax-benchmarks/checkpoints/llama-3.1-70B-checkpoints/0/items"
15+
sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-256-data-4-fsdp-8-tensor-4/abstract_state.json"
16+
17+
benchmarks:
18+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark"
19+
options:
20+
# --- Generator Options ---
21+
# These keys must match the attributes of the `BenchmarkOptions` class
22+
# associated with the `Benchmark` generator.
23+
async_enabled: true
24+
use_ocdbt: true
25+
use_zarr3: true
26+
use_replica_parallel: false
27+
use_compression: true
28+
chunk_byte_size: 33554432 # 32 MiB
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# The name for the entire test suite run.
2+
# Assumes v5p-256 (128 chips)
3+
suite_name: "llama-70b_reshard_4-8-4_subchunked_to_1-128-1"
4+
num_repeats: 20
5+
6+
7+
mesh_config:
8+
mesh_axes: ["data", "fsdp", "tensor"]
9+
# Should match reference_sharding_path.
10+
ici_parallelism: {"data": 1, "fsdp": 128, "tensor": 1}
11+
12+
# Note: checkpoint_config field not specified.
13+
14+
benchmarks:
15+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
16+
options:
17+
# --- Generator Options ---
18+
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
19+
# associated with the `ReshardingBenchmark` generator.
20+
async_enabled: true
21+
use_ocdbt: true
22+
use_zarr3: true
23+
use_replica_parallel: false
24+
use_compression: true
25+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
26+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-256-data-1-fsdp-128-tensor-1/abstract_state.json"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# The name for the entire test suite run.
2+
# Assumes v5p-64 (32 chips)
3+
suite_name: "llama-70b_reshard_4-8-4_subchunked_to_1-32-1"
4+
num_repeats: 20
5+
6+
7+
mesh_config:
8+
mesh_axes: ["data", "fsdp", "tensor"]
9+
# Should match reference_sharding_path.
10+
ici_parallelism: {"data": 1, "fsdp": 32, "tensor": 1}
11+
12+
# Note: checkpoint_config field not specified.
13+
14+
benchmarks:
15+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
16+
options:
17+
# --- Generator Options ---
18+
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
19+
# associated with the `ReshardingBenchmark` generator.
20+
async_enabled: true
21+
use_ocdbt: true
22+
use_zarr3: true
23+
use_replica_parallel: false
24+
use_compression: true
25+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
26+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-64-data-1-fsdp-32-tensor-1/abstract_state.json"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# The name for the entire test suite run.
2+
# Assumes v5p-256 (128 chips)
3+
suite_name: "llama-70b_reshard_4-8-4_to_1-128-1"
4+
num_repeats: 20
5+
6+
7+
mesh_config:
8+
mesh_axes: ["data", "fsdp", "tensor"]
9+
# Should match reference_sharding_path.
10+
ici_parallelism: {"data": 1, "fsdp": 128, "tensor": 1}
11+
12+
# Note: checkpoint_config field not specified.
13+
14+
benchmarks:
15+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
16+
options:
17+
# --- Generator Options ---
18+
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
19+
# associated with the `ReshardingBenchmark` generator.
20+
async_enabled: true
21+
use_ocdbt: true
22+
use_zarr3: true
23+
use_replica_parallel: false
24+
use_compression: true
25+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4/ckpt"
26+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-256-data-1-fsdp-128-tensor-1/abstract_state.json"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# The name for the entire test suite run.
2+
# Assumes v5p-64 (32 chips)
3+
suite_name: "llama-70b_reshard_4-8-4_to_1-32-1"
4+
num_repeats: 20
5+
6+
7+
mesh_config:
8+
mesh_axes: ["data", "fsdp", "tensor"]
9+
# Should match reference_sharding_path.
10+
ici_parallelism: {"data": 1, "fsdp": 32, "tensor": 1}
11+
12+
# Note: checkpoint_config field not specified.
13+
14+
benchmarks:
15+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
16+
options:
17+
# --- Generator Options ---
18+
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
19+
# associated with the `ReshardingBenchmark` generator.
20+
async_enabled: true
21+
use_ocdbt: true
22+
use_zarr3: true
23+
use_replica_parallel: false
24+
use_compression: true
25+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4/ckpt"
26+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-64-data-1-fsdp-32-tensor-1/abstract_state.json"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# The name for the entire test suite run.
2+
# Assumes v5p-256 (128 chips)
3+
# This benchmark is the no-resharding baseline.
4+
suite_name: "llama-70b_reshard_4-8-4_to_4-8-4"
5+
num_repeats: 20
6+
7+
8+
mesh_config:
9+
mesh_axes: ["data", "fsdp", "tensor"]
10+
# Should match reference_sharding_path.
11+
ici_parallelism: {"data": 4, "fsdp": 8, "tensor": 4}
12+
13+
# Note: checkpoint_config field not specified.
14+
15+
benchmarks:
16+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
17+
options:
18+
# --- Generator Options ---
19+
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
20+
# associated with the `ReshardingBenchmark` generator.
21+
async_enabled: true
22+
use_ocdbt: true
23+
use_zarr3: true
24+
use_replica_parallel: false
25+
use_compression: true
26+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4/ckpt"
27+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-256-data-4-fsdp-8-tensor-4/abstract_state.json"

0 commit comments

Comments
 (0)