Skip to content

Commit e827ac4

Browse files
[moe] Add great 10T K sweep experiment for issue #4047
Add experiments/grug/moe/great_10t_sweep_k.py generating an isoflop grid: 3 FLOP budgets (1e18, 3e18, 1e19) x 5 hidden dims x 4 K values (1,2,4,8) using the E=128 MoE recipe. Includes test validating sweep config generation. Fixes #4047 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a243fe5 commit e827ac4

2 files changed

Lines changed: 231 additions & 0 deletions

File tree

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Great 10T ablation: sweep K (num_experts_per_token) across isoflop budgets.
5+
6+
Generates an isoflop grid varying K in {1, 2, 4, 8} at three FLOP budgets
7+
(1e18, 3e18, 1e19) with five hidden dims per budget. Architecture follows
8+
the E=128 MoE recipe from iteration-02 scaling work. Higher K means more
9+
active FLOPs per token, so fewer training steps at the same budget.
10+
11+
See: https://github.com/marin-community/marin/issues/4047
12+
Parent sweep: https://github.com/marin-community/marin/issues/3469
13+
Gate: https://github.com/marin-community/marin/issues/4014
14+
"""
15+
16+
import math
17+
18+
from fray.cluster import ResourceConfig
19+
from levanter.optim import AdamConfig
20+
from levanter.tracker.wandb import WandbConfig
21+
from levanter.utils.flop_utils import lm_flops_per_token
22+
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
23+
24+
from experiments.grug.moe.launch import (
25+
NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
26+
GrugMoeLaunchConfig,
27+
run_grug_moe,
28+
)
29+
from experiments.grug.moe.model import GrugModelConfig
30+
from experiments.grug.moe.train import GrugEvalConfig, GrugTrainerConfig
31+
32+
# ---------------------------------------------------------------------------
# Sweep axes
# ---------------------------------------------------------------------------
# Total training-compute budgets (FLOPs) for the isoflop grid.
FLOP_BUDGETS: tuple[float, ...] = (1e18, 3e18, 1e19)
# Model widths swept at each budget.
HIDDEN_DIMS: tuple[int, ...] = (512, 768, 1024, 1536, 2048)
# num_experts_per_token values under ablation (the "K" axis).
K_VALUES: tuple[int, ...] = (1, 2, 4, 8)

# ---------------------------------------------------------------------------
# Architecture constants (E=128 recipe)
# ---------------------------------------------------------------------------
VOCAB_SIZE = 128_256  # tokenizer vocabulary size
NUM_EXPERTS = 128  # total routed experts per MoE layer (the "E=128" recipe)
SEQ_LEN = 4096  # training sequence length, in tokens
TARGET_STEPS_LOG2 = 14  # ~16384 steps per run
MIN_BATCH_SIZE = 8  # floor on the derived effective batch size

# Scaling heuristics (from iteration-02 MoE recipe).
BASE_HIDDEN_LAYER_RATIO = 64  # base term in the depth-vs-width denominator
LAYER_SCALING_FACTOR = 4.0  # coefficient on log2(hidden_dim) in that denominator
LAYER_FORMULA_OFFSET = 9  # constant subtracted in that denominator
LR_CONSTANT = 0.33  # lr ~= LR_CONSTANT * sqrt(batch) / hidden_dim
MAX_LR = 0.01  # cap on the derived learning rate
BETA2_BASE = 0.98  # Adam beta2 shrinks toward 0.95 as batch size grows
BETA2_BATCH_DIVISOR = 128  # batch scale at which beta2 equals BETA2_BASE
58+
def _compute_num_layers(hidden_dim: int) -> int:
    """Depth heuristic from the iteration-02 recipe.

    num_layers grows roughly linearly with width, damped by a log2(width)
    term in the denominator so deeper stacks are favored less at large widths.
    """
    denominator = BASE_HIDDEN_LAYER_RATIO + LAYER_SCALING_FACTOR * math.log2(hidden_dim) - LAYER_FORMULA_OFFSET
    return round(hidden_dim / denominator)
61+
62+
63+
def _round_up_pow2(x: float) -> int:
64+
if x <= 1:
65+
return 1
66+
return 2 ** math.ceil(math.log2(x))
67+
68+
69+
def _flops_per_token_for_config(cfg: GrugModelConfig) -> float:
    """Active (per-token) forward FLOPs for ``cfg`` via Levanter's estimator."""
    # A shared expert is counted only when its intermediate dim is configured.
    num_shared = 1 if cfg.shared_expert_intermediate_dim > 0 else 0
    return lm_flops_per_token(
        hidden_dim=cfg.hidden_dim,
        intermediate_dim=cfg.intermediate_dim,
        shared_intermediate_dim=cfg.shared_expert_intermediate_dim,
        num_layers=cfg.num_layers,
        num_kv_heads=cfg.num_kv_heads,
        num_heads=cfg.num_heads,
        seq_len=cfg.max_seq_len,
        vocab_size=cfg.vocab_size,
        glu=True,
        num_experts=cfg.num_experts,
        num_shared_experts=num_shared,
        num_experts_per_tok=cfg.num_experts_per_token,
    )
84+
85+
86+
def _build_model_config(hidden_dim: int, k: int) -> GrugModelConfig:
    """Instantiate the E=128 recipe at width ``hidden_dim`` with K=``k`` routed experts."""
    heads = max(1, hidden_dim // 128)  # one attention head per 128 channels, minimum one
    return GrugModelConfig(
        vocab_size=VOCAB_SIZE,
        hidden_dim=hidden_dim,
        intermediate_dim=hidden_dim // 2,  # per-expert FFN width
        shared_expert_intermediate_dim=hidden_dim,  # shared expert matches model width
        num_experts=NUM_EXPERTS,
        num_experts_per_token=k,
        num_layers=_compute_num_layers(hidden_dim),
        num_heads=heads,
        num_kv_heads=heads,  # num_kv_heads == num_heads, i.e. no KV-head grouping
        max_seq_len=SEQ_LEN,
    )
103+
104+
105+
def _build_sweep_step(
    budget: float,
    hidden_dim: int,
    k: int,
) -> ExecutorStep:
    """Build one isoflop run: fix (budget, hidden_dim, K) and derive batch/steps/lr.

    Higher K raises active FLOPs per token, so at a fixed budget the derived
    token count — and hence step count — shrinks.
    """
    model = _build_model_config(hidden_dim, k)
    fpt = _flops_per_token_for_config(model)
    # Training FLOPs ~ 3x forward FLOPs (fwd + bwd), so tokens = budget / (3 * fpt).
    tokens = budget / (3 * fpt)
    target_steps = 2**TARGET_STEPS_LOG2

    # Pick the power-of-two batch size that lands near target_steps for this budget.
    batch_exact = tokens / (target_steps * SEQ_LEN)
    effective_bs = _round_up_pow2(batch_exact)
    effective_bs = max(MIN_BATCH_SIZE, effective_bs)

    # Heuristic hyperparameter scaling: lr grows with sqrt(batch) and shrinks with
    # width (capped at MAX_LR); beta2 decays from BETA2_BASE toward a 0.95 floor
    # as the batch grows past BETA2_BATCH_DIVISOR.
    lr = min(MAX_LR, (LR_CONSTANT * math.sqrt(effective_bs)) / hidden_dim)
    beta2 = max(0.95, BETA2_BASE ** (effective_bs / BETA2_BATCH_DIVISOR))
    # Recompute steps from the rounded batch so total tokens stay on budget.
    steps = max(1, round(tokens / (effective_bs * SEQ_LEN)))

    budget_tag = f"{budget:.0e}"
    k_tag = f"k{k}"
    run_id = f"great-10t-sweepk-{budget_tag}-d{hidden_dim}-{k_tag}"

    return ExecutorStep(
        name=f"grug/great-10t-sweepk/{budget_tag}/d{hidden_dim}/{k_tag}",
        fn=run_grug_moe,
        config=GrugMoeLaunchConfig(
            model=versioned(model),
            data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
            output_path=this_output_path(),
            run_id=run_id,
            resources=versioned(ResourceConfig.with_tpu("v5p-8")),
            steps=versioned(steps),
            batch_size=versioned(effective_bs),
            seed=versioned(0),
            mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
            tracker=WandbConfig(
                project="marin",
                tags=["grug", "moe", "great-10t", "sweep-k", budget_tag, k_tag],
                group="great-10t-sweep-k",
                # name=None lets wandb derive the display name from the run.
                name=None,
            ),
            optimizer=versioned(
                AdamConfig(
                    learning_rate=lr,
                    weight_decay=0.1,
                    lr_schedule="cosine",
                    decay=0.2,
                    min_lr_ratio=0.1,
                    warmup=0.1,
                    beta2=beta2,
                )
            ),
            grug_trainer=versioned(
                GrugTrainerConfig(
                    z_loss_weight=1e-4,
                    ema_beta=None,  # EMA disabled for this sweep
                    log_every=1,
                )
            ),
            eval=versioned(
                GrugEvalConfig(
                    # Eval batch never exceeds the training batch.
                    eval_batch_size=min(512, effective_bs),
                    steps_per_eval=1000,
                    max_eval_batches=8,
                    eval_current=True,
                    eval_ema=False,
                )
            ),
        ),
    )
175+
176+
177+
def build_sweep_steps() -> list[ExecutorStep]:
    """Generate all ExecutorSteps for the great 10T K sweep.

    One step per (budget, hidden_dim, K) combination, iterated budget-major so
    ordering matches the sweep axes declared at module top.
    """
    return [
        _build_sweep_step(budget, hidden_dim, k)
        for budget in FLOP_BUDGETS
        for hidden_dim in HIDDEN_DIMS
        for k in K_VALUES
    ]
185+
186+
187+
# Materialized at import time so tests (and the executor) can inspect the full grid.
ALL_STEPS = build_sweep_steps()


if __name__ == "__main__":
    executor_main(
        steps=ALL_STEPS,
        description="Great 10T ablation: sweep K in {1,2,4,8} across isoflop budgets (issue #4047).",
    )

tests/test_grug_variant_contracts.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,40 @@ def test_grug_base_run_emits_expected_metrics_with_json_tracker(tmp_path: Path):
263263
]
264264
for key in required_keys:
265265
assert key in summary
266+
267+
268+
def test_great_10t_sweep_k_generates_valid_configs():
    """Verify the great 10T K sweep produces well-formed configs."""
    from experiments.grug.moe.great_10t_sweep_k import (
        ALL_STEPS,
        FLOP_BUDGETS,
        HIDDEN_DIMS,
        K_VALUES,
        _build_model_config,
        _flops_per_token_for_config,
    )

    # Full cartesian grid: budgets x widths x K values.
    assert len(ALL_STEPS) == len(FLOP_BUDGETS) * len(HIDDEN_DIMS) * len(K_VALUES)

    # Step names must be unique or executor outputs would collide.
    names = [s.name for s in ALL_STEPS]
    assert len(names) == len(set(names)), f"Duplicate step names: {[n for n in names if names.count(n) > 1]}"

    # Every config carries positive batch/steps and a K drawn from the sweep axis.
    observed_k = set()
    for step in ALL_STEPS:
        launch_cfg = step.config
        # batch_size and steps are wrapped in VersionedValue; unwrap with .value.
        assert launch_cfg.batch_size.value > 0
        assert launch_cfg.steps.value > 0
        model_cfg = launch_cfg.model.value
        assert model_cfg.num_experts == 128
        assert model_cfg.num_experts_per_token in K_VALUES
        observed_k.add(model_cfg.num_experts_per_token)
    assert observed_k == set(K_VALUES)

    # Routing more experts per token must cost more FLOPs per token
    # (same architecture otherwise).
    low_k = _build_model_config(1024, 1)
    high_k = _build_model_config(1024, 8)
    assert _flops_per_token_for_config(low_k) < _flops_per_token_for_config(high_k)

0 commit comments

Comments
 (0)