2121- Three data domains: pretrain (Nemotron), midtrain (full Dolmino), SFT
2222
2323Usage:
24+ # Run training
2425 python -m experiments.domain_phase_mix.three_phase_experiment [--n_runs N] [--seed SEED]
26+
27+ # Run analysis (after training completes)
28+ python -m experiments.domain_phase_mix.three_phase_experiment --analyze
2529"""
2630
2731import logging
2832import os
2933
3034from experiments .evals .task_configs import CORE_TASKS
3135from experiments .domain_phase_mix .proxy_sweep import regmix_60m_proxy
36+ from experiments .domain_phase_mix .analysis import create_analysis_step
3237from marin .execution .executor import executor_main
3338
3439from experiments .domain_phase_mix .config import PhaseSchedule
6368
6469
6570def create_three_phase_experiment (
66- name : str = "pinlin_calvin_xu/data_mixture/domain_phase_mix " ,
71+ name : str = "pinlin_calvin_xu/data_mixture/3_partitions_3_phases " ,
6772 experiment_budget : int = EXPERIMENT_BUDGET ,
6873 target_budget : int = TARGET_BUDGET ,
6974 batch_size : int = BATCH_SIZE ,
@@ -117,40 +122,55 @@ def create_three_phase_experiment(
117122def main (
118123 n_runs : int = 100 ,
119124 seed : int = 42 ,
120- name_prefix : str = "pinlin_calvin_xu/data_mixture/domain_phase_mix" ,
125+ name_prefix : str = "pinlin_calvin_xu/data_mixture/3_partitions_3_phases" ,
126+ analyze : bool = False ,
121127):
122128 """Main entry point for running the swarm experiment.
123129
124130 Args:
125131 n_runs: Number of training runs.
126132 seed: Random seed for weight sampling.
127133 name_prefix: Prefix for run names.
134+ analyze: If True, only run analysis step (collect results from W&B).
128135 """
129136 if os .getenv ("CI" , None ) is not None :
130137 logger .info ("Skipping experiment execution on CI environment." )
131138 return
132139
133- # Create experiment
134140 experiment = create_three_phase_experiment (name = name_prefix )
135141
136- # Create steps (weight_configs_step saves to GCS, training_steps run the models)
137142 weight_configs_step , training_steps = experiment .create_swarm_steps (
138143 n_runs = n_runs , seed = seed , name_prefix = name_prefix
139144 )
145+
146+ analysis_step = create_analysis_step (
147+ weight_configs_step = weight_configs_step ,
148+ name_prefix = name_prefix ,
149+ )
150+
151+ if analyze :
152+ # Only run analysis
153+ logger .info ("Running analysis only (collecting results from W&B)" )
154+ all_steps = [weight_configs_step , analysis_step ]
155+ executor_main (
156+ steps = all_steps ,
157+ description = f"Analysis for { name_prefix } " ,
158+ )
159+ return
140160
141161 # Log experiment details
142162 tokens_per_step = BATCH_SIZE * SEQ_LEN
143163 total_steps = EXPERIMENT_BUDGET // tokens_per_step
144164 phase1_end = int (total_steps * PHASE_BOUNDARIES [0 ])
145165 phase2_end = int (total_steps * PHASE_BOUNDARIES [1 ])
146166
147- logger .info (f"Created { len (training_steps )} training steps + 1 weight configs step" )
167+ logger .info (f"Created { len (training_steps )} training steps + 1 weight configs step + 1 analysis step " )
148168 logger .info (f"Total tokens per run: { EXPERIMENT_BUDGET :,} " )
149169 logger .info (f"Total steps per run: { total_steps :,} " )
150170 logger .info (f"Phase boundaries: step { phase1_end } (33%), step { phase2_end } (67%)" )
151171
152- # All steps: weight configs first, then training runs
153- all_steps = [weight_configs_step , * training_steps ]
172+ # All steps: weight configs first, then training runs, then analysis
173+ all_steps = [weight_configs_step , * training_steps , analysis_step ]
154174
155175 executor_main (
156176 steps = all_steps ,
@@ -177,9 +197,14 @@ def _parse_args():
177197 parser .add_argument (
178198 "--name_prefix" ,
179199 type = str ,
180- default = "pinlin_calvin_xu/data_mixture/domain_phase_mix " ,
200+ default = "pinlin_calvin_xu/data_mixture/3_partitions_3_phases " ,
181201 help = "Prefix for run names." ,
182202 )
203+ parser .add_argument (
204+ "--analyze" ,
205+ action = "store_true" ,
206+ help = "Run analysis only (collect results from W&B and export CSV)." ,
207+ )
183208 return parser .parse_known_args ()
184209
185210
@@ -193,4 +218,5 @@ def _parse_args():
193218 n_runs = args .n_runs ,
194219 seed = args .seed ,
195220 name_prefix = args .name_prefix ,
221+ analyze = args .analyze ,
196222 )
0 commit comments