Skip to content

Commit 21631ac

Browse files
committed
config
1 parent d9dbc85 commit 21631ac

File tree

6 files changed

+509
-11
lines changed

6 files changed

+509
-11
lines changed

scripts/models/llama3.1-8B.sh

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
# Megatron-LM architecture arguments for Llama-3.1-8B.
# Sourced by scripts/run-llama3.1-8B.sh and consumed as "${MODEL_ARGS[@]}".
# (Reconstructed: the diff line-number markers previously interleaved here
# would have been captured as bogus array elements.)
MODEL_ARGS=(
   --swiglu
   --num-layers 32
   --hidden-size 4096
   --ffn-hidden-size 14336
   --num-attention-heads 32
   --group-query-attention
   --num-query-groups 8
   --max-position-embeddings 131072
   --use-rotary-position-embeddings
   --disable-bias-linear
   --normalization "RMSNorm"
   --norm-epsilon 1e-5
   --rotary-base 500000
   --vocab-size 128256
   --kv-channels 128
   # Llama 3.1 long-context RoPE scaling (factor 8 for the 8B model).
   --use-rope-scaling
   --rotary-scaling-factor 8.0
   --untie-embeddings-and-output-weights
)

scripts/models/llama3.2-3B.sh

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
# Megatron-LM architecture arguments for Llama-3.2-3B.
# Sourced by scripts/run-llama3.2-3B.sh and consumed as "${MODEL_ARGS[@]}".
# (Reconstructed: the diff line-number markers previously interleaved here
# would have been captured as bogus array elements.)
MODEL_ARGS=(
   --swiglu
   --num-layers 28
   --hidden-size 3072
   --ffn-hidden-size 8192
   --num-attention-heads 24
   --group-query-attention
   --num-query-groups 8
   --max-position-embeddings 131072
   --use-rotary-position-embeddings
   --disable-bias-linear
   --normalization "RMSNorm"
   --norm-epsilon 1e-5
   --rotary-base 500000
   --vocab-size 128256
   --kv-channels 128
   # Llama 3.2 long-context RoPE scaling (factor 32 for the 3B model).
   --use-rope-scaling
   --rotary-scaling-factor 32.0
   --untie-embeddings-and-output-weights
)

scripts/run-llama3.1-8B.sh

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
#!/bin/bash
#
# Launch GRPO training for Llama-3.1-8B with Megatron-LM + SGLang via Ray on a
# single 8-GPU node.
# Required: checkpoints/datasets under /root (see CKPT_ARGS / ROLLOUT_ARGS),
# WANDB_API_KEY in the environment, and scripts/models/llama3.1-8B.sh next to
# this script.

# For re-running the task: kill any leftover inference/Ray/Python processes
# from a previous run so GPUs and ports are free.
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python

set -ex

# Prevent Python (and therefore Ray workers) from buffering stdout/stderr.
# FIX: the variable Python actually honors is PYTHONUNBUFFERED (any non-empty
# value); the previous spelling PYTHONBUFFERED was silently ignored.
export PYTHONUNBUFFERED=1

# Enable NCCL NVLS only on machines where nvidia-smi reports NVLink.
NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
# Defines MODEL_ARGS (Llama-3.1-8B architecture flags).
source "${SCRIPT_DIR}/models/llama3.1-8B.sh"

CKPT_ARGS=(
   --hf-checkpoint /root/Llama-3.1-8B
   --ref-load /root/Llama-3.1-8B_torch_dist
   --load /root/Llama-3.1-8B_slime/
   --save /root/Llama-3.1-8B_slime/
   --save-interval 50
)

ROLLOUT_ARGS=(
   --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
   --input-key prompt
   --label-key label
   --apply-chat-template
   --rollout-shuffle
   --rm-type math
   --num-epoch 1
   --rollout-batch-size 32
   --n-samples-per-prompt 8
   --rollout-max-response-len 16384
   --rollout-temperature 0.8

   --global-batch-size 256
   --balance-data
   --partial-rollout
   --over-sampling-batch-size 64
)

EVAL_ARGS=(
   --eval-interval 10
   --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl
   --n-samples-per-eval-prompt 4
   --eval-max-response-len 16384
   --eval-top-p 0.7
)

PERF_ARGS=(
   --tensor-model-parallel-size 2
   --sequence-parallel
   --pipeline-model-parallel-size 1
   --context-parallel-size 1
   --expert-model-parallel-size 1
   --expert-tensor-parallel-size 1

   --recompute-granularity full
   --recompute-method uniform
   --recompute-num-layers 1

   # --micro-batch-size 1
   --use-dynamic-batch-size
   --max-tokens-per-gpu 9216
)

GRPO_ARGS=(
   --advantage-estimator grpo
   --use-kl-loss
   --kl-loss-coef 0.00
   --kl-loss-type low_var_kl
   --entropy-coef 0.00
   --eps-clip 0.2
   --eps-clip-high 0.28
)

OPTIMIZER_ARGS=(
   --optimizer adam
   --lr 1e-6
   --lr-decay-style constant
   --weight-decay 0.1
   --adam-beta1 0.9
   --adam-beta2 0.98
)

WANDB_ARGS=(
   --use-wandb
   --wandb-project llama3.1-8B-training
   --wandb-group llama3.1-8B-grpo
   --wandb-key "${WANDB_API_KEY}"
)

SGLANG_ARGS=(
   --rollout-num-gpus-per-engine 2
   --sglang-mem-fraction-static 0.8
)

MISC_ARGS=(
   # default dropout in megatron is 0.1
   --attention-dropout 0.0
   --hidden-dropout 0.0
   # should be good for model performance
   --accumulate-allreduce-grads-in-fp32
   --attention-softmax-in-fp32
   # need to comment this when using model with MLA
   --attention-backend flash
)

# Launch the master node of Ray in the container.
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265

# Build the runtime environment JSON with proper variable substitution.
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
  }
}"

# Quote every array expansion so elements are passed through verbatim.
ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json="${RUNTIME_ENV_JSON}" \
   -- python3 train.py \
   --actor-num-nodes 1 \
   --actor-num-gpus-per-node 8 \
   --colocate \
   "${MODEL_ARGS[@]}" \
   "${CKPT_ARGS[@]}" \
   "${ROLLOUT_ARGS[@]}" \
   "${OPTIMIZER_ARGS[@]}" \
   "${GRPO_ARGS[@]}" \
   "${WANDB_ARGS[@]}" \
   "${PERF_ARGS[@]}" \
   "${EVAL_ARGS[@]}" \
   "${SGLANG_ARGS[@]}" \
   "${MISC_ARGS[@]}"

scripts/run-llama3.2-3B.sh

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
#!/bin/bash
#
# Launch GRPO training for Llama-3.2-3B with Megatron-LM + SGLang via Ray on a
# single 8-GPU node.
# Required: checkpoints/datasets under /root (see CKPT_ARGS / ROLLOUT_ARGS),
# WANDB_API_KEY in the environment, and scripts/models/llama3.2-3B.sh next to
# this script.

# For re-running the task: kill any leftover inference/Ray/Python processes
# from a previous run so GPUs and ports are free.
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python

set -ex

# Prevent Python (and therefore Ray workers) from buffering stdout/stderr.
# FIX: the variable Python actually honors is PYTHONUNBUFFERED (any non-empty
# value); the previous spelling PYTHONBUFFERED was silently ignored.
export PYTHONUNBUFFERED=1

# Enable NCCL NVLS only on machines where nvidia-smi reports NVLink.
NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
# Defines MODEL_ARGS (Llama-3.2-3B architecture flags).
source "${SCRIPT_DIR}/models/llama3.2-3B.sh"

CKPT_ARGS=(
   --hf-checkpoint /root/Llama-3.2-3B
   --ref-load /root/Llama-3.2-3B_torch_dist
   --load /root/Llama-3.2-3B_slime/
   --save /root/Llama-3.2-3B_slime/
   --save-interval 500
)

ROLLOUT_ARGS=(
   --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
   --input-key prompt
   --label-key label
   --apply-chat-template
   --rollout-shuffle
   --rm-type math
   --num-epoch 1
   --rollout-batch-size 32
   --n-samples-per-prompt 8
   --rollout-max-response-len 16384
   --rollout-temperature 1.0

   --global-batch-size 256
   --balance-data
   --partial-rollout
   --over-sampling-batch-size 64
)

EVAL_ARGS=(
   --eval-interval 10
   --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl
   --n-samples-per-eval-prompt 4
   --eval-max-response-len 16384
   --eval-top-p 0.7
)

PERF_ARGS=(
   --tensor-model-parallel-size 2
   --sequence-parallel
   --pipeline-model-parallel-size 1
   --context-parallel-size 1
   --expert-model-parallel-size 1
   --expert-tensor-parallel-size 1

   --recompute-granularity full
   --recompute-method uniform
   --recompute-num-layers 1

   # --micro-batch-size 1
   --use-dynamic-batch-size
   --max-tokens-per-gpu 9216
)

GRPO_ARGS=(
   --advantage-estimator grpo
   --use-kl-loss
   --kl-loss-coef 0.00
   --kl-loss-type low_var_kl
   --entropy-coef 0.00
   --eps-clip 0.2
   --eps-clip-high 0.28
)

OPTIMIZER_ARGS=(
   --optimizer adam
   --lr 1e-6
   --lr-decay-style constant
   --weight-decay 0.1
   --adam-beta1 0.9
   --adam-beta2 0.98
)

WANDB_ARGS=(
   --use-wandb
   --wandb-project debug
   --wandb-group h200-llama3.2-3B
   --wandb-key "${WANDB_API_KEY}"
)

SGLANG_ARGS=(
   --rollout-num-gpus-per-engine 2
   --sglang-mem-fraction-static 0.8
)

MISC_ARGS=(
   # default dropout in megatron is 0.1
   --attention-dropout 0.0
   --hidden-dropout 0.0
   # should be good for model performance
   --accumulate-allreduce-grads-in-fp32
   --attention-softmax-in-fp32
   # need to comment this when using model with MLA
   --attention-backend flash
)

# Launch the master node of Ray in the container.
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265

# Build the runtime environment JSON with proper variable substitution.
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
  }
}"

# FIX: the previous version also expanded ${DISTRIBUTED_ARGS[@]}, but that
# array is never defined in this script (it expanded to nothing); removed for
# consistency with run-llama3.1-8B.sh.
# Quote every array expansion so elements are passed through verbatim.
ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json="${RUNTIME_ENV_JSON}" \
   -- python3 train.py \
   --actor-num-nodes 1 \
   --actor-num-gpus-per-node 8 \
   --colocate \
   "${MODEL_ARGS[@]}" \
   "${CKPT_ARGS[@]}" \
   "${ROLLOUT_ARGS[@]}" \
   "${OPTIMIZER_ARGS[@]}" \
   "${GRPO_ARGS[@]}" \
   "${WANDB_ARGS[@]}" \
   "${PERF_ARGS[@]}" \
   "${EVAL_ARGS[@]}" \
   "${SGLANG_ARGS[@]}" \
   "${MISC_ARGS[@]}"

0 commit comments

Comments
 (0)