-
Notifications
You must be signed in to change notification settings - Fork 127
Expand file tree
/
Copy path: evo2_pretrain.yaml
More file actions
130 lines (129 loc) · 3.87 KB
/
evo2_pretrain.yaml
File metadata and controls
130 lines (129 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Test scope and wall-clock budget for this JET job.
# NOTE(review): time_limit is presumably seconds (14400 = 4 h) — confirm against JET docs.
scope: partial-conv
time_limit: 14400
# artifacts:
# # Artifact data mount paths for script execution, specified as mount_path: artifact_tag pairs.
# # See Confluence Onboarding Guide section 5.4 for more details on locating this data.
# # Needs update of script_args.data_path: /data-jetart/evo2. Cannot be enabled since Evo2 does not work with read-only folders as data mount.
# /data-jetart/evo2/data : text/opengenome2/processed/2025-01
key_segments:
  # Modify keys to be renamed (str) or excluded (false) from run identifier.
  # By default, all args under script_args are included.
  # Canonical lowercase `false` (yamllint `truthy`): `False` parses to the same
  # boolean under YAML 1.1/1.2 loaders, but lowercase is the portable spelling.
  data_path: false
  clip_grad: false
  lr: false
  min_lr: false
  wu_steps: false
  pckg_url: false
  file_name_wheel: false
  seed: false
  ops_kwargs: false
script_args:
  # All arguments referenced in the script string must be specified here.
  # Arguments not referenced in the script string must have the 'arg' field specified.
  # See jet/core/configs.py for the specification of the configuration class.
  workspace: /workspace/bionemo2
  data_path: /data/evo2
  # Private PyPI index hosting the subquadratic-ops wheel (authenticated in `script`).
  pckg_url: gitlab-master.nvidia.com/api/v4/projects/180496/packages/pypi/simple/
  file_name_wheel: subquadratic-ops
  model: evo2
  variant: train
  precision: fp8
  # Cluster topology: gpus per node x nodes.
  gpus: 8
  nodes: 4
  batch_size: 8
  # Scheduler horizon vs. early-stop point actually used by this test.
  max_steps: 490000
  stop_steps: 6900
  # Parallelism defaults (overridden per-product below).
  pp: 1
  cp: 1
  tp: 1
  seq_len: 8192
  acc_grad: 1
  clip_grad: 250
  seed: 3735928559  # 0xDEADBEEF
  # LR schedule: min_lr is lr / 10 with linear warmup over wu_steps.
  lr: 0.00015
  min_lr: 0.000015
  wu_steps: 5000
  wd: 0.1
products:
  # Parameter matrix: one entry per tested model configuration; values here
  # override the defaults declared under script_args.
  - config_name: "1b"  # quoted: number-looking IDs are safest as explicit strings
    ops_kwargs: "--use-subquadratic_ops"
    tp: 1
    pp: 1
    batch_size: 8
    stop_steps: 6900
  # FIXME: mamba training is not finished
  # - config_name: hybrid_mamba_8b
  #   ops_kwargs: ""
  #   tp: 8
  #   pp: 2
  #   acc_grad: 1
  #   batch_size: 1
  #   stop_steps: 4000
# Launch script. ${var} placeholders are substituted from script_args/products by
# the JET templating layer; ${{...}} escapes to a literal ${...} for the shell.
script: |-
  INSTALL_FLAG="/tmp/install_done_${{SLURMD_NODENAME}}";
  # One rank per node installs the wheel; the others wait on the flag file.
  if [ "$SLURM_LOCALID" = "0" ]; then
    pip install ${file_name_wheel} --index-url https://oauth2:$JET_GITLAB_TOKEN@${pckg_url} --extra-index-url https://pypi.org/simple/
    touch "$INSTALL_FLAG"
  fi
  # All ranks wait until install flag file appears
  while [ ! -f "$INSTALL_FLAG" ]; do
    sleep 1
  done
  WANDB_API_KEY=$BIONEMO_WANDB_API_KEY ${variant}_${model} \
    -d /workspace/bionemo2/sub-packages/bionemo-evo2/examples/configs/full_pretrain_shortphase_config.yaml \
    --dataset-dir ${data_path} \
    --grad-acc-batches ${acc_grad} \
    --fp8 --fp8-wgrad --activation-checkpoint-recompute-num-layers 5 \
    --enable-preemption \
    --ckpt-async-save \
    --use-megatron-comm-overlap-llama3-8k \
    --overlap-grad-reduce \
    --clip-grad=${clip_grad} \
    --eod-pad-in-loss-mask \
    --seq-length=${seq_len} \
    --seed ${seed} \
    --lr=${lr} \
    --wd=${wd} \
    --min-lr=${min_lr} \
    --warmup-steps=${wu_steps} \
    --tensor-parallel-size=${tp} \
    --context-parallel-size=${cp} \
    --pipeline-model-parallel-size=${pp} \
    --workers 8 \
    --num-nodes=${nodes} \
    --devices=${gpus} \
    --micro-batch-size=${batch_size} \
    --model-size=${config_name} \
    --max-steps=${max_steps} \
    --early-stop-on-step ${stop_steps} \
    --limit-val-batches=20 \
    --log-every-n-steps=50 \
    --val-check-interval=500 \
    ${ops_kwargs} \
    --create-tflops-callback \
    --create-tensorboard-logger \
    --result-dir=${tensorboard_dir} \
    --wandb-project=${wandb_project_name} \
    --wandb-group=${model}_${variant}_${config_name}__${target}__slen${seq_len} \
    --wandb-job-type=${pipeline_label} \
    --disable-checkpointing;
tests:
  # Static pass/fail criteria checked against the run's logged metrics.
  - logic_type: static
    # '1b' quoted so the identifier is unambiguously a string, matching the
    # quoted style of the sibling 'target' value.
    product_identifier: { 'target': 'dgxh100_eos', 'config_name': '1b' }
    logic_spec:
      exit_codes:
        - 0
      baselines:
        # stop_steps * nodes * gpus * batch_size = 6900 * 4 * 8 * 8 = 1766400.
        consumed_samples:
          operator: eq
          value: 1766400
        val_loss:
          operator: range
          max: 1.26
          min: 1.22
        reduced_train_loss:
          operator: range
          max: 1.24
          min: 1.19
        TFLOPS_per_GPU:
          operator: geq
          value: 390