@@ -75,6 +75,7 @@ class OptimalTrainingConfig:
7575 label : str
7676 output_path : str
7777 tokenized : LMMixtureDatasetConfig
78+ seed : int = 0
7879 validation_configs : dict [str , DatasetComponent ] | None = None
7980
8081
@@ -163,6 +164,7 @@ def run_optimal_training(config: OptimalTrainingConfig) -> None:
163164 f"FLOPs={ config .target_budget :.1e} " ,
164165 f"label={ config .label } " ,
165166 f"N={ params :.1e} " ,
167+ f"seed={ config .seed } " ,
166168 ],
167169 ),
168170 mp = jmp .get_policy ("p=f32,c=bfloat16" ),
@@ -181,6 +183,7 @@ def run_optimal_training(config: OptimalTrainingConfig) -> None:
181183 "token_repeat" : (ResourceAxis .REPLICA_DCN , ResourceAxis .REPLICA , ResourceAxis .DATA ),
182184 },
183185 ),
186+ seed = config .seed ,
184187 allow_nondivisible_batch_size = True ,
185188 ),
186189 train_seq_len = SEQ_LEN ,
@@ -215,23 +218,33 @@ def run_optimal_training(config: OptimalTrainingConfig) -> None:
215218}
216219
217220# --- Step 2: Optimal Training Runs ---
# Map each FLOP budget to the RNG seeds it is trained with.  The two cheaper
# budgets (1e21, 1e22) are replicated over three seeds (0, 42, 62746) to get
# run-to-run variance estimates; the expensive 1e23 budget runs seed 0 only.
SEEDS_PER_BUDGET: dict[float, list[int]] = {
    budget: [0, 42, 62746] for budget in (1e21, 1e22)
}
SEEDS_PER_BUDGET[1e23] = [0]
227+
# Fan each compute budget out into one training step per seed.
optimal_runs: list[ExecutorStep] = []
for budget, (tpu_type, batch_size) in TARGET_BUDGETS.items():
    for seed in SEEDS_PER_BUDGET[budget]:
        # Seed 0 keeps the historical, un-suffixed step name so previously
        # launched runs are not invalidated; other seeds get "-seed<N>".
        seed_tag = "" if seed == 0 else f"-seed{seed}"
        run_config = OptimalTrainingConfig(
            analysis_output_path=analysis_step.as_input_name(),
            target_budget=budget,
            tpu_type=tpu_type,
            batch_size=batch_size,
            label=LABEL,
            output_path=this_output_path(),
            tokenized=nemotron_mix,
            seed=seed,
            validation_configs=validation_configs,
        )
        optimal_runs.append(
            ExecutorStep(
                name=f"{EXPERIMENT_NAME}-optimal-{budget:.0e}-v5{seed_tag}",
                fn=run_optimal_training,
                config=run_config,
            )
        )

all_steps = [analysis_step] + optimal_runs
237250