Skip to content

Commit a5d4e33

Browse files
[moe] Add multi-scale AdamH vs Adam isoflop experiment
Add isoflop suite comparing Adam and AdamH (GrugAdamHConfig) on grug MoE models at 3e18, 1e19, 3e19, and 1e20 FLOPs. Each budget gets an appropriately-sized model (E=8, K=2, shared expert) with LR/beta2 scaled by batch size and model width. Produces 8 runs total for the great 10T gate optimizer comparison. Includes config generation tests. Fixes #4042 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b666a77 commit a5d4e33

4 files changed

Lines changed: 353 additions & 14 deletions

File tree

experiments/grug/moe/adamh_vs_adam_1e19.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,13 @@
1010
Part of #4024 / #4013.
1111
"""
1212

13-
import dataclasses
1413
import math
1514
import os
16-
from dataclasses import field
17-
from datetime import timedelta
1815

19-
import jmp
2016
from fray.cluster import ResourceConfig
21-
from levanter.callbacks.profiler import ProfilerConfig
22-
from levanter.checkpoint import CheckpointerConfig
23-
from levanter.data.text import LmDataConfig
2417
from levanter.optim import AdamConfig, GrugAdamHConfig, OptimizerConfig
2518
from levanter.tracker.wandb import WandbConfig
26-
from levanter.trainer import TrainerConfig
2719
from levanter.utils.flop_utils import lm_flops_per_token
28-
from levanter.utils.mesh import MeshConfig
2920

3021
from experiments.defaults import default_validation_sets
3122
from experiments.grug.moe.launch import GrugMoeLaunchConfig, run_grug_moe
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Multi-scale isoflop comparison of AdamH vs Adam on the grug MoE architecture.
5+
6+
Launches paired Adam / AdamH runs at four FLOP budgets (3e18, 1e19, 3e19, 1e20)
7+
on appropriately-sized grug MoE models (E=8, K=2, shared expert). Each budget
8+
gets a model sized so that token count stays within ~20-40x parameters (roughly
9+
Chinchilla-optimal for MoE). The only variable within each pair is the optimizer.
10+
11+
This produces 8 runs total, enough to check whether AdamH vs Adam trends hold
12+
across scales before locking in the optimizer for the 10T TPU path.
13+
14+
Part of #4042 / #4014.
15+
"""
16+
17+
import math
18+
import os
19+
from dataclasses import dataclass
20+
21+
from fray.cluster import ResourceConfig
22+
from levanter.optim import AdamConfig, GrugAdamHConfig, OptimizerConfig
23+
from levanter.tracker.wandb import WandbConfig
24+
from levanter.utils.flop_utils import lm_flops_per_token
25+
26+
from experiments.defaults import default_validation_sets
27+
from experiments.grug.moe.launch import GrugMoeLaunchConfig, run_grug_moe
28+
from experiments.grug.moe.model import GrugModelConfig
29+
from experiments.grug.moe.train import GrugEvalConfig, GrugTrainerConfig
30+
from experiments.pretraining_datasets import nemotron_mix_block_shuffle
31+
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
32+
from marin.processing.tokenize import add_validation_sets_to_mixture
33+
34+
# ---------- constants ----------
35+
SEQ_LEN = 4096
36+
VOCAB_SIZE = 128_256
37+
NUM_EXPERTS = 8
38+
NUM_EXPERTS_PER_TOKEN = 2
39+
HEAD_DIM = 128
40+
41+
NEMOTRON_MIX = add_validation_sets_to_mixture(
42+
nemotron_mix_block_shuffle,
43+
default_validation_sets(tokenizer=nemotron_mix_block_shuffle.tokenizer),
44+
)
45+
46+
GRUG_TRAINER = GrugTrainerConfig(
47+
z_loss_weight=1e-4,
48+
ema_beta=None,
49+
log_every=1,
50+
)
51+
52+
EVAL = GrugEvalConfig(
53+
eval_batch_size=512,
54+
steps_per_eval=1000,
55+
max_eval_batches=8,
56+
eval_current=True,
57+
eval_ema=False,
58+
)
59+
60+
61+
@dataclass(frozen=True)
62+
class ScalePoint:
63+
"""One point in the isoflop suite."""
64+
65+
budget: float
66+
hidden_dim: int
67+
num_layers: int
68+
batch_size: int
69+
70+
71+
def _num_layers_for_hidden(hidden_dim: int) -> int:
72+
"""Depth-to-width heuristic: ~hidden/64, clamped to reasonable range."""
73+
raw = hidden_dim / 64
74+
return max(6, min(48, round(raw)))
75+
76+
77+
def _flops_per_token(hidden_dim: int, intermediate_dim: int, num_layers: int, num_heads: int) -> float:
78+
return lm_flops_per_token(
79+
hidden_dim=hidden_dim,
80+
intermediate_dim=intermediate_dim,
81+
num_layers=num_layers,
82+
num_kv_heads=num_heads,
83+
num_heads=num_heads,
84+
seq_len=SEQ_LEN,
85+
vocab_size=VOCAB_SIZE,
86+
glu=True,
87+
num_experts=NUM_EXPERTS,
88+
num_shared_experts=1,
89+
num_experts_per_tok=NUM_EXPERTS_PER_TOKEN,
90+
shared_intermediate_dim=intermediate_dim,
91+
)
92+
93+
94+
def _pick_batch_size(tokens: float, target_steps: int = 2**16) -> int:
95+
"""Pick batch size (power of 2) to hit ~target_steps of training."""
96+
raw = tokens / (target_steps * SEQ_LEN)
97+
bs = 2 ** round(math.log2(max(8, raw)))
98+
return max(8, min(512, bs))
99+
100+
101+
def _find_hidden_dim_for_budget(budget: float, step_size: int = 128) -> int:
102+
"""Binary-ish search for the hidden_dim that makes the model ~Chinchilla-optimal.
103+
104+
We want tokens/params ~ 20-40x. Search over hidden_dim in steps of step_size.
105+
"""
106+
best_dim = 512
107+
best_score = float("inf")
108+
for dim in range(384, 4096 + 1, step_size):
109+
n_layers = _num_layers_for_hidden(dim)
110+
n_heads = dim // HEAD_DIM
111+
if n_heads < 1 or dim % HEAD_DIM != 0:
112+
continue
113+
intermediate = dim * 3
114+
fpt = _flops_per_token(dim, intermediate, n_layers, n_heads)
115+
tokens = budget / (3 * fpt)
116+
# Rough param count for MoE: embedding + layers*(attn + K*expert_mlp + shared_mlp)
117+
# Simplified: use a proxy based on dense equivalent
118+
attn_params = 4 * dim * dim * n_layers
119+
expert_params = 3 * dim * intermediate * NUM_EXPERTS * n_layers # GLU: 3 matrices
120+
shared_params = 3 * dim * intermediate * n_layers
121+
embed_params = 2 * VOCAB_SIZE * dim
122+
total_params = attn_params + expert_params + shared_params + embed_params
123+
# Active params per token (what matters for Chinchilla)
124+
active_params = attn_params + 3 * dim * intermediate * NUM_EXPERTS_PER_TOKEN * n_layers + shared_params + embed_params
125+
ratio = tokens / active_params
126+
# Target ratio ~20-40x, aim for ~25
127+
score = abs(math.log(ratio / 25))
128+
if score < best_score:
129+
best_score = score
130+
best_dim = dim
131+
return best_dim
132+
133+
134+
def make_scale_point(budget: float) -> ScalePoint:
135+
hidden_dim = _find_hidden_dim_for_budget(budget)
136+
num_layers = _num_layers_for_hidden(hidden_dim)
137+
num_heads = hidden_dim // HEAD_DIM
138+
intermediate_dim = hidden_dim * 3
139+
fpt = _flops_per_token(hidden_dim, intermediate_dim, num_layers, num_heads)
140+
tokens = budget / (3 * fpt)
141+
batch_size = _pick_batch_size(tokens)
142+
return ScalePoint(budget=budget, hidden_dim=hidden_dim, num_layers=num_layers, batch_size=batch_size)
143+
144+
145+
def make_model(sp: ScalePoint) -> GrugModelConfig:
146+
num_heads = sp.hidden_dim // HEAD_DIM
147+
return GrugModelConfig(
148+
vocab_size=VOCAB_SIZE,
149+
hidden_dim=sp.hidden_dim,
150+
intermediate_dim=sp.hidden_dim * 3,
151+
shared_expert_intermediate_dim=sp.hidden_dim * 3,
152+
num_experts=NUM_EXPERTS,
153+
num_experts_per_token=NUM_EXPERTS_PER_TOKEN,
154+
num_layers=sp.num_layers,
155+
num_heads=num_heads,
156+
num_kv_heads=num_heads,
157+
max_seq_len=SEQ_LEN,
158+
)
159+
160+
161+
def compute_train_steps(sp: ScalePoint) -> int:
162+
num_heads = sp.hidden_dim // HEAD_DIM
163+
fpt = _flops_per_token(sp.hidden_dim, sp.hidden_dim * 3, sp.num_layers, num_heads)
164+
tokens = sp.budget / (3 * fpt)
165+
return round(tokens / (sp.batch_size * SEQ_LEN))
166+
167+
168+
def make_adam_optimizer(sp: ScalePoint) -> AdamConfig:
169+
"""Adam optimizer with LR/beta2 scaled by batch size and model width."""
170+
effective_bs = sp.batch_size * SEQ_LEN / 4096
171+
lr = min(0.01, (0.33 * math.sqrt(effective_bs)) / sp.hidden_dim)
172+
beta2 = max(0.95, 0.98 ** (effective_bs / 128))
173+
return AdamConfig(
174+
learning_rate=lr,
175+
weight_decay=0.1,
176+
beta1=0.9,
177+
beta2=beta2,
178+
epsilon=1e-8,
179+
lr_schedule="linear",
180+
decay=0.2,
181+
min_lr_ratio=0.0,
182+
warmup=0.1,
183+
max_grad_norm=1.0,
184+
)
185+
186+
187+
def make_adamh_optimizer(sp: ScalePoint) -> GrugAdamHConfig:
188+
"""AdamH optimizer: sqrt(lr * wd) for scale-invariant weights, standard lr for Adam params."""
189+
effective_bs = sp.batch_size * SEQ_LEN / 4096
190+
adam_lr = min(0.01, (0.33 * math.sqrt(effective_bs)) / sp.hidden_dim)
191+
beta2 = max(0.95, 0.98 ** (effective_bs / 128))
192+
adamh_lr = math.sqrt(adam_lr * 0.1)
193+
return GrugAdamHConfig(
194+
learning_rate=adamh_lr,
195+
adam_lr=adam_lr,
196+
beta1=0.9,
197+
beta2=beta2,
198+
epsilon=1e-8,
199+
lr_schedule="linear",
200+
decay=0.2,
201+
min_lr_ratio=0.0,
202+
warmup=0.1,
203+
max_grad_norm=0.1,
204+
weight_decay=0.0,
205+
)
206+
207+
208+
def _resolve_run_id(base: str) -> str:
209+
run_id = os.environ.get("GRUG_RUN_ID", base)
210+
ferry_date = os.environ.get("FERRY_DATE")
211+
if ferry_date:
212+
run_id = f"{run_id}-{ferry_date}"
213+
return run_id
214+
215+
216+
def _make_step(
217+
sp: ScalePoint, optimizer: OptimizerConfig, opt_name: str, tags: list[str]
218+
) -> ExecutorStep:
219+
budget_str = f"{sp.budget:.0e}"
220+
name = f"moe-{opt_name}-{budget_str}-d{sp.hidden_dim}"
221+
run_id = _resolve_run_id(name)
222+
train_steps = compute_train_steps(sp)
223+
return ExecutorStep(
224+
name=f"grug/{name}",
225+
fn=run_grug_moe,
226+
config=GrugMoeLaunchConfig(
227+
model=versioned(make_model(sp)),
228+
data=NEMOTRON_MIX,
229+
output_path=this_output_path(),
230+
run_id=run_id,
231+
resources=versioned(ResourceConfig.with_tpu("v5p-8")),
232+
steps=versioned(train_steps),
233+
batch_size=versioned(sp.batch_size),
234+
seed=versioned(42),
235+
mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
236+
tracker=WandbConfig(
237+
project="marin",
238+
tags=["grug", "moe", "adamh-vs-adam", "isoflop", budget_str, *tags],
239+
group="moe-adamh-vs-adam-isoflop",
240+
name=None,
241+
),
242+
optimizer=versioned(optimizer),
243+
grug_trainer=versioned(GRUG_TRAINER),
244+
eval=versioned(EVAL),
245+
),
246+
)
247+
248+
249+
# ---------- isoflop suite ----------
250+
BUDGETS = (3e18, 1e19, 3e19, 1e20)
251+
SCALE_POINTS = [make_scale_point(b) for b in BUDGETS]
252+
253+
all_steps: list[ExecutorStep] = []
254+
for sp in SCALE_POINTS:
255+
adam_opt = make_adam_optimizer(sp)
256+
adamh_opt = make_adamh_optimizer(sp)
257+
all_steps.append(_make_step(sp, adam_opt, "adam", ["adam", "baseline"]))
258+
all_steps.append(_make_step(sp, adamh_opt, "adamh", ["adamh"]))
259+
260+
261+
if __name__ == "__main__":
262+
executor_main(
263+
steps=all_steps,
264+
description=(
265+
"AdamH vs Adam isoflop suite on grug MoE (E=8, K=2) at 3e18/1e19/3e19/1e20 FLOPs. "
266+
"Part of #4042."
267+
),
268+
)

lib/levanter/tests/test_grug_adamh.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
def _make_fake_grug_params():
1313
"""Minimal pytree mimicking grug MoE parameter structure."""
1414
return {
15-
"embed": jnp.zeros((128, 64)), # embedding: 2D but name contains "embed"
15+
"embed": jnp.zeros((128, 64)), # embedding: 2D but name contains "embed"
1616
"layers": {
1717
"attn_w_q": jnp.zeros((64, 64)), # weight matrix
1818
"attn_w_k": jnp.zeros((64, 64)), # weight matrix
19-
"mlp_w1": jnp.zeros((64, 192)), # weight matrix
20-
"mlp_w2": jnp.zeros((192, 64)), # weight matrix
21-
"norm_weight": jnp.zeros((64,)), # 1D norm scale
19+
"mlp_w1": jnp.zeros((64, 192)), # weight matrix
20+
"mlp_w2": jnp.zeros((192, 64)), # weight matrix
21+
"norm_weight": jnp.zeros((64,)), # 1D norm scale
2222
},
2323
"router_weight": jnp.zeros((64, 8)), # router: 2D but name contains "router"
24-
"lm_head": jnp.zeros((64, 128)), # weight matrix
24+
"lm_head": jnp.zeros((64, 128)), # weight matrix
2525
}
2626

2727

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Tests for the multi-scale AdamH vs Adam isoflop experiment config generation."""
5+
6+
from experiments.grug.moe.adamh_vs_adam_isoflop import (
7+
BUDGETS,
8+
SCALE_POINTS,
9+
all_steps,
10+
compute_train_steps,
11+
make_adam_optimizer,
12+
make_adamh_optimizer,
13+
make_model,
14+
make_scale_point,
15+
)
16+
from levanter.optim import AdamConfig, GrugAdamHConfig
17+
18+
19+
def test_scale_points_cover_all_budgets():
20+
assert len(SCALE_POINTS) == len(BUDGETS)
21+
for sp, budget in zip(SCALE_POINTS, BUDGETS):
22+
assert sp.budget == budget
23+
24+
25+
def test_hidden_dim_increases_with_budget():
26+
dims = [sp.hidden_dim for sp in SCALE_POINTS]
27+
for i in range(1, len(dims)):
28+
assert dims[i] >= dims[i - 1], f"hidden_dim should not decrease: {dims}"
29+
30+
31+
def test_model_configs_are_valid():
32+
for sp in SCALE_POINTS:
33+
model = make_model(sp)
34+
assert model.hidden_dim == sp.hidden_dim
35+
assert model.num_layers == sp.num_layers
36+
assert model.num_experts == 8
37+
assert model.num_experts_per_token == 2
38+
assert model.hidden_dim % 128 == 0, "hidden_dim must be divisible by head_dim=128"
39+
40+
41+
def test_train_steps_positive():
42+
for sp in SCALE_POINTS:
43+
steps = compute_train_steps(sp)
44+
assert steps > 0, f"train_steps must be positive for budget={sp.budget}"
45+
46+
47+
def test_adam_optimizer_types():
48+
for sp in SCALE_POINTS:
49+
adam = make_adam_optimizer(sp)
50+
adamh = make_adamh_optimizer(sp)
51+
assert isinstance(adam, AdamConfig)
52+
assert isinstance(adamh, GrugAdamHConfig)
53+
assert adam.learning_rate > 0
54+
assert adamh.learning_rate > 0
55+
assert adamh.adam_lr > 0
56+
57+
58+
def test_adamh_lr_follows_heuristic():
59+
"""AdamH scale-invariant LR = sqrt(adam_lr * 0.1)."""
60+
import math
61+
62+
for sp in SCALE_POINTS:
63+
adamh = make_adamh_optimizer(sp)
64+
expected = math.sqrt(adamh.adam_lr * 0.1)
65+
assert abs(adamh.learning_rate - expected) < 1e-12, (
66+
f"AdamH LR {adamh.learning_rate} != sqrt({adamh.adam_lr} * 0.1) = {expected}"
67+
)
68+
69+
70+
def test_all_steps_generated():
71+
# 2 optimizers x 4 budgets = 8 steps
72+
assert len(all_steps) == 8
73+
names = [s.name for s in all_steps]
74+
# Each budget should have an adam and adamh step
75+
for budget in BUDGETS:
76+
budget_str = f"{budget:.0e}"
77+
adam_names = [n for n in names if "adam-" in n and budget_str in n]
78+
adamh_names = [n for n in names if "adamh-" in n and budget_str in n]
79+
assert len(adam_names) == 1, f"Expected 1 adam step for {budget_str}, got {adam_names}"
80+
assert len(adamh_names) == 1, f"Expected 1 adamh step for {budget_str}, got {adamh_names}"

0 commit comments

Comments
 (0)