Skip to content

Commit dd56492

Browse files
authored
Add GLM demo on Blackwell hardware (#883)
1 parent 8ee2d64 commit dd56492

File tree

1 file changed

+330
-0
lines changed

1 file changed

+330
-0
lines changed

scripts/run_glm45_355b_a32b.py

Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
"""
2+
This file is in preview, and will be further refined and optimized.
3+
"""
4+
5+
from dataclasses import dataclass
6+
from pathlib import Path
7+
from typing import Literal
8+
9+
import typer
10+
11+
import slime.utils.external_utils.command_utils as U
12+
13+
app = typer.Typer()
14+
15+
16+
@dataclass
17+
class ScriptArgs(U.ExecuteTrainConfig):
18+
mode: Literal["normal", "debug_minimal"] = "normal"
19+
run_id: str = U.create_run_id()
20+
model_org: str = "zai-org"
21+
model_name: str = "GLM-4.5"
22+
megatron_model_type: str = "glm4.5-355B-A32B"
23+
num_gpus_per_node: int = 4
24+
enable_eval: bool = True
25+
extra_args: str = ""
26+
rollout_fp8: bool = False
27+
dynamic_sampling: bool = False
28+
# TODO use more complex task
29+
task: Literal["dapo_aime", "gsm8k"] = "gsm8k"
30+
31+
32+
@app.command()
33+
@U.dataclass_cli
34+
def prepare_single(args: ScriptArgs):
35+
"""This script only needs to be executed on one node."""
36+
U.exec_command("mkdir -p /root/models /root/datasets")
37+
U.exec_command(
38+
f"huggingface-cli download {args.model_org}/{args.model_name} --local-dir /root/models/{args.model_name}"
39+
)
40+
match args.task:
41+
case "dapo_aime":
42+
U.hf_download_dataset("zhuzilin/dapo-math-17k")
43+
U.hf_download_dataset("zhuzilin/aime-2024")
44+
case "gsm8k":
45+
U.hf_download_dataset("zhuzilin/gsm8k")
46+
47+
if args.rollout_fp8:
48+
_convert_hf_to_fp8(args)
49+
50+
51+
def _convert_hf_to_fp8(args: ScriptArgs):
52+
path_output = f"/root/models/{args.model_name}-FP8/"
53+
if Path(path_output).exists():
54+
return
55+
56+
U.exec_command(
57+
"python tools/convert_hf_to_fp8.py "
58+
f"--model-dir /root/models/{args.model_name} "
59+
f"--save-dir {path_output} "
60+
"--strategy block --block-size 128 128 "
61+
"--max-workers 4"
62+
)
63+
64+
65+
@app.command()
66+
@U.dataclass_cli
67+
def prepare_spmd(args: ScriptArgs):
68+
U.convert_checkpoint(
69+
model_name=args.model_name,
70+
megatron_model_type=args.megatron_model_type,
71+
num_gpus_per_node=args.num_gpus_per_node,
72+
multinode=True,
73+
dir_dst="/root/models",
74+
)
75+
76+
77+
@app.command()
78+
@U.dataclass_cli
79+
def prepare_cp(args: ScriptArgs):
80+
_prepare_cp(args)
81+
82+
83+
def _prepare_cp(args: ScriptArgs):
84+
U.rsync_simple(
85+
path_src=f"/root/models/{args.model_name}_torch_dist",
86+
path_dst=f"/root/local_data/{args.model_name}_torch_dist",
87+
)
88+
U.rsync_simple(
89+
path_src=f"/root/models/{args.model_name}",
90+
path_dst=f"/root/local_data/{args.model_name}",
91+
)
92+
93+
94+
@app.command()
95+
@U.dataclass_cli
96+
def train(args: ScriptArgs):
97+
# ensure files are there is it was not synced before
98+
_prepare_cp(args)
99+
100+
hf_checkpoint = (
101+
f"/root/models/{args.model_name}_FP8" if args.rollout_fp8 else f"/root/local_data/{args.model_name}"
102+
)
103+
104+
load_save_path = f"/root/shared_data/{args.run_id}/checkpoints"
105+
ckpt_args = (
106+
f"--hf-checkpoint {hf_checkpoint} "
107+
f"--ref-load /root/local_data/{args.model_name}_torch_dist "
108+
f"--load {load_save_path} "
109+
f"--save {load_save_path} "
110+
f"--save-interval {2 if args.mode == 'debug_minimal' else 20} "
111+
f"--save-retain-interval {2 if args.mode == 'debug_minimal' else 20} "
112+
)
113+
114+
rollout_args = (
115+
"--label-key label "
116+
"--apply-chat-template "
117+
"--rollout-shuffle "
118+
"--rm-type math "
119+
"--num-rollout 3000 "
120+
# TODO enlarge
121+
"--rollout-batch-size 32 "
122+
"--n-samples-per-prompt 8 "
123+
"--rollout-temperature 0.8 "
124+
# ------------
125+
# TODO enlarge
126+
"--num-steps-per-rollout 1 "
127+
"--balance-data "
128+
"--rollout-stop-token-ids 151329 151336 151338 "
129+
)
130+
131+
if args.dynamic_sampling and (args.true_on_policy != "debug_minimal"):
132+
rollout_args += (
133+
"--over-sampling-batch-size 256 "
134+
"--dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std "
135+
)
136+
137+
# sometimes disable eval to speed up debugging
138+
eval_args = ""
139+
if (args.mode != "debug_minimal") and args.enable_eval:
140+
eval_args += "--eval-interval 20 " "--eval-top-p 0.7 "
141+
142+
match args.task:
143+
case "dapo_aime":
144+
rollout_args += (
145+
"--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl "
146+
"--input-key prompt "
147+
f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} "
148+
)
149+
eval_args += (
150+
"--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl "
151+
"--n-samples-per-eval-prompt 8 "
152+
"--eval-max-response-len 8192 "
153+
)
154+
case "gsm8k":
155+
rollout_args += (
156+
"--prompt-data /root/datasets/gsm8k/train.parquet "
157+
"--input-key messages "
158+
# Deliberately make it very short for this easy task
159+
f"--rollout-max-response-len 256 "
160+
)
161+
eval_args += (
162+
"--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet "
163+
"--n-samples-per-eval-prompt 1 "
164+
"--eval-max-response-len 256 "
165+
)
166+
167+
if args.num_nodes <= 4:
168+
# Not really runnable, useful for --debug-rollout-only
169+
perf_args = (
170+
"--tensor-model-parallel-size 4 "
171+
"--sequence-parallel "
172+
f"--pipeline-model-parallel-size 1 "
173+
"--context-parallel-size 2 "
174+
"--expert-model-parallel-size 8 "
175+
"--expert-tensor-parallel-size 1 "
176+
)
177+
else:
178+
perf_args = (
179+
# TODO choose a good config
180+
"--tensor-model-parallel-size 4 "
181+
"--sequence-parallel "
182+
f"--pipeline-model-parallel-size {8 if args.num_nodes == 8 else 4} "
183+
"--context-parallel-size 2 "
184+
"--expert-model-parallel-size 8 "
185+
"--expert-tensor-parallel-size 1 "
186+
)
187+
perf_args += (
188+
"--recompute-granularity full "
189+
"--recompute-method uniform "
190+
"--recompute-num-layers 1 "
191+
# ------------
192+
"--use-dynamic-batch-size "
193+
"--max-tokens-per-gpu 16384 "
194+
)
195+
196+
grpo_args = (
197+
"--advantage-estimator grpo "
198+
# TODO enables use-kl-loss but w/ coef 0. can we just disable it like this?
199+
# "--use-kl-loss "
200+
"--kl-loss-coef 0.00 "
201+
"--kl-loss-type low_var_kl "
202+
"--kl-coef 0.00 "
203+
"--entropy-coef 0.00 "
204+
# TODO wrong?
205+
"--eps-clip 1e-4 "
206+
"--eps-clip-high 2e-4 "
207+
"--use-tis "
208+
)
209+
210+
optimizer_args = (
211+
"--optimizer adam "
212+
"--lr 1e-6 "
213+
"--lr-decay-style constant "
214+
"--weight-decay 0.1 "
215+
"--adam-beta1 0.9 "
216+
"--adam-beta2 0.98 "
217+
# ------------
218+
# "--optimizer-cpu-offload "
219+
# "--overlap-cpu-optimizer-d2h-h2d "
220+
# "--use-precision-aware-optimizer "
221+
)
222+
223+
# TODO optimize parameters, especially for FP8
224+
# TODO pure tp attention is very inefficient
225+
# sglang_decode_max_bs = 256
226+
sglang_world_size = min(32, args.num_gpus_per_node * args.num_nodes)
227+
# sglang_attn_dp_size = 4
228+
# sglang_attn_tp_size = sglang_world_size // sglang_attn_dp_size
229+
sglang_args = (
230+
f"--rollout-num-gpus-per-engine {sglang_world_size} "
231+
"--sglang-mem-fraction-static 0.85 "
232+
f"--sglang-tp-size {sglang_world_size} "
233+
# f"--sglang-ep-size {sglang_world_size} "
234+
# dp attention
235+
# "--sglang-enable-dp-attention "
236+
# f"--sglang-dp-size {sglang_attn_dp_size} "
237+
# "--sglang-moe-dense-tp-size 1 "
238+
# "--sglang-enable-dp-lm-head "
239+
# TODO why disable?
240+
# "--sglang-disable-radix-cache "
241+
# enable deepep for sglang
242+
# "--sglang-moe-a2a-backend deepep "
243+
# "--sglang-deepep-mode low_latency "
244+
# make every dp rank has 128 concurrency
245+
# "--sglang-server-concurrency 1024 "
246+
# f"--sglang-max-running-requests {sglang_world_size * sglang_decode_max_bs // sglang_attn_tp_size} "
247+
# f"--sglang-chunked-prefill-size {sglang_world_size * sglang_decode_max_bs} "
248+
# f"--sglang-cuda-graph-max-bs {sglang_decode_max_bs} "
249+
# For quick experiments
250+
# """--sglang-json-model-override-args '{"num_hidden_layers": 5}' """
251+
f"--sglang-chunked-prefill-size {sglang_world_size * 2048} "
252+
)
253+
sglang_extra_env_vars = {}
254+
# sglang_extra_env_vars = {
255+
# "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": f"{sglang_decode_max_bs}",
256+
# }
257+
258+
misc_args = (
259+
# default dropout in megatron is 0.1
260+
"--attention-dropout 0.0 "
261+
"--hidden-dropout 0.0 "
262+
# should be good for model performance
263+
"--accumulate-allreduce-grads-in-fp32 "
264+
"--attention-softmax-in-fp32 "
265+
# need to comment this when using model with MLA
266+
"--attention-backend flash "
267+
# 4GB will lead to oom, not checked yet
268+
f"--update-weight-buffer-size {2 * 1024 ** 3} "
269+
# TODO maybe enable it
270+
# use deepep for megatron
271+
# "--moe-enable-deepep "
272+
# "--moe-token-dispatcher-type flex "
273+
# ------------
274+
f"--actor-num-nodes {args.num_nodes} "
275+
f"--actor-num-gpus-per-node {args.num_gpus_per_node} "
276+
f"--num-gpus-per-node {args.num_gpus_per_node} "
277+
"--colocate "
278+
"--use-fault-tolerance "
279+
f"--dump-details /root/shared_data/{args.run_id}/dump_details "
280+
"--disable-weights-backuper "
281+
)
282+
283+
train_args = (
284+
f"{ckpt_args} "
285+
f"{rollout_args} "
286+
f"{optimizer_args} "
287+
f"{grpo_args} "
288+
f"{U.get_default_wandb_args(__file__, run_id=args.run_id)} "
289+
f"{perf_args} "
290+
f"{eval_args} "
291+
f"{sglang_args} "
292+
f"{misc_args} "
293+
f"{args.extra_args} "
294+
)
295+
296+
U.execute_train(
297+
train_args=train_args,
298+
config=args,
299+
num_gpus_per_node=args.num_gpus_per_node,
300+
megatron_model_type=args.megatron_model_type,
301+
extra_env_vars={
302+
**sglang_extra_env_vars,
303+
# TODO handle these
304+
# "GLOO_SOCKET_IFNAME": "${MLP_SOCKET_IFNAME}",
305+
# "TP_SOCKET_IFNAME": "${MLP_SOCKET_IFNAME}",
306+
# "NVTE_BWD_LAYERNORM_SM_MARGIN": "20",
307+
# "NCCL_IB_TC": "160",
308+
# "NCCL_PXN_DISABLE": "0",
309+
# "NCCL_IB_GID_INDEX": "3",
310+
# "NCCL_NET_GDR_LEVEL": "4",
311+
# "NCCL_IB_RETRY_CNT": "7",
312+
# "NCCL_IB_TIMEOUT": "32",
313+
# "NCCL_IB_QPS_PER_CONNECTION": "8",
314+
# "NCCL_P2P_LEVEL": "NVL",
315+
# "TORCH_NCCL_AVOID_RECORD_STREAMS": "1", # TODO should this be used
316+
# "NCCL_NVLS_ENABLE": "0",
317+
# "NCCL_MIN_CTAS": "4",
318+
# "OMPI_MCA_pml": "ob1",
319+
# "OMPI_MCA_btl": "^openib",
320+
# "OMPI_MCA_routed": "direct",
321+
# "OMPI_MCA_routed_radix": "1024",
322+
# "OMPI_MCA_plm_rsh_no_tree_spawn": "1",
323+
# "OMPI_MCA_oob_tcp_if_include": "${MLP_SOCKET_IFNAME}",
324+
# "OMPI_MCA_btl_tcp_if_include": "${MLP_SOCKET_IFNAME}",
325+
},
326+
)
327+
328+
329+
if __name__ == "__main__":
330+
app()

0 commit comments

Comments
 (0)