Skip to content

Commit 41df5ed

Browse files
cpgaffney1Orbax Authors
authored andcommitted
Restore and broadcast benchmark.
PiperOrigin-RevId: 873114079
1 parent 60b50ba commit 41df5ed

23 files changed

+1164
-4
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
---
2+
# 1. Headless Service: Required for distributed pods to discover each other
3+
apiVersion: v1
4+
kind: Service
5+
metadata:
6+
name: ${JOB_NAME}
7+
namespace: default
8+
spec:
9+
clusterIP: None
10+
selector:
11+
job-name: ${JOB_NAME}
12+
---
13+
# 2. Indexed Job: Manages the distributed workload and queues via Kueue
14+
apiVersion: batch/v1
15+
kind: Job
16+
metadata:
17+
name: ${JOB_NAME}
18+
namespace: default
19+
labels:
20+
kueue.x-k8s.io/queue-name: multislice-queue
21+
spec:
22+
completions: ${TOTAL_PODS}
23+
parallelism: ${TOTAL_PODS}
24+
completionMode: Indexed
25+
template:
26+
metadata:
27+
labels:
28+
job-name: ${JOB_NAME}
29+
spec:
30+
subdomain: ${JOB_NAME}
31+
restartPolicy: Never
32+
containers:
33+
- name: benchmark
34+
image: ${IMAGE}
35+
36+
# ---> IMPORTANT: UPDATE THIS COMMAND <---
37+
command:
38+
- "python3"
39+
- "/path/to/your/benchmark_script.py"
40+
- "--config_file=${FULL_CONFIG_PATH}"
41+
- "--output_directory=${OUTPUT_DIR}"
42+
43+
# 3. Distributed Setup: Injecting JAX environment variables natively
44+
env:
45+
- name: JAX_COORDINATOR_ADDRESS
46+
value: "${JOB_NAME}-0.${JOB_NAME}.default.svc.cluster.local"
47+
- name: JAX_COORDINATOR_PORT
48+
value: "1234"
49+
- name: JAX_PROCESS_COUNT
50+
value: "${TOTAL_PODS}"
51+
- name: JAX_PROCESS_INDEX
52+
valueFrom:
53+
fieldRef:
54+
fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
55+
56+
# 4. Resource constraint tailored to your cluster
57+
resources:
58+
requests:
59+
cpu: "1"
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
apiVersion: kueue.x-k8s.io/v1beta1
2+
kind: ResourceFlavor
3+
metadata:
4+
name: "spot-flavor"
5+
---
6+
apiVersion: kueue.x-k8s.io/v1beta1
7+
kind: ClusterQueue
8+
metadata:
9+
name: "xpk-cluster-queue"
10+
spec:
11+
namespaceSelector: {} # Allows jobs from any namespace
12+
resourceGroups:
13+
- coveredResources: ["cpu", "memory"]
14+
flavors:
15+
- name: "spot-flavor"
16+
resources:
17+
- name: "cpu"
18+
nominalQuota: 1000 # Set artificially high to allow scaling
19+
- name: "memory"
20+
nominalQuota: 4000Gi
21+
---
22+
apiVersion: kueue.x-k8s.io/v1beta1
23+
kind: LocalQueue
24+
metadata:
25+
name: "multislice-queue" # XPK strictly looks for this name by default
26+
namespace: "default"
27+
spec:
28+
clusterQueue: "xpk-cluster-queue"
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# The name for the entire test suite run.
2+
# Assumes n2-standard-32-32 (32 machines) X 16 replicas
3+
suite_name: "llama-70b_replicas_16"
4+
num_repeats: 1
5+
6+
7+
mesh_config:
8+
mesh_axes: ["replica", "model"]
9+
# Should match reference_sharding_path.
10+
ici_parallelism: {"replica": 1, "model": 32}
11+
dcn_parallelism: {"replica": 16}
12+
13+
# Note: checkpoint_config field not specified.
14+
15+
benchmarks:
16+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark"
17+
options:
18+
# --- Generator Options ---
19+
# These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class
20+
# associated with the `RestoreAndBroadcastBenchmark` generator.
21+
async_enabled: true
22+
use_ocdbt: true
23+
use_zarr3: true
24+
use_replica_parallel: false
25+
use_compression: true
26+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
27+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
28+
use_load_and_broadcast: true
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# The name for the entire test suite run.
2+
# Assumes n2-standard-2-64 (64 machines) X 2 replicas
3+
suite_name: "llama-70b_replicas_2"
4+
num_repeats: 1
5+
6+
7+
mesh_config:
8+
mesh_axes: ["replica", "model"]
9+
# Should match reference_sharding_path.
10+
ici_parallelism: {"replica": 1, "model": 64}
11+
dcn_parallelism: {"replica": 2}
12+
13+
# Note: checkpoint_config field not specified.
14+
15+
benchmarks:
16+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark"
17+
options:
18+
# --- Generator Options ---
19+
# These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class
20+
# associated with the `RestoreAndBroadcastBenchmark` generator.
21+
async_enabled: true
22+
use_ocdbt: true
23+
use_zarr3: true
24+
use_replica_parallel: false
25+
use_compression: true
26+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
27+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
28+
use_load_and_broadcast: true
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# The name for the entire test suite run.
2+
# Assumes n2-standard-32-32 (32 machines) X 4 replicas
3+
suite_name: "llama-70b_replicas_4"
4+
num_repeats: 20
5+
6+
7+
mesh_config:
8+
mesh_axes: ["replica", "model"]
9+
# Should match reference_sharding_path.
10+
ici_parallelism: {"replica": 1, "model": 64}
11+
dcn_parallelism: {"replica": 4}
12+
13+
# Note: checkpoint_config field not specified.
14+
15+
benchmarks:
16+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark"
17+
options:
18+
# --- Generator Options ---
19+
# These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class
20+
# associated with the `RestoreAndBroadcastBenchmark` generator.
21+
async_enabled: true
22+
use_ocdbt: true
23+
use_zarr3: true
24+
use_replica_parallel: false
25+
use_compression: true
26+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
27+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
28+
use_load_and_broadcast: true
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# The name for the entire test suite run.
2+
# Assumes n2-standard-32-32 (32 machines) X 4 replicas
3+
suite_name: "llama-70b_replicas_4_no_broadcast"
4+
num_repeats: 20
5+
6+
7+
mesh_config:
8+
mesh_axes: ["replica", "model"]
9+
# Should match reference_sharding_path.
10+
ici_parallelism: {"replica": 1, "model": 64}
11+
dcn_parallelism: {"replica": 4}
12+
13+
# Note: checkpoint_config field not specified.
14+
15+
benchmarks:
16+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark"
17+
options:
18+
# --- Generator Options ---
19+
# These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class
20+
# associated with the `RestoreAndBroadcastBenchmark` generator.
21+
async_enabled: true
22+
use_ocdbt: true
23+
use_zarr3: true
24+
use_replica_parallel: false
25+
use_compression: true
26+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4/ckpt"
27+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-256-data-4-fsdp-8-tensor-4/abstract_state.json"
28+
use_load_and_broadcast: False
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# The name for the entire test suite run.
2+
# Assumes 32 devices X 16 replicas
3+
suite_name: "llama-8b_replicas_16"
4+
num_repeats: 20
5+
6+
7+
mesh_config:
8+
mesh_axes: ["replica", "model"]
9+
# Should match reference_sharding_path.
10+
ici_parallelism: {"replica": 1, "model": 32}
11+
dcn_parallelism: {"replica": 16}
12+
13+
# Note: checkpoint_config field not specified.
14+
15+
benchmarks:
16+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark"
17+
options:
18+
# --- Generator Options ---
19+
# These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class
20+
# associated with the `RestoreAndBroadcastBenchmark` generator.
21+
async_enabled: true
22+
use_ocdbt: true
23+
use_zarr3: true
24+
use_replica_parallel: false
25+
use_compression: true
26+
# Note, uses a bucket in EU, assuming the benchmark will run from a cell in the US. This
27+
# should increase the storage latency and make the effect of broadcasting more pronounced,
28+
# since the scale we can run at is too small to see much difference otherwise.
29+
reference_checkpoint_path: "gs://cpgaffney-eu-bucket/checkpoints/llama-8b_generate_8-2-1/ckpt"
30+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-8b-v5p-8-data-1-fsdp-4-tensor-1/abstract_state.json"
31+
use_load_and_broadcast: true
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# The name for the entire test suite run.
2+
# Assumes 32 devices X 16 replicas
3+
suite_name: "llama-8b_replicas_16_no_broadcast"
4+
num_repeats: 20
5+
6+
7+
mesh_config:
8+
mesh_axes: ["replica", "model"]
9+
# Should match reference_sharding_path.
10+
ici_parallelism: {"replica": 1, "model": 32}
11+
dcn_parallelism: {"replica": 16}
12+
13+
# Note: checkpoint_config field not specified.
14+
15+
benchmarks:
16+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark"
17+
options:
18+
# --- Generator Options ---
19+
# These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class
20+
# associated with the `RestoreAndBroadcastBenchmark` generator.
21+
async_enabled: true
22+
use_ocdbt: true
23+
use_zarr3: true
24+
use_replica_parallel: false
25+
use_compression: true
26+
# Note, uses a bucket in EU, assuming the benchmark will run from a cell in the US. This
27+
# should increase the storage latency and make the effect of broadcasting more pronounced,
28+
# since the scale we can run at is too small to see much difference otherwise.
29+
reference_checkpoint_path: "gs://cpgaffney-eu-bucket/checkpoints/llama-8b_generate_8-2-1/ckpt"
30+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-8b-v5p-8-data-1-fsdp-4-tensor-1/abstract_state.json"
31+
use_load_and_broadcast: false
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# The name for the entire test suite run.
2+
# Assumes v5p-8 (4 chips) X 2 replicas
3+
suite_name: "llama-8b_replicas_2"
4+
num_repeats: 20
5+
6+
7+
mesh_config:
8+
mesh_axes: ["replica", "model"]
9+
# Should match reference_sharding_path.
10+
ici_parallelism: {"replica": 1, "model": 4}
11+
dcn_parallelism: {"replica": 2}
12+
13+
# Note: checkpoint_config field not specified.
14+
15+
benchmarks:
16+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark"
17+
options:
18+
# --- Generator Options ---
19+
# These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class
20+
# associated with the `RestoreAndBroadcastBenchmark` generator.
21+
async_enabled: true
22+
use_ocdbt: true
23+
use_zarr3: true
24+
use_replica_parallel: false
25+
use_compression: true
26+
# Note, uses a bucket in EU, assuming the benchmark will run from a cell in the US. This
27+
# should increase the storage latency and make the effect of broadcasting more pronounced,
28+
# since the scale we can run at is too small to see much difference otherwise.
29+
reference_checkpoint_path: "gs://cpgaffney-eu-bucket/checkpoints/llama-8b_generate_8-2-1/ckpt"
30+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-8b-v5p-8-data-1-fsdp-4-tensor-1/abstract_state.json"
31+
use_load_and_broadcast: true
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# The name for the entire test suite run.
2+
# Assumes v5p-8 (4 chips) X 2 replicas
3+
suite_name: "llama-8b_replicas_2_no_broadcast"
4+
num_repeats: 20
5+
6+
7+
mesh_config:
8+
mesh_axes: ["replica", "model"]
9+
# Should match reference_sharding_path.
10+
ici_parallelism: {"replica": 1, "model": 4}
11+
dcn_parallelism: {"replica": 2}
12+
13+
# Note: checkpoint_config field not specified.
14+
15+
benchmarks:
16+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark"
17+
options:
18+
# --- Generator Options ---
19+
# These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class
20+
# associated with the `RestoreAndBroadcastBenchmark` generator.
21+
async_enabled: true
22+
use_ocdbt: true
23+
use_zarr3: true
24+
use_replica_parallel: false
25+
use_compression: true
26+
# Note, uses a bucket in EU, assuming the benchmark will run from a cell in the US. This
27+
# should increase the storage latency and make the effect of broadcasting more pronounced,
28+
# since the scale we can run at is too small to see much difference otherwise.
29+
reference_checkpoint_path: "gs://cpgaffney-eu-bucket/checkpoints/llama-8b_generate_8-2-1/ckpt"
30+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-8b-v5p-8-data-1-fsdp-4-tensor-1/abstract_state.json"
31+
use_load_and_broadcast: false

0 commit comments

Comments
 (0)