
Commit b76c7cc

tohtana and tjruwase authored
Add example of DeepCompile (#967)
* import files for deepcompile benchmark
* add figures
* add figures
* update document
* fix links to images
* add images
* specify deepspeed version

Signed-off-by: Masahiro Tanaka <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 7b34e07 commit b76c7cc

27 files changed (+1729, -0 lines)

Diff for: benchmarks/deepcompile/.gitignore

+3
@@ -0,0 +1,3 @@
*.log
*.pyc
*.png

Diff for: benchmarks/deepcompile/README.md

+151
@@ -0,0 +1,151 @@
# Benchmarks for DeepCompile

## Setup

These experiment scripts require 4 nodes, each with 8 A100/H100 GPUs.
We tested the scripts with Python 3.10.12 and CUDA 12.4.
### Libraries

In addition, you need to install the following:

- PyTorch v2.6.0
- DeepSpeed (v0.16.6 or newer)
- transformers
- accelerate
- datasets v3.1

Here is an example of the installation commands:
```bash
pip3 install torch==2.6.0 torchvision torchaudio
pip3 install transformers datasets==3.1 accelerate

# Install DeepSpeed
pip install deepspeed

# Clone this repository
git clone https://github.com/deepspeedai/DeepSpeedExamples
cd DeepSpeedExamples/benchmarks/deepcompile
```

You need to set these up on all nodes.

### Setup for multi-node runs

You need to set the host names in `hostfile_n${NUM_NODES}`. The file should look like the following:

```
node-0 slots=8
node-1 slots=8
node-2 slots=8
node-3 slots=8
```

## Evaluation on throughput

The following script runs the throughput benchmark. It sweeps the following conditions:

- Models: meta-llama/Meta-Llama-3-70B-Instruct, mistralai/Mixtral-8x7B-v0.1
- Batch size: 1, 2, 4
- Sequence length: 512, 1024, 2048
- Frameworks and settings:
  - DeepSpeed ZeRO3 (ZeRO3)
  - DeepSpeed ZeRO3 + Compiler (ZeRO3 (C))
  - FSDP (FSDP)
  - FSDP + Compiler (FSDP (C))
  - DeepCompile + proactive prefetching (DeepCompile (P))
  - DeepCompile + selective unsharding (DeepCompile (S))
  - DeepCompile + proactive prefetching + selective unsharding (DeepCompile (P+S))

The script downloads the models from the Hugging Face Model Hub. Please make sure that you have access to the models.
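
For gated models such as Llama-3, you may also need to authenticate with Hugging Face first. One way to do this, assuming the `huggingface_hub` CLI is available (it is typically installed alongside `transformers`), is:

```bash
# Log in with an access token that has been granted access to the models
huggingface-cli login
```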

```bash
export PROFILE_DIR=/path/to/profile
bash run_bench.sh
```

The logs from our experiments are stored in the `logs/` directory. The summary of results is written to `profiles/result.txt`. You can copy that file to `results/acc_step_1` and plot the throughput with the following commands.
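
As a minimal sketch, the copy step could look like this (the destination directory matches the plotting command below):

```bash
mkdir -p results/acc_step_1
cp profiles/result.txt results/acc_step_1/
```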

```bash
python plot.py --result_dir results/acc_step_1 --metric throughput
```

Here are some example charts:

<table>
<tr>
<td><img src="results/acc_step_1/throughput/chart_throughput_Llama-3-70B_np32_bs1.png" alt="Llama-3-70B/bs=1" width="300"></td>
<td><img src="results/acc_step_1/throughput/chart_throughput_Mixtral-8x7B_np32_bs1.png" alt="Mixtral-8x7B/bs=1" width="300"></td>
</tr>
</table>

The following script runs the benchmark with different numbers of gradient accumulation steps (2, 4, 8, 16).
The batch size and sequence length are fixed to 1 and 1024, respectively. (Note that FSDP doesn't work for this experiment.)

```bash
bash run_bench_acc.sh
```

You can use the same plotting script with `--acc_step_eval` to plot the results along gradient accumulation steps.

```bash
python plot.py --result_dir results/acc_step_1_16 --acc_step_eval --metric throughput
```

Here are some example charts:

<table>
<tr>
<td><img src="results/acc_step_1_16/throughput/chart_throughput_Llama-3-70B_np32_bs1.png" alt="Llama-3-70B/bs=1" width="300"></td>
<td><img src="results/acc_step_1_16/throughput/chart_throughput_Mixtral-8x7B_np32_bs1.png" alt="Mixtral-8x7B/bs=1" width="300"></td>
</tr>
</table>

## APIs and custom optimization passes

To enable DeepCompile, simply set `"deepcompile": true` in the `compile` section of your DeepSpeed configuration JSON:

```json
{
  "zero_optimization": {
    "stage": 3
  },
  "compile": {
    "deepcompile": true
  }
}
```

In your training script, call the `compile()` API to invoke DeepCompile. The function signature is:

```python
def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, schedule=None) -> None:
```
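
As a rough sketch (assuming `compile()` is invoked on the DeepSpeed engine returned by `deepspeed.initialize()`, and using placeholder variable names), a call with the default settings might look like:

```python
import deepspeed

# model, optimizer, and ds_config are defined elsewhere in your training script
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                optimizer=optimizer,
                                                config=ds_config)

# Invoke DeepCompile with the default backend and no custom schedule
engine.compile()
```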

You can pass a custom optimization schedule using the `schedule` argument. For example, to apply ZeRO-3-style partitioning and the optimizations described above, you can define the schedule as follows:
```python
# Import path assumed from the pass implementations linked at the end of this section
from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather

WARMUP = 5  # example warmup step count (an assumption; any step after the profiling warmup works)

schedule = []
schedule.append((0, [zero3_compile.add_z3_gather_release]))
schedule.append(
    (WARMUP,
     [zero3_compile.add_z3_gather_release, prefetch.schedule_prefetch, selective_gather.selective_gather]))
```

A schedule is defined as a list of tuples, where each tuple consists of:

- A step index (e.g., 0 or `WARMUP`), indicating when to apply the passes
- A list of optimization functions to apply at that step

In the example above, `add_z3_gather_release` is applied at step 0 to minimize memory usage. After a warmup phase (e.g., after the first few training iterations), additional optimizations such as prefetching and selective unsharding are applied based on profiled memory usage.
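
The schedule defined above can then be passed to `compile()` via its `schedule` argument; for example (again using a placeholder engine variable):

```python
# Apply the custom pass schedule when compiling the engine
engine.compile(schedule=schedule)
```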

Each optimization pass takes a standardized set of arguments provided by DeepCompile. For details, please refer to the implementation of each pass:

- [ZeRO3 (All-gather and reduce-scatter insertion)](https://github.com/deepspeedai/DeepSpeed/blob/tohtana/deepcompile/deepspeed/compile/passes/zero3_compile.py)
- [Proactive prefetching](https://github.com/deepspeedai/DeepSpeed/blob/tohtana/deepcompile/deepspeed/compile/passes/prefetch.py)
- [Selective unsharding](https://github.com/deepspeedai/DeepSpeed/blob/tohtana/deepcompile/deepspeed/compile/passes/selective_gather.py)
- [Reduce-scatter insertion (ZeRO1)](https://github.com/deepspeedai/DeepSpeed/blob/tohtana/deepcompile/deepspeed/compile/passes/zero1_compile.py)
- [Adaptive offloading](https://github.com/deepspeedai/DeepSpeed/blob/tohtana/deepcompile/deepspeed/compile/passes/offload_adam_states.py)
+14

@@ -0,0 +1,14 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
machine_rank: {{ machine_rank }}
main_training_function: main
mixed_precision: bf16
num_machines: {{ num_machines }}
num_processes: {{ num_processes }}
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
+33

@@ -0,0 +1,33 @@
{
  {% if fp16 %}
  "fp16": {
    "enabled": true,
    "initial_scale_power": 8
  },
  {% else %}
  "bf16": {
    "enabled": true
  },
  {% endif %}
  "zero_optimization": {
    "stage": {{ zero_stage }},
    "sub_group_size": 100000000
  },
  "compile": {
    "deepcompile": {{ deepcompile }},
    "offload_activation": false,
    "offload_opt_states": false,
    "double_buffer": true,
    "symmetric_memory": false,
    "free_activation": false,
    "debug_log": {{ debug_log }},
    "sync_before_reduce": {{ sync_before_reduce }},
    "sync_after_reduce": {{ sync_after_reduce }}
  },
  "gradient_accumulation_steps": {{ gradient_accumulation_steps }},
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
+19

@@ -0,0 +1,19 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  {%- if zero_stage == 3 %}
  zero3_init_flag: true
  {%- endif %}
  deepspeed_config_file: configs/ds_config.json
distributed_type: DEEPSPEED
machine_rank: {{ machine_rank }}
main_training_function: main
num_machines: {{ num_machines }}
num_processes: {{ num_processes }}
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
+28

@@ -0,0 +1,28 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  {%- if zero_stage == 3 %}
  fsdp_sharding_strategy: FULL_SHARD
  {%- else %}
  fsdp_sharding_strategy: SHARD_GRAD_OP
  {%- endif %}
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: {{ machine_rank }}
main_training_function: main
mixed_precision: bf16
num_machines: {{ num_machines }}
num_processes: {{ num_processes }}
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@@ -0,0 +1,6 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: NO
main_training_function: main
mixed_precision: bf16
use_cpu: false
