|
1 | | -diff --git a/examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn_grpo_npu.py b/examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn_grpo_npu.py |
2 | | -new file mode 100644 |
3 | | -index 00000000..8c862444 |
4 | | ---- /dev/null |
5 | | -+++ b/examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn_grpo_npu.py |
6 | | -@@ -0,0 +1,163 @@ |
7 | | -+import os |
8 | | -+ |
9 | | -+import slime.utils.misc as U |
10 | | -+from slime.utils.external_utils.command_utils import execute_train_npu |
11 | | -+ |
12 | | -+MODEL_NAME = os.environ.get("SLIME_SCRIPT_MODEL_NAME", "Qwen3-VL-2B-Instruct") |
13 | | -+assert MODEL_NAME in { |
14 | | -+ "Qwen3-VL-2B-Instruct", |
15 | | -+ "Qwen3-VL-4B-Instruct", |
16 | | -+ "Qwen3-VL-8B-Instruct", |
17 | | -+ "Qwen3-VL-2B-Thinking", |
18 | | -+ "Qwen3-VL-4B-Thinking", |
19 | | -+ "Qwen3-VL-8B-Thinking", |
20 | | -+} |
21 | | -+ |
22 | | -+EXTERNAL_RAY = int(os.environ.get("SLIME_SCRIPT_EXTERNAL_RAY", "0")) |
23 | | -+TRAIN_BACKEND = os.environ.get("SLIME_SCRIPT_TRAIN_BACKEND", "fsdp").lower() |
24 | | -+assert TRAIN_BACKEND in {"fsdp", "megatron"} |
25 | | -+ |
26 | | -+DATASET_NAME = "VeraIsHere/geo3k_imgurl_processed" |
27 | | -+DATA_ROOT = "/root/datasets/geo3k_imgurl_processed" |
28 | | -+TRAIN_DATA_PATH = os.path.join(DATA_ROOT, "train.parquet") |
29 | | -+ |
30 | | -+ |
31 | | -+def get_megatron_model_type(model_name: str) -> str: |
32 | | -+ model_type = model_name.replace("-Instruct", "").replace("-Thinking", "") |
33 | | -+ model_type = model_type.replace("Qwen3-VL-", "qwen3-") |
34 | | -+ return model_type.replace("-2B", "-1.7B") |
35 | | -+ |
36 | | -+ |
37 | | -+def execute(): |
38 | | -+ ckpt_args = f"--hf-checkpoint /home/data/{MODEL_NAME} " |
39 | | -+ |
40 | | -+ wandb_args = ( |
41 | | -+ ( |
42 | | -+ "--use-wandb " |
43 | | -+ "--wandb-project slime-dev " |
44 | | -+ "--wandb-group geo3k_vlm_multi_turn " |
45 | | -+ f"--wandb-key '{wandb_api_key}' " |
46 | | -+ ) |
47 | | -+ if (wandb_api_key := os.environ.get("WANDB_API_KEY")) |
48 | | -+ else "" |
49 | | -+ ) |
50 | | -+ |
51 | | -+ rollout_args = ( |
52 | | -+ f"--prompt-data {TRAIN_DATA_PATH} " |
53 | | -+ "--input-key problem " |
54 | | -+ "--label-key answer " |
55 | | -+ '--multimodal-keys \'{"image": "images"}\' ' |
56 | | -+ "--rm-type math " |
57 | | -+ "--apply-chat-template " |
58 | | -+ "--custom-generate-function-path examples.geo3k_vlm_multi_turn.rollout.generate " |
59 | | -+ "--custom-config-path examples/geo3k_vlm_multi_turn/geo3k_vlm_multi_turn_config.yaml " |
60 | | -+ "--rollout-shuffle " |
61 | | -+ "--num-rollout 3000 " |
62 | | -+ "--rollout-batch-size 32 " |
63 | | -+ "--n-samples-per-prompt 8 " |
64 | | -+ "--rollout-max-response-len 4096 " |
65 | | -+ "--rollout-temperature 1 " |
66 | | -+ "--global-batch-size 256 " |
67 | | -+ ) |
68 | | -+ |
69 | | -+ # eval_args = ( |
70 | | -+ # "--eval-interval 20 " |
71 | | -+ # f"--eval-prompt-data geo3k_eval {TRAIN_DATA_PATH}@[0:64] " |
72 | | -+ # "--n-samples-per-eval-prompt 1 " |
73 | | -+ # "--eval-max-response-len 4096 " |
74 | | -+ # "--eval-top-k 1 " |
75 | | -+ # ) |
76 | | -+ |
77 | | -+ grpo_args = ( |
78 | | -+ "--advantage-estimator grpo " |
79 | | -+ "--kl-loss-coef 0.00 " |
80 | | -+ "--kl-loss-type low_var_kl " |
81 | | -+ "--kl-coef 0.00 " |
82 | | -+ "--entropy-coef 0.00 " |
83 | | -+ "--eps-clip 0.2 " |
84 | | -+ "--eps-clip-high 0.28 " |
85 | | -+ "--use-kl-loss " |
86 | | -+ ) |
87 | | -+ |
88 | | -+ optimizer_args = ( |
89 | | -+ "--optimizer adam " |
90 | | -+ "--lr 1e-6 " |
91 | | -+ "--lr-decay-style constant " |
92 | | -+ "--weight-decay 0.1 " |
93 | | -+ "--adam-beta1 0.9 " |
94 | | -+ "--adam-beta2 0.98 " |
95 | | -+ |
96 | | -+ "--optimizer-cpu-offload " |
97 | | -+ "--overlap-cpu-optimizer-d2h-h2d " |
98 | | -+ "--use-precision-aware-optimizer " |
99 | | -+ ) |
100 | | -+ |
101 | | -+ sglang_args = ( |
102 | | -+ "--rollout-num-gpus-per-engine 1 " |
103 | | -+ "--sglang-mem-fraction-static 0.6 " |
104 | | -+ f"--sglang-cuda-graph-bs {' '.join(map(str, [4, 8] + list(range(16, 257, 8))))} " |
105 | | -+ "--sglang-device npu " |
106 | | -+ "--sglang-disable-radix-cache " |
107 | | -+ "--sglang-chunked-prefill-size 32768 " |
108 | | -+ "--sglang-max-prefill-tokens 4000 " |
109 | | -+ "--sglang-max-total-tokens 327680 " |
110 | | -+ ) |
111 | | -+ |
112 | | -+ megatron_args = ( |
113 | | -+ "--train-backend megatron " |
114 | | -+ f"--load /home/data/{MODEL_NAME} " |
115 | | -+ f"--ref-load /home/data/{MODEL_NAME} " |
116 | | -+ "--tensor-model-parallel-size 4 " |
117 | | -+ "--sequence-parallel " |
118 | | -+ "--pipeline-model-parallel-size 1 " |
119 | | -+ "--context-parallel-size 1 " |
120 | | -+ "--expert-model-parallel-size 1 " |
121 | | -+ "--expert-tensor-parallel-size 1 " |
122 | | -+ "--recompute-granularity full " |
123 | | -+ "--recompute-method uniform " |
124 | | -+ "--recompute-num-layers 1 " |
125 | | -+ "--use-dynamic-batch-size " |
126 | | -+ "--max-tokens-per-gpu 16384 " |
127 | | -+ "--balance-data " |
128 | | -+ "--attention-dropout 0.0 " |
129 | | -+ "--hidden-dropout 0.0 " |
130 | | -+ "--accumulate-allreduce-grads-in-fp32 " |
131 | | -+ "--attention-softmax-in-fp32 " |
132 | | -+ "--attention-backend flash " |
133 | | -+ "--megatron-to-hf-mode bridge " |
134 | | -+ ) |
135 | | -+ |
136 | | -+ misc_args = ( |
137 | | -+ "--actor-num-nodes 1 " f"--actor-num-gpus-per-node 8 " f"--rollout-num-gpus 8 " |
138 | | -+ "--no-gradient-accumulation-fusion " |
139 | | -+ "--use-flash-attn " |
140 | | -+ ) |
141 | | -+ |
142 | | -+ if TRAIN_BACKEND == "megatron": |
143 | | -+ backend_args = megatron_args |
144 | | -+ megatron_model_type = get_megatron_model_type(MODEL_NAME) |
145 | | -+ os.environ["MODEL_ARGS_ROTARY_BASE"] = "5000000" |
146 | | -+ else: |
147 | | -+ exit() |
148 | | -+ |
149 | | -+ train_args = ( |
150 | | -+ f"{ckpt_args} " |
151 | | -+ f"{rollout_args} " |
152 | | -+ f"{optimizer_args} " |
153 | | -+ f"{grpo_args} " |
154 | | -+ f"{sglang_args} " |
155 | | -+ f"{backend_args} " |
156 | | -+ f"{misc_args} " |
157 | | -+ f"{wandb_args} " |
158 | | -+ # f"{get_default_wandb_args(__file__)} " |
159 | | -+ ) |
160 | | -+ |
161 | | -+ execute_train_npu( |
162 | | -+ train_args=train_args, |
163 | | -+ megatron_model_type=megatron_model_type, |
164 | | -+ extra_env_vars=({"WANDB_API_KEY": os.environ["WANDB_API_KEY"]} if os.environ.get("WANDB_API_KEY") else {}), |
165 | | -+ ) |
166 | | -+ |
167 | | -+ |
168 | | -+if __name__ == "__main__": |
169 | | -+ execute() |
170 | 1 | diff --git a/slime/backends/megatron_utils/__init__.py b/slime/backends/megatron_utils/__init__.py |
171 | 2 | index a4666fbe..96e7b1b0 100644 |
172 | 3 | --- a/slime/backends/megatron_utils/__init__.py |
|
0 commit comments