2424)
2525
2626
27- def prepare (workspace , cluster , num_gpus , expname_prefix , wandb_params ):
27+ def prepare (workspace , cluster , expname_prefix ):
2828 # data preparation needs to run locally without container, so not wrapping with run_cmd
2929 prepare_datasets (["aime24" , "aime25" ])
3030
@@ -90,7 +90,7 @@ def run_sdg(workspace, cluster, num_gpus, expname_prefix, wandb_params):
9090 )
9191
9292
93- def run_training (workspace , cluster , num_gpus , expname_prefix , wandb_params ):
93+ def run_training (workspace , cluster , num_gpus , expname_prefix , backend , wandb_params ):
9494 # convert the generated solutions to a format that can be used for training
9595 run_cmd (
9696 ctx = wrap_arguments (
@@ -110,47 +110,54 @@ def run_training(workspace, cluster, num_gpus, expname_prefix, wandb_params):
110110 )
111111
112112 # train the model
113-
114- sft_nemo_rl (
115- ctx = wrap_arguments (
116- "++policy.max_total_sequence_length=8192 "
117- "++policy.train_global_batch_size=32 "
118- "++policy.tensor_model_parallel_size=4 "
119- "++policy.context_parallel_size=2 "
120- "++policy.lr=1e-5 "
121- "++sft.max_num_epochs=2 "
122- ),
123- cluster = cluster ,
124- output_dir = f"{ workspace } /training" ,
125- hf_model = "Qwen/Qwen2.5-14B-Instruct" ,
126- backend = "megatron" ,
127- num_gpus = num_gpus ,
128- num_nodes = 1 ,
129- disable_wandb = wandb_params ["disable_wandb" ],
130- wandb_project = wandb_params ["wandb_project" ],
131- training_data = f"{ workspace } /sft-data.jsonl" ,
132- expname = f"{ expname_prefix } -training" ,
133- run_after = f"{ expname_prefix } -prepare-training-data" ,
134- final_hf_path = f"{ workspace } /training/qwen2.5-14b-improved-hf" ,
135- )
136-
137-
138- def final_eval (workspace , cluster , num_gpus , expname_prefix , wandb_params ):
113+ base_args = [
114+ "++policy.max_total_sequence_length=8192" ,
115+ "++policy.train_global_batch_size=32" ,
116+ "++policy.tensor_model_parallel_size=4" ,
117+ "++policy.context_parallel_size=2" ,
118+ "++policy.lr=1e-5" ,
119+ "++sft.max_num_epochs=2" ,
120+ ]
121+ # For FSDP, sequence_packing cannot be used with context parallel
122+ for training_backend in backend :
123+ args = list (base_args )
124+ if training_backend == "fsdp" :
125+ args .append ("++policy.sequence_packing.enabled=False" )
126+
127+ sft_nemo_rl (
128+ ctx = wrap_arguments (" " .join (args )),
129+ cluster = cluster ,
130+ output_dir = f"{ workspace } /training-{ training_backend } " ,
131+ hf_model = "Qwen/Qwen2.5-14B-Instruct" ,
132+ backend = training_backend ,
133+ num_gpus = num_gpus ,
134+ num_nodes = 1 ,
135+ disable_wandb = wandb_params ["disable_wandb" ],
136+ wandb_project = wandb_params ["wandb_project" ],
137+ training_data = f"{ workspace } /sft-data.jsonl" ,
138+ expname = f"{ expname_prefix } -training-{ training_backend } " ,
139+ run_after = f"{ expname_prefix } -prepare-training-data" ,
140+ final_hf_path = f"{ workspace } /training-{ training_backend } /qwen2.5-14b-improved-hf" ,
141+ )
142+
143+
144+ def final_eval (workspace , cluster , num_gpus , expname_prefix , backend , wandb_params ):
139145 # launching evaluation
140- eval (
141- ctx = wrap_arguments ("++inference.tokens_to_generate=16384 ++parse_reasoning=True " ),
142- cluster = cluster ,
143- model = f"{ workspace } /training/qwen2.5-14b-improved-hf" ,
144- server_type = "vllm" ,
145- server_gpus = num_gpus ,
146- benchmarks = "aime24:8,aime25:8" ,
147- output_dir = f"{ workspace } /evals/after-training" ,
148- num_jobs = 1 ,
149- expname = f"{ expname_prefix } -final-eval" ,
150- run_after = f"{ expname_prefix } -training" ,
151- wandb_name = f"{ expname_prefix } -final-eval" if not wandb_params ["disable_wandb" ] else None ,
152- wandb_project = wandb_params ["wandb_project" ],
153- )
146+ for training_backend in backend :
147+ eval (
148+ ctx = wrap_arguments ("++inference.tokens_to_generate=16384 ++parse_reasoning=True " ),
149+ cluster = cluster ,
150+ model = f"{ workspace } /training-{ training_backend } /qwen2.5-14b-improved-hf" ,
151+ server_type = "vllm" ,
152+ server_gpus = num_gpus ,
153+ benchmarks = "aime24:8,aime25:8" ,
154+ output_dir = f"{ workspace } /evals/after-training-{ training_backend } " ,
155+ num_jobs = 1 ,
156+ expname = f"{ expname_prefix } -final-eval-{ training_backend } " ,
157+ run_after = f"{ expname_prefix } -training-{ training_backend } " ,
158+ wandb_name = f"{ expname_prefix } -final-eval" if not wandb_params ["disable_wandb" ] else None ,
159+ wandb_project = wandb_params ["wandb_project" ],
160+ )
154161
155162
156163def initial_eval (workspace , cluster , num_gpus , expname_prefix , wandb_params ):
@@ -203,21 +210,42 @@ def initial_eval(workspace, cluster, num_gpus, expname_prefix, wandb_params):
203210 default = "nemo-skills" ,
204211 help = "WandB project name for tracking experiments." ,
205212 )
213+ parser .add_argument (
214+ "--backend" ,
215+ type = str ,
216+ nargs = "+" ,
217+ choices = ["megatron" , "fsdp" ],
218+ default = ["megatron" ],
219+ )
220+
206221 args = parser .parse_args ()
207222
208223 wandb_params = {
209224 "disable_wandb" : args .disable_wandb ,
210225 "wandb_project" : args .wandb_project ,
211226 }
212- args = (
227+ common_args = (
213228 args .workspace ,
214229 args .cluster ,
215230 args .num_gpus ,
216231 args .expname_prefix ,
232+ args .backend ,
217233 wandb_params ,
218234 )
219- prepare (* args )
220- initial_eval (* args )
221- run_sdg (* args )
222- run_training (* args )
223- final_eval (* args )
235+ prepare (workspace = args .workspace , cluster = args .cluster , expname_prefix = args .expname_prefix )
236+ initial_eval (
237+ workspace = args .workspace ,
238+ cluster = args .cluster ,
239+ num_gpus = args .num_gpus ,
240+ expname_prefix = args .expname_prefix ,
241+ wandb_params = wandb_params ,
242+ )
243+ run_sdg (
244+ workspace = args .workspace ,
245+ cluster = args .cluster ,
246+ num_gpus = args .num_gpus ,
247+ expname_prefix = args .expname_prefix ,
248+ wandb_params = wandb_params ,
249+ )
250+ run_training (* common_args )
251+ final_eval (* common_args )
0 commit comments