Skip to content

Commit 46c07bd

Browse files
committed
add analyze
1 parent ffb69bd commit 46c07bd

File tree

2 files changed

+54
-18
lines changed

2 files changed

+54
-18
lines changed

experiments/domain_phase_mix/analysis.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,14 @@
4949
DEFAULT_METRICS = [
5050
"eval/loss",
5151
"eval/paloma/c4_en/bpb",
52-
"eval/paloma/wikipedia_en/bpb",
53-
"eval_harness/gsm8k/acc",
54-
"eval_harness/mmlu/acc",
55-
"eval_harness/hellaswag/acc",
56-
"eval_harness/arc_challenge/acc",
52+
"eval/paloma/m2d2_wikipedia_unsplit/bpb",
53+
"lm_eval/arc_challenge/acc",
54+
"lm_eval/arc_challenge/acc_norm",
55+
"lm_eval/hellaswag_0shot/acc",
56+
"lm_eval/hellaswag_0shot/acc_norm",
57+
"lm_eval/piqa/acc",
58+
"lm_eval/boolq/acc",
59+
"lm_eval/averages/macro_avg_acc",
5760
]
5861

5962

@@ -114,7 +117,7 @@ def collect_results(config: CollectResultsConfig):
114117
logger.info(f"Found {len(runs)} W&B runs")
115118

116119
# 3. Match runs to configs by run_id
117-
matched = match_runs_to_configs(runs, configs)
120+
matched = match_runs_to_configs(runs, configs, experiment_name=experiment_name)
118121
logger.info(f"Matched {sum(1 for m in matched if m.get('wandb_run_id'))} runs to configs")
119122

120123
# 4. Build DataFrame with all weights and metrics
@@ -203,22 +206,29 @@ def query_wandb_runs(
203206
return results
204207

205208

206-
def match_runs_to_configs(runs: list[dict], configs: list[dict]) -> list[dict]:
209+
def match_runs_to_configs(
210+
runs: list[dict], configs: list[dict], experiment_name: str
211+
) -> list[dict]:
207212
"""Match W&B runs to weight configurations by run_id pattern.
208213
209-
Extracts run_id from W&B run names (e.g., "experiment/run_042" -> 42)
210-
and matches to the corresponding config.
214+
Extracts the run_id from each W&B run name and matches it to the corresponding config.
215+
Tries multiple patterns to handle different W&B naming conventions:
216+
1. Full path: "pinlin_calvin_xu/data_mixture/3_partitions_3_phases/run_00042"
217+
2. Short name: "run_00042-abc123" (W&B may truncate long names)
211218
212219
Args:
213220
runs: List of W&B run dictionaries.
214221
configs: List of weight configuration dictionaries.
222+
experiment_name: Experiment name that must directly precede "/run_<id>" in a W&B run name for the run to match (required to avoid false positives).
215223
216224
Returns:
217225
List of matched dictionaries with config + run info.
218226
"""
219227
# Build lookup from run_id to W&B run
220228
run_by_id: dict[int, dict] = {}
221-
run_id_pattern = re.compile(r"run_(\d+)")
229+
230+
escaped_name = re.escape(experiment_name)
231+
run_id_pattern = re.compile(rf"{escaped_name}/run_(\d+)")
222232

223233
for run in runs:
224234
name = run.get("wandb_run_name", "")

experiments/domain_phase_mix/three_phase_experiment.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,19 @@
2121
- Three data domains: pretrain (Nemotron), midtrain (full Dolmino), SFT
2222
2323
Usage:
24+
# Run training
2425
python -m experiments.domain_phase_mix.three_phase_experiment [--n_runs N] [--seed SEED]
26+
27+
# Run analysis (after training completes)
28+
python -m experiments.domain_phase_mix.three_phase_experiment --analyze
2529
"""
2630

2731
import logging
2832
import os
2933

3034
from experiments.evals.task_configs import CORE_TASKS
3135
from experiments.domain_phase_mix.proxy_sweep import regmix_60m_proxy
36+
from experiments.domain_phase_mix.analysis import create_analysis_step
3237
from marin.execution.executor import executor_main
3338

3439
from experiments.domain_phase_mix.config import PhaseSchedule
@@ -63,7 +68,7 @@
6368

6469

6570
def create_three_phase_experiment(
66-
name: str = "pinlin_calvin_xu/data_mixture/domain_phase_mix",
71+
name: str = "pinlin_calvin_xu/data_mixture/3_partitions_3_phases",
6772
experiment_budget: int = EXPERIMENT_BUDGET,
6873
target_budget: int = TARGET_BUDGET,
6974
batch_size: int = BATCH_SIZE,
@@ -117,40 +122,55 @@ def create_three_phase_experiment(
117122
def main(
118123
n_runs: int = 100,
119124
seed: int = 42,
120-
name_prefix: str = "pinlin_calvin_xu/data_mixture/domain_phase_mix",
125+
name_prefix: str = "pinlin_calvin_xu/data_mixture/3_partitions_3_phases",
126+
analyze: bool = False,
121127
):
122128
"""Main entry point for running the swarm experiment.
123129
124130
Args:
125131
n_runs: Number of training runs.
126132
seed: Random seed for weight sampling.
127133
name_prefix: Prefix for run names.
134+
analyze: If True, run only the analysis step (collect results from W&B).
128135
"""
129136
if os.getenv("CI", None) is not None:
130137
logger.info("Skipping experiment execution on CI environment.")
131138
return
132139

133-
# Create experiment
134140
experiment = create_three_phase_experiment(name=name_prefix)
135141

136-
# Create steps (weight_configs_step saves to GCS, training_steps run the models)
137142
weight_configs_step, training_steps = experiment.create_swarm_steps(
138143
n_runs=n_runs, seed=seed, name_prefix=name_prefix
139144
)
145+
146+
analysis_step = create_analysis_step(
147+
weight_configs_step=weight_configs_step,
148+
name_prefix=name_prefix,
149+
)
150+
151+
if analyze:
152+
# Only run analysis
153+
logger.info("Running analysis only (collecting results from W&B)")
154+
all_steps = [weight_configs_step, analysis_step]
155+
executor_main(
156+
steps=all_steps,
157+
description=f"Analysis for {name_prefix}",
158+
)
159+
return
140160

141161
# Log experiment details
142162
tokens_per_step = BATCH_SIZE * SEQ_LEN
143163
total_steps = EXPERIMENT_BUDGET // tokens_per_step
144164
phase1_end = int(total_steps * PHASE_BOUNDARIES[0])
145165
phase2_end = int(total_steps * PHASE_BOUNDARIES[1])
146166

147-
logger.info(f"Created {len(training_steps)} training steps + 1 weight configs step")
167+
logger.info(f"Created {len(training_steps)} training steps + 1 weight configs step + 1 analysis step")
148168
logger.info(f"Total tokens per run: {EXPERIMENT_BUDGET:,}")
149169
logger.info(f"Total steps per run: {total_steps:,}")
150170
logger.info(f"Phase boundaries: step {phase1_end} (33%), step {phase2_end} (67%)")
151171

152-
# All steps: weight configs first, then training runs
153-
all_steps = [weight_configs_step, *training_steps]
172+
# All steps: weight configs first, then training runs, then analysis
173+
all_steps = [weight_configs_step, *training_steps, analysis_step]
154174

155175
executor_main(
156176
steps=all_steps,
@@ -177,9 +197,14 @@ def _parse_args():
177197
parser.add_argument(
178198
"--name_prefix",
179199
type=str,
180-
default="pinlin_calvin_xu/data_mixture/domain_phase_mix",
200+
default="pinlin_calvin_xu/data_mixture/3_partitions_3_phases",
181201
help="Prefix for run names.",
182202
)
203+
parser.add_argument(
204+
"--analyze",
205+
action="store_true",
206+
help="Run analysis only (collect results from W&B and export CSV).",
207+
)
183208
return parser.parse_known_args()
184209

185210

@@ -193,4 +218,5 @@ def _parse_args():
193218
n_runs=args.n_runs,
194219
seed=args.seed,
195220
name_prefix=args.name_prefix,
221+
analyze=args.analyze,
196222
)

0 commit comments

Comments
 (0)