Skip to content

Commit 4e4c888

Browse files
authored
add moonlight test (THUDM#935)
1 parent 5ee4e82 commit 4e4c888

File tree

3 files changed

+123
-1
lines changed

3 files changed

+123
-1
lines changed

.github/workflows/pr-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
strategy:
4545
fail-fast: false
4646
matrix:
47-
info: [{"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}]
47+
info: [{"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}]
4848
defaults:
4949
run:
5050
working-directory: ${{ github.workspace }}

.github/workflows/pr-test.yml.j2

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
{'test_file': 'test_quick_start_glm4_9B.py', 'num_gpus': 8},
66
{'test_file': 'test_qwen3_30B_A3B.py', 'num_gpus': 8},
77
{'test_file': 'test_qwen3_4B_ppo.py', 'num_gpus': 8},
8+
{'test_file': 'test_moonlight_16B_A3B.py', 'num_gpus': 8},
89
],
910
},
1011
'e2e-test-long': {

tests/test_moonlight_16B_A3B.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import os
2+
import slime.utils.external_utils.command_utils as U
3+
4+
# Integer-valued environment toggles for the CI test; both default to on ("1").
# Any value parsing as a non-zero int enables the flag.
ENABLE_EVAL = os.environ.get("SLIME_TEST_ENABLE_EVAL", "1")
ENABLE_EVAL = bool(int(ENABLE_EVAL))
TIGHT_HOST_MEMORY = os.environ.get("SLIME_TEST_TIGHT_HOST_MEMORY", "1")
TIGHT_HOST_MEMORY = bool(int(TIGHT_HOST_MEMORY))

# Model under test and the Megatron model-type key used for checkpoint conversion.
MODEL_NAME = "Moonlight-16B-A3B-Instruct"
MODEL_TYPE = "moonlight"
# Single-node GPU count used for conversion and training below.
NUM_GPUS = 8
11+
12+
def prepare():
    """Fetch the model and datasets, then convert the HF checkpoint for Megatron.

    Side effects: creates /root/models and /root/datasets, downloads the
    Moonlight HF checkpoint plus the train (dapo-math-17k) and eval
    (aime-2024) datasets, and writes the converted torch_dist checkpoint.
    """
    U.exec_command("mkdir -p /root/models /root/datasets")
    # Build the download command from MODEL_NAME so it cannot drift from the
    # paths used in execute(). (Original hard-coded the name twice and used an
    # f-string with no placeholders.)
    U.exec_command(f"hf download moonshotai/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}")
    U.hf_download_dataset("zhuzilin/dapo-math-17k")
    U.hf_download_dataset("zhuzilin/aime-2024")

    U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS)
21+
22+
23+
def execute():
    """Run the Moonlight-16B-A3B GRPO/GSPO CI training job via slime.

    Assembles the command-line argument groups (checkpoint, rollout, eval,
    parallelism/perf, RL loss, optimizer, sglang, misc) and hands them to
    ``U.execute_train``. Behavior depends on the module-level ENABLE_EVAL and
    TIGHT_HOST_MEMORY flags.
    """
    # Checkpoint locations produced by prepare().
    ckpt_args = (
        f"--hf-checkpoint /root/models/{MODEL_NAME} "
        f"--ref-load /root/{MODEL_NAME}_torch_dist "
    )

    # Rollout / data-generation configuration (small sizes for CI).
    rollout_args = (
        "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl "
        "--input-key prompt "
        "--label-key label "
        "--apply-chat-template "
        "--rollout-shuffle "
        "--rm-type math "
        "--num-rollout 3 "
        "--rollout-batch-size 8 "
        "--n-samples-per-prompt 8 "
        "--rollout-max-response-len 4096 "
        "--rollout-temperature 1 "
        "--global-batch-size 32 "
    )

    # Eval on AIME-2024; the interval flag is only emitted when eval is enabled.
    eval_args = (
        f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}"
        "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl "
        "--n-samples-per-eval-prompt 1 "
        "--eval-max-response-len 4096 "
        "--eval-top-k 1 "
    )

    # Megatron parallelism and memory/recompute settings for one 8-GPU node.
    perf_args = (
        "--tensor-model-parallel-size 2 "
        "--sequence-parallel "
        "--pipeline-model-parallel-size 1 "
        "--context-parallel-size 2 "
        "--expert-model-parallel-size 8 "
        "--expert-tensor-parallel-size 1 "
        "--recompute-granularity full "
        "--recompute-method uniform "
        "--recompute-num-layers 1 "
        "--use-dynamic-batch-size "
        # NOTE(review): the original wrote {2048 if TIGHT_HOST_MEMORY else 2048} —
        # both branches were 2048, so the conditional was dead. Kept the common
        # value; confirm whether a different non-tight value was intended.
        "--max-tokens-per-gpu 2048 "
    )

    # RL objective: GSPO advantage estimator; KL loss only when host memory
    # is not tight (all KL coefficients are zero regardless).
    grpo_args = (
        "--advantage-estimator gspo "
        f"{'' if TIGHT_HOST_MEMORY else '--use-kl-loss '}"
        "--kl-loss-coef 0.00 "
        "--kl-loss-type low_var_kl "
        "--kl-coef 0.00 "
        "--entropy-coef 0.00 "
        "--eps-clip 4e-4 "
    )

    optimizer_args = (
        "--optimizer adam "
        "--lr 1e-6 "
        "--lr-decay-style constant "
        "--weight-decay 0.1 "
        "--adam-beta1 0.9 "
        "--adam-beta2 0.98 "
    )

    # Inference-engine (sglang) settings for rollout generation.
    sglang_args = (
        "--rollout-num-gpus-per-engine 2 "
        "--sglang-mem-fraction-static 0.8 "
        "--sglang-max-running-requests 512 "
    )

    ci_args = "--ci-test "

    misc_args = (
        "--attention-dropout 0.0 "
        "--hidden-dropout 0.0 "
        "--accumulate-allreduce-grads-in-fp32 "
        "--attention-softmax-in-fp32 "
        "--attention-backend flash "
        "--actor-num-nodes 1 "
        "--actor-num-gpus-per-node 8 "
        "--colocate "
    )

    train_args = (
        f"{ckpt_args} "
        f"{rollout_args} "
        f"{optimizer_args} "
        f"{grpo_args} "
        f"{U.get_default_wandb_args(__file__)} "
        f"{perf_args} "
        f"{eval_args} "
        f"{sglang_args} "
        f"{ci_args} "
        f"{misc_args} "
    )

    U.execute_train(
        train_args=train_args,
        num_gpus_per_node=NUM_GPUS,
        megatron_model_type=MODEL_TYPE,
    )
117+
118+
119+
# Entry point: download/convert assets first, then run the training test.
if __name__ == "__main__":
    prepare()
    execute()

0 commit comments

Comments
 (0)