
Commit 8e4d824

[moe] Add great 10T gated norm ablation experiment

Add exp4044_great_10t_gated_norms.py with two training arms at a 10T-token budget to ablate gated norms. The baseline has no gated norms; the ablation uses gated_norm_rank=16. Both arms share the d=2048 MoE architecture (E=8, K=2), the Adam optimizer, and the Nemotron mix data on v5p-128.

Fixes #4044

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8df5916 commit 8e4d824

1 file changed: exp4044_great_10t_gated_norms.py

Lines changed: 285 additions & 0 deletions
@@ -0,0 +1,285 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Great 10T ablation: gated norms (issue #4044).

Runs two MoE training configurations at a 10T-token budget to determine
whether gated norms (low-rank self-gating after RMSNorm) improve quality at
scale.

Baseline: no gated norms (gated_norm_rank=None).
Ablation: gated_norm_rank=16.

Both arms share the same optimizer, data, and architecture. The comparison
metric is validation perplexity on c4en and the default validation suite.
"""

import dataclasses
import math
import os
from dataclasses import dataclass, field
from datetime import timedelta

import jmp
from fray.cluster import ResourceConfig
from levanter.callbacks.profiler import ProfilerConfig
from levanter.checkpoint import CheckpointerConfig
from levanter.data.text import LmDataConfig
from levanter.optim import AdamConfig, OptimizerConfig
from levanter.tracker import TrackerConfig
from levanter.tracker.wandb import WandbConfig
from levanter.trainer import TrainerConfig
from levanter.utils.flop_utils import lm_flops_per_token
from levanter.utils.mesh import MeshConfig
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
from marin.processing.tokenize import add_validation_sets_to_mixture

from experiments.defaults import default_validation_sets
from experiments.grug.moe.model import GrugModelConfig
from experiments.grug.moe.train import GrugEvalConfig, GrugRunConfig, GrugTrainerConfig, run_grug
from experiments.pretraining_datasets import nemotron_mix_block_shuffle

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

SEQ_LEN: int = 4096
VOCAB_SIZE: int = 128_256
MIN_BATCH_SIZE: int = 32

GATED_NORM_RANK: int = 16

# 10 trillion tokens.
TOKEN_BUDGET: float = 10e12

NEMOTRON_MIX_WITH_DEFAULT_VALIDATION = add_validation_sets_to_mixture(
    nemotron_mix_block_shuffle,
    default_validation_sets(tokenizer=nemotron_mix_block_shuffle.tokenizer),
)

# ---------------------------------------------------------------------------
# Launch config (mirrors grug/moe/launch.py)
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class GreatMoeLaunchConfig:
    """Launch config for a single arm of the gated norm ablation."""

    model: GrugModelConfig
    data: LmDataConfig
    output_path: str
    run_id: str
    resources: ResourceConfig
    steps: int
    batch_size: int
    seed: int
    mp: str
    tracker: TrackerConfig
    optimizer: OptimizerConfig
    grug_trainer: GrugTrainerConfig = field(default_factory=GrugTrainerConfig)
    eval: GrugEvalConfig | None = field(default_factory=GrugEvalConfig)
    profiler: ProfilerConfig = field(default_factory=lambda: ProfilerConfig(enabled=False))


def _resolve_run_id(default_run_id: str) -> str:
    run_id = os.environ.get("GRUG_RUN_ID", default_run_id)
    ferry_date = os.environ.get("FERRY_DATE")
    if ferry_date:
        run_id = f"{run_id}-{ferry_date}"
    return run_id


def _resolve_tracker(tracker: TrackerConfig, run_id: str, output_path: str) -> TrackerConfig:
    if isinstance(tracker, WandbConfig):
        return dataclasses.replace(tracker, name=run_id, replicate_path=output_path)
    return tracker
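
# Illustrative example of the run-id resolution (values are hypothetical, not
# from this commit): with GRUG_RUN_ID unset and FERRY_DATE=2025-01-01,
# _resolve_run_id("great-10t-gated-norm-baseline") returns
# "great-10t-gated-norm-baseline-2025-01-01".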


def run_great_moe(config: GreatMoeLaunchConfig) -> None:
    """Map GreatMoeLaunchConfig onto TrainerConfig and run training."""
    trainer = TrainerConfig(
        id=config.run_id,
        seed=config.seed,
        train_batch_size=config.batch_size,
        num_train_steps=config.steps,
        profiler=config.profiler,
        mp=jmp.get_policy(config.mp),
        tracker=_resolve_tracker(config.tracker, config.run_id, config.output_path),
        use_explicit_mesh_axes=True,
        mesh=MeshConfig(axes={"expert": 1}),
        require_accelerator=True,
        allow_nondivisible_batch_size=False,
        checkpointer=CheckpointerConfig(
            base_path=os.path.join(config.output_path, "checkpoints"),
            append_run_id_to_base_path=False,
            save_interval=timedelta(minutes=10),
            keep=[{"every": 5000}],
        ),
    )

    grug_trainer = dataclasses.replace(config.grug_trainer, trainer=trainer)

    run_config = GrugRunConfig(
        model=config.model,
        data=config.data,
        resources=config.resources,
        optimizer=config.optimizer,
        trainer=grug_trainer,
        eval=config.eval,
    )
    run_grug(run_config)


# ---------------------------------------------------------------------------
# Model and training arithmetic
# ---------------------------------------------------------------------------

HIDDEN_DIM = 2048
NUM_HEADS = HIDDEN_DIM // 128  # 16
NUM_KV_HEADS = NUM_HEADS


def _compute_num_layers(hidden_dim: int) -> int:
    """Depth-width formula from Marin2025Recipe."""
    hs_pow = math.log2(hidden_dim)
    return round(hidden_dim / (64 + (hs_pow * 4.0) - 9))


NUM_LAYERS = _compute_num_layers(HIDDEN_DIM)
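
# Sanity check of the depth formula (plain arithmetic, not part of the original
# commit): for HIDDEN_DIM = 2048, log2(2048) = 11, so the denominator is
# 64 + 11 * 4.0 - 9 = 99 and 2048 / 99 ≈ 20.69, which rounds to NUM_LAYERS = 21.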


def _round_to_power_of_two(x: float) -> int:
    if x <= 1:
        return 1
    return 2 ** math.ceil(math.log2(x))
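
# For instance (again arithmetic, not from the commit): the helper always
# rounds up, so _round_to_power_of_two(33) == 64 and
# _round_to_power_of_two(37252.9) == 65536.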


def _build_model(*, gated_norm_rank: int | None) -> GrugModelConfig:
    return GrugModelConfig(
        vocab_size=VOCAB_SIZE,
        hidden_dim=HIDDEN_DIM,
        intermediate_dim=HIDDEN_DIM // 2,
        shared_expert_intermediate_dim=HIDDEN_DIM,
        num_experts=8,
        num_experts_per_token=2,
        num_layers=NUM_LAYERS,
        num_heads=NUM_HEADS,
        num_kv_heads=NUM_KV_HEADS,
        max_seq_len=SEQ_LEN,
        head_dim=None,
        load_balancing_loss_coef=0.01,
        router_z_loss_coef=0.001,
        gated_norm_rank=gated_norm_rank,
    )
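
# With HIDDEN_DIM = 2048 this yields intermediate_dim = 1024 per routed expert
# and a shared expert of width 2048. head_dim=None presumably lets
# GrugModelConfig derive hidden_dim // num_heads = 128; that is an assumption
# about the config class, not something this diff confirms.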


def _compute_flops_per_token(cfg: GrugModelConfig) -> float:
    return lm_flops_per_token(
        hidden_dim=cfg.hidden_dim,
        intermediate_dim=cfg.intermediate_dim,
        num_layers=cfg.num_layers,
        num_kv_heads=cfg.num_kv_heads,
        num_heads=cfg.num_heads,
        seq_len=cfg.max_seq_len,
        vocab_size=cfg.vocab_size,
        glu=True,
        num_experts=cfg.num_experts,
        num_shared_experts=1 if cfg.shared_expert_intermediate_dim > 0 else 0,
        num_experts_per_tok=cfg.num_experts_per_token,
    )


def _compute_training_params(token_budget: float, flops_per_token: float) -> tuple[int, int]:
    """Compute batch_size and train_steps for a given token budget.

    Targets ~2**16 steps; minimum batch size 32.
    """
    target_steps = 2**16
    batch_exact = token_budget / (target_steps * SEQ_LEN)
    batch_size = max(MIN_BATCH_SIZE, _round_to_power_of_two(batch_exact))
    train_steps = max(1, round(token_budget / (batch_size * SEQ_LEN)))
    return batch_size, train_steps
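
# Worked numbers for the 10T budget (arithmetic only, not from the commit):
# batch_exact = 1e13 / (2**16 * 4096) ≈ 37,252.9, which rounds up to
# BATCH_SIZE = 2**16 = 65,536 sequences (≈268M tokens per step), and then
# TRAIN_STEPS = round(1e13 / (65_536 * 4096)) = 37,253. Note that the
# flops_per_token argument is accepted but unused by this formula.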

# Precompute training params from the baseline model (both arms use the same architecture).
_BASELINE_MODEL = _build_model(gated_norm_rank=None)
_FPT = _compute_flops_per_token(_BASELINE_MODEL)
BATCH_SIZE, TRAIN_STEPS = _compute_training_params(TOKEN_BUDGET, _FPT)

# Learning rate scaled with sqrt(batch).
_EFFECTIVE_BS = BATCH_SIZE * SEQ_LEN / 4096
LR = min(0.01, (0.33 * math.sqrt(_EFFECTIVE_BS)) / HIDDEN_DIM)
BETA2 = max(0.95, 0.98 ** (_EFFECTIVE_BS / 128))
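
# With the values above (arithmetic only): _EFFECTIVE_BS = 65,536, so
# 0.33 * sqrt(65_536) / 2048 ≈ 0.0412 and LR clamps to the 0.01 cap, while
# 0.98 ** (65_536 / 128) ≈ 3e-5 and BETA2 clamps to the 0.95 floor.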

OPTIMIZER = AdamConfig(
    learning_rate=LR,
    weight_decay=0.1,
    lr_schedule="cosine",
    decay=0.2,
    min_lr_ratio=0.1,
    warmup=1000,
    beta2=BETA2,
)

TRAINER_CONFIG = GrugTrainerConfig(
    z_loss_weight=1e-4,
    ema_beta=None,
    log_every=1,
)

EVAL_CONFIG = GrugEvalConfig(
    eval_batch_size=512,
    steps_per_eval=5000,
    max_eval_batches=8,
    eval_current=True,
    eval_ema=False,
)
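
# Schedule note (an assumption about levanter's AdamConfig semantics, not
# stated in this diff): with min_lr_ratio=0.1 the cosine presumably decays from
# the 0.01 peak toward 0.001, with decay=0.2 confining the decay phase to the
# final 20% of training after the 1000-step warmup.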


# ---------------------------------------------------------------------------
# Ablation arms
# ---------------------------------------------------------------------------


def _make_arm(label: str, gated_norm_rank: int | None) -> ExecutorStep:
    model = _build_model(gated_norm_rank=gated_norm_rank)
    run_id = _resolve_run_id(f"great-10t-gated-norm-{label}")
    gnr_tag = f"gated_norm_rank={gated_norm_rank}" if gated_norm_rank is not None else "gated_norm=off"
    return ExecutorStep(
        name=f"grug/great-10t-gated-norm-{label}",
        fn=run_great_moe,
        config=GreatMoeLaunchConfig(
            model=versioned(model),
            data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
            output_path=this_output_path(),
            run_id=run_id,
            resources=versioned(ResourceConfig.with_tpu("v5p-128")),
            steps=versioned(TRAIN_STEPS),
            batch_size=versioned(BATCH_SIZE),
            seed=versioned(0),
            mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
            tracker=WandbConfig(
                project="marin",
                tags=["grug", "moe", "great-10t", "gated-norm-ablation", gnr_tag],
                group="great-10t-gated-norm-ablation",
                name=None,
            ),
            optimizer=versioned(OPTIMIZER),
            grug_trainer=versioned(TRAINER_CONFIG),
            eval=versioned(EVAL_CONFIG),
        ),
    )

# Baseline: no gated norms.
great_10t_gated_norm_baseline = _make_arm("baseline", gated_norm_rank=None)

# Ablation: gated norms with rank 16.
great_10t_gated_norm_enabled = _make_arm("enabled", gated_norm_rank=GATED_NORM_RANK)


if __name__ == "__main__":
    executor_main(
        steps=[great_10t_gated_norm_baseline, great_10t_gated_norm_enabled],
        description="Great 10T ablation: gated norms (issue #4044). Two arms at 10T tokens.",
    )
