Skip to content

Commit c21dcd5

Browse files
[moe] Add great 10T first-k-dense isoflop ablation sweep
Add experiments/grug/moe/great_10t_first_k_dense.py generating an isoflop grid: 3 FLOP budgets (1e18, 3e18, 1e19) x 5 hidden dims x 4 num_dense_layers values (0, 1, 2, 4) using the E=128 K=4 MoE recipe. Includes test validating sweep config generation. Fixes #4040 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4fbf80a commit c21dcd5

2 files changed

Lines changed: 265 additions & 0 deletions

File tree

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Great 10T ablation: sweep num_dense_layers across isoflop budgets.
5+
6+
Generates an isoflop grid varying num_dense_layers (0, 1, 2, 4) at three FLOP
7+
budgets (1e18, 3e18, 1e19) with five hidden dims per budget. Architecture
8+
follows the E=128, K=4 MoE recipe from iteration-02 scaling work.
9+
10+
See: https://github.com/marin-community/marin/issues/4040
11+
Parent sweep: https://github.com/marin-community/marin/issues/3469
12+
Gate: https://github.com/marin-community/marin/issues/4014
13+
"""
14+
15+
import math
16+
from dataclasses import replace
17+
18+
from fray.cluster import ResourceConfig
19+
from levanter.optim import AdamConfig
20+
from levanter.tracker.wandb import WandbConfig
21+
from levanter.utils.flop_utils import lm_flops_per_token
22+
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
23+
24+
from experiments.grug.moe.launch import (
25+
NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
26+
GrugMoeLaunchConfig,
27+
run_grug_moe,
28+
)
29+
from experiments.grug.moe.model import GrugModelConfig
30+
from experiments.grug.moe.train import GrugEvalConfig, GrugTrainerConfig
31+
32+
# ---------------------------------------------------------------------------
# Sweep axes
# ---------------------------------------------------------------------------
FLOP_BUDGETS: tuple[float, ...] = (1e18, 3e18, 1e19)
HIDDEN_DIMS: tuple[int, ...] = (512, 768, 1024, 1536, 2048)
NUM_DENSE_LAYERS_VALUES: tuple[int, ...] = (0, 1, 2, 4)

# ---------------------------------------------------------------------------
# Architecture constants (E=128, K=4 recipe)
# ---------------------------------------------------------------------------
VOCAB_SIZE = 128_256
NUM_EXPERTS = 128
NUM_EXPERTS_PER_TOKEN = 4
SEQ_LEN = 4096
TARGET_STEPS_LOG2 = 14  # ~16384 steps per run

# Scaling heuristics (from iteration-02 MoE recipe).
BASE_HIDDEN_LAYER_RATIO = 64
LAYER_SCALING_FACTOR = 4.0
LAYER_FORMULA_OFFSET = 9
LR_CONSTANT = 0.33
MAX_LR = 0.01
BETA2_BASE = 0.98
BETA2_BATCH_DIVISOR = 128


def _compute_num_layers(hidden_dim: int) -> int:
    """Heuristic depth-for-width rule from the iteration-02 MoE recipe.

    Depth grows roughly linearly in ``hidden_dim`` with a logarithmic
    correction: ``round(d / (RATIO + FACTOR * log2(d) - OFFSET))``.
    """
    width_log2 = math.log2(hidden_dim)
    divisor = BASE_HIDDEN_LAYER_RATIO + LAYER_SCALING_FACTOR * width_log2 - LAYER_FORMULA_OFFSET
    return round(hidden_dim / divisor)
61+
62+
63+
def _round_up_pow2(x: float) -> int:
    """Return the smallest power of two >= ``x``, with a floor of 1."""
    if x <= 1:
        return 1
    return 1 << math.ceil(math.log2(x))
67+
68+
69+
def _flops_per_token_for_config(cfg: GrugModelConfig) -> float:
    """Forward-pass FLOPs per token for a mixed dense/MoE model.

    The MoE layers and the (optional) leading dense layers are costed with
    two separate ``lm_flops_per_token`` calls and summed; a dense layer is
    modelled as a one-expert MoE with no shared expert.

    NOTE(review): if ``lm_flops_per_token`` includes the vocab/LM-head term,
    that term is counted in both calls when dense layers exist — confirm
    against the levanter implementation.
    """
    moe_layer_count = cfg.num_layers - cfg.num_dense_layers
    shared_expert_count = 1 if cfg.shared_expert_intermediate_dim > 0 else 0
    moe_part = lm_flops_per_token(
        hidden_dim=cfg.hidden_dim,
        intermediate_dim=cfg.intermediate_dim,
        shared_intermediate_dim=cfg.shared_expert_intermediate_dim,
        num_layers=moe_layer_count,
        num_kv_heads=cfg.num_kv_heads,
        num_heads=cfg.num_heads,
        seq_len=cfg.max_seq_len,
        vocab_size=cfg.vocab_size,
        glu=True,
        num_experts=cfg.num_experts,
        num_shared_experts=shared_expert_count,
        num_experts_per_tok=cfg.num_experts_per_token,
    )

    dense_part = 0.0
    if cfg.num_dense_layers > 0:
        dense_part = lm_flops_per_token(
            hidden_dim=cfg.hidden_dim,
            intermediate_dim=cfg.resolved_dense_intermediate_dim,
            shared_intermediate_dim=0,
            num_layers=cfg.num_dense_layers,
            num_kv_heads=cfg.num_kv_heads,
            num_heads=cfg.num_heads,
            seq_len=cfg.max_seq_len,
            vocab_size=cfg.vocab_size,
            glu=True,
            num_experts=1,
            num_shared_experts=0,
            num_experts_per_tok=1,
        )
    return moe_part + dense_part
104+
105+
106+
def _build_model_config(hidden_dim: int, num_dense_layers: int) -> GrugModelConfig:
    """Instantiate the E=128/K=4 recipe at the given width.

    ``num_dense_layers`` is clamped to the computed depth so a request for
    more dense layers than the model has cannot yield an invalid config.
    Per-expert intermediate dim is half the width; the shared expert gets a
    full-width intermediate dim; one attention head per 128 dims (min 1),
    with no GQA (kv heads == heads).
    """
    depth = _compute_num_layers(hidden_dim)
    heads = max(1, hidden_dim // 128)
    return GrugModelConfig(
        vocab_size=VOCAB_SIZE,
        hidden_dim=hidden_dim,
        intermediate_dim=hidden_dim // 2,
        shared_expert_intermediate_dim=hidden_dim,
        num_experts=NUM_EXPERTS,
        num_experts_per_token=NUM_EXPERTS_PER_TOKEN,
        num_layers=depth,
        num_heads=heads,
        num_kv_heads=heads,
        max_seq_len=SEQ_LEN,
        num_dense_layers=min(num_dense_layers, depth),
    )
125+
126+
127+
def _build_sweep_step(
    budget: float,
    hidden_dim: int,
    num_dense_layers: int,
) -> ExecutorStep:
    """Assemble one ExecutorStep for a (budget, width, dense-count) cell.

    The token budget follows ``budget = 3 * flops_per_token * tokens``
    (forward + backward), the batch size is the power of two rounded up from
    the value that would hit ~2**TARGET_STEPS_LOG2 steps (floored at 8), and
    LR / beta2 follow the iteration-02 batch-size heuristics.
    """
    model = _build_model_config(hidden_dim, num_dense_layers)
    flops_per_token = _flops_per_token_for_config(model)
    token_budget = budget / (3 * flops_per_token)

    # Power-of-two batch size landing near the target step count, min 8.
    exact_batch = token_budget / ((2**TARGET_STEPS_LOG2) * SEQ_LEN)
    batch_size = max(8, _round_up_pow2(exact_batch))

    # Batch-size-aware hyperparameters: sqrt-scaled LR capped at MAX_LR,
    # beta2 decayed with batch size but floored at 0.95.
    learning_rate = min(MAX_LR, (LR_CONSTANT * math.sqrt(batch_size)) / hidden_dim)
    beta2 = max(0.95, BETA2_BASE ** (batch_size / BETA2_BATCH_DIVISOR))
    num_steps = max(1, round(token_budget / (batch_size * SEQ_LEN)))

    # NOTE: f"{budget:.0e}" renders with a '+' (e.g. "1e+18"), which ends up
    # in run ids and step names.
    budget_tag = f"{budget:.0e}"
    dense_tag = f"dense{num_dense_layers}"
    run_id = f"great-10t-fkd-{budget_tag}-d{hidden_dim}-{dense_tag}"

    return ExecutorStep(
        name=f"grug/great-10t-fkd/{budget_tag}/d{hidden_dim}/{dense_tag}",
        fn=run_grug_moe,
        config=GrugMoeLaunchConfig(
            model=versioned(model),
            data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
            output_path=this_output_path(),
            run_id=run_id,
            resources=versioned(ResourceConfig.with_tpu("v5p-8")),
            steps=versioned(num_steps),
            batch_size=versioned(batch_size),
            seed=versioned(0),
            mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
            tracker=WandbConfig(
                project="marin",
                tags=["grug", "moe", "great-10t", "first-k-dense", budget_tag, dense_tag],
                group="great-10t-first-k-dense",
                name=None,
            ),
            optimizer=versioned(
                AdamConfig(
                    learning_rate=learning_rate,
                    weight_decay=0.1,
                    lr_schedule="cosine",
                    decay=0.2,
                    min_lr_ratio=0.1,
                    warmup=0.1,
                    beta2=beta2,
                )
            ),
            grug_trainer=versioned(
                GrugTrainerConfig(
                    z_loss_weight=1e-4,
                    ema_beta=None,
                    log_every=1,
                )
            ),
            eval=versioned(
                GrugEvalConfig(
                    eval_batch_size=min(512, batch_size),
                    steps_per_eval=1000,
                    max_eval_batches=8,
                    eval_current=True,
                    eval_ema=False,
                )
            ),
        ),
    )
197+
198+
199+
def build_sweep_steps() -> list[ExecutorStep]:
    """Generate all ExecutorSteps for the first-k-dense isoflop ablation.

    Cells where ``num_dense_layers`` exceeds the model depth at that hidden
    dim are skipped (rather than clamped) so every emitted step is a distinct
    architecture.

    Returns:
        One ExecutorStep per surviving (budget, hidden_dim, num_dense) cell.
    """
    steps: list[ExecutorStep] = []
    for budget in FLOP_BUDGETS:
        for hidden_dim in HIDDEN_DIMS:
            # Depth depends only on hidden_dim — hoist out of the inner loop.
            num_layers = _compute_num_layers(hidden_dim)
            for num_dense in NUM_DENSE_LAYERS_VALUES:
                if num_dense > num_layers:
                    continue
                steps.append(_build_sweep_step(budget, hidden_dim, num_dense))
    return steps
210+
211+
212+
# Materialized at import time so tests (and notebooks) can inspect the full
# sweep without launching anything.
ALL_STEPS = build_sweep_steps()


if __name__ == "__main__":
    executor_main(
        steps=ALL_STEPS,
        description="Great 10T ablation: sweep num_dense_layers across isoflop budgets (issue #4040).",
    )

tests/test_grug_variant_contracts.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,49 @@ def test_grug_base_run_emits_expected_metrics_with_json_tracker(tmp_path: Path):
321321
]
322322
for key in required_keys:
323323
assert key in summary
324+
325+
326+
def test_great_10t_first_k_dense_sweep_generates_valid_configs():
    """Verify the first-k-dense isoflop sweep produces well-formed configs."""
    from experiments.grug.moe.great_10t_first_k_dense import (
        ALL_STEPS,
        FLOP_BUDGETS,
        HIDDEN_DIMS,
        NUM_DENSE_LAYERS_VALUES,
        _build_model_config,
        _compute_num_layers,
        _flops_per_token_for_config,
    )

    # Must generate at least one step per budget.
    assert len(ALL_STEPS) > len(FLOP_BUDGETS)

    # Verify no duplicate step names.
    names = [s.name for s in ALL_STEPS]
    assert len(names) == len(set(names)), f"Duplicate step names: {[n for n in names if names.count(n) > 1]}"

    # Check that every generated config is valid and has positive batch/steps.
    for step in ALL_STEPS:
        cfg = step.config
        assert cfg.batch_size > 0
        assert cfg.steps > 0
        model = cfg.model
        assert model.num_dense_layers <= model.num_layers
        # The sweep skips (not clamps) oversized dense counts, so every value
        # should come straight from the axis definition.
        assert model.num_dense_layers in NUM_DENSE_LAYERS_VALUES or model.num_dense_layers < max(NUM_DENSE_LAYERS_VALUES)

    # Verify dense layers that exceed num_layers are skipped.
    for hidden_dim in HIDDEN_DIMS:
        num_layers = _compute_num_layers(hidden_dim)
        for nd in NUM_DENSE_LAYERS_VALUES:
            if nd > num_layers:
                # Match on the "denseN" suffix baked into the step name.
                matching = [s for s in ALL_STEPS if s.config.model.hidden_dim == hidden_dim and f"dense{nd}" in s.name]
                assert len(matching) == 0, f"Should skip dense{nd} for hidden_dim={hidden_dim} (layers={num_layers})"

    # Spot-check: converting MoE layers to dense layers must CHANGE the FLOP
    # estimate (dense layers have no router overhead and no shared expert),
    # so a dense-free baseline and a 2-dense variant cannot cost the same.
    baseline = _build_model_config(1024, 0)
    with_dense = _build_model_config(1024, 2)
    assert baseline.num_dense_layers == 0
    assert with_dense.num_dense_layers == 2
    # Different num_dense_layers should yield different FLOP counts
    # (dense layers have no router overhead and no shared expert).
    assert _flops_per_token_for_config(baseline) != _flops_per_token_for_config(with_dense)

0 commit comments

Comments
 (0)