|
| 1 | +""" |
| 2 | +This file is in preview, and will be further refined and optimized. |
| 3 | +""" |
| 4 | + |
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal

import typer

import slime.utils.external_utils.command_utils as U
| 12 | + |
| 13 | +app = typer.Typer() |
| 14 | + |
| 15 | + |
@dataclass
class ScriptArgs(U.ExecuteTrainConfig):
    """Configuration for the GLM-4.5 training scripts (CLI-exposed via typer)."""

    # "debug_minimal" shrinks save intervals / response lengths and disables eval.
    mode: Literal["normal", "debug_minimal"] = "normal"
    # default_factory gives every invocation a fresh run id; a plain
    # `= U.create_run_id()` default would be evaluated once at import time
    # and silently reused across all runs in the same process.
    run_id: str = field(default_factory=U.create_run_id)
    model_org: str = "zai-org"
    model_name: str = "GLM-4.5"
    megatron_model_type: str = "glm4.5-355B-A32B"
    num_gpus_per_node: int = 4
    enable_eval: bool = True
    # Extra CLI args appended verbatim to the train command.
    extra_args: str = ""
    # Serve rollouts from an FP8-quantized copy of the checkpoint.
    rollout_fp8: bool = False
    dynamic_sampling: bool = False
    # TODO use more complex task
    task: Literal["dapo_aime", "gsm8k"] = "gsm8k"
| 30 | + |
| 31 | + |
@app.command()
@U.dataclass_cli
def prepare_single(args: ScriptArgs):
    """This script only needs to be executed on one node."""
    U.exec_command("mkdir -p /root/models /root/datasets")
    U.exec_command(
        f"huggingface-cli download {args.model_org}/{args.model_name} --local-dir /root/models/{args.model_name}"
    )

    # Datasets needed by each supported task.
    datasets_by_task = {
        "dapo_aime": ("zhuzilin/dapo-math-17k", "zhuzilin/aime-2024"),
        "gsm8k": ("zhuzilin/gsm8k",),
    }
    for dataset_name in datasets_by_task.get(args.task, ()):
        U.hf_download_dataset(dataset_name)

    if args.rollout_fp8:
        _convert_hf_to_fp8(args)
| 49 | + |
| 50 | + |
def _convert_hf_to_fp8(args: ScriptArgs):
    """Quantize the downloaded HF checkpoint to FP8; no-op if the output dir exists."""
    # Trailing slash kept: this exact string is passed to --save-dir below.
    path_output = f"/root/models/{args.model_name}-FP8/"
    if Path(path_output).exists():
        # Already converted on a previous run.
        return

    convert_cmd = (
        "python tools/convert_hf_to_fp8.py "
        f"--model-dir /root/models/{args.model_name} "
        f"--save-dir {path_output} "
        "--strategy block --block-size 128 128 "
        "--max-workers 4"
    )
    U.exec_command(convert_cmd)
| 63 | + |
| 64 | + |
@app.command()
@U.dataclass_cli
def prepare_spmd(args: ScriptArgs):
    """Convert the HF checkpoint into the Megatron-compatible format under /root/models."""
    conversion_kwargs = dict(
        model_name=args.model_name,
        megatron_model_type=args.megatron_model_type,
        num_gpus_per_node=args.num_gpus_per_node,
        multinode=True,
        dir_dst="/root/models",
    )
    U.convert_checkpoint(**conversion_kwargs)
| 75 | + |
| 76 | + |
@app.command()
@U.dataclass_cli
def prepare_cp(args: ScriptArgs):
    """CLI wrapper: copy checkpoints from shared storage to node-local disk."""
    _prepare_cp(args)
| 81 | + |
| 82 | + |
def _prepare_cp(args: ScriptArgs):
    """Mirror the torch_dist and HF checkpoints into /root/local_data via rsync."""
    # Sync the torch_dist checkpoint first, then the plain HF one.
    for suffix in ("_torch_dist", ""):
        U.rsync_simple(
            path_src=f"/root/models/{args.model_name}{suffix}",
            path_dst=f"/root/local_data/{args.model_name}{suffix}",
        )
| 92 | + |
| 93 | + |
| 94 | +@app.command() |
| 95 | +@U.dataclass_cli |
| 96 | +def train(args: ScriptArgs): |
| 97 | + # ensure files are there is it was not synced before |
| 98 | + _prepare_cp(args) |
| 99 | + |
| 100 | + hf_checkpoint = ( |
| 101 | + f"/root/models/{args.model_name}_FP8" if args.rollout_fp8 else f"/root/local_data/{args.model_name}" |
| 102 | + ) |
| 103 | + |
| 104 | + load_save_path = f"/root/shared_data/{args.run_id}/checkpoints" |
| 105 | + ckpt_args = ( |
| 106 | + f"--hf-checkpoint {hf_checkpoint} " |
| 107 | + f"--ref-load /root/local_data/{args.model_name}_torch_dist " |
| 108 | + f"--load {load_save_path} " |
| 109 | + f"--save {load_save_path} " |
| 110 | + f"--save-interval {2 if args.mode == 'debug_minimal' else 20} " |
| 111 | + f"--save-retain-interval {2 if args.mode == 'debug_minimal' else 20} " |
| 112 | + ) |
| 113 | + |
| 114 | + rollout_args = ( |
| 115 | + "--label-key label " |
| 116 | + "--apply-chat-template " |
| 117 | + "--rollout-shuffle " |
| 118 | + "--rm-type math " |
| 119 | + "--num-rollout 3000 " |
| 120 | + # TODO enlarge |
| 121 | + "--rollout-batch-size 32 " |
| 122 | + "--n-samples-per-prompt 8 " |
| 123 | + "--rollout-temperature 0.8 " |
| 124 | + # ------------ |
| 125 | + # TODO enlarge |
| 126 | + "--num-steps-per-rollout 1 " |
| 127 | + "--balance-data " |
| 128 | + "--rollout-stop-token-ids 151329 151336 151338 " |
| 129 | + ) |
| 130 | + |
| 131 | + if args.dynamic_sampling and (args.true_on_policy != "debug_minimal"): |
| 132 | + rollout_args += ( |
| 133 | + "--over-sampling-batch-size 256 " |
| 134 | + "--dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " |
| 135 | + ) |
| 136 | + |
| 137 | + # sometimes disable eval to speed up debugging |
| 138 | + eval_args = "" |
| 139 | + if (args.mode != "debug_minimal") and args.enable_eval: |
| 140 | + eval_args += "--eval-interval 20 " "--eval-top-p 0.7 " |
| 141 | + |
| 142 | + match args.task: |
| 143 | + case "dapo_aime": |
| 144 | + rollout_args += ( |
| 145 | + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " |
| 146 | + "--input-key prompt " |
| 147 | + f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " |
| 148 | + ) |
| 149 | + eval_args += ( |
| 150 | + "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " |
| 151 | + "--n-samples-per-eval-prompt 8 " |
| 152 | + "--eval-max-response-len 8192 " |
| 153 | + ) |
| 154 | + case "gsm8k": |
| 155 | + rollout_args += ( |
| 156 | + "--prompt-data /root/datasets/gsm8k/train.parquet " |
| 157 | + "--input-key messages " |
| 158 | + # Deliberately make it very short for this easy task |
| 159 | + f"--rollout-max-response-len 256 " |
| 160 | + ) |
| 161 | + eval_args += ( |
| 162 | + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " |
| 163 | + "--n-samples-per-eval-prompt 1 " |
| 164 | + "--eval-max-response-len 256 " |
| 165 | + ) |
| 166 | + |
| 167 | + if args.num_nodes <= 4: |
| 168 | + # Not really runnable, useful for --debug-rollout-only |
| 169 | + perf_args = ( |
| 170 | + "--tensor-model-parallel-size 4 " |
| 171 | + "--sequence-parallel " |
| 172 | + f"--pipeline-model-parallel-size 1 " |
| 173 | + "--context-parallel-size 2 " |
| 174 | + "--expert-model-parallel-size 8 " |
| 175 | + "--expert-tensor-parallel-size 1 " |
| 176 | + ) |
| 177 | + else: |
| 178 | + perf_args = ( |
| 179 | + # TODO choose a good config |
| 180 | + "--tensor-model-parallel-size 4 " |
| 181 | + "--sequence-parallel " |
| 182 | + f"--pipeline-model-parallel-size {8 if args.num_nodes == 8 else 4} " |
| 183 | + "--context-parallel-size 2 " |
| 184 | + "--expert-model-parallel-size 8 " |
| 185 | + "--expert-tensor-parallel-size 1 " |
| 186 | + ) |
| 187 | + perf_args += ( |
| 188 | + "--recompute-granularity full " |
| 189 | + "--recompute-method uniform " |
| 190 | + "--recompute-num-layers 1 " |
| 191 | + # ------------ |
| 192 | + "--use-dynamic-batch-size " |
| 193 | + "--max-tokens-per-gpu 16384 " |
| 194 | + ) |
| 195 | + |
| 196 | + grpo_args = ( |
| 197 | + "--advantage-estimator grpo " |
| 198 | + # TODO enables use-kl-loss but w/ coef 0. can we just disable it like this? |
| 199 | + # "--use-kl-loss " |
| 200 | + "--kl-loss-coef 0.00 " |
| 201 | + "--kl-loss-type low_var_kl " |
| 202 | + "--kl-coef 0.00 " |
| 203 | + "--entropy-coef 0.00 " |
| 204 | + # TODO wrong? |
| 205 | + "--eps-clip 1e-4 " |
| 206 | + "--eps-clip-high 2e-4 " |
| 207 | + "--use-tis " |
| 208 | + ) |
| 209 | + |
| 210 | + optimizer_args = ( |
| 211 | + "--optimizer adam " |
| 212 | + "--lr 1e-6 " |
| 213 | + "--lr-decay-style constant " |
| 214 | + "--weight-decay 0.1 " |
| 215 | + "--adam-beta1 0.9 " |
| 216 | + "--adam-beta2 0.98 " |
| 217 | + # ------------ |
| 218 | + # "--optimizer-cpu-offload " |
| 219 | + # "--overlap-cpu-optimizer-d2h-h2d " |
| 220 | + # "--use-precision-aware-optimizer " |
| 221 | + ) |
| 222 | + |
| 223 | + # TODO optimize parameters, especially for FP8 |
| 224 | + # TODO pure tp attention is very inefficient |
| 225 | + # sglang_decode_max_bs = 256 |
| 226 | + sglang_world_size = min(32, args.num_gpus_per_node * args.num_nodes) |
| 227 | + # sglang_attn_dp_size = 4 |
| 228 | + # sglang_attn_tp_size = sglang_world_size // sglang_attn_dp_size |
| 229 | + sglang_args = ( |
| 230 | + f"--rollout-num-gpus-per-engine {sglang_world_size} " |
| 231 | + "--sglang-mem-fraction-static 0.85 " |
| 232 | + f"--sglang-tp-size {sglang_world_size} " |
| 233 | + # f"--sglang-ep-size {sglang_world_size} " |
| 234 | + # dp attention |
| 235 | + # "--sglang-enable-dp-attention " |
| 236 | + # f"--sglang-dp-size {sglang_attn_dp_size} " |
| 237 | + # "--sglang-moe-dense-tp-size 1 " |
| 238 | + # "--sglang-enable-dp-lm-head " |
| 239 | + # TODO why disable? |
| 240 | + # "--sglang-disable-radix-cache " |
| 241 | + # enable deepep for sglang |
| 242 | + # "--sglang-moe-a2a-backend deepep " |
| 243 | + # "--sglang-deepep-mode low_latency " |
| 244 | + # make every dp rank has 128 concurrency |
| 245 | + # "--sglang-server-concurrency 1024 " |
| 246 | + # f"--sglang-max-running-requests {sglang_world_size * sglang_decode_max_bs // sglang_attn_tp_size} " |
| 247 | + # f"--sglang-chunked-prefill-size {sglang_world_size * sglang_decode_max_bs} " |
| 248 | + # f"--sglang-cuda-graph-max-bs {sglang_decode_max_bs} " |
| 249 | + # For quick experiments |
| 250 | + # """--sglang-json-model-override-args '{"num_hidden_layers": 5}' """ |
| 251 | + f"--sglang-chunked-prefill-size {sglang_world_size * 2048} " |
| 252 | + ) |
| 253 | + sglang_extra_env_vars = {} |
| 254 | + # sglang_extra_env_vars = { |
| 255 | + # "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": f"{sglang_decode_max_bs}", |
| 256 | + # } |
| 257 | + |
| 258 | + misc_args = ( |
| 259 | + # default dropout in megatron is 0.1 |
| 260 | + "--attention-dropout 0.0 " |
| 261 | + "--hidden-dropout 0.0 " |
| 262 | + # should be good for model performance |
| 263 | + "--accumulate-allreduce-grads-in-fp32 " |
| 264 | + "--attention-softmax-in-fp32 " |
| 265 | + # need to comment this when using model with MLA |
| 266 | + "--attention-backend flash " |
| 267 | + # 4GB will lead to oom, not checked yet |
| 268 | + f"--update-weight-buffer-size {2 * 1024 ** 3} " |
| 269 | + # TODO maybe enable it |
| 270 | + # use deepep for megatron |
| 271 | + # "--moe-enable-deepep " |
| 272 | + # "--moe-token-dispatcher-type flex " |
| 273 | + # ------------ |
| 274 | + f"--actor-num-nodes {args.num_nodes} " |
| 275 | + f"--actor-num-gpus-per-node {args.num_gpus_per_node} " |
| 276 | + f"--num-gpus-per-node {args.num_gpus_per_node} " |
| 277 | + "--colocate " |
| 278 | + "--use-fault-tolerance " |
| 279 | + f"--dump-details /root/shared_data/{args.run_id}/dump_details " |
| 280 | + "--disable-weights-backuper " |
| 281 | + ) |
| 282 | + |
| 283 | + train_args = ( |
| 284 | + f"{ckpt_args} " |
| 285 | + f"{rollout_args} " |
| 286 | + f"{optimizer_args} " |
| 287 | + f"{grpo_args} " |
| 288 | + f"{U.get_default_wandb_args(__file__, run_id=args.run_id)} " |
| 289 | + f"{perf_args} " |
| 290 | + f"{eval_args} " |
| 291 | + f"{sglang_args} " |
| 292 | + f"{misc_args} " |
| 293 | + f"{args.extra_args} " |
| 294 | + ) |
| 295 | + |
| 296 | + U.execute_train( |
| 297 | + train_args=train_args, |
| 298 | + config=args, |
| 299 | + num_gpus_per_node=args.num_gpus_per_node, |
| 300 | + megatron_model_type=args.megatron_model_type, |
| 301 | + extra_env_vars={ |
| 302 | + **sglang_extra_env_vars, |
| 303 | + # TODO handle these |
| 304 | + # "GLOO_SOCKET_IFNAME": "${MLP_SOCKET_IFNAME}", |
| 305 | + # "TP_SOCKET_IFNAME": "${MLP_SOCKET_IFNAME}", |
| 306 | + # "NVTE_BWD_LAYERNORM_SM_MARGIN": "20", |
| 307 | + # "NCCL_IB_TC": "160", |
| 308 | + # "NCCL_PXN_DISABLE": "0", |
| 309 | + # "NCCL_IB_GID_INDEX": "3", |
| 310 | + # "NCCL_NET_GDR_LEVEL": "4", |
| 311 | + # "NCCL_IB_RETRY_CNT": "7", |
| 312 | + # "NCCL_IB_TIMEOUT": "32", |
| 313 | + # "NCCL_IB_QPS_PER_CONNECTION": "8", |
| 314 | + # "NCCL_P2P_LEVEL": "NVL", |
| 315 | + # "TORCH_NCCL_AVOID_RECORD_STREAMS": "1", # TODO should this be used |
| 316 | + # "NCCL_NVLS_ENABLE": "0", |
| 317 | + # "NCCL_MIN_CTAS": "4", |
| 318 | + # "OMPI_MCA_pml": "ob1", |
| 319 | + # "OMPI_MCA_btl": "^openib", |
| 320 | + # "OMPI_MCA_routed": "direct", |
| 321 | + # "OMPI_MCA_routed_radix": "1024", |
| 322 | + # "OMPI_MCA_plm_rsh_no_tree_spawn": "1", |
| 323 | + # "OMPI_MCA_oob_tcp_if_include": "${MLP_SOCKET_IFNAME}", |
| 324 | + # "OMPI_MCA_btl_tcp_if_include": "${MLP_SOCKET_IFNAME}", |
| 325 | + }, |
| 326 | + ) |
| 327 | + |
| 328 | + |
# Typer CLI entry point: dispatches to prepare-single / prepare-spmd / prepare-cp / train.
if __name__ == "__main__":
    app()
0 commit comments