Skip to content

Commit 3c3d0a5

Browse files
[moe] Add headwise attention gate and ablation launch script
Add attention_gate config field to GrugModelConfig with headwise gating support in CausalSelfAttention. When enabled, a learned per-head sigmoid gate modulates attention output before the output projection. Includes ablation launch script comparing baseline vs headwise-gated at ~1e19 FLOPs and a lowering contract test. Fixes #4020 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a243fe5 commit 3c3d0a5

3 files changed

Lines changed: 190 additions & 2 deletions

File tree

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Ablation: attention gate in the MoE grug model at ~1e19 FLOPs.
5+
6+
Runs two matched configurations:
7+
- baseline (no attention gate)
8+
- headwise attention gate
9+
10+
See https://github.com/marin-community/marin/issues/4020
11+
"""
12+
13+
import dataclasses
14+
15+
from fray.cluster import ResourceConfig
16+
from levanter.optim import AdamConfig
17+
from levanter.tracker.wandb import WandbConfig
18+
19+
from experiments.grug.moe.launch import (
20+
NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
21+
GrugMoeLaunchConfig,
22+
GrugTrainerConfig,
23+
run_grug_moe,
24+
)
25+
from experiments.grug.moe.model import GrugModelConfig
26+
from experiments.grug.moe.train import GrugEvalConfig
27+
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
28+
29+
# Model configuration shared by both ablation arms; the gated arm below
# differs only in the `attention_gate` field.
_BASE_MODEL = GrugModelConfig(
    vocab_size=128_256,
    hidden_dim=768,
    intermediate_dim=2048,
    shared_expert_intermediate_dim=2048,
    num_experts=8,
    num_experts_per_token=2,
    num_layers=12,
    num_heads=12,
    num_kv_heads=12,
    max_seq_len=4096,
)

# Identical to the baseline except for the headwise attention gate under test.
_GATED_MODEL = dataclasses.replace(_BASE_MODEL, attention_gate="headwise")

# Optimizer is held fixed across both arms so any metric difference is
# attributable to the gate.
_OPTIMIZER = AdamConfig(
    learning_rate=3e-3,
    weight_decay=0.1,
    lr_schedule="cosine",
    decay=0.2,
    min_lr_ratio=0.1,
    warmup=500,
)

# Trainer settings shared by both arms (no EMA; log every step).
_TRAINER = GrugTrainerConfig(
    z_loss_weight=1e-4,
    ema_beta=None,
    log_every=1,
)

# Periodic evaluation of the current (non-EMA) parameters only.
_EVAL = GrugEvalConfig(
    eval_batch_size=512,
    steps_per_eval=500,
    max_eval_batches=8,
    eval_current=True,
    eval_ema=False,
)

# Base wandb tags; per-arm tags ("baseline"/"headwise") are appended at launch.
_WANDB_TAGS = ["grug", "moe", "good-10t", "ablation", "attention-gate"]
# Step/batch budget targeting ~1e19 FLOPs per the module docstring — TODO confirm.
_STEPS = 2_130
_BATCH_SIZE = 512
_RESOURCES = ResourceConfig.with_tpu("v5p-8")
71+
72+
73+
def _make_launch_config(
    model: GrugModelConfig,
    run_id: str,
    wandb_group: str,
    extra_tags: list[str] | None = None,
) -> GrugMoeLaunchConfig:
    """Build the launch config for one ablation arm.

    Everything except ``model`` and the wandb identifiers is shared between
    arms, so the two runs differ only in the attention-gate setting.
    """
    run_tags = [*_WANDB_TAGS, *(extra_tags or [])]
    tracker = WandbConfig(
        project="marin",
        tags=run_tags,
        group=wandb_group,
        name=None,
    )
    return GrugMoeLaunchConfig(
        model=versioned(model),
        data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
        output_path=this_output_path(),
        run_id=run_id,
        resources=versioned(_RESOURCES),
        steps=versioned(_STEPS),
        batch_size=versioned(_BATCH_SIZE),
        seed=versioned(0),
        mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
        tracker=tracker,
        optimizer=versioned(_OPTIMIZER),
        grug_trainer=versioned(_TRAINER),
        eval=versioned(_EVAL),
    )
100+
101+
102+
# Arm 1: baseline MoE model with no attention gate.
ablate_attn_gate_baseline = ExecutorStep(
    name="grug/ablate-attn-gate-baseline",
    fn=run_grug_moe,
    config=_make_launch_config(
        model=_BASE_MODEL,
        run_id="ablate-attn-gate-baseline",
        wandb_group="ablate-attn-gate",
        extra_tags=["baseline"],
    ),
)

# Arm 2: matched run with the headwise attention gate enabled; shares the
# same wandb group so the two curves plot side by side.
ablate_attn_gate_headwise = ExecutorStep(
    name="grug/ablate-attn-gate-headwise",
    fn=run_grug_moe,
    config=_make_launch_config(
        model=_GATED_MODEL,
        run_id="ablate-attn-gate-headwise",
        wandb_group="ablate-attn-gate",
        extra_tags=["headwise"],
    ),
)
123+
124+
125+
if __name__ == "__main__":
    # Launch both arms under the executor.
    executor_main(
        steps=[ablate_attn_gate_baseline, ablate_attn_gate_headwise],
        description="Ablation: attention gate in MoE grug at ~1e19 FLOPs (issue #4020).",
    )

experiments/grug/moe/model.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import dataclasses
1212

1313
from dataclasses import dataclass
14-
from typing import get_args
14+
from typing import Literal, get_args
1515
import equinox as eqx
1616
import jax
1717
import jax.numpy as jnp
@@ -67,6 +67,7 @@ class GrugModelConfig:
6767
load_balancing_loss_coef: float | None = 0.01
6868
router_z_loss_coef: float | None = 0.001
6969
moe_implementation: MoeImplementation | None = None
70+
attention_gate: Literal["none", "headwise"] = "none"
7071
rope: RotaryConfig = dataclasses.field(default_factory=RotaryConfig)
7172

7273
def __post_init__(self) -> None:
@@ -92,6 +93,8 @@ def __post_init__(self) -> None:
9293
raise ValueError("load_balancing_loss_coef must be non-negative when set")
9394
if self.router_z_loss_coef is not None and self.router_z_loss_coef < 0:
9495
raise ValueError("router_z_loss_coef must be non-negative when set")
96+
if self.attention_gate not in ("none", "headwise"):
97+
raise ValueError(f"attention_gate must be 'none' or 'headwise', got {self.attention_gate!r}")
9598

9699
@property
97100
def inferred_head_dim(self) -> int:
@@ -109,17 +112,22 @@ class CausalSelfAttention(eqx.Module):
109112
w_k: Float[Array, "D MH"]
110113
w_v: Float[Array, "D MH"]
111114
w_o: Float[Array, "NH D"]
115+
w_gate: Float[Array, "D N"] | None
112116
cfg: GrugModelConfig = eqx.field(static=True)
113117

114118
    @staticmethod
    def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "CausalSelfAttention":
        """Initialize the q/k/v/o projections and, optionally, the headwise gate.

        w_gate is only materialized when cfg.attention_gate == "headwise";
        otherwise it is None and the layer applies no gating.
        """
        # Split five keys unconditionally: the q/k/v/o streams are then
        # identical whether or not the gate is enabled for a given cfg.
        k_q, k_k, k_v, k_o, k_g = random.split(key, 5)
        d, n, m, h = cfg.hidden_dim, cfg.num_heads, cfg.num_kv_heads, cfg.inferred_head_dim
        w_gate = None
        if cfg.attention_gate == "headwise":
            # One gate logit per head: (hidden_dim, num_heads).
            w_gate = reshard(_init_weight(k_g, (d, n), cfg.initializer_std), P("data", "model"))
        return CausalSelfAttention(
            w_q=reshard(_init_weight(k_q, (d, n * h), cfg.initializer_std), P("data", "model")),
            w_k=reshard(_init_weight(k_k, (d, m * h), cfg.initializer_std), P("data", "model")),
            w_v=reshard(_init_weight(k_v, (d, m * h), cfg.initializer_std), P("data", "model")),
            w_o=reshard(_init_weight(k_o, (n * h, d), cfg.initializer_std), P("model", "data")),
            w_gate=w_gate,
            cfg=cfg,
        )
125133

@@ -134,6 +142,11 @@ def __call__(self, x: Float[Array, "B S D"], mask: AttentionMask | jax.Array) ->
134142
v = rearrange(jnp.einsum("bsh,hd->bsd", x, self.w_v), "... (m d) -> ... m d", d=head_dim)
135143
q, k = apply_rotary_embedding(q, k, seq_len=seq_len, head_dim=head_dim, rope=self.cfg.rope)
136144
attn_out = attention(q, k, v, mask)
145+
if self.w_gate is not None:
146+
# Headwise gating: sigmoid(x @ w_gate) produces one scalar per head,
147+
# broadcast across head_dim. Shape: (B, S, N, 1).
148+
gate = jax.nn.sigmoid(jnp.einsum("bsd,dn->bsn", x, self.w_gate))[..., None]
149+
attn_out = attn_out * gate
137150
attn_out = rearrange(attn_out, "... n d -> ... (n d)")
138151
return jnp.einsum("bsh,hd->bsd", attn_out, self.w_o, out_sharding=batch_spec)
139152

tests/test_grug_variant_contracts.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,52 @@ def build():
179179
assert with_ema_state_shape.ema_params is not None
180180

181181

182+
def test_grug_moe_attention_gate_lowers():
    """Verify that the MoE variant with attention_gate='headwise' lowers without error."""
    from experiments.grug.moe.model import GrugModelConfig, debug_mesh_and_token_pspec
    from experiments.grug.moe.train import initial_state as moe_initial_state, _make_train_step

    model_cfg = GrugModelConfig(
        vocab_size=1024,
        hidden_dim=32,
        intermediate_dim=64,
        shared_expert_intermediate_dim=64,
        num_experts=4,
        num_experts_per_token=2,
        num_layers=2,
        num_heads=2,
        num_kv_heads=2,
        max_seq_len=4,
        attention_gate="headwise",
    )

    # Tiny debug mesh and a toy batch; shapes only — everything runs abstractly.
    mesh, token_pspec = debug_mesh_and_token_pspec(num_devices=4)
    example = GrugLmExample(
        tokens=jnp.zeros((8, 4), dtype=jnp.int32),
        loss_weight=jnp.ones((8, 4), dtype=jnp.float32),
        attn_mask=GrugAttentionMask.causal(),
    )

    optimizer = optax.adam(1e-2)
    mp = jmp.get_policy("f32")
    train_step = _make_train_step(optimizer, mp, z_loss_weight=0.0, ema_beta=None)

    def step_once():
        resharded = dataclasses.replace(
            example,
            tokens=jax.sharding.reshard(example.tokens, token_pspec),
            loss_weight=jax.sharding.reshard(example.loss_weight, token_pspec),
        )
        state = moe_initial_state(model_cfg, optimizer=optimizer, mp=mp, key=jax.random.PRNGKey(0), ema_beta=None)
        return train_step(state, resharded, compute_watch=False)

    with _reset_abstract_mesh(), use_abstract_mesh(mesh):
        state_shape, metrics_shape, _watch_shape = eqx.filter_eval_shape(step_once)

    assert state_shape.step.shape == ()
    assert "train/loss" in metrics_shape
    first_block = state_shape.params.blocks[0]
    assert first_block.attn.w_gate is not None
226+
227+
182228
def test_grug_base_run_emits_expected_metrics_with_json_tracker(tmp_path: Path):
183229
train_module = importlib.import_module("experiments.grug.base.train")
184230
model_module = importlib.import_module("experiments.grug.base.model")

0 commit comments

Comments
 (0)