Skip to content

Commit 13a9abf

Browse files
[moe] Add multi-budget SWA ablation for great 10T gate
Add sliding_window field to GrugModelConfig and wire through Transformer.__call__. Create exp4045 isoflop sweep comparing full attention vs SWA-4096 at five FLOP budgets (3e18 through 9e19) to determine whether sliding-window attention is justified at scale. Includes contract tests for config wiring, FLOP matching, and validation. Fixes #4045 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a243fe5 commit 13a9abf

3 files changed

Lines changed: 251 additions & 1 deletion

File tree

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Experiment 4045: multi-budget sliding-window attention ablation for the great 10T gate.
5+
6+
Runs full-attention vs sliding-window (window=4096) at multiple FLOP budgets
7+
to build an isoflop scaling curve. Each arm is compute-matched:
8+
3 * flops_per_token * batch_size * seq_len * steps ~ budget.
9+
10+
Tracking issue: https://github.com/marin-community/marin/issues/4045
11+
Parent gate: https://github.com/marin-community/marin/issues/4014
12+
"""
13+
14+
import dataclasses
15+
16+
from fray.cluster import ResourceConfig
17+
from levanter.optim import AdamConfig
18+
from levanter.tracker.wandb import WandbConfig
19+
from levanter.utils.flop_utils import lm_flops_per_token
20+
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
21+
22+
from experiments.grug.moe.launch import (
23+
GrugMoeLaunchConfig,
24+
GrugTrainerConfig,
25+
NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
26+
_resolve_run_id,
27+
run_grug_moe,
28+
)
29+
from experiments.grug.moe.model import GrugModelConfig
30+
from experiments.grug.moe.train import GrugEvalConfig
31+
32+
# ---------------------------------------------------------------------------
# FLOP budgets — sweep from small to moderate scale
# ---------------------------------------------------------------------------

# Nominal total-compute targets (in FLOPs) for the isoflop curve; each budget
# gets one full-attention arm and one SWA arm (see _build_steps).
FLOP_BUDGETS: tuple[float, ...] = (3e18, 9e18, 1.8e19, 3e19, 9e19)

# Global batch size (sequences per optimizer step), shared by every arm so the
# only compute lever per budget is the step count.
BATCH_SIZE = 512
# Training sequence length in tokens.
SEQ_LEN = 4096
VOCAB_SIZE = 128_256
NUM_EXPERTS = 8
NUM_EXPERTS_PER_TOKEN = 2
# NOTE(review): SLIDING_WINDOW_SIZE == SEQ_LEN, so within a 4096-token context
# the sliding window may never actually truncate attention — confirm the
# semantics of AttentionMask.causal(sliding_window=...); otherwise both arms of
# this ablation could be computing identical masks.
SLIDING_WINDOW_SIZE = 4096

# ---------------------------------------------------------------------------
# Model configs per budget — wider + deeper at higher budgets
# ---------------------------------------------------------------------------

# Each entry: (hidden_dim, intermediate_dim, shared_expert_intermediate_dim,
#              num_layers, num_heads, num_kv_heads)
# intermediate_dim ~ 3.5 * hidden_dim (SwiGLU convention).

_MODEL_SPECS: dict[float, tuple[int, int, int, int, int, int]] = {
    3e18: (384, 1344, 1344, 6, 6, 6),
    9e18: (512, 1792, 1792, 6, 8, 8),
    1.8e19: (512, 1792, 1792, 12, 8, 8),
    3e19: (768, 2688, 2688, 12, 12, 12),
    9e19: (1024, 3584, 3584, 16, 16, 16),
}
60+
61+
62+
def _make_model(budget: float, *, swa: bool) -> GrugModelConfig:
    """Build the model config for *budget*, optionally with sliding-window attention.

    Raises KeyError if *budget* is not a key of ``_MODEL_SPECS``.
    """
    spec = _MODEL_SPECS[budget]
    hidden_dim, intermediate_dim, shared_inter_dim, n_layers, n_heads, n_kv_heads = spec
    # The two arms differ only in this field; everything else is budget-determined.
    window = SLIDING_WINDOW_SIZE if swa else None
    return GrugModelConfig(
        vocab_size=VOCAB_SIZE,
        hidden_dim=hidden_dim,
        intermediate_dim=intermediate_dim,
        shared_expert_intermediate_dim=shared_inter_dim,
        num_experts=NUM_EXPERTS,
        num_experts_per_token=NUM_EXPERTS_PER_TOKEN,
        num_layers=n_layers,
        num_heads=n_heads,
        num_kv_heads=n_kv_heads,
        max_seq_len=SEQ_LEN,
        head_dim=None,
        sliding_window=window,
    )
78+
79+
80+
def model_flops_per_token(model: GrugModelConfig) -> float:
    """FLOPs per token for *model*, delegated to levanter's ``lm_flops_per_token``.

    Note: the estimate is mask-agnostic, so full-attention and SWA arms report
    the same number.
    """
    m = model
    # A single shared expert is counted only when its FFN dim is non-zero.
    num_shared = 1 if m.shared_expert_intermediate_dim > 0 else 0
    return lm_flops_per_token(
        hidden_dim=m.hidden_dim,
        intermediate_dim=m.intermediate_dim,
        shared_intermediate_dim=m.shared_expert_intermediate_dim,
        num_layers=m.num_layers,
        num_kv_heads=m.num_kv_heads,
        num_heads=m.num_heads,
        seq_len=m.max_seq_len,
        vocab_size=m.vocab_size,
        glu=True,
        num_experts=m.num_experts,
        num_shared_experts=num_shared,
        num_experts_per_tok=m.num_experts_per_token,
    )
95+
96+
97+
def steps_for_budget(fpt: float, budget: float) -> int:
    """Compute training steps so total FLOPs ~ budget.

    Uses the standard 3x (forward + backward) FLOP accounting:
    3 * fpt * tokens ≈ budget.
    """
    token_budget = budget / (3 * fpt)
    step_estimate = round(token_budget / (BATCH_SIZE * SEQ_LEN))
    # Never schedule zero steps, even for tiny budgets or huge models.
    return max(1, step_estimate)
101+
102+
103+
# ---------------------------------------------------------------------------
# Common training knobs
# ---------------------------------------------------------------------------

# Shared across every arm so optimizer settings never confound the ablation.
# NOTE(review): warmup=1000 is likely a large fraction of total steps at the
# smallest budgets (3e18 may schedule only a few thousand steps) — confirm this
# is intended, or consider a warmup proportional to the step count.
_OPTIMIZER = AdamConfig(
    learning_rate=3e-3,
    weight_decay=0.1,
    lr_schedule="cosine",
    decay=0.2,
    min_lr_ratio=0.1,
    warmup=1000,
)

_GRUG_TRAINER = GrugTrainerConfig(
    z_loss_weight=1e-4,
    ema_beta=None,  # no EMA of weights; eval_ema below is correspondingly off
    log_every=1,
)

_EVAL = GrugEvalConfig(
    eval_batch_size=512,
    # NOTE(review): steps_per_eval=1000 may yield few (or zero) mid-run evals
    # at the smallest budgets — confirm against the scheduled step counts.
    steps_per_eval=1000,
    max_eval_batches=8,
    eval_current=True,
    eval_ema=False,
)
129+
130+
131+
def _wandb(group: str) -> WandbConfig:
    """W&B tracker config; *group* collects both arms of a single budget."""
    run_tags = ["grug", "moe", "exp4045", "swa-ablation", "great-gate"]
    # name=None lets the launcher/W&B assign the run name.
    return WandbConfig(project="marin", tags=run_tags, group=group, name=None)
138+
139+
140+
# ---------------------------------------------------------------------------
141+
# Build executor steps for every (budget, swa/full) pair
142+
# ---------------------------------------------------------------------------
143+
144+
145+
def _build_steps() -> list[ExecutorStep]:
    """Create one ExecutorStep per (FLOP budget, attention arm) pair.

    For each budget, the full-attention arm comes first, then the SWA arm;
    both are compute-matched via steps_for_budget.
    """
    built: list[ExecutorStep] = []
    for budget in FLOP_BUDGETS:
        tag = f"{budget:.0e}"
        for use_swa in (False, True):
            arm_name = "swa-4096" if use_swa else "full-attn"
            model_cfg = _make_model(budget, swa=use_swa)
            flops_per_token = model_flops_per_token(model_cfg)
            n_steps = steps_for_budget(flops_per_token, budget)
            # Resolve the run id up front (before any other launch plumbing).
            run_identifier = _resolve_run_id(f"exp4045-{arm_name}-{tag}")
            launch_cfg = GrugMoeLaunchConfig(
                model=versioned(model_cfg),
                data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
                output_path=this_output_path(),
                run_id=run_identifier,
                resources=versioned(ResourceConfig.with_tpu("v5p-8")),
                steps=versioned(n_steps),
                batch_size=versioned(BATCH_SIZE),
                seed=versioned(0),
                mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
                tracker=_wandb(f"exp4045-swa-ablation-{tag}"),
                optimizer=versioned(_OPTIMIZER),
                grug_trainer=versioned(_GRUG_TRAINER),
                eval=versioned(_EVAL),
            )
            built.append(
                ExecutorStep(
                    name=f"grug/exp4045-{arm_name}-{tag}",
                    fn=run_grug_moe,
                    config=launch_cfg,
                )
            )
    return built
176+
177+
178+
# Materialized at import time so tests (and tooling) can introspect the sweep.
ALL_STEPS = _build_steps()


if __name__ == "__main__":
    executor_main(
        steps=ALL_STEPS,
        description="Exp 4045: multi-budget SWA ablation for the great 10T gate. Fixes #4045.",
    )

experiments/grug/moe/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class GrugModelConfig:
6767
load_balancing_loss_coef: float | None = 0.01
6868
router_z_loss_coef: float | None = 0.001
6969
moe_implementation: MoeImplementation | None = None
70+
sliding_window: int | None = None
7071
rope: RotaryConfig = dataclasses.field(default_factory=RotaryConfig)
7172

7273
def __post_init__(self) -> None:
@@ -92,6 +93,8 @@ def __post_init__(self) -> None:
9293
raise ValueError("load_balancing_loss_coef must be non-negative when set")
9394
if self.router_z_loss_coef is not None and self.router_z_loss_coef < 0:
9495
raise ValueError("router_z_loss_coef must be non-negative when set")
96+
if self.sliding_window is not None and self.sliding_window <= 0:
97+
raise ValueError(f"sliding_window must be positive when set, got {self.sliding_window}")
9598

9699
@property
97100
def inferred_head_dim(self) -> int:
@@ -419,7 +422,7 @@ def __call__(
419422
mask: AttentionMask | jax.Array | None = None,
420423
) -> tuple[Float[Array, "B S D"], dict[str, jax.Array]]:
421424
if mask is None:
422-
mask = AttentionMask.causal()
425+
mask = AttentionMask.causal(sliding_window=self.config.sliding_window)
423426

424427
batch_spec = _batch_spec()
425428
hidden = self.token_embed.at[token_ids].get(out_sharding=batch_spec)

tests/test_exp4045_swa_sweep.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import dataclasses
5+
6+
from experiments.grug.moe.exp4045_ablate_swa_sweep import (
7+
ALL_STEPS,
8+
BATCH_SIZE,
9+
FLOP_BUDGETS,
10+
SEQ_LEN,
11+
SLIDING_WINDOW_SIZE,
12+
_make_model,
13+
model_flops_per_token,
14+
steps_for_budget,
15+
)
16+
from experiments.grug.moe.model import GrugModelConfig
17+
18+
import pytest
19+
20+
21+
def test_each_budget_has_two_arms():
    """Every FLOP budget contributes exactly two steps: full-attn and SWA."""
    expected = 2 * len(FLOP_BUDGETS)
    assert len(ALL_STEPS) == expected
23+
24+
25+
def test_swa_and_full_differ_only_in_sliding_window():
26+
for budget in FLOP_BUDGETS:
27+
full = _make_model(budget, swa=False)
28+
swa = _make_model(budget, swa=True)
29+
assert full.sliding_window is None
30+
assert swa.sliding_window == SLIDING_WINDOW_SIZE
31+
assert full.hidden_dim == swa.hidden_dim
32+
assert full.num_experts == swa.num_experts
33+
assert full.num_layers == swa.num_layers
34+
assert full.shared_expert_intermediate_dim == swa.shared_expert_intermediate_dim
35+
36+
37+
def test_flop_budgets_are_close_to_target():
38+
for budget in FLOP_BUDGETS:
39+
for swa in (False, True):
40+
model = _make_model(budget, swa=swa)
41+
fpt = model_flops_per_token(model)
42+
num_steps = steps_for_budget(fpt, budget)
43+
total = 3 * fpt * num_steps * BATCH_SIZE * SEQ_LEN
44+
ratio = total / budget
45+
assert 0.9 <= ratio <= 1.1, (
46+
f"budget={budget:.0e} swa={swa}: total={total:.2e} ratio={ratio:.3f}"
47+
)
48+
49+
50+
def test_swa_same_flops_as_full():
51+
"""SWA does not change the FLOP estimate (lm_flops_per_token uses full attention)."""
52+
for budget in FLOP_BUDGETS:
53+
full = _make_model(budget, swa=False)
54+
swa = _make_model(budget, swa=True)
55+
assert model_flops_per_token(full) == model_flops_per_token(swa)
56+
57+
58+
def test_sliding_window_validation():
59+
with pytest.raises(ValueError, match="sliding_window must be positive"):
60+
GrugModelConfig(vocab_size=128, sliding_window=0)
61+
with pytest.raises(ValueError, match="sliding_window must be positive"):
62+
GrugModelConfig(vocab_size=128, sliding_window=-1)

0 commit comments

Comments
 (0)