diff --git a/benchmarks/deepcompile/.gitignore b/benchmarks/deepcompile/.gitignore
new file mode 100644
index 000000000..d73b31758
--- /dev/null
+++ b/benchmarks/deepcompile/.gitignore
@@ -0,0 +1,3 @@
+*.log
+*.pyc
+*.png
diff --git a/benchmarks/deepcompile/README.md b/benchmarks/deepcompile/README.md
new file mode 100644
index 000000000..4e580a47a
--- /dev/null
+++ b/benchmarks/deepcompile/README.md
@@ -0,0 +1,151 @@
+# Benchmarks for DeepCompile
+
+## Setup
+
+These experiment scripts require 4 nodes, each with 8 A100/H100 GPUs.
+We tested the scripts with Python 3.10.12 and CUDA 12.4.
+
+### Libraries
+
+In addition, you need to install the following:
+
+- PyTorch v2.6.0
+- DeepSpeed (v0.16.6 or newer)
+- transformers
+- accelerate
+- datasets v3.1
+
+Here is an example of the installation commands:
+
+```bash
+pip3 install torch==2.6.0 torchvision torchaudio
+pip3 install transformers datasets==3.1 accelerate
+
+# Install DeepSpeed
+pip install deepspeed
+
+# Clone this repository
+git clone https://github.com/deepspeedai/DeepSpeedExamples
+cd DeepSpeedExamples/benchmarks/deepcompile
+```
+
+You need to set these up on all nodes.
+
+### Setup for multi-node runs
+
+You need to list the host names in `hostfile_n${NUM_NODES}`. The file should look like the following:
+
+```
+node-0 slots=8
+node-1 slots=8
+node-2 slots=8
+node-3 slots=8
+```
+
+## Throughput evaluation
+
+The following script runs the throughput benchmark. It sweeps over the following conditions:
+
+- Models: meta-llama/Meta-Llama-3-70B-Instruct, mistralai/Mixtral-8x7B-v0.1
+- Batch size: 1, 2, 4
+- Sequence length: 512, 1024, 2048
+- Frameworks and settings:
+ - DeepSpeed ZeRO3 (ZeRO3)
+ - DeepSpeed ZeRO3 + Compiler (ZeRO3 (C))
+ - FSDP (FSDP)
+ - FSDP + Compiler (FSDP (C))
+ - DeepCompile + proactive prefetching (DeepCompile (P))
+ - DeepCompile + selective unsharding (DeepCompile (S))
+ - DeepCompile + proactive prefetching + selective unsharding (DeepCompile (P+S))
+
+The script downloads the models from the Hugging Face Hub. Please make sure that you have access to these models.
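+
+If the nodes are not yet authenticated to the Hugging Face Hub, you can log in programmatically. The snippet below is only an illustrative sketch; it assumes the `huggingface_hub` package (installed as a dependency of `transformers`) and a valid access token.
+
+```python
+# Minimal sketch: authenticate to the Hugging Face Hub before running the benchmark.
+from huggingface_hub import login
+
+login()  # prompts for a token; alternatively, pass login(token="hf_...")
+```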
+
+```bash
+export PROFILE_DIR=/path/to/profile
+bash run_bench.sh
+```
+
+The logs resulting from our experiments are stored in the `logs/` directory. The summary of results is output to `profiles/result.txt`. You can copy the file to `results/acc_step_1` (as `results.txt`, the default file name expected by `plot.py`) and plot the throughput with the following command.
+
+```bash
+python plot.py --result_dir results/acc_step_1 --metric throughput
+```
+
+Here are some example charts:
+
+*(Example charts: `results/acc_step_1/throughput/`, throughput for Llama-3-70B and Mixtral-8x7B at batch sizes 1, 2, and 4.)*
+
+The following script runs the benchmark with different numbers of gradient accumulation steps (2, 4, 8, 16).
+
+The batch size and sequence length are fixed at 1 and 1024, respectively. (Note that FSDP does not work for this experiment.)
+
+```bash
+bash run_bench_acc.sh
+```
+
+You can use the same plotting script with `--acc_step_eval` to plot the results against the number of gradient accumulation steps.
+
+```bash
+python plot.py --result_dir results/acc_step_1_16 --acc_step_eval --metric throughput
+```
+
+Here are some example charts:
+
+*(Example charts: `results/acc_step_1_16/throughput/`, throughput for Llama-3-70B and Mixtral-8x7B at batch size 1.)*
+
+## APIs and custom optimization passes
+
+To enable DeepCompile, simply set `"deepcompile": true` in the `compile` section of your DeepSpeed configuration JSON:
+
+```json
+{
+  …
+  "zero_optimization": {
+    "stage": 3
+  },
+  "compile": {
+    "deepcompile": true
+  },
+  …
+}
+```
+
+In your training script, call the `compile()` API to invoke DeepCompile. The function signature is:
+
+```python
+def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, schedule=None) -> None:
+```
+
+You can pass a custom optimization schedule using the schedule argument. For example, to apply ZeRO-3-style partitioning and the optimizations described above, you can define the schedule as follows:
+
+```python
+schedule = []
+schedule.append((0, [zero3_compile.add_z3_gather_release]))
+schedule.append(
+    (WARMUP,
+     [zero3_compile.add_z3_gather_release, prefetch.schedule_prefetch, selective_gather.selective_gather]))
+```
+
+A schedule is defined as a list of tuples, where each tuple consists of:
+
+- A step index (e.g., 0 or `WARMUP`), indicating the training step at which to apply the passes
+- A list of optimization functions to apply at that step
+
+In the example above, `add_z3_gather_release` is applied at step 0 to minimize memory usage. After a warmup phase (e.g., after the first few training iterations), additional optimizations such as prefetching and selective unsharding are applied based on profiled memory usage.
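+
+Putting these together, the following is a minimal sketch of passing a custom schedule to `compile()`. It is illustrative only: the tiny stand-in model, the `WARMUP` value, and the config path are assumptions, and it assumes `compile()` is exposed on the engine returned by `deepspeed.initialize`. The pass import path follows the links below.
+
+```python
+import torch
+import deepspeed
+from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather
+
+WARMUP = 5  # assumption: switch to the profile-guided passes after 5 steps
+
+schedule = [
+    (0, [zero3_compile.add_z3_gather_release]),
+    (WARMUP, [zero3_compile.add_z3_gather_release,
+              prefetch.schedule_prefetch,
+              selective_gather.selective_gather]),
+]
+
+model = torch.nn.Linear(1024, 1024)  # stand-in for a real model
+engine, _, _, _ = deepspeed.initialize(model=model,
+                                       model_parameters=model.parameters(),
+                                       config="configs/ds_config.json")
+engine.compile(schedule=schedule)
+```
+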
+Each optimization pass takes a standardized set of arguments provided by DeepCompile. For details, please refer to the implementation of each pass:
+
+- [ZeRO3 (All-gather and reduce-scatter insertion)](https://github.com/deepspeedai/DeepSpeed/blob/tohtana/deepcompile/deepspeed/compile/passes/zero3_compile.py)
+- [Proactive prefetching](https://github.com/deepspeedai/DeepSpeed/blob/tohtana/deepcompile/deepspeed/compile/passes/prefetch.py)
+- [Selective unsharding](https://github.com/deepspeedai/DeepSpeed/blob/tohtana/deepcompile/deepspeed/compile/passes/selective_gather.py)
+- [Reduce-scatter insertion (ZeRO1)](https://github.com/deepspeedai/DeepSpeed/blob/tohtana/deepcompile/deepspeed/compile/passes/zero1_compile.py)
+- [Adaptive offloading](https://github.com/deepspeedai/DeepSpeed/blob/tohtana/deepcompile/deepspeed/compile/passes/offload_adam_states.py)
diff --git a/benchmarks/deepcompile/configs/ddp_config.yaml.template b/benchmarks/deepcompile/configs/ddp_config.yaml.template
new file mode 100644
index 000000000..947b06949
--- /dev/null
+++ b/benchmarks/deepcompile/configs/ddp_config.yaml.template
@@ -0,0 +1,14 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+machine_rank: {{ machine_rank }}
+main_training_function: main
+mixed_precision: bf16
+num_machines: {{ num_machines }}
+num_processes: {{ num_processes }}
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/benchmarks/deepcompile/configs/ds_config.json.template b/benchmarks/deepcompile/configs/ds_config.json.template
new file mode 100644
index 000000000..b5eb1589c
--- /dev/null
+++ b/benchmarks/deepcompile/configs/ds_config.json.template
@@ -0,0 +1,33 @@
+{
+ {% if fp16 %}
+ "fp16": {
+ "enabled": true,
+ "initial_scale_power": 8
+ },
+ {% else %}
+ "bf16": {
+ "enabled": true
+ },
+ {% endif %}
+ "zero_optimization": {
+ "stage": {{ zero_stage }},
+ "sub_group_size": 100000000
+ },
+ "compile": {
+ "deepcompile": {{ deepcompile }},
+ "offload_activation": false,
+ "offload_opt_states": false,
+ "double_buffer": true,
+ "symmetric_memory": false,
+ "free_activation": false,
+ "debug_log": {{ debug_log }},
+ "sync_before_reduce": {{ sync_before_reduce }},
+ "sync_after_reduce": {{ sync_after_reduce }}
+ },
+ "gradient_accumulation_steps": {{ gradient_accumulation_steps }},
+ "gradient_clipping": "auto",
+ "steps_per_print": 2000,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/benchmarks/deepcompile/configs/ds_config.yaml.template b/benchmarks/deepcompile/configs/ds_config.yaml.template
new file mode 100644
index 000000000..f130fbea7
--- /dev/null
+++ b/benchmarks/deepcompile/configs/ds_config.yaml.template
@@ -0,0 +1,19 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ {%- if zero_stage == 3 %}
+ zero3_init_flag: true
+ {%- endif %}
+ deepspeed_config_file: configs/ds_config.json
+distributed_type: DEEPSPEED
+machine_rank: {{ machine_rank }}
+main_training_function: main
+num_machines: {{ num_machines }}
+num_processes: {{ num_processes }}
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/benchmarks/deepcompile/configs/fsdp_config.yaml.template b/benchmarks/deepcompile/configs/fsdp_config.yaml.template
new file mode 100644
index 000000000..ec1cebaea
--- /dev/null
+++ b/benchmarks/deepcompile/configs/fsdp_config.yaml.template
@@ -0,0 +1,28 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+fsdp_config:
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+ fsdp_backward_prefetch: BACKWARD_PRE
+ fsdp_cpu_ram_efficient_loading: true
+ fsdp_forward_prefetch: false
+ fsdp_offload_params: false
+ {%- if zero_stage == 3 %}
+ fsdp_sharding_strategy: FULL_SHARD
+ {%- else %}
+ fsdp_sharding_strategy: SHARD_GRAD_OP
+ {%- endif %}
+ fsdp_state_dict_type: SHARDED_STATE_DICT
+ fsdp_sync_module_states: true
+ fsdp_use_orig_params: true
+machine_rank: {{ machine_rank }}
+main_training_function: main
+mixed_precision: bf16
+num_machines: {{ num_machines }}
+num_processes: {{ num_processes }}
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/benchmarks/deepcompile/configs/singlegpu_config.yaml.template b/benchmarks/deepcompile/configs/singlegpu_config.yaml.template
new file mode 100644
index 000000000..8763d4d2a
--- /dev/null
+++ b/benchmarks/deepcompile/configs/singlegpu_config.yaml.template
@@ -0,0 +1,6 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: NO
+main_training_function: main
+mixed_precision: bf16
+use_cpu: false
diff --git a/benchmarks/deepcompile/gen_chart_acc_steps.py b/benchmarks/deepcompile/gen_chart_acc_steps.py
new file mode 100644
index 000000000..8b3cbd920
--- /dev/null
+++ b/benchmarks/deepcompile/gen_chart_acc_steps.py
@@ -0,0 +1,263 @@
+import argparse
+import re
+import pandas as pd
+import matplotlib.pyplot as plt
+from pathlib import Path
+
+def throughput_calculator(micro_batch_size, acc_steps, np, elapsed_time_per_iter,
+ hidden_size, num_attention_heads, num_key_value_heads,
+ ffn_hidden_size, num_layers, padded_vocab_size, seq_len,
+ topk: int, swiglu: bool, checkpoint_activations: bool):
+ batch_size = micro_batch_size * acc_steps * np
+ samples_per_second = batch_size / elapsed_time_per_iter
+
+ head_dim = hidden_size // num_attention_heads
+ gqa = num_attention_heads // num_key_value_heads
+ ffn_multiplier = 3 if swiglu else 2
+ macs_per_flops = 2
+
+ pre_and_post_mha_gemm_macs = batch_size * num_layers * (1 + (2 // gqa) + 1) * (hidden_size**2) * seq_len
+ mha_bgemm_macs = batch_size * num_layers * 2 * head_dim * num_attention_heads * (seq_len**2)
+ ffn_gemm_macs = batch_size * num_layers * ffn_multiplier * ffn_hidden_size * hidden_size * seq_len * topk
+ logit_lmhead_gemm_macs = batch_size * padded_vocab_size * hidden_size * seq_len
+
+ fwd_macs = pre_and_post_mha_gemm_macs + mha_bgemm_macs + ffn_gemm_macs + logit_lmhead_gemm_macs
+ bwd_macs = 2 * fwd_macs
+ fwd_bwd_macs = fwd_macs + bwd_macs
+
+ if checkpoint_activations:
+ fwd_bwd_macs += fwd_macs
+
+ flops_per_iteration = fwd_bwd_macs * macs_per_flops
+ tflops = flops_per_iteration / (elapsed_time_per_iter * np * (10**12))
+ return samples_per_second, tflops
+
+
+model_info = {
+ "meta-llama/Meta-Llama-3-8B": {
+ "hidden_size": 4096,
+ "num_attention_heads": 32,
+ "num_key_value_heads": 8,
+ "ffn_hidden_size": 16384,
+ "num_layers": 32,
+ "padded_vocab_size": 32000,
+ "topk": 1,
+ "swiglu": True # Meta-Llama-3ではswigluが使われていると仮定
+ },
+ "meta-llama/Meta-Llama-3-70B-Instruct": {
+ "hidden_size": 8192,
+ "num_attention_heads": 64,
+ "num_key_value_heads": 8,
+ "ffn_hidden_size": 32768,
+ "num_layers": 80,
+ "padded_vocab_size": 32000,
+ "topk": 1,
+ "swiglu": True # Meta-Llama-3ではswigluが使われていると仮定
+ },
+ "mistralai/Mixtral-8x7B-v0.1": {
+ "hidden_size": 4096,
+ "num_attention_heads": 32,
+ "num_key_value_heads": 8,
+ "ffn_hidden_size": 16384,
+ "num_layers": 32,
+ "padded_vocab_size": 32000,
+ "topk": 2, # MixtralではMoEで2エキスパート
+ "swiglu": False # Mistralはswigluを使っていないと仮定
+ }
+}
+
+parser = argparse.ArgumentParser(description="Plot performance metrics.")
+parser.add_argument("--metric", choices=["iteration_time", "throughput", "flops", "mfu", "peak_mem"], required=True,
+ help="Metric to plot: 'iteration_time', 'flops', 'mfu', or 'peak_mem'")
+parser.add_argument("--result_dir", type=str, required=True, help="Path to the directory containing results.txt")
+parser.add_argument("--result_file", type=str, default="results.txt", help="Name of the result file")
+args = parser.parse_args()
+
+
+# Parse the log lines
+pattern = re.compile(
+ r"(?P<time>\d+) (?P<model>[\w./-]+) ds=(?P<ds>\w+) np=(?P<np>\d+) batch_size=(?P<batch_size>\d+) "
+ r"seq=(?P<seq>\d+) acc=(?P<acc>\d+) ac=(?P<ac>\w+) compile=(?P<compile>\w+) iteration time: (?P<iteration_time>[\d.]+) "
+ r"alloc_mem: (?P<alloc_mem>\d+) peak_mem: (?P<peak_mem>\d+)"
+)
+pattern_ctime = re.compile(
+ r"(?P<time>\d+) (?P<model>[\w./-]+) ds=(?P<ds>\w+) np=(?P<np>\d+) batch_size=(?P<batch_size>\d+) "
+ r"seq=(?P<seq>\d+) acc=(?P<acc>\d+) ac=(?P<ac>\w+) compile=(?P<compile>\w+) passes=(?P<passes>[\w,_]+) compile_time=(?P<compile_time>[\d.]+) iteration time: (?P<iteration_time>[\d.]+) "
+ r"alloc_mem: (?P<alloc_mem>\d+) peak_mem: (?P<peak_mem>\d+)"
+)
+pattern_cs = re.compile(
+ r"(?P<time>\d+) (?P<model>[\w./-]+) ds=(?P<ds>\w+) np=(?P<np>\d+) batch_size=(?P<batch_size>\d+) "
+ r"seq=(?P<seq>\d+) acc=(?P<acc>\d+) ac=(?P<ac>\w+) compile=(?P<compile>\w+) schedule=(?P<schedule>\w+) passes=(?P<passes>[\w,_]+) compile_time=(?P<compile_time>[\d.]+) iteration time: (?P<iteration_time>[\d.]+) "
+ r"alloc_mem: (?P<alloc_mem>\d+) peak_mem: (?P<peak_mem>\d+)"
+)
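+
+# Example of a log line matched by `pattern` above (schematic; reconstructed
+# from the regex itself, not copied from a real log):
+#   1700000000 meta-llama/Meta-Llama-3-70B-Instruct ds=True np=32 batch_size=1 seq=1024 acc=1 ac=True compile=False iteration time: 1.234 alloc_mem: 123456 peak_mem: 654321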
+
+file = Path(args.result_dir) / args.result_file
+matches = []
+with open(file) as f:
+ for line in f:
+ match = pattern.match(line)
+ if not match:
+ match = pattern_ctime.match(line)
+ if not match:
+ match = pattern_cs.match(line)
+ if not match:
+ print(f"Not matched: {line}")
+ if match:
+ d = match.groupdict()
+ if "passes" not in d:
+ d["passes"] = ""
+ if "compile_time" not in d:
+ d["compile_time"] = 0
+ if "schedule" not in d:
+ d["schedule"] = d["compile"]
+ matches.append(d)
+
+df = pd.DataFrame(matches)
+
+# Type conversions
+df["ds"] = df["ds"] == "True"
+df["compile"] = df["compile"] == "True"
+df["np"] = df["np"].astype(int)
+df["batch_size"] = df["batch_size"].astype(int) # batch_sizeをfloatに変換
+df["seq"] = df["seq"].astype(int)
+df["iteration_time"] = df["iteration_time"].astype(float) # iteration_timeをfloatに変換
+df["alloc_mem"] = df["alloc_mem"].astype(float)
+df["peak_mem"] = df["peak_mem"].astype(float)
+df["acc"] = df["acc"].astype(int) # accも明示的にint型へ
+df["ac"] = df["ac"] == "True" # acを真偽値に変換
+df["compile_time"] = df["compile_time"].astype(float)
+df["schedule"] = df["schedule"] == "True"
+
+
+# Per-model computation and plotting
+grouped = df.groupby(["model", "np", "batch_size"])
+
+theoretical_peak = 312 # theoretical peak performance in TFLOPS (A100)
+
+
+LABEL_ZERO3 = "ZeRO3"
+LABEL_ZERO3_C = "ZeRO3 (C)"
+LABEL_FSDP = "FSDP"
+LABEL_DC_PS = "DeepCompile (P+S)"
+LABEL_DC_P = "DeepCompile (P)"
+LABEL_DC_S = "DeepCompile (S)"
+
+for (model, np, batch_size), group in grouped:
+ group = group.sort_values("acc")
+ acc_labels = group["acc"].unique()
+
+ print(f"acc_labels: {acc_labels}")
+
+ metric_values = {LABEL_ZERO3: [0] * len(acc_labels),
+ LABEL_ZERO3_C: [0] * len(acc_labels),
+ LABEL_FSDP: [0] * len(acc_labels),
+ LABEL_DC_PS: [0] * len(acc_labels),
+ LABEL_DC_P: [0] * len(acc_labels),
+ LABEL_DC_S: [0] * len(acc_labels)}
+
+ for _, row in group.iterrows():
+
+ if row["ds"] and not row["compile"]:
+ category = LABEL_ZERO3
+ elif not row["ds"] and not row["compile"]:
+ category = LABEL_FSDP
+ elif row["ds"] and row["compile"]:
+ if not row["schedule"]:
+ category = LABEL_ZERO3_C
+ elif row["passes"] == "" or row["passes"] == 'prefetch,selective_gather':
+ category = LABEL_DC_PS
+ # print(f"found prefetch,selective_gather")
+ elif row["passes"] == 'prefetch':
+ category = LABEL_DC_P
+ # print(f"found prefetch")
+ elif row["passes"] == 'selective_gather':
+ category = LABEL_DC_S
+ # print(f"found selective_gather")
+ else:
+ print(f"Unknown category: {row}")
+ continue
+ else:
+ print(f"Unknown category: {row}")
+ continue
+
+ acc_index = list(acc_labels).index(row["acc"])
+ if args.metric == "iteration_time":
+ metric_values[category][acc_index] = row["iteration_time"]
+ elif args.metric == "peak_mem":
+ metric_values[category][acc_index] = row["peak_mem"] / (1024**3)
+ elif args.metric == "throughput":
+ metric_values[category][acc_index] = row["batch_size"] * row["seq"] * row["acc"] / row["iteration_time"]
+ elif args.metric in ["flops", "mfu"]:
+ # Compute FLOPs using the model configuration
+ model_params = model_info[row["model"]]
+ samples_per_second, tflops = throughput_calculator(
+ micro_batch_size=row["batch_size"],
+ acc_steps=row["acc"], # from the log
+ np=row["np"],
+ elapsed_time_per_iter=row["iteration_time"],
+ hidden_size=model_params["hidden_size"],
+ num_attention_heads=model_params["num_attention_heads"],
+ num_key_value_heads=model_params["num_key_value_heads"],
+ ffn_hidden_size=model_params["ffn_hidden_size"],
+ num_layers=model_params["num_layers"],
+ padded_vocab_size=model_params["padded_vocab_size"],
+ seq_len=row["seq"],
+ topk=model_params["topk"],
+ swiglu=model_params["swiglu"], # from the model definition
+ checkpoint_activations=row["ac"] # from the log
+ )
+ if args.metric == "flops":
+ metric_values[category][acc_index] = tflops
+ elif args.metric == "mfu":
+ metric_values[category][acc_index] = tflops / theoretical_peak
+
+ # Build the chart
+ x = range(len(acc_labels))
+ width = 0.15 # bar width
+ ylabel = {
+ "iteration_time": "Iteration Time (s)",
+ "flops": "TFLOPS",
+ "throughput": "Throughput (tokens/s/GPU)",
+ "mfu": "MFU",
+ "peak_mem": "Peak Memory (GB)"
+ }[args.metric]
+
+ plt.figure(figsize=(10, 8))
+ adjust = - 0.5 * width
+ plt.bar([i - width*2 + adjust for i in x], metric_values[LABEL_ZERO3], width, label=LABEL_ZERO3, alpha=0.7)
+ plt.bar([i - width + adjust for i in x], metric_values[LABEL_ZERO3_C], width, label=LABEL_ZERO3_C, alpha=0.7)
+ plt.bar([i + adjust for i in x], metric_values[LABEL_FSDP], width, label=LABEL_FSDP, alpha=0.7)
+ plt.bar([i + width + adjust for i in x], metric_values[LABEL_DC_P], width, label=LABEL_DC_P, alpha=0.7)
+ plt.bar([i + width*2 + adjust for i in x], metric_values[LABEL_DC_S], width, label=LABEL_DC_S, alpha=0.7)
+ plt.bar([i + width*3 + adjust for i in x], metric_values[LABEL_DC_PS], width, label=LABEL_DC_PS, alpha=0.7)
+
+ gain_zero3 = [metric_values[LABEL_DC_PS][i] / metric_values[LABEL_ZERO3][i] for i in range(len(acc_labels))]
+ print(f"model {model} np {np} batch_size {batch_size} {LABEL_ZERO3} metric_values: {metric_values[LABEL_ZERO3]} gain_zero3: {gain_zero3}")
+ print(f"model {model} np {np} batch_size {batch_size} {LABEL_DC_PS} metric_values: {metric_values[LABEL_DC_PS]}")
+
+ model = model.split('/')[1]
+ model = model.replace("Meta-Llama-3-8B", "Llama-3-8B")
+ model = model.replace("Meta-Llama-3-70B-Instruct", "Llama-3-70B")
+ model = model.replace("Mixtral-8x7B-v0.1", "Mixtral-8x7B")
+
+ plt.title(f"Model: {model}, #GPUs: {np}, Batch Size: {batch_size}", fontsize=24)
+ plt.xlabel("Acc Steps", fontsize=24)
+ plt.ylabel(ylabel, fontsize=24)
+ plt.xticks(x, acc_labels, fontsize=24)
+
+ if args.metric == "peak_mem":
+ plt.ylim(0, 80)
+
+ plt.yticks(fontsize=20)
+ plt.legend(loc="lower right", fontsize=18)
+ plt.grid(axis="y")
+
+ # Save the figure
+ metric_name = args.metric
+ model = model.replace("/", "_")
+ chart_dir = Path(args.result_dir) / Path(metric_name)
+ chart_dir.mkdir(parents=True, exist_ok=True)
+ conf_str = f"{metric_name}_{model}_np{np}_bs{batch_size}"
+ img_path = chart_dir / f"chart_{conf_str}.png"
+ plt.savefig(str(img_path))
+ plt.close()
diff --git a/benchmarks/deepcompile/generate_conf.py b/benchmarks/deepcompile/generate_conf.py
new file mode 100644
index 000000000..b901bd9d5
--- /dev/null
+++ b/benchmarks/deepcompile/generate_conf.py
@@ -0,0 +1,52 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import argparse
+from jinja2 import Template
+from pathlib import Path
+
+def get_args():
+ parser = argparse.ArgumentParser(description='Config generation')
+
+ parser.add_argument('--machine_rank', type=int, help='machine_rank')
+ parser.add_argument('--num_machines', type=int, help='num_machines')
+ parser.add_argument('--num_processes', type=int, help='num_processes')
+ parser.add_argument('--zero_stage', type=int, choices=[0, 1, 2, 3], help='ZeRO stage')
+ parser.add_argument('--fp16', action='store_true', help='Use fp16')
+ parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
+ parser.add_argument('--deepcompile', action='store_true', help='Use deepcompile')
+ parser.add_argument('--debug_log', action='store_true', help='Debug log')
+ parser.add_argument('--sync_before_reduce', action='store_true', help='Sync before reduce')
+ parser.add_argument('--sync_after_reduce', action='store_true', help='Sync after reduce')
+ parser.add_argument('--sync_before_allgather', action='store_true', help='Sync before allgather')
+ parser.add_argument('--sync_after_allgather', action='store_true', help='Sync after allgather')
+
+ parser.add_argument('--template_file', type=Path, help='Template file')
+ parser.add_argument('--output_file', type=Path, help='Output file')
+
+ return parser.parse_args()
+
+
+def main(args):
+ with open(args.template_file, 'r') as f:
+ template = Template(f.read())
+
+ with open(args.output_file, 'w') as f:
+ f.write(template.render(machine_rank=args.machine_rank,
+ num_machines=args.num_machines,
+ num_processes=args.num_processes,
+ zero_stage=args.zero_stage,
+ fp16=args.fp16,
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ deepcompile=str(args.deepcompile).lower(),
+ debug_log=str(args.debug_log).lower(),
+ sync_before_reduce=str(args.sync_before_reduce).lower(),
+ sync_after_reduce=str(args.sync_after_reduce).lower(),
+ sync_before_allgather=str(args.sync_before_allgather).lower(),
+ sync_after_allgather=str(args.sync_after_allgather).lower()))
+
+if __name__ == '__main__':
+ args = get_args()
+ main(args)
diff --git a/benchmarks/deepcompile/hostfile_n4 b/benchmarks/deepcompile/hostfile_n4
new file mode 100644
index 000000000..6d23cdd7f
--- /dev/null
+++ b/benchmarks/deepcompile/hostfile_n4
@@ -0,0 +1,4 @@
+node-0 slots=8
+node-1 slots=8
+node-2 slots=8
+node-3 slots=8
diff --git a/benchmarks/deepcompile/plot.py b/benchmarks/deepcompile/plot.py
new file mode 100644
index 000000000..e55fa1e37
--- /dev/null
+++ b/benchmarks/deepcompile/plot.py
@@ -0,0 +1,258 @@
+import argparse
+import re
+import pandas as pd
+import matplotlib.pyplot as plt
+from pathlib import Path
+
+def throughput_calculator(micro_batch_size, acc_steps, np, elapsed_time_per_iter,
+ hidden_size, num_attention_heads, num_key_value_heads,
+ ffn_hidden_size, num_layers, padded_vocab_size, seq_len,
+ topk: int, swiglu: bool, checkpoint_activations: bool):
+ batch_size = micro_batch_size * acc_steps * np
+ samples_per_second = batch_size / elapsed_time_per_iter
+
+ head_dim = hidden_size // num_attention_heads
+ gqa = num_attention_heads // num_key_value_heads
+ ffn_multiplier = 3 if swiglu else 2
+ macs_per_flops = 2
+
+ pre_and_post_mha_gemm_macs = batch_size * num_layers * (1 + (2 // gqa) + 1) * (hidden_size**2) * seq_len
+ mha_bgemm_macs = batch_size * num_layers * 2 * head_dim * num_attention_heads * (seq_len**2)
+ ffn_gemm_macs = batch_size * num_layers * ffn_multiplier * ffn_hidden_size * hidden_size * seq_len * topk
+ logit_lmhead_gemm_macs = batch_size * padded_vocab_size * hidden_size * seq_len
+
+ fwd_macs = pre_and_post_mha_gemm_macs + mha_bgemm_macs + ffn_gemm_macs + logit_lmhead_gemm_macs
+ bwd_macs = 2 * fwd_macs
+ fwd_bwd_macs = fwd_macs + bwd_macs
+
+ if checkpoint_activations:
+ fwd_bwd_macs += fwd_macs
+
+ flops_per_iteration = fwd_bwd_macs * macs_per_flops
+ tflops = flops_per_iteration / (elapsed_time_per_iter * np * (10**12))
+ return samples_per_second, tflops
+
+
+model_info = {
+ "meta-llama/Meta-Llama-3-8B": {
+ "hidden_size": 4096,
+ "num_attention_heads": 32,
+ "num_key_value_heads": 8,
+ "ffn_hidden_size": 16384,
+ "num_layers": 32,
+ "padded_vocab_size": 32000,
+ "topk": 1,
+ "swiglu": True
+ },
+ "meta-llama/Meta-Llama-3-70B-Instruct": {
+ "hidden_size": 8192,
+ "num_attention_heads": 64,
+ "num_key_value_heads": 8,
+ "ffn_hidden_size": 32768,
+ "num_layers": 80,
+ "padded_vocab_size": 32000,
+ "topk": 1,
+ "swiglu": True
+ },
+ "mistralai/Mixtral-8x7B-v0.1": {
+ "hidden_size": 4096,
+ "num_attention_heads": 32,
+ "num_key_value_heads": 8,
+ "ffn_hidden_size": 16384,
+ "num_layers": 32,
+ "padded_vocab_size": 32000,
+ "topk": 2,
+ "swiglu": False
+ }
+}
+
+parser = argparse.ArgumentParser(description="Plot performance metrics.")
+parser.add_argument("--metric", choices=["iteration_time", "throughput", "flops", "mfu", "peak_mem"], required=True,
+ help="Metric to plot: 'iteration_time', 'flops', 'mfu', or 'peak_mem'")
+parser.add_argument("--result_dir", type=str, required=True, help="Path to the directory containing results.txt")
+parser.add_argument("--result_file", type=str, default="results.txt", help="Name of the result file")
+parser.add_argument("--acc_step_eval", action="store_true", help="Evaluate the accuracy of the model")
+args = parser.parse_args()
+
+
+pattern = re.compile(
+ r"(?P<time>\d+) (?P<model>[\w./-]+) ds=(?P<ds>\w+) np=(?P<np>\d+) batch_size=(?P<batch_size>\d+) "
+ r"seq=(?P<seq>\d+) zero_stage=(?P<zero_stage>\d+) acc=(?P<acc>\d+) ac=(?P<ac>\w+) compile=(?P<compile>\w+) deepcompile=(?P<deepcompile>\w+) "
+ r"passes=(?P<passes>[\w,_]+) compile_time=(?P<compile_time>[\d.]+) iteration time: (?P<iteration_time>[\d.]+) "
+ r"alloc_mem: (?P<alloc_mem>\d+) peak_mem: (?P<peak_mem>\d+)"
+)
+file = Path(args.result_dir) / args.result_file
+matches = []
+with open(file) as f:
+ for line in f:
+ match = pattern.match(line)
+ if not match:
+ print(f"Not matched: {line}")
+ if match:
+ d = match.groupdict()
+ if "passes" not in d:
+ d["passes"] = ""
+ if "compile_time" not in d:
+ d["compile_time"] = 0
+ if "deepcompile" not in d:
+ d["deepcompile"] = d["compile"]
+ matches.append(d)
+
+df = pd.DataFrame(matches)
+print(df)
+
+df["ds"] = df["ds"] == "True"
+df["compile"] = df["compile"] == "True"
+df["np"] = df["np"].astype(int)
+df["batch_size"] = df["batch_size"].astype(int)
+df["seq"] = df["seq"].astype(int)
+df["iteration_time"] = df["iteration_time"].astype(float)
+df["alloc_mem"] = df["alloc_mem"].astype(float)
+df["peak_mem"] = df["peak_mem"].astype(float)
+df["acc"] = df["acc"].astype(int)
+df["ac"] = df["ac"] == "True"
+df["compile_time"] = df["compile_time"].astype(float)
+df["deepcompile"] = df["deepcompile"] == "True"
+
+
+grouped = df.groupby(["model", "np", "batch_size"])
+
+# We used A100
+theoretical_peak = 312
+
+LABEL_ZERO3 = "ZeRO3"
+LABEL_ZERO3_C = "ZeRO3 (C)"
+LABEL_FSDP = "FSDP"
+LABEL_FSDP_C = "FSDP (C)"
+LABEL_DC_PS = "DeepCompile (P+S)"
+LABEL_DC_P = "DeepCompile (P)"
+LABEL_DC_S = "DeepCompile (S)"
+
+for (model, np, batch_size), group in grouped:
+
+ sort_group_name = "acc" if args.acc_step_eval else "seq"
+
+ group = group.sort_values(sort_group_name)
+ labels = group[sort_group_name].unique()
+
+ metric_values = {LABEL_ZERO3: [0] * len(labels),
+ LABEL_ZERO3_C: [0] * len(labels),
+ LABEL_FSDP: [0] * len(labels),
+ LABEL_FSDP_C: [0] * len(labels),
+ LABEL_DC_PS: [0] * len(labels),
+ LABEL_DC_P: [0] * len(labels),
+ LABEL_DC_S: [0] * len(labels)}
+
+ for _, row in group.iterrows():
+ if row["ds"] and not row["compile"]:
+ category = LABEL_ZERO3
+ elif not row["ds"]:
+ if row["compile"]:
+ category = LABEL_FSDP_C
+ else:
+ category = LABEL_FSDP
+ elif row["ds"] and row["compile"]:
+ if not row["deepcompile"]:
+ category = LABEL_ZERO3_C
+ elif row["passes"] == "" or row["passes"] == 'prefetch,selective_gather':
+ category = LABEL_DC_PS
+ elif row["passes"] == 'prefetch':
+ category = LABEL_DC_P
+ elif row["passes"] == 'selective_gather':
+ category = LABEL_DC_S
+ else:
+ print(f"Unknown category1 : {row}")
+ continue
+ else:
+ print(f"Unknown category2 : {row}")
+ continue
+
+ group_index = list(labels).index(row[sort_group_name])
+ if args.metric == "iteration_time":
+ metric_values[category][group_index] = row["iteration_time"]
+ elif args.metric == "peak_mem":
+ metric_values[category][group_index] = row["peak_mem"] / (1024**3)
+ elif args.metric == "throughput":
+ metric_values[category][group_index] = row["batch_size"] * row["seq"] / row["iteration_time"] * row["acc"]
+ elif args.metric in ["flops", "mfu"]:
+ model_params = model_info[row["model"]]
+ samples_per_second, tflops = throughput_calculator(
+ micro_batch_size=row["batch_size"],
+ acc_steps=row["acc"],
+ np=row["np"],
+ elapsed_time_per_iter=row["iteration_time"],
+ hidden_size=model_params["hidden_size"],
+ num_attention_heads=model_params["num_attention_heads"],
+ num_key_value_heads=model_params["num_key_value_heads"],
+ ffn_hidden_size=model_params["ffn_hidden_size"],
+ num_layers=model_params["num_layers"],
+ padded_vocab_size=model_params["padded_vocab_size"],
+ seq_len=row["seq"],
+ topk=model_params["topk"],
+ swiglu=model_params["swiglu"],
+ checkpoint_activations=row["ac"]
+ )
+ if args.metric == "flops":
+ metric_values[category][group_index] = tflops
+ elif args.metric == "mfu":
+ metric_values[category][group_index] = tflops / theoretical_peak
+
+ x = range(len(labels))
+ width = 0.1
+ ylabel = {
+ "iteration_time": "Iteration Time (s)",
+ "flops": "TFLOPS",
+ "throughput": "Throughput (tokens/s/GPU)",
+ "mfu": "MFU",
+ "peak_mem": "Peak Memory (GB)"
+ }[args.metric]
+
+ if args.metric == "peak_mem":
+ plt.figure(figsize=(7, 8))
+ else:
+ plt.figure(figsize=(10, 8))
+ adjust = - .0 * width
+ plt.bar([i - width*3 + adjust for i in x], metric_values[LABEL_ZERO3], width, label=LABEL_ZERO3, alpha=0.7)
+ plt.bar([i - width*2 + adjust for i in x], metric_values[LABEL_ZERO3_C], width, label=LABEL_ZERO3_C, alpha=0.7)
+ plt.bar([i - width + adjust for i in x], metric_values[LABEL_FSDP], width, label=LABEL_FSDP, alpha=0.7)
+ plt.bar([i + adjust for i in x], metric_values[LABEL_FSDP_C], width, label=LABEL_FSDP_C, alpha=0.7)
+ plt.bar([i + width + adjust for i in x], metric_values[LABEL_DC_P], width, label=LABEL_DC_P, alpha=0.7)
+ plt.bar([i + width*2 + adjust for i in x], metric_values[LABEL_DC_S], width, label=LABEL_DC_S, alpha=0.7)
+ plt.bar([i + width*3 + adjust for i in x], metric_values[LABEL_DC_PS], width, label=LABEL_DC_PS, alpha=0.7)
+
+ gain_zero3 = [metric_values[LABEL_DC_PS][i] / metric_values[LABEL_ZERO3][i] for i in range(len(labels))]
+ print(f"model {model} np {np} batch_size {batch_size} {LABEL_ZERO3} metric_values: {metric_values[LABEL_ZERO3]} gain_zero3: {gain_zero3}")
+ gain_fsdp = [0 if metric_values[LABEL_FSDP][i] == 0 else metric_values[LABEL_DC_PS][i] / metric_values[LABEL_FSDP][i] for i in range(len(labels))]
+ print(f"model {model} np {np} batch_size {batch_size} {LABEL_FSDP} metric_values: {metric_values[LABEL_FSDP]} gain_fsdp: {gain_fsdp}")
+ print(f"model {model} np {np} batch_size {batch_size} {LABEL_DC_PS} metric_values: {metric_values[LABEL_DC_PS]}")
+
+ model = model.split('/')[1]
+ model = model.replace("Meta-Llama-3-8B", "Llama-3-8B")
+ model = model.replace("Meta-Llama-3-70B-Instruct", "Llama-3-70B")
+ model = model.replace("Mixtral-8x7B-v0.1", "Mixtral-8x7B")
+
+ plt.title(f"{model}, #GPUs: {np}, Batch Size: {batch_size}", fontsize=20)
+ if args.acc_step_eval:
+ plt.xlabel("Accumulation Steps", fontsize=20)
+ else:
+ plt.xlabel("Sequence Length", fontsize=20)
+ plt.ylabel(ylabel, fontsize=20)
+ plt.xticks(x, labels, fontsize=20)
+ plt.yticks(fontsize=20)
+
+ if args.metric == "peak_mem":
+ plt.ylim(0, 80)
+ plt.legend(loc="lower right", fontsize=16)
+ else:
+ plt.legend(loc="lower right", fontsize=18)
+
+ plt.grid(axis="y")
+
+ metric_name = args.metric
+ model = model.replace("/", "_")
+ chart_dir = Path(args.result_dir) / Path(metric_name)
+ chart_dir.mkdir(parents=True, exist_ok=True)
+ conf_str = f"{metric_name}_{model}_np{np}_bs{batch_size}"
+ img_path = chart_dir / f"chart_{conf_str}.png"
+ plt.savefig(str(img_path))
+ plt.close()
diff --git a/benchmarks/deepcompile/plot_common.py b/benchmarks/deepcompile/plot_common.py
new file mode 100644
index 000000000..8ebdadbeb
--- /dev/null
+++ b/benchmarks/deepcompile/plot_common.py
@@ -0,0 +1,251 @@
+import argparse
+import re
+import pandas as pd
+import matplotlib.pyplot as plt
+from pathlib import Path
+
+def throughput_calculator(micro_batch_size, acc_steps, np, elapsed_time_per_iter,
+ hidden_size, num_attention_heads, num_key_value_heads,
+ ffn_hidden_size, num_layers, padded_vocab_size, seq_len,
+ topk: int, swiglu: bool, checkpoint_activations: bool):
+ batch_size = micro_batch_size * acc_steps * np
+ samples_per_second = batch_size / elapsed_time_per_iter
+
+ head_dim = hidden_size // num_attention_heads
+ gqa = num_attention_heads // num_key_value_heads
+ ffn_multiplier = 3 if swiglu else 2
+ macs_per_flops = 2
+
+ pre_and_post_mha_gemm_macs = batch_size * num_layers * (1 + (2 // gqa) + 1) * (hidden_size**2) * seq_len
+ mha_bgemm_macs = batch_size * num_layers * 2 * head_dim * num_attention_heads * (seq_len**2)
+ ffn_gemm_macs = batch_size * num_layers * ffn_multiplier * ffn_hidden_size * hidden_size * seq_len * topk
+ logit_lmhead_gemm_macs = batch_size * padded_vocab_size * hidden_size * seq_len
+
+ fwd_macs = pre_and_post_mha_gemm_macs + mha_bgemm_macs + ffn_gemm_macs + logit_lmhead_gemm_macs
+ bwd_macs = 2 * fwd_macs
+ fwd_bwd_macs = fwd_macs + bwd_macs
+
+ if checkpoint_activations:
+ fwd_bwd_macs += fwd_macs
+
+ flops_per_iteration = fwd_bwd_macs * macs_per_flops
+ tflops = flops_per_iteration / (elapsed_time_per_iter * np * (10**12))
+ return samples_per_second, tflops
+
+
+model_info = {
+ "meta-llama/Meta-Llama-3-8B": {
+ "hidden_size": 4096,
+ "num_attention_heads": 32,
+ "num_key_value_heads": 8,
+ "ffn_hidden_size": 16384,
+ "num_layers": 32,
+ "padded_vocab_size": 32000,
+ "topk": 1,
+ "swiglu": True
+ },
+ "meta-llama/Meta-Llama-3-70B-Instruct": {
+ "hidden_size": 8192,
+ "num_attention_heads": 64,
+ "num_key_value_heads": 8,
+ "ffn_hidden_size": 32768,
+ "num_layers": 80,
+ "padded_vocab_size": 32000,
+ "topk": 1,
+ "swiglu": True
+ },
+ "mistralai/Mixtral-8x7B-v0.1": {
+ "hidden_size": 4096,
+ "num_attention_heads": 32,
+ "num_key_value_heads": 8,
+ "ffn_hidden_size": 16384,
+ "num_layers": 32,
+ "padded_vocab_size": 32000,
+ "topk": 2,
+ "swiglu": False
+ }
+}
+
+parser = argparse.ArgumentParser(description="Plot performance metrics.")
+parser.add_argument("--metric", choices=["iteration_time", "throughput", "flops", "mfu", "peak_mem"], required=True,
+ help="Metric to plot: 'iteration_time', 'flops', 'mfu', or 'peak_mem'")
+parser.add_argument("--result_dir", type=str, required=True, help="Path to the directory containing results.txt")
+parser.add_argument("--result_file", type=str, default="results.txt", help="Name of the result file")
+args = parser.parse_args()
+
+
+pattern = re.compile(
+ r"(?P<time>\d+) (?P<model>[\w./-]+) ds=(?P<ds>\w+) np=(?P<np>\d+) batch_size=(?P<batch_size>\d+) "
+ r"seq=(?P<seq>\d+) zero_stage=(?P<zero_stage>\d+) acc=(?P<acc>\d+) ac=(?P<ac>\w+) compile=(?P<compile>\w+) schedule=(?P<schedule>\w+) "
+ r"passes=(?P<passes>[\w,_]+) compile_time=(?P<compile_time>[\d.]+) iteration time: (?P<iteration_time>[\d.]+) "
+ r"alloc_mem: (?P<alloc_mem>\d+) peak_mem: (?P<peak_mem>\d+)"
+)
+file = Path(args.result_dir) / args.result_file
+matches = []
+with open(file) as f:
+ for line in f:
+ match = pattern.match(line)
+ if not match:
+ print(f"Not matched: {line}")
+ if match:
+ d = match.groupdict()
+ if "passes" not in d:
+ d["passes"] = ""
+ if "compile_time" not in d:
+ d["compile_time"] = 0
+ if "schedule" not in d:
+ d["schedule"] = d["compile"]
+ matches.append(d)
+
+df = pd.DataFrame(matches)
+print(df)
+
+df["ds"] = df["ds"] == "True"
+df["compile"] = df["compile"] == "True"
+df["np"] = df["np"].astype(int)
+df["batch_size"] = df["batch_size"].astype(int)
+df["seq"] = df["seq"].astype(int)
+df["iteration_time"] = df["iteration_time"].astype(float)
+df["alloc_mem"] = df["alloc_mem"].astype(float)
+df["peak_mem"] = df["peak_mem"].astype(float)
+df["acc"] = df["acc"].astype(int)
+df["ac"] = df["ac"] == "True"
+df["compile_time"] = df["compile_time"].astype(float)
+df["schedule"] = df["schedule"] == "True"
+
+
+grouped = df.groupby(["model", "np", "batch_size"])
+
+# We used A100
+theoretical_peak = 312
+
+LABEL_ZERO3 = "ZeRO3"
+LABEL_ZERO3_C = "ZeRO3 (C)"
+LABEL_FSDP = "FSDP"
+LABEL_FSDP_C = "FSDP (C)"
+LABEL_DC_PS = "DeepCompile (P+S)"
+LABEL_DC_P = "DeepCompile (P)"
+LABEL_DC_S = "DeepCompile (S)"
+
+for (model, np, batch_size), group in grouped:
+ group = group.sort_values("seq")
+ seq_labels = group["seq"].unique()
+
+ metric_values = {LABEL_ZERO3: [0] * len(seq_labels),
+ LABEL_ZERO3_C: [0] * len(seq_labels),
+ LABEL_FSDP: [0] * len(seq_labels),
+ LABEL_FSDP_C: [0] * len(seq_labels),
+ LABEL_DC_PS: [0] * len(seq_labels),
+ LABEL_DC_P: [0] * len(seq_labels),
+ LABEL_DC_S: [0] * len(seq_labels)}
+
+ for _, row in group.iterrows():
+ if row["ds"] and not row["compile"]:
+ category = LABEL_ZERO3
+ elif not row["ds"]:
+ if row["compile"]:
+ category = LABEL_FSDP_C
+ else:
+ category = LABEL_FSDP
+ elif row["ds"] and row["compile"]:
+ if not row["schedule"]:
+ category = LABEL_ZERO3_C
+ elif row["passes"] == "" or row["passes"] == 'prefetch,selective_gather':
+ category = LABEL_DC_PS
+ elif row["passes"] == 'prefetch':
+ category = LABEL_DC_P
+ elif row["passes"] == 'selective_gather':
+ category = LABEL_DC_S
+ else:
+ print(f"Unknown category1 : {row}")
+ continue
+ else:
+ print(f"Unknown category2 : {row}")
+ continue
+
+ seq_index = list(seq_labels).index(row["seq"])
+ if args.metric == "iteration_time":
+ metric_values[category][seq_index] = row["iteration_time"]
+ elif args.metric == "peak_mem":
+ metric_values[category][seq_index] = row["peak_mem"] / (1024**3)
+ elif args.metric == "throughput":
+ metric_values[category][seq_index] = row["batch_size"] * row["seq"] / row["iteration_time"]
+ elif args.metric in ["flops", "mfu"]:
+ model_params = model_info[row["model"]]
+ samples_per_second, tflops = throughput_calculator(
+ micro_batch_size=row["batch_size"],
+ acc_steps=row["acc"],
+ np=row["np"],
+ elapsed_time_per_iter=row["iteration_time"],
+ hidden_size=model_params["hidden_size"],
+ num_attention_heads=model_params["num_attention_heads"],
+ num_key_value_heads=model_params["num_key_value_heads"],
+ ffn_hidden_size=model_params["ffn_hidden_size"],
+ num_layers=model_params["num_layers"],
+ padded_vocab_size=model_params["padded_vocab_size"],
+ seq_len=row["seq"],
+ topk=model_params["topk"],
+ swiglu=model_params["swiglu"],
+ checkpoint_activations=row["ac"]
+ )
+ if args.metric == "flops":
+ metric_values[category][seq_index] = tflops
+ elif args.metric == "mfu":
+ metric_values[category][seq_index] = tflops / theoretical_peak
+
+ x = range(len(seq_labels))
+ width = 0.1
+ ylabel = {
+ "iteration_time": "Iteration Time (s)",
+ "flops": "TFLOPS",
+ "throughput": "Throughput (tokens/s/GPU)",
+ "mfu": "MFU",
+ "peak_mem": "Peak Memory (GB)"
+ }[args.metric]
+
+ if args.metric == "peak_mem":
+ plt.figure(figsize=(7, 8))
+ else:
+ plt.figure(figsize=(10, 8))
+ adjust = - .0 * width
+ plt.bar([i - width*3 + adjust for i in x], metric_values[LABEL_ZERO3], width, label=LABEL_ZERO3, alpha=0.7)
+ plt.bar([i - width*2 + adjust for i in x], metric_values[LABEL_ZERO3_C], width, label=LABEL_ZERO3_C, alpha=0.7)
+ plt.bar([i - width + adjust for i in x], metric_values[LABEL_FSDP], width, label=LABEL_FSDP, alpha=0.7)
+ plt.bar([i + adjust for i in x], metric_values[LABEL_FSDP_C], width, label=LABEL_FSDP_C, alpha=0.7)
+ plt.bar([i + width + adjust for i in x], metric_values[LABEL_DC_P], width, label=LABEL_DC_P, alpha=0.7)
+ plt.bar([i + width*2 + adjust for i in x], metric_values[LABEL_DC_S], width, label=LABEL_DC_S, alpha=0.7)
+ plt.bar([i + width*3 + adjust for i in x], metric_values[LABEL_DC_PS], width, label=LABEL_DC_PS, alpha=0.7)
+
+ gain_zero3 = [metric_values[LABEL_DC_PS][i] / metric_values[LABEL_ZERO3][i] for i in range(len(seq_labels))]
+ print(f"model {model} np {np} batch_size {batch_size} {LABEL_ZERO3} metric_values: {metric_values[LABEL_ZERO3]} gain_zero3: {gain_zero3}")
+ gain_fsdp = [0 if metric_values[LABEL_FSDP][i] == 0 else metric_values[LABEL_DC_PS][i] / metric_values[LABEL_FSDP][i] for i in range(len(seq_labels))]
+ print(f"model {model} np {np} batch_size {batch_size} {LABEL_FSDP} metric_values: {metric_values[LABEL_FSDP]} gain_fsdp: {gain_fsdp}")
+ print(f"model {model} np {np} batch_size {batch_size} {LABEL_DC_PS} metric_values: {metric_values[LABEL_DC_PS]}")
+
+ model = model.split('/')[1]
+ model = model.replace("Meta-Llama-3-8B", "Llama-3-8B")
+ model = model.replace("Meta-Llama-3-70B-Instruct", "Llama-3-70B")
+ model = model.replace("Mixtral-8x7B-v0.1", "Mixtral-8x7B")
+
+ plt.title(f"{model}, #GPUs: {np}, Batch Size: {batch_size}", fontsize=20)
+ plt.xlabel("Sequence Length", fontsize=20)
+ plt.ylabel(ylabel, fontsize=20)
+ plt.xticks(x, seq_labels, fontsize=20)
+ plt.yticks(fontsize=20)
+
+ if args.metric == "peak_mem":
+ plt.ylim(0, 80)
+ plt.legend(loc="lower right", fontsize=16)
+ else:
+ plt.legend(loc="lower right", fontsize=18)
+
+ plt.grid(axis="y")
+
+ metric_name = args.metric
+ model = model.replace("/", "_")
+ chart_dir = Path(args.result_dir) / Path(metric_name)
+ chart_dir.mkdir(parents=True, exist_ok=True)
+ conf_str = f"{metric_name}_{model}_np{np}_bs{batch_size}"
+ img_path = chart_dir / f"chart_{conf_str}.png"
+ plt.savefig(str(img_path))
+ plt.close()
diff --git a/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Llama-3-70B_np32_bs1.png b/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Llama-3-70B_np32_bs1.png
new file mode 100644
index 000000000..3aa22100a
Binary files /dev/null and b/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Llama-3-70B_np32_bs1.png differ
diff --git a/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Llama-3-70B_np32_bs2.png b/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Llama-3-70B_np32_bs2.png
new file mode 100644
index 000000000..666df8570
Binary files /dev/null and b/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Llama-3-70B_np32_bs2.png differ
diff --git a/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Llama-3-70B_np32_bs4.png b/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Llama-3-70B_np32_bs4.png
new file mode 100644
index 000000000..92909c148
Binary files /dev/null and b/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Llama-3-70B_np32_bs4.png differ
diff --git a/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Mixtral-8x7B_np32_bs1.png b/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Mixtral-8x7B_np32_bs1.png
new file mode 100644
index 000000000..6758d6fce
Binary files /dev/null and b/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Mixtral-8x7B_np32_bs1.png differ
diff --git a/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Mixtral-8x7B_np32_bs2.png b/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Mixtral-8x7B_np32_bs2.png
new file mode 100644
index 000000000..0f42e564e
Binary files /dev/null and b/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Mixtral-8x7B_np32_bs2.png differ
diff --git a/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Mixtral-8x7B_np32_bs4.png b/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Mixtral-8x7B_np32_bs4.png
new file mode 100644
index 000000000..f9bbf34ae
Binary files /dev/null and b/benchmarks/deepcompile/results/acc_step_1/throughput/chart_throughput_Mixtral-8x7B_np32_bs4.png differ
diff --git a/benchmarks/deepcompile/results/acc_step_1_16/throughput/chart_throughput_Llama-3-70B_np32_bs1.png b/benchmarks/deepcompile/results/acc_step_1_16/throughput/chart_throughput_Llama-3-70B_np32_bs1.png
new file mode 100644
index 000000000..557af85a4
Binary files /dev/null and b/benchmarks/deepcompile/results/acc_step_1_16/throughput/chart_throughput_Llama-3-70B_np32_bs1.png differ
diff --git a/benchmarks/deepcompile/results/acc_step_1_16/throughput/chart_throughput_Mixtral-8x7B_np32_bs1.png b/benchmarks/deepcompile/results/acc_step_1_16/throughput/chart_throughput_Mixtral-8x7B_np32_bs1.png
new file mode 100644
index 000000000..c3f992b1f
Binary files /dev/null and b/benchmarks/deepcompile/results/acc_step_1_16/throughput/chart_throughput_Mixtral-8x7B_np32_bs1.png differ
diff --git a/benchmarks/deepcompile/run.sh b/benchmarks/deepcompile/run.sh
new file mode 100644
index 000000000..57da03193
--- /dev/null
+++ b/benchmarks/deepcompile/run.sh
@@ -0,0 +1,225 @@
+#!/bin/bash
+
+
+NUM_NODES=${NUM_NODES:-$(wc -l < /job/hostfile)}
+NGPUS_PER_NODE=${NGPUS_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
+NUM_PROCESSES=$((${NUM_NODES} * ${NGPUS_PER_NODE}))
+
+BACKEND="deepspeed"
+MODEL="meta-llama/Meta-Llama-3-8B"
+ZERO_STAGE=3
+COMPILE=0
+PASSES="ALL"
+EXTRA_OPTS=""
+
+EAGER=0
+DEEPCOMPILE=0
+GRADIENT_ACCUMULATION_STEPS=1
+ACTIVATION_CHECKPOINTING=1
+BATCH_SIZE=1
+SEQ_LENGTH=512
+DEBUG_LOG=0
+SYNC_BEFORE_REDUCE=0
+SYNC_AFTER_REDUCE=0
+SYNC_BEFORE_ALLGATHER=0
+SYNC_AFTER_ALLGATHER=0
+
+echo "NUM_NODES: ${NUM_NODES} NGPUS_PER_NODE: ${NGPUS_PER_NODE} NUM_PROCESSES: ${NUM_PROCESSES}"
+
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ --host-ip)
+ HOST_IP="$2"
+ shift 2
+ ;;
+ --backend)
+ BACKEND="$2"
+ shift 2
+ ;;
+ --zero-stage)
+ ZERO_STAGE="$2"
+ shift 2
+ ;;
+ --batch-size)
+ BATCH_SIZE="$2"
+ EXTRA_OPTS="${EXTRA_OPTS} --batch_size $2"
+ shift 2
+ ;;
+ --seq-length)
+ SEQ_LENGTH="$2"
+ EXTRA_OPTS="${EXTRA_OPTS} --seq_length $2"
+ shift 2
+ ;;
+ --gradient-accumulation-steps)
+ GRADIENT_ACCUMULATION_STEPS="$2"
+ EXTRA_OPTS="${EXTRA_OPTS} --gradient_accumulation_steps $2"
+ shift 2
+ ;;
+ --activation-checkpointing)
+ ACTIVATION_CHECKPOINTING=1
+ EXTRA_OPTS="${EXTRA_OPTS} --activation_checkpointing"
+ shift
+ ;;
+ --compile)
+ COMPILE=1
+ EXTRA_OPTS="${EXTRA_OPTS} $1"
+ shift
+ ;;
+ --eager)
+ EAGER=1
+ EXTRA_OPTS="${EXTRA_OPTS} --backend eager"
+ shift
+ ;;
+ --deepcompile)
+ DEEPCOMPILE=1
+ shift
+ ;;
+ --passes)
+ PASSES="$2"
+ EXTRA_OPTS="${EXTRA_OPTS} $1 $2"
+ shift 2
+ ;;
+ --profile)
+ EXTRA_OPTS="${EXTRA_OPTS} $1"
+ shift
+ ;;
+ --profile-dir)
+ EXTRA_OPTS="${EXTRA_OPTS} --profile_dir $2"
+ shift 2
+ ;;
+ --model)
+ MODEL="$2"
+ shift 2
+ ;;
+ --num-layers)
+ EXTRA_OPTS="${EXTRA_OPTS} --num_layers $2"
+ shift 2
+ ;;
+ --debug-log)
+ DEBUG_LOG=1
+ shift
+ ;;
+ --sync-before-reduce)
+ SYNC_BEFORE_REDUCE=1
+ shift
+ ;;
+ --sync-after-reduce)
+ SYNC_AFTER_REDUCE=1
+ shift
+ ;;
+ --sync-before-allgather)
+ SYNC_BEFORE_ALLGATHER=1
+ shift
+ ;;
+ --sync-after-allgather)
+ SYNC_AFTER_ALLGATHER=1
+ shift
+ ;;
+ *)
+ EXTRA_OPTS="${EXTRA_OPTS} $1"
+ shift
+ ;;
+ esac
+done
+
+
+
+export NCCL_DEBUG=WARN
+
+CONFIG_TEMPLATE=configs/ds_config.yaml.template
+if [ "${BACKEND}" == "fsdp" ]; then
+ CONFIG_TEMPLATE=configs/fsdp_config.yaml.template
+elif [ "${BACKEND}" == "ddp" ]; then
+ CONFIG_TEMPLATE=configs/ddp_config.yaml.template
+elif [ "${BACKEND}" == "singlegpu" ]; then
+ CONFIG_TEMPLATE=configs/singlegpu_config.yaml.template
+elif [ "${BACKEND}" != "deepspeed" ]; then
+ echo "Invalid backend: ${BACKEND}"
+ exit 1
+fi
+
+if [ "${BACKEND}" != "deepspeed" ]; then
+ ZERO_STAGE=0
+fi
+
+echo "HOST_IP: ${HOST_IP}"
+echo "NUM_NODES: ${NUM_NODES}"
+echo "NUM_PROCESSES: ${NUM_PROCESSES}"
+echo "BACKEND: ${BACKEND}"
+echo "ZERO_STAGE: ${ZERO_STAGE}"
+echo "MODEL: ${MODEL}"
+echo "GRADIENT_ACCUMULATION_STEPS: ${GRADIENT_ACCUMULATION_STEPS}"
+echo "EXTRA_OPTS: ${EXTRA_OPTS}"
+
+MACHINE_RANK=$(hostname | sed 's/[^0-9]*//g')
+
+python generate_conf.py \
+ --machine_rank ${MACHINE_RANK} \
+ --num_machines ${NUM_NODES} \
+ --num_processes ${NUM_PROCESSES} \
+ --zero_stage ${ZERO_STAGE} \
+ --template_file ${CONFIG_TEMPLATE} \
+ --output_file configs/config.yaml
+
+GAS_OPTS="--gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS}"
+
+if [ "${BACKEND}" == "deepspeed" ]; then
+ DEEPCOMPILE_OPTS=""
+ if [ "${DEEPCOMPILE}" == "1" ]; then
+ DEEPCOMPILE_OPTS="--deepcompile"
+ fi
+
+ DEBUG_LOG_OPTS=""
+ if [ "${DEBUG_LOG}" == "1" ]; then
+ DEBUG_LOG_OPTS="--debug_log"
+ fi
+
+ SYNC_BEFORE_REDUCE_OPTS=""
+ if [ "${SYNC_BEFORE_REDUCE}" == "1" ]; then
+ SYNC_BEFORE_REDUCE_OPTS="--sync_before_reduce"
+ fi
+
+ SYNC_AFTER_REDUCE_OPTS=""
+ if [ "${SYNC_AFTER_REDUCE}" == "1" ]; then
+ SYNC_AFTER_REDUCE_OPTS="--sync_after_reduce"
+ fi
+
+ SYNC_BEFORE_ALLGATHER_OPTS=""
+ if [ "${SYNC_BEFORE_ALLGATHER}" == "1" ]; then
+ SYNC_BEFORE_ALLGATHER_OPTS="--sync_before_allgather"
+ fi
+
+ SYNC_AFTER_ALLGATHER_OPTS=""
+ if [ "${SYNC_AFTER_ALLGATHER}" == "1" ]; then
+ SYNC_AFTER_ALLGATHER_OPTS="--sync_after_allgather"
+ fi
+
+ python generate_conf.py \
+ --machine_rank ${MACHINE_RANK} \
+ --num_machines ${NUM_NODES} \
+ --num_processes ${NUM_PROCESSES} \
+ --zero_stage ${ZERO_STAGE} \
+ --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \
+ ${DEEPCOMPILE_OPTS} ${DEBUG_LOG_OPTS} \
+ ${SYNC_BEFORE_REDUCE_OPTS} ${SYNC_AFTER_REDUCE_OPTS} \
+ ${SYNC_BEFORE_ALLGATHER_OPTS} ${SYNC_AFTER_ALLGATHER_OPTS} \
+ --template_file configs/ds_config.json.template \
+ --output_file configs/ds_config.json
+fi
+
+# Replace "," with "_" in PASSES
+PASSES=$(echo $PASSES | tr ',' '_')
+LOG_DIR=logs
+mkdir -p ${LOG_DIR}
+LOG_FILE=${LOG_DIR}/debug_n${MACHINE_RANK}_${MODEL##*/}_${BACKEND}_np${NUM_PROCESSES}z${ZERO_STAGE}c${COMPILE}dc${DEEPCOMPILE}E${EAGER}b${BATCH_SIZE}seq${SEQ_LENGTH}g${GRADIENT_ACCUMULATION_STEPS}a${ACTIVATION_CHECKPOINTING}p${PASSES}.log
+echo "Logging to ${LOG_FILE}"
+
+${HOME}/.local/bin/accelerate launch --main_process_ip ${HOST_IP} --main_process_port 12345 \
+--num_machines ${NUM_NODES} --num_processes ${NUM_PROCESSES} --machine_rank ${MACHINE_RANK} \
+--config_file configs/config.yaml \
+run_acc_lm.py \
+--model_name "${MODEL}" \
+--zero_stage ${ZERO_STAGE} \
+${GAS_OPTS} \
+${EXTRA_OPTS} \
+2>&1 | tee ${LOG_FILE}
diff --git a/benchmarks/deepcompile/run_bench.sh b/benchmarks/deepcompile/run_bench.sh
new file mode 100644
index 000000000..174e34951
--- /dev/null
+++ b/benchmarks/deepcompile/run_bench.sh
@@ -0,0 +1,50 @@
+PROFILE_DIR=${PROFILE_DIR:-"profiles"}
+mkdir -p ${PROFILE_DIR}
+PROFILE_OPTS="--profile --profile-dir ${PROFILE_DIR}"
+COMPILE_OPTS="--compile"
+DC_OPTS="--compile --deepcompile"
+ACC_OPTS="--gradient-accumulation-steps 1"
+AC_OPTS="--activation-checkpointing"
+
+MODEL="meta-llama/Meta-Llama-3-70B-Instruct"
+BATCH_SIZE_OPTS=(1 2 4)
+SEQ_LENGTH_OPTS=(512 1024 2048)
+for BATCH_SIZE in ${BATCH_SIZE_OPTS[@]}; do
+ for SEQ_LENGTH in ${SEQ_LENGTH_OPTS[@]}; do
+ # skip if batch size is 4 and seq length is 2048, as it causes OOM
+ if [ ${BATCH_SIZE} -eq 4 ] && [ ${SEQ_LENGTH} -eq 2048 ]; then
+ continue
+ fi
+
+ ARGS="--model ${MODEL} --batch-size ${BATCH_SIZE} --seq-length ${SEQ_LENGTH} ${ACC_OPTS} ${AC_OPTS}"
+ bash ./run_multinode.sh --backend deepspeed ${ARGS}
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${COMPILE_OPTS}
+ bash ./run_multinode.sh --backend fsdp ${ARGS}
+ bash ./run_multinode.sh --backend fsdp ${ARGS} ${COMPILE_OPTS}
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${DC_OPTS} --passes prefetch,selective_gather
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${DC_OPTS} --passes prefetch
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${DC_OPTS} --passes selective_gather
+
+ cp -r logs ${PROFILE_DIR}/
+ done
+done
+
+MODEL="mistralai/Mixtral-8x7B-v0.1"
+BATCH_SIZE_OPTS=(1 2 4)
+SEQ_LENGTH_OPTS=(512 1024 2048)
+for BATCH_SIZE in ${BATCH_SIZE_OPTS[@]}; do
+ for SEQ_LENGTH in ${SEQ_LENGTH_OPTS[@]}; do
+ ARGS="--model ${MODEL} --batch-size ${BATCH_SIZE} --seq-length ${SEQ_LENGTH} ${ACC_OPTS} ${AC_OPTS}"
+ bash ./run_multinode.sh --backend deepspeed ${ARGS}
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${COMPILE_OPTS}
+ bash ./run_multinode.sh --backend fsdp ${ARGS}
+ bash ./run_multinode.sh --backend fsdp ${ARGS} ${COMPILE_OPTS}
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${DC_OPTS} --passes prefetch,selective_gather
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${DC_OPTS} --passes prefetch
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${DC_OPTS} --passes selective_gather
+
+ cp -r logs ${PROFILE_DIR}/
+ done
+done
+
diff --git a/benchmarks/deepcompile/run_bench_acc.sh b/benchmarks/deepcompile/run_bench_acc.sh
new file mode 100644
index 000000000..a3b66844d
--- /dev/null
+++ b/benchmarks/deepcompile/run_bench_acc.sh
@@ -0,0 +1,42 @@
+PROFILE_DIR=${PROFILE_DIR:-profiles}
+mkdir -p ${PROFILE_DIR}
+PROFILE_OPTS="--profile --profile-dir ${PROFILE_DIR}"
+COMPILE_OPTS="--compile"
+N3Z_OPTS="--compile --deepcompile"
+AC_OPTS="--activation-checkpointing"
+
+MODEL="meta-llama/Meta-Llama-3-70B-Instruct"
+BATCH_SIZE_OPTS=(1)
+SEQ_LENGTH_OPTS=(1024)
+ACC_OPTS=(2 4 8 16)
+for ACC_STEP in ${ACC_OPTS[@]}; do
+ for BATCH_SIZE in ${BATCH_SIZE_OPTS[@]}; do
+ for SEQ_LENGTH in ${SEQ_LENGTH_OPTS[@]}; do
+ ARGS="--model ${MODEL} --batch-size ${BATCH_SIZE} --seq-length ${SEQ_LENGTH} ${AC_OPTS} ${PROFILE_OPTS} --gradient-accumulation-steps ${ACC_STEP}"
+ bash ./run_multinode.sh --backend deepspeed ${ARGS}
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${COMPILE_OPTS}
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes prefetch,selective_gather
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes prefetch
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes selective_gather
+ cp -r logs ${PROFILE_DIR}/
+ done
+ done
+done
+
+MODEL="mistralai/Mixtral-8x7B-v0.1"
+BATCH_SIZE_OPTS=(1)
+SEQ_LENGTH_OPTS=(1024)
+ACC_OPTS=(2 4 8 16)
+for ACC_STEP in ${ACC_OPTS[@]}; do
+ for BATCH_SIZE in ${BATCH_SIZE_OPTS[@]}; do
+ for SEQ_LENGTH in ${SEQ_LENGTH_OPTS[@]}; do
+ ARGS="--model ${MODEL} --batch-size ${BATCH_SIZE} --seq-length ${SEQ_LENGTH} ${AC_OPTS} ${PROFILE_OPTS} --gradient-accumulation-steps ${ACC_STEP}"
+ bash ./run_multinode.sh --backend deepspeed ${ARGS}
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${COMPILE_OPTS}
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes prefetch,selective_gather
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes prefetch
+ bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes selective_gather
+ cp -r logs ${PROFILE_DIR}/
+ done
+ done
+done
diff --git a/benchmarks/deepcompile/run_bench_lm.py b/benchmarks/deepcompile/run_bench_lm.py
new file mode 100644
index 000000000..f175d84d7
--- /dev/null
+++ b/benchmarks/deepcompile/run_bench_lm.py
@@ -0,0 +1,270 @@
+import os
+import argparse
+import time
+from datetime import datetime
+from contextlib import nullcontext
+from typing import List
+
+import torch
+from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, enable_full_determinism
+from datasets import load_dataset, DownloadConfig
+from accelerate import Accelerator
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import SequentialSampler
+
+from datasets.utils.logging import disable_progress_bar
+
+from patch_phi3_moe import patch_phi3moe
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_name", type=str, default="meta-llama/Llama-2-7b-hf")
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--num_epochs", type=int, default=100)
+ parser.add_argument("--seq_length", type=int, default=512)
+ parser.add_argument("--learning_rate", type=float, default=2e-5)
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+ parser.add_argument("--activation_checkpointing", action="store_true")
+ parser.add_argument("--dataset_name", type=str, default="timdettmers/openassistant-guanaco")
+ parser.add_argument("--num_layers", type=int, default=0)
+ parser.add_argument("--attn_impl", type=str, default="spda")
+ parser.add_argument("--compile", action="store_true")
+ parser.add_argument("--passes", type=str, default=None)
+ parser.add_argument("--backend", type=str, default="inductor")
+ parser.add_argument("--offload_opt_states", action="store_true")
+ parser.add_argument("--profile", action="store_true")
+ parser.add_argument("--deterministic", action="store_true")
+ parser.add_argument("--profile_dir", type=str, default=None)
+ parser.add_argument("--bench_step", type=int, default=30)
+ parser.add_argument("--warmup_step", type=int, default=15)
+ parser.add_argument("--zero_stage", type=int, default=3)
+ parser.add_argument("--print_interval", type=int, default=1)
+ parser.add_argument("--save_weights", action="store_true")
+ parser.add_argument("--load_weights", action="store_true")
+
+ return parser.parse_args()
+
+
+def make_schedule(passes: List[str], warmup):
+    from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather, offload_adam_states
+
+    schedule = []
+
+    if "offload_adam_states" in passes:
+        assert len(passes) == 1, "offload_adam_states should be the only pass"
+        schedule.append((0, [offload_adam_states.offload_adam_states_for_init, zero3_compile.add_z3_gather_release, offload_adam_states.move_opt_states_sync]))
+        schedule.append((5, [offload_adam_states.offload_adam_states_for_init, zero3_compile.add_z3_gather_release, offload_adam_states.move_opt_states]))
+    elif "offload_adam_states_sync" in passes:
+        assert len(passes) == 1, "offload_adam_states_sync should be the only pass"
+        schedule.append((0, [zero3_compile.add_z3_gather_release, offload_adam_states.move_opt_states_sync]))
+    else:
+        schedule.append((0, [zero3_compile.add_z3_gather_release]))
+        second_opt = [zero3_compile.add_z3_gather_release]
+        if "prefetch" in passes:
+            second_opt.append(prefetch.schedule_prefetch)
+        if "selective_gather" in passes:
+            second_opt.append(selective_gather.selective_gather)
+        schedule.append((warmup, second_opt))
+    return schedule
+
+
+def main():
+    args = get_args()
+    print(args)
+
+    if args.passes and "offload_adam_states" in args.passes:
+        # Limit CUDA allocator block splitting to reduce fragmentation while optimizer states are offloaded
+        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
+
+    if args.deterministic:
+        enable_full_determinism(1)
+        from torch._inductor import config
+        config.fallback_random = True
+
+    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
+    device = accelerator.device
+    is_deepspeed = accelerator.state.deepspeed_plugin is not None
+    print(f"Running on device: {device} is_deepspeed: {is_deepspeed}")
+
+    # Load model and tokenizer
+    if accelerator.is_main_process:
+        print("Loading model and tokenizer...")
+
+    model_name = args.model_name
+
+    model_weight_path = f"{model_name.split('/')[-1]}_cp_layer{args.num_layers}"
+    if args.load_weights:
+        model = AutoModelForCausalLM.from_pretrained(model_weight_path, trust_remote_code=True)
+    else:
+        if args.num_layers > 0:
+            model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+            print(f"num_hidden_layers: {model_config.num_hidden_layers} -> {args.num_layers}")
+            model_config.num_hidden_layers = args.num_layers
+            model = AutoModelForCausalLM.from_config(model_config, trust_remote_code=True)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+
+    if patch_phi3moe(model) and accelerator.is_main_process:
+        print("Patched Phi-3.5-MoE model")
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+    if args.save_weights and accelerator.is_main_process:
+        model.save_pretrained(model_weight_path)
+
+    if args.activation_checkpointing:
+        model.gradient_checkpointing_enable()
+
+    tokenizer.pad_token = tokenizer.eos_token
+
+    # Load dataset
+    if accelerator.is_main_process:
+        print("Loading dataset...")
+    else:
+        disable_progress_bar()
+
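+    # Note: the benchmark always trains on ag_news; the --dataset_name flag is unused here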
+    dataset = load_dataset('ag_news', split='train', download_config=DownloadConfig(disable_tqdm=True))
+
+    # The tokenizer was already loaded above; just pin the pad token to a fixed id
+    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(2)
+
+    def tokenize_function(examples):
+        return tokenizer(examples['text'], padding='max_length', max_length=args.seq_length, truncation=True)
+
+    tokenized_dataset = dataset.map(tokenize_function, batched=True, num_proc=1, keep_in_memory=True)
+    tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
+
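+    # Shard the dataset so each rank sees a disjoint partition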
+    sampler = DistributedSampler(tokenized_dataset, num_replicas=accelerator.num_processes, rank=accelerator.process_index)
+    data_loader = DataLoader(tokenized_dataset, batch_size=args.batch_size, sampler=sampler, num_workers=4)
+
+    # Prepare optimizer
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
+
+    # Prepare everything with accelerator
+    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
+    print(f"Model prepared: {model.__class__} optimizer: {optimizer.__class__}")
+
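+    # MoE token routing produces data-dependent shapes; let Dynamo capture them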
+ if "Mixtral" in model_name:
+ torch._dynamo.config.capture_dynamic_output_shape_ops = True
+ torch._dynamo.config.capture_scalar_outputs = True
+
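+    # DeepSpeed engines expose their own compile() entry point (DeepCompile); otherwise fall back to torch.compile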
+    if is_deepspeed:
+        if args.compile:
+            schedule = make_schedule(args.passes.split(","), warmup=5) if args.passes else None
+            model.compile(backend=args.backend, schedule=schedule)
+    else:
+        if args.compile:
+            model = torch.compile(model, backend=args.backend)
+
+    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+    model_name = args.model_name.split("/")[-1]
+    exp_name = f"{model_name}_np{accelerator.num_processes}ds{1 if is_deepspeed else 0}" \
+               f"B{args.backend}z{args.zero_stage}" \
+               f"L{args.num_layers}" \
+               f"bs{args.batch_size}seq{args.seq_length}acc{args.gradient_accumulation_steps}ac{1 if args.activation_checkpointing else 0}" \
+               f"pass_{'none' if args.passes is None else args.passes.replace(',', '_')}_" \
+               f"os{1 if args.offload_opt_states else 0}" \
+               f"T{timestamp}"
+
+    # Create profiling output dirs; fall back to the experiment name as the trace
+    # directory when --profile_dir is unset, so prof_dir is always defined
+    prof_dir = f"{args.profile_dir}/{exp_name}" if args.profile_dir else exp_name
+    if args.profile_dir:
+        if accelerator.is_main_process:
+            os.makedirs(args.profile_dir, exist_ok=True)
+            if args.profile:
+                os.makedirs(prof_dir, exist_ok=True)
+        accelerator.wait_for_everyone()
+
+    # Profile only on the main rank: no wait, warm up for 10 accumulation cycles, then record 3 iterations
+    do_profile = args.profile and accelerator.is_main_process
+    prof_context = torch.profiler.profile(
+        activities=[
+            torch.profiler.ProfilerActivity.CPU,
+            torch.profiler.ProfilerActivity.CUDA,
+        ],
+        schedule=torch.profiler.schedule(wait=0, warmup=10*args.gradient_accumulation_steps, active=3, repeat=1),
+        on_trace_ready=torch.profiler.tensorboard_trace_handler(prof_dir),
+    ) if do_profile else nullcontext()
+
+    # Training loop
+    model.train()
+    global_step = 0
+
+    iter_times = []
+
+    # See https://github.com/microsoft/DeepSpeed/issues/6793
+    acc_context = nullcontext if is_deepspeed else accelerator.accumulate
+
+    stop = False
+    with prof_context as prof:
+        for epoch in range(args.num_epochs):
+            start_iter = time.time()
+
+            for step, batch in enumerate(data_loader):
+                input_ids = batch['input_ids'].to(device)
+                attention_mask = batch['attention_mask'].to(device)
+
+                with acc_context(model):
+                    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids, use_cache=False)
+                    loss = outputs.loss
+
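+                    # Query the accumulation boundary before backward/step, which advance the engine's micro-step counter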
+                    update_step = (is_deepspeed and model.is_gradient_accumulation_boundary()) \
+                        or (not is_deepspeed and accelerator.sync_gradients)
+                    accelerator.backward(loss)
+                    optimizer.step()
+                    optimizer.zero_grad()
+                    global_step += 1
+
+                    if update_step:
+                        if accelerator.is_main_process and global_step % (args.print_interval * args.gradient_accumulation_steps) == 0:
+                            print(f"Epoch {epoch+1}, Step {global_step}, Loss: {loss.item()} sync: {accelerator.sync_gradients} time: {time.time() - start_iter} alloc_mem: {torch.cuda.memory_allocated()} peak_mem: {torch.cuda.max_memory_allocated()}")
+
+                        iter_times.append(time.time() - start_iter)
+                        start_iter = time.time()
+
+                if do_profile:
+                    prof.step()
+
+                stop = global_step >= args.bench_step * args.gradient_accumulation_steps
+                if stop:
+                    break
+            if stop:
+                break
+
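+    # Drop warmup iterations (e.g., compilation) before averaging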
+    iter_times = iter_times[args.warmup_step:]
+
+    if accelerator.is_main_process:
+        compile_time_sum = 0
+        compile_time = 0
+        if args.compile and hasattr(model, "get_compile_time"):
+            compile_time = model.get_compile_time()
+            compile_time_sum = sum(t for _, _, _, t in compile_time)
+
+        is_deepcompile = is_deepspeed and model._config.compile_config.deepcompile
+        msg = f"{args.model_name} ds={is_deepspeed} np={accelerator.num_processes} batch_size={args.batch_size} seq={args.seq_length} zero_stage={args.zero_stage} acc={args.gradient_accumulation_steps} ac={args.activation_checkpointing} compile={args.compile} backend={args.backend} deepcompile={is_deepcompile} passes={args.passes} compile_time={compile_time_sum} iteration time: {sum(iter_times) / len(iter_times):.4f} alloc_mem: {torch.cuda.memory_allocated()} peak_mem: {torch.cuda.max_memory_allocated()}"
+        print(msg)
+
+        if args.profile_dir:
+            from pathlib import Path
+            filepath = Path(args.profile_dir) / "result.txt"
+            with open(filepath, "a") as f:
+                f.write(f"{timestamp} {msg}" + "\n")
+
+            if args.compile:
+                filepath = Path(args.profile_dir) / "compile_time.txt"
+                with open(filepath, "a") as f:
+                    msg = f"{msg} compile_time={compile_time_sum} {compile_time}"
+                    f.write(f"{timestamp} {msg}" + "\n")
+
+    # # Save the model
+    # if accelerator.is_main_process:
+    #     accelerator.wait_for_everyone()
+    #     unwrapped_model = accelerator.unwrap_model(model)
+    #     unwrapped_model.save_pretrained("fine_tuned_model", save_function=accelerator.save)
+    #     tokenizer.save_pretrained("fine_tuned_model")
+
+if __name__ == "__main__":
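+    # Raise Dynamo's cache limits so sweeping many configurations does not hit
+    # the recompile cap, and disable its DDP-specific graph splitting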
+    torch._dynamo.config.accumulated_cache_size_limit = 256
+    torch._dynamo.config.cache_size_limit = 128
+    torch._dynamo.config.optimize_ddp = False
+
+    main()
diff --git a/benchmarks/deepcompile/run_bench_offload.sh b/benchmarks/deepcompile/run_bench_offload.sh
new file mode 100644
index 000000000..ea72db195
--- /dev/null
+++ b/benchmarks/deepcompile/run_bench_offload.sh
@@ -0,0 +1,25 @@
+PROFILE_DIR=${PROFILE_DIR:-"profile_offload"}
+mkdir -p ${PROFILE_DIR}
+PROFILE_OPTS="--profile --profile-dir ${PROFILE_DIR}"
+COMPILE_OPTS="--compile"
+DC_OPTS="--compile --deepcompile"
+ACC_OPTS="--gradient-accumulation-steps 1"
+AC_OPTS="--activation-checkpointing"
+
+mkdir -p logs
+
+export LOG_BASE="logs_offload"
+mkdir -p ${LOG_BASE}
+
+MODEL="meta-llama/Meta-Llama-3-70B-Instruct"
+BATCH_SIZE_OPTS=(1)
+SEQ_LENGTH_OPTS=(1024)
+for BATCH_SIZE in ${BATCH_SIZE_OPTS[@]}; do
+    for SEQ_LENGTH in ${SEQ_LENGTH_OPTS[@]}; do
+        ARGS="--model ${MODEL} --batch-size ${BATCH_SIZE} --seq-length ${SEQ_LENGTH} ${ACC_OPTS} ${AC_OPTS} ${PROFILE_OPTS}"
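+        # ZeRO-3 baseline, ZeRO-3 with DeepSpeed optimizer-state offloading, then
+        # DeepCompile's offload_adam_states passes (default and synchronous variants)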
+        bash ./run.sh --backend deepspeed ${ARGS} --zero-stage 3
+        bash ./run.sh --backend deepspeed ${ARGS} --zero-stage 3 --ds-offload
+        bash ./run.sh --backend deepspeed ${ARGS} ${DC_OPTS} --zero-stage 3 --eager --passes offload_adam_states
+        bash ./run.sh --backend deepspeed ${ARGS} ${DC_OPTS} --zero-stage 3 --eager --passes offload_adam_states_sync
+    done
+done
diff --git a/benchmarks/deepcompile/run_bench_z1.sh b/benchmarks/deepcompile/run_bench_z1.sh
new file mode 100644
index 000000000..b5491e3fc
--- /dev/null
+++ b/benchmarks/deepcompile/run_bench_z1.sh
@@ -0,0 +1,21 @@
+PROFILE_DIR=${PROFILE_DIR:-profiles}
+mkdir -p ${PROFILE_DIR}
+PROFILE_OPTS="--profile --profile-dir ${PROFILE_DIR}"
+COMPILE_OPTS="--compile"
+DC_OPTS="--compile --deepcompile"
+ACC_OPTS="--gradient-accumulation-steps 1"
+AC_OPTS="--activation-checkpointing"
+
+MODEL="meta-llama/Meta-Llama-3-8B-Instruct"
+BATCH_SIZE_OPTS=(1 2 4)
+SEQ_LENGTH_OPTS=(512 1024 2048)
+for BATCH_SIZE in ${BATCH_SIZE_OPTS[@]}; do
+    for SEQ_LENGTH in ${SEQ_LENGTH_OPTS[@]}; do
+        ARGS="--model ${MODEL} --batch-size ${BATCH_SIZE} --seq-length ${SEQ_LENGTH} --zero-stage 1 ${ACC_OPTS} ${AC_OPTS}"
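+        # Eager ZeRO-1 baseline, torch.compile baseline, and DeepCompile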
+        bash ./run_multinode.sh --backend deepspeed ${ARGS}
+        bash ./run_multinode.sh --backend deepspeed ${ARGS} ${COMPILE_OPTS}
+        bash ./run_multinode.sh --backend deepspeed ${ARGS} ${DC_OPTS}
+
+        cp -r logs ${PROFILE_DIR}/
+    done
+done
diff --git a/benchmarks/deepcompile/run_multinode.sh b/benchmarks/deepcompile/run_multinode.sh
new file mode 100644
index 000000000..6f3feba9a
--- /dev/null
+++ b/benchmarks/deepcompile/run_multinode.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+echo $*
+
+SCRIPT_DIR=$(dirname $(realpath $0))
+HOST_IP=$(hostname -i)
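+# Default to the number of hosts in the cluster-provided hostfile when NUM_NODES is unset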
+NUM_NODES=${NUM_NODES:-$(wc -l < /job/hostfile)}
+
+if [ "${NUM_NODES}" == "1" ]; then
+ # avoid dependency on pdsh when possible
+ cd ${SCRIPT_DIR}; bash ./run.sh --host-ip ${HOST_IP} $*
+else
+ ds_ssh -f hostfile_n${NUM_NODES} "cd ${SCRIPT_DIR}; bash ./run.sh --host-ip ${HOST_IP} $*"
+fi