
Commit 8df5916
[moe] Add gated norm support and ablation launch script

Add GatedNorm (low-rank self-gating after RMSNorm) to the MoE grug model behind a new gated_norm_rank config field, and add an ablation launch script comparing baseline vs. gated-norm at ~1e19 FLOPs for the good 10T gate.

Fixes #4026

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Parent: a243fe5

3 files changed, 225 additions, 2 deletions
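The mechanism itself is small: after each RMSNorm, the normalized activations are self-gated through a rank-r bottleneck. A minimal standalone sketch of the computation this commit implements (plain jax.numpy, leaving out the repo's sharding and initialization helpers):

import jax
import jax.numpy as jnp

def gated_norm(x, w_down, w_up):
    # x: (..., D); w_down: (D, r); w_up: (r, D), with r << D.
    gate_hidden = jax.nn.silu(jnp.einsum("...d,dr->...r", x, w_down))
    gate = jax.nn.sigmoid(jnp.einsum("...r,rd->...d", gate_hidden, w_up))
    # Elementwise self-gate: x * sigmoid(up(silu(down(x)))).
    return x * gate.astype(x.dtype)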


Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+# Copyright The Marin Authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Ablation: gated norms in the MoE grug model at ~1e19 FLOPs.
+
+Runs two matched configurations:
+- baseline (no gated norms)
+- gated_norm_rank=16
+
+See https://github.com/marin-community/marin/issues/4026
+"""
+
+import dataclasses
+
+from fray.cluster import ResourceConfig
+from levanter.optim import AdamConfig
+from levanter.tracker.wandb import WandbConfig
+
+from experiments.grug.moe.launch import (
+    NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
+    GrugMoeLaunchConfig,
+    GrugTrainerConfig,
+    run_grug_moe,
+)
+from experiments.grug.moe.model import GrugModelConfig
+from experiments.grug.moe.train import GrugEvalConfig
+from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
+
+GATED_NORM_RANK = 16
+
+_BASE_MODEL = GrugModelConfig(
+    vocab_size=128_256,
+    hidden_dim=768,
+    intermediate_dim=2048,
+    shared_expert_intermediate_dim=2048,
+    num_experts=8,
+    num_experts_per_token=2,
+    num_layers=12,
+    num_heads=12,
+    num_kv_heads=12,
+    max_seq_len=4096,
+)
+
+_GATED_NORM_MODEL = dataclasses.replace(_BASE_MODEL, gated_norm_rank=GATED_NORM_RANK)
+
+_OPTIMIZER = AdamConfig(
+    learning_rate=3e-3,
+    weight_decay=0.1,
+    lr_schedule="cosine",
+    decay=0.2,
+    min_lr_ratio=0.1,
+    warmup=500,
+)
+
+_TRAINER = GrugTrainerConfig(
+    z_loss_weight=1e-4,
+    ema_beta=None,
+    log_every=1,
+)
+
+_EVAL = GrugEvalConfig(
+    eval_batch_size=512,
+    steps_per_eval=500,
+    max_eval_batches=8,
+    eval_current=True,
+    eval_ema=False,
+)
+
+_WANDB_TAGS = ["grug", "moe", "good-10t", "ablation", "gated-norm"]
+_STEPS = 2_130
+_BATCH_SIZE = 512
+_RESOURCES = ResourceConfig.with_tpu("v5p-8")
+
+
+def _make_launch_config(
+    model: GrugModelConfig,
+    run_id: str,
+    wandb_group: str,
+    extra_tags: list[str] | None = None,
+) -> GrugMoeLaunchConfig:
+    tags = list(_WANDB_TAGS) + (extra_tags or [])
+    return GrugMoeLaunchConfig(
+        model=versioned(model),
+        data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
+        output_path=this_output_path(),
+        run_id=run_id,
+        resources=versioned(_RESOURCES),
+        steps=versioned(_STEPS),
+        batch_size=versioned(_BATCH_SIZE),
+        seed=versioned(0),
+        mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
+        tracker=WandbConfig(
+            project="marin",
+            tags=tags,
+            group=wandb_group,
+            name=None,
+        ),
+        optimizer=versioned(_OPTIMIZER),
+        grug_trainer=versioned(_TRAINER),
+        eval=versioned(_EVAL),
+    )
+
+
+ablate_gated_norm_baseline = ExecutorStep(
+    name="grug/ablate-gated-norm-baseline",
+    fn=run_grug_moe,
+    config=_make_launch_config(
+        model=_BASE_MODEL,
+        run_id="ablate-gated-norm-baseline",
+        wandb_group="ablate-gated-norm",
+        extra_tags=["baseline"],
+    ),
+)
+
+ablate_gated_norm_enabled = ExecutorStep(
+    name="grug/ablate-gated-norm-enabled",
+    fn=run_grug_moe,
+    config=_make_launch_config(
+        model=_GATED_NORM_MODEL,
+        run_id="ablate-gated-norm-enabled",
+        wandb_group="ablate-gated-norm",
+        extra_tags=[f"gated_norm_rank={GATED_NORM_RANK}"],
+    ),
+)
+
+
+if __name__ == "__main__":
+    executor_main(
+        steps=[ablate_gated_norm_baseline, ablate_gated_norm_enabled],
+        description="Ablation: gated norms in MoE grug at ~1e19 FLOPs (issue #4026).",
+    )
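At this ablation's scale the gating overhead is tiny; a back-of-envelope count, assuming only the two per-block GatedNorm modules the commit adds:

# Each GatedNorm holds w_down (hidden_dim x rank) and w_up (rank x hidden_dim).
hidden_dim, rank, num_layers = 768, 16, 12
params_per_gate = 2 * hidden_dim * rank      # 24,576
gates_per_block = 2                          # one before attention, one before the MLP
extra_params = params_per_gate * gates_per_block * num_layers
print(extra_params)  # 589,824 (~0.6M extra parameters)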

experiments/grug/moe/model.py

Lines changed: 46 additions & 2 deletions
@@ -67,6 +67,7 @@ class GrugModelConfig:
     load_balancing_loss_coef: float | None = 0.01
     router_z_loss_coef: float | None = 0.001
     moe_implementation: MoeImplementation | None = None
+    gated_norm_rank: int | None = None
     rope: RotaryConfig = dataclasses.field(default_factory=RotaryConfig)

     def __post_init__(self) -> None:
@@ -92,6 +93,8 @@ def __post_init__(self) -> None:
             raise ValueError("load_balancing_loss_coef must be non-negative when set")
         if self.router_z_loss_coef is not None and self.router_z_loss_coef < 0:
             raise ValueError("router_z_loss_coef must be non-negative when set")
+        if self.gated_norm_rank is not None and self.gated_norm_rank <= 0:
+            raise ValueError("gated_norm_rank must be positive when set")

     @property
     def inferred_head_dim(self) -> int:
@@ -156,6 +159,32 @@ def __call__(self, x: Float[Array, "... D"]) -> Float[Array, "... D"]:
         return (normed * weight).astype(dtype)


+class GatedNorm(eqx.Module):
+    """Low-rank self-gating applied after RMSNorm.
+
+    Computes: x * sigmoid(up(silu(down(x)))), where down projects from
+    hidden_dim to rank and up projects back.
+    """
+
+    w_down: jax.Array
+    w_up: jax.Array
+
+    @staticmethod
+    def init(hidden_dim: int, rank: int, initializer_std: float, *, key: PRNGKeyArray) -> "GatedNorm":
+        k_down, k_up = random.split(key)
+        return GatedNorm(
+            w_down=reshard(_init_weight(k_down, (hidden_dim, rank), initializer_std), P(None, None)),
+            w_up=reshard(_init_weight(k_up, (rank, hidden_dim), initializer_std), P(None, None)),
+        )
+
+    @named_call
+    def __call__(self, x: Float[Array, "... D"]) -> Float[Array, "... D"]:
+        gate_hidden = jnp.einsum("...d,dr->...r", x, self.w_down)
+        gate_hidden = jax.nn.silu(gate_hidden)
+        gate = jax.nn.sigmoid(jnp.einsum("...r,rd->...d", gate_hidden, self.w_up))
+        return x * gate.astype(x.dtype)
+
+
 class DenseMLP(eqx.Module):
     w_gate: jax.Array
     w_up: jax.Array
@@ -343,14 +372,16 @@ def __call__(

 class Block(eqx.Module):
     rms_attn: RMSNorm
+    gated_norm_attn: GatedNorm | None
     attn: CausalSelfAttention
     rms_mlp: RMSNorm
+    gated_norm_mlp: GatedNorm | None
     mlp: MoEMLP
     shared: DenseMLP | None

     @staticmethod
     def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "Block":
-        attn_key, mlp_key, shared_key = random.split(key, 3)
+        attn_key, mlp_key, shared_key, gn_attn_key, gn_mlp_key = random.split(key, 5)
         shared = None
         if cfg.shared_expert_intermediate_dim > 0:
             shared = DenseMLP.init(
@@ -359,10 +390,17 @@ def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "Block":
                 cfg.initializer_std,
                 key=shared_key,
             )
+        gated_norm_attn = None
+        gated_norm_mlp = None
+        if cfg.gated_norm_rank is not None:
+            gated_norm_attn = GatedNorm.init(cfg.hidden_dim, cfg.gated_norm_rank, cfg.initializer_std, key=gn_attn_key)
+            gated_norm_mlp = GatedNorm.init(cfg.hidden_dim, cfg.gated_norm_rank, cfg.initializer_std, key=gn_mlp_key)
         return Block(
             rms_attn=RMSNorm.init(cfg.hidden_dim, cfg.layer_norm_eps),
+            gated_norm_attn=gated_norm_attn,
             attn=CausalSelfAttention.init(cfg, key=attn_key),
             rms_mlp=RMSNorm.init(cfg.hidden_dim, cfg.layer_norm_eps),
+            gated_norm_mlp=gated_norm_mlp,
             mlp=MoEMLP.init(cfg, key=mlp_key),
             shared=shared,
         )
@@ -373,8 +411,13 @@ def __call__(
         x: Float[Array, "B S D"],
         mask: AttentionMask | jax.Array,
     ) -> tuple[Float[Array, "B S D"], dict[str, jax.Array]]:
-        x = x + self.attn(self.rms_attn(x), mask)
+        attn_in = self.rms_attn(x)
+        if self.gated_norm_attn is not None:
+            attn_in = self.gated_norm_attn(attn_in)
+        x = x + self.attn(attn_in, mask)
         mlp_in = self.rms_mlp(x)
+        if self.gated_norm_mlp is not None:
+            mlp_in = self.gated_norm_mlp(mlp_in)
         mlp_out, router_stats = self.mlp(mlp_in)
         if self.shared is not None:
             mlp_out = mlp_out + self.shared(mlp_in, activation=ActivationFunctionEnum.silu)
@@ -518,6 +561,7 @@ def debug_mesh_and_token_pspec(num_devices: int) -> tuple[jax.sharding.AbstractM
     "Block",
     "CausalSelfAttention",
     "DenseMLP",
+    "GatedNorm",
     "GrugModelConfig",
     "MoEMLP",
     "MoeActivation",

tests/test_grug_variant_contracts.py

Lines changed: 48 additions & 0 deletions
@@ -179,6 +179,54 @@ def build():
     assert with_ema_state_shape.ema_params is not None


+def test_grug_moe_gated_norm_lowers():
+    """Verify that the MoE variant with gated_norm_rank lowers without error."""
+    from experiments.grug.moe.model import GrugModelConfig, debug_mesh_and_token_pspec
+    from experiments.grug.moe.train import initial_state as moe_initial_state, _make_train_step
+
+    cfg = GrugModelConfig(
+        vocab_size=1024,
+        hidden_dim=32,
+        intermediate_dim=64,
+        shared_expert_intermediate_dim=64,
+        num_experts=4,
+        num_experts_per_token=2,
+        num_layers=2,
+        num_heads=2,
+        num_kv_heads=2,
+        max_seq_len=4,
+        gated_norm_rank=8,
+    )
+    optimizer = optax.adam(1e-2)
+    mp = jmp.get_policy("f32")
+    train_step = _make_train_step(optimizer, mp, z_loss_weight=0.0, ema_beta=None)
+    mesh, token_pspec = debug_mesh_and_token_pspec(num_devices=4)
+    batch = GrugLmExample(
+        tokens=jnp.zeros((8, 4), dtype=jnp.int32),
+        loss_weight=jnp.ones((8, 4), dtype=jnp.float32),
+        attn_mask=GrugAttentionMask.causal(),
+    )
+
+    def one_step():
+        sharded_batch = dataclasses.replace(
+            batch,
+            tokens=jax.sharding.reshard(batch.tokens, token_pspec),
+            loss_weight=jax.sharding.reshard(batch.loss_weight, token_pspec),
+        )
+        state = moe_initial_state(cfg, optimizer=optimizer, mp=mp, key=jax.random.PRNGKey(0), ema_beta=None)
+        return train_step(state, sharded_batch, compute_watch=False)
+
+    with _reset_abstract_mesh(), use_abstract_mesh(mesh):
+        out_state_shape, out_metrics_shape, _out_watch_shape = eqx.filter_eval_shape(one_step)
+
+    assert out_state_shape.step.shape == ()
+    assert "train/loss" in out_metrics_shape
+    # Verify gated norm params exist in the model tree
+    block = out_state_shape.params.blocks[0]
+    assert block.gated_norm_attn is not None
+    assert block.gated_norm_mlp is not None
+
+
 def test_grug_base_run_emits_expected_metrics_with_json_tracker(tmp_path: Path):
     train_module = importlib.import_module("experiments.grug.base.train")
     model_module = importlib.import_module("experiments.grug.base.model")
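The lowering test never materializes parameters or runs a real step: eqx.filter_eval_shape traces one_step on abstract shape/dtype values under an abstract mesh, so it checks that the gated-norm model compiles and shards cleanly without devices. The same idea with plain jax.eval_shape (a standalone illustration, not repo code):

import jax
import jax.numpy as jnp

# f is traced on abstract values, never executed: no FLOPs, no data allocated.
x_spec = jax.ShapeDtypeStruct((8, 4), jnp.float32)
out = jax.eval_shape(lambda x: x @ x.T, x_spec)
print(out.shape, out.dtype)  # (8, 8) float32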
