Skip to content

Commit ca3477d

Browse files
[MoE] Add 3e18 FLOP hparam sweep experiment for MoE recipe search
Adds an experiment script that sweeps core MoE architecture knobs (expert count, shared expert, routing density, aux loss coefficients) at 3e18 FLOPs using the grug MoE template. Includes tests validating FLOP budget math and sweep config validity. Fixes #4018 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a243fe5 commit ca3477d

2 files changed

Lines changed: 304 additions & 0 deletions

File tree

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""MoE hyperparameter sweep at 3e18 FLOPs.
5+
6+
Explores core MoE architecture knobs (expert count, shared expert, routing
7+
density, aux loss coefficients) at the 3e18 FLOP budget to establish a
8+
reference recipe for the 10T scaling path.
9+
10+
Each config targets 3e18 total training FLOPs (including 3x fwd+bwd multiplier).
11+
All runs use the grug MoE template on Nemotron mix with seq_len=4096 and
12+
vocab=128256.
13+
14+
See https://github.com/marin-community/marin/issues/4018.
15+
"""
16+
17+
import dataclasses
18+
import logging
19+
import os
20+
from dataclasses import dataclass
21+
from datetime import timedelta
22+
23+
import jmp
24+
from fray.cluster import ResourceConfig
25+
from levanter.checkpoint import CheckpointerConfig
26+
from levanter.data.text import LmDataConfig
27+
from levanter.optim import AdamConfig
28+
from levanter.tracker import TrackerConfig
29+
from levanter.tracker.wandb import WandbConfig
30+
from levanter.trainer import TrainerConfig
31+
from levanter.utils.flop_utils import lm_flops_per_token
32+
from levanter.utils.mesh import MeshConfig
33+
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
34+
from marin.processing.tokenize import add_validation_sets_to_mixture
35+
36+
from experiments.defaults import default_validation_sets
37+
from experiments.grug.moe.launch import GrugMoeLaunchConfig, run_grug_moe
38+
from experiments.grug.moe.model import GrugModelConfig
39+
from experiments.grug.moe.train import GrugEvalConfig, GrugTrainerConfig
40+
from experiments.pretraining_datasets import nemotron_mix_block_shuffle
41+
42+
logger = logging.getLogger(__name__)
43+
44+
TARGET_FLOPS = 3e18
45+
SEQ_LEN = 4096
46+
VOCAB_SIZE = 128_256
47+
BATCH_SIZE = 256
48+
49+
NEMOTRON_MIX_WITH_VALIDATION = add_validation_sets_to_mixture(
50+
nemotron_mix_block_shuffle,
51+
default_validation_sets(tokenizer=nemotron_mix_block_shuffle.tokenizer),
52+
)
53+
54+
55+
def _steps_for_budget(model: GrugModelConfig, batch_size: int) -> int:
    """Return the number of training steps that consumes TARGET_FLOPS.

    Uses levanter's per-token FLOP estimate for the MoE architecture, scaled
    by the 3x forward+backward multiplier, then truncates so the run never
    exceeds the budget.
    """
    has_shared_expert = model.shared_expert_intermediate_dim > 0
    per_token = lm_flops_per_token(
        hidden_dim=model.hidden_dim,
        intermediate_dim=model.intermediate_dim,
        shared_intermediate_dim=model.shared_expert_intermediate_dim,
        num_layers=model.num_layers,
        num_kv_heads=model.num_kv_heads,
        num_heads=model.num_heads,
        seq_len=model.max_seq_len,
        vocab_size=model.vocab_size,
        glu=True,
        num_experts=model.num_experts,
        num_shared_experts=int(has_shared_expert),
        num_experts_per_tok=model.num_experts_per_token,
    )
    # fwd + bwd ≈ 3x the forward pass; one step processes batch_size sequences.
    flops_per_step = 3 * per_token * model.max_seq_len * batch_size
    return int(TARGET_FLOPS / flops_per_step)
73+
74+
75+
# ---------------------------------------------------------------------------
# Baseline model: d=768, L=12, E=8, K=2, shared expert
# ~1300 steps at bs=256 → ~1.4B tokens
# Every sweep axis below is a dataclasses.replace() of this config, so any
# field not explicitly varied is held fixed at these values.
# ---------------------------------------------------------------------------
BASELINE = GrugModelConfig(
    vocab_size=VOCAB_SIZE,
    hidden_dim=768,
    intermediate_dim=2048,  # per-expert (routed) MLP width
    shared_expert_intermediate_dim=2048,  # >0 enables the always-on shared expert
    num_experts=8,
    num_experts_per_token=2,  # top-K routing density
    num_layers=12,
    num_heads=12,
    num_kv_heads=4,  # GQA: 12 query heads share 4 KV heads
    max_seq_len=SEQ_LEN,
    head_dim=None,  # presumably derived as hidden_dim / num_heads — TODO confirm
    load_balancing_loss_coef=0.01,
    router_z_loss_coef=0.001,
)
94+
95+
96+
@dataclass(frozen=True)
class SweepPoint:
    """One arm of the sweep."""

    # Unique arm name; used for dedup across axes and embedded in the run id
    # and executor step name.
    name: str
    # Fully-specified model config for this arm.
    model: GrugModelConfig
    # Global batch size in sequences per step.
    batch_size: int = BATCH_SIZE
103+
104+
105+
def _expert_count_variants() -> list[SweepPoint]:
    """Axis 1: E in {8, 16, 32} with K=2 (FLOPs ~constant)."""
    # With K fixed, adding experts grows parameters but not per-token compute.
    variants = [SweepPoint("e8-k2", BASELINE)]
    for n_experts in (16, 32):
        scaled = dataclasses.replace(BASELINE, num_experts=n_experts)
        variants.append(SweepPoint(f"e{n_experts}-k2", scaled))
    return variants
112+
113+
114+
def _shared_expert_variants() -> list[SweepPoint]:
    """Axis 2: shared expert on vs off."""
    # A zero shared intermediate dim removes the always-on expert entirely.
    without_shared = dataclasses.replace(BASELINE, shared_expert_intermediate_dim=0)
    return [
        SweepPoint("shared-on", BASELINE),
        SweepPoint("shared-off", without_shared),
    ]
121+
122+
123+
def _routing_density_variants() -> list[SweepPoint]:
    """Axis 3: K=2 vs K=4. K=4 doubles routed MLP FLOPs so we halve intermediate_dim."""
    # Halving both the routed and shared widths keeps total MLP FLOPs matched
    # to the baseline arm.
    denser_routing = dataclasses.replace(
        BASELINE,
        num_experts_per_token=4,
        intermediate_dim=1024,
        shared_expert_intermediate_dim=1024,
    )
    return [
        SweepPoint("k2-i2048", BASELINE),
        SweepPoint("k4-i1024", denser_routing),
    ]
135+
136+
137+
def _aux_loss_variants() -> list[SweepPoint]:
    """Axis 4: aux loss coefficient grid."""
    grid: list[SweepPoint] = []
    for lb_coef in (0.001, 0.01, 0.1):
        for z_coef in (0.0, 0.001):
            # A zero z-loss coefficient is passed as None, presumably to skip
            # the term entirely rather than compute it with weight 0 — TODO confirm.
            variant = dataclasses.replace(
                BASELINE,
                load_balancing_loss_coef=lb_coef,
                router_z_loss_coef=z_coef if z_coef > 0 else None,
            )
            grid.append(SweepPoint(f"lbl{lb_coef}-rzl{z_coef}", variant))
    return grid
150+
151+
152+
def all_sweep_points() -> list[SweepPoint]:
    """Deduplicated union of all sweep axes."""
    # Several axes reuse the baseline arm under different names; dedupe by
    # name, keeping the first occurrence, so each arm launches exactly once.
    # A dict preserves insertion order, so axis order is maintained.
    unique: dict[str, SweepPoint] = {}
    axes = (
        _expert_count_variants,
        _shared_expert_variants,
        _routing_density_variants,
        _aux_loss_variants,
    )
    for axis in axes:
        for point in axis():
            unique.setdefault(point.name, point)
    return list(unique.values())
162+
163+
164+
def _resolve_run_id(sweep_name: str) -> str:
165+
run_id = f"moe-3e18-{sweep_name}"
166+
ferry_date = os.environ.get("FERRY_DATE")
167+
if ferry_date:
168+
run_id = f"{run_id}-{ferry_date}"
169+
return run_id
170+
171+
172+
def _build_step(point: SweepPoint) -> ExecutorStep:
    """Build an ExecutorStep for one sweep arm."""
    n_steps = _steps_for_budget(point.model, point.batch_size)
    run_id = _resolve_run_id(point.name)
    logger.info("Sweep point %s: %d steps at bs=%d", point.name, n_steps, point.batch_size)

    # All arms share one wandb group so they overlay in the UI.
    tracker = WandbConfig(
        project="marin",
        tags=["moe", "3e18", "hparam-sweep"],
        group="moe-3e18-sweep",
        name=None,
    )
    optimizer = AdamConfig(
        learning_rate=3e-3,
        weight_decay=0.1,
        lr_schedule="cosine",
        decay=0.2,
        min_lr_ratio=0.1,
        warmup=200,
    )
    grug_trainer = GrugTrainerConfig(
        z_loss_weight=1e-4,
        ema_beta=None,
        log_every=1,
    )
    eval_config = GrugEvalConfig(
        eval_batch_size=256,
        steps_per_eval=200,
        max_eval_batches=8,
        eval_current=True,
        eval_ema=False,
    )

    launch = GrugMoeLaunchConfig(
        model=versioned(point.model),
        data=NEMOTRON_MIX_WITH_VALIDATION,
        output_path=this_output_path(),
        run_id=run_id,
        resources=versioned(ResourceConfig.with_tpu("v5p-8")),
        steps=versioned(n_steps),
        batch_size=versioned(point.batch_size),
        seed=versioned(0),
        mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
        tracker=tracker,
        optimizer=versioned(optimizer),
        grug_trainer=versioned(grug_trainer),
        eval=versioned(eval_config),
    )

    return ExecutorStep(
        name=f"moe-3e18/{point.name}",
        fn=run_grug_moe,
        config=launch,
    )
227+
228+
229+
def build_sweep_steps() -> list[ExecutorStep]:
    """Materialize one ExecutorStep per deduplicated sweep arm."""
    return list(map(_build_step, all_sweep_points()))
231+
232+
233+
if __name__ == "__main__":
    # Launch every sweep arm under the executor framework.
    sweep_steps = build_sweep_steps()
    executor_main(
        steps=sweep_steps,
        description="MoE hparam sweep at 3e18 FLOPs (issue #4018).",
    )

tests/test_moe_3e18_sweep.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Tests for the MoE 3e18 FLOP hparam sweep (issue #4018)."""
5+
6+
from experiments.moe_3e18_hparam_sweep import (
7+
BASELINE,
8+
TARGET_FLOPS,
9+
all_sweep_points,
10+
_steps_for_budget,
11+
_build_step,
12+
)
13+
14+
15+
def test_baseline_steps_within_budget():
    """Steps computed for baseline config should consume ~3e18 FLOPs."""
    # Local import keeps collection cheap when levanter isn't needed elsewhere.
    from levanter.utils.flop_utils import lm_flops_per_token

    batch_size = 256
    steps = _steps_for_budget(BASELINE, batch_size)
    # Recompute the per-token FLOP estimate with the same kwargs the sweep
    # module passes, so this test independently checks the budget arithmetic.
    fpt = lm_flops_per_token(
        hidden_dim=BASELINE.hidden_dim,
        intermediate_dim=BASELINE.intermediate_dim,
        shared_intermediate_dim=BASELINE.shared_expert_intermediate_dim,
        num_layers=BASELINE.num_layers,
        num_kv_heads=BASELINE.num_kv_heads,
        num_heads=BASELINE.num_heads,
        seq_len=BASELINE.max_seq_len,
        vocab_size=BASELINE.vocab_size,
        glu=True,
        num_experts=BASELINE.num_experts,
        num_shared_experts=1 if BASELINE.shared_expert_intermediate_dim > 0 else 0,
        num_experts_per_tok=BASELINE.num_experts_per_token,
    )
    # 3x multiplier covers forward + backward.
    actual_flops = 3 * fpt * BASELINE.max_seq_len * batch_size * steps
    ratio = actual_flops / TARGET_FLOPS
    # Step-count truncation means the run can undershoot but never overshoot.
    assert 0.9 <= ratio <= 1.0, f"Budget ratio {ratio:.3f} outside [0.9, 1.0]"
38+
39+
40+
def test_all_sweep_points_no_duplicates():
    """Arm names must be unique — dedup across axes happens by name."""
    names = [point.name for point in all_sweep_points()]
    assert len(set(names)) == len(names), f"Duplicate sweep point names: {names}"
44+
45+
46+
def test_all_sweep_points_have_valid_models():
    """Every sweep point should have a valid GrugModelConfig (post_init passes)."""
    for point in all_sweep_points():
        model = point.model
        # Routing can select at most the available experts.
        assert model.num_experts_per_token <= model.num_experts
        # Attention head split must divide the hidden dimension evenly.
        assert model.hidden_dim % model.num_heads == 0
        assert model.vocab_size > 0
        # Each arm must get at least one training step inside the budget.
        assert _steps_for_budget(model, point.batch_size) > 0
53+
54+
55+
def test_sweep_point_count():
    """Sweep should produce a reasonable number of arms."""
    points = all_sweep_points()
    arm_count = len(points)
    assert 8 <= arm_count <= 30, f"Expected 8-30 sweep points, got {len(points)}"
59+
60+
61+
def test_build_step_produces_executor_step():
    """_build_step should return a valid ExecutorStep for each sweep point."""
    # A sample of three arms keeps the test fast while covering the builder.
    sample = all_sweep_points()[:3]
    for point in sample:
        step = _build_step(point)
        assert step.name.startswith("moe-3e18/")
        assert step.fn is not None

0 commit comments

Comments
 (0)