Skip to content

Commit dc6e4c0

Browse files
[moe] Make capacity_factor configurable in GrugModelConfig and add sweep script
Add capacity_factor field to GrugModelConfig (default 1.25, matching the existing hardcoded value) so it can be varied in experiment sweeps. Add sweep_capacity_factor.py to sweep over {1.0, 1.125, 1.25, 1.5, 2.0}.

Fixes #4017

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a243fe5 commit dc6e4c0

3 files changed

Lines changed: 144 additions & 1 deletion

File tree

experiments/grug/moe/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class GrugModelConfig:
6666
initializer_std: float = 0.02
6767
load_balancing_loss_coef: float | None = 0.01
6868
router_z_loss_coef: float | None = 0.001
69+
capacity_factor: float = _DEFAULT_EP_CAPACITY_FACTOR
6970
moe_implementation: MoeImplementation | None = None
7071
rope: RotaryConfig = dataclasses.field(default_factory=RotaryConfig)
7172

@@ -92,6 +93,8 @@ def __post_init__(self) -> None:
9293
raise ValueError("load_balancing_loss_coef must be non-negative when set")
9394
if self.router_z_loss_coef is not None and self.router_z_loss_coef < 0:
9495
raise ValueError("router_z_loss_coef must be non-negative when set")
96+
if self.capacity_factor <= 0:
97+
raise ValueError("capacity_factor must be positive")
9598

9699
@property
97100
def inferred_head_dim(self) -> int:
@@ -334,7 +337,7 @@ def __call__(
334337
activation=ActivationFunctionEnum.silu,
335338
implementation=self.cfg.moe_implementation,
336339
mesh=get_abstract_mesh(),
337-
capacity_factor=_DEFAULT_EP_CAPACITY_FACTOR,
340+
capacity_factor=self.cfg.capacity_factor,
338341
)
339342
routed = rearrange(routed_flat, "(b s) d -> b s d", b=b, s=s)
340343
routed = reshard(routed, _batch_spec())
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Sweep capacity_factor for the MoE grug variant.
5+
6+
Runs the trial model at several capacity factors to determine whether the
7+
default 1.25 is safe or whether it masks avoidable overflow or throughput loss.
8+
9+
See: https://github.com/marin-community/marin/issues/4017
10+
"""
11+
12+
import dataclasses
13+
14+
from levanter.tracker.wandb import WandbConfig
15+
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
16+
17+
from experiments.grug.moe.launch import (
18+
GRUG_MOE_TRIAL_MODEL,
19+
NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
20+
GrugMoeLaunchConfig,
21+
run_grug_moe,
22+
)
23+
from experiments.grug.moe.model import GrugModelConfig
24+
from experiments.grug.moe.train import GrugEvalConfig, GrugTrainerConfig
25+
26+
from fray.cluster import ResourceConfig
27+
from levanter.optim import AdamConfig
28+
29+
CAPACITY_FACTORS = [1.0, 1.125, 1.25, 1.5, 2.0]
30+
31+
32+
def _build_sweep_steps() -> list[ExecutorStep]:
    """Create one executor step per capacity factor in CAPACITY_FACTORS.

    Each step trains the trial MoE model for 2,000 steps with an identical
    recipe; only ``capacity_factor`` varies between runs.
    """
    sweep: list[ExecutorStep] = []
    for capacity_factor in CAPACITY_FACTORS:
        # Encode the factor in run names, e.g. 1.125 -> "cf1p125" (dots are
        # replaced so the tag is safe in paths and run ids).
        suffix = f"cf{capacity_factor:.3f}".replace(".", "p")
        swept_model = dataclasses.replace(GRUG_MOE_TRIAL_MODEL, capacity_factor=capacity_factor)
        launch_config = GrugMoeLaunchConfig(
            model=versioned(swept_model),
            data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
            output_path=this_output_path(),
            run_id=f"grug-moe-sweep-cf-{suffix}",
            resources=versioned(ResourceConfig.with_tpu("v5p-8")),
            steps=versioned(2_000),
            batch_size=versioned(512),
            seed=versioned(0),
            mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
            tracker=WandbConfig(
                project="marin",
                tags=["grug", "moe", "sweep", "capacity-factor"],
                group="grug-moe-sweep-capacity-factor",
                name=None,
            ),
            optimizer=versioned(
                AdamConfig(
                    learning_rate=3e-3,
                    weight_decay=0.1,
                    lr_schedule="cosine",
                    decay=0.2,
                    min_lr_ratio=0.1,
                    warmup=1000,
                )
            ),
            grug_trainer=versioned(
                GrugTrainerConfig(
                    z_loss_weight=1e-4,
                    ema_beta=None,
                    log_every=1,
                )
            ),
            eval=versioned(
                GrugEvalConfig(
                    eval_batch_size=512,
                    steps_per_eval=1000,
                    max_eval_batches=8,
                    eval_current=True,
                    eval_ema=False,
                )
            ),
        )
        sweep.append(
            ExecutorStep(
                name=f"grug/moe-sweep-cf-{suffix}",
                fn=run_grug_moe,
                config=launch_config,
            )
        )
    return sweep
87+
88+
89+
# Built at import time so executor tooling that imports this module can
# discover the sweep steps without running the script.
sweep_steps = _build_sweep_steps()

if __name__ == "__main__":
    executor_main(
        steps=sweep_steps,
        description="Sweep capacity_factor over {1.0, 1.125, 1.25, 1.5, 2.0} for the MoE grug trial model.",
    )

tests/test_grug_variant_contracts.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,48 @@ def test_grug_base_run_emits_expected_metrics_with_json_tracker(tmp_path: Path):
263263
]
264264
for key in required_keys:
265265
assert key in summary
266+
267+
268+
@pytest.mark.parametrize("capacity_factor", [1.0, 1.5, 2.0])
def test_moe_capacity_factor_config_propagates_to_lowering(capacity_factor: float):
    """Verify that GrugModelConfig.capacity_factor is accepted and the model lowers with non-default values."""
    # Imported lazily inside the test so collecting the suite does not require
    # the heavy grug modules to import successfully.
    model_module = importlib.import_module("experiments.grug.moe.model")
    train_module = importlib.import_module("experiments.grug.moe.train")
    model_config_cls = model_module.GrugModelConfig
    mesh_fn = model_module.debug_mesh_and_token_pspec

    # The config must carry the swept value verbatim (no clamping/rounding).
    cfg = model_config_cls(vocab_size=1024, capacity_factor=capacity_factor)
    assert cfg.capacity_factor == capacity_factor

    optimizer = optax.adam(1e-2)
    mp = jmp.get_policy("f32")
    train_step = train_module._make_train_step(optimizer, mp, z_loss_weight=0.0, ema_beta=None)
    mesh, token_pspec = mesh_fn(num_devices=4)
    # Tiny batch: 8 sequences of 4 tokens is enough to exercise lowering.
    batch = GrugLmExample(
        tokens=jnp.zeros((8, 4), dtype=jnp.int32),
        loss_weight=jnp.ones((8, 4), dtype=jnp.float32),
        attn_mask=GrugAttentionMask.causal(),
    )

    def one_step():
        # Reshard inputs onto the token partition spec before stepping.
        sharded_batch = dataclasses.replace(
            batch,
            tokens=jax.sharding.reshard(batch.tokens, token_pspec),
            loss_weight=jax.sharding.reshard(batch.loss_weight, token_pspec),
        )
        state = train_module.initial_state(cfg, optimizer=optimizer, mp=mp, key=jax.random.PRNGKey(0), ema_beta=None)
        return train_step(state, sharded_batch, compute_watch=False)

    # filter_eval_shape traces/lowers without executing, so this validates the
    # capacity factor reaches the MoE lowering even on hosts without 4 devices.
    with _reset_abstract_mesh(), use_abstract_mesh(mesh):
        out_state_shape, out_metrics_shape, _ = eqx.filter_eval_shape(one_step)

    # Sanity-check the traced outputs: a scalar step counter and loss metric.
    assert out_state_shape.step.shape == ()
    assert "train/loss" in out_metrics_shape
303+
304+
305+
def test_moe_capacity_factor_rejects_non_positive():
    """Zero and negative capacity factors must be rejected at config construction."""
    model_module = importlib.import_module("experiments.grug.moe.model")
    for bad_value in (0.0, -1.0):
        with pytest.raises(ValueError, match="capacity_factor must be positive"):
            model_module.GrugModelConfig(vocab_size=1024, capacity_factor=bad_value)

0 commit comments

Comments
 (0)