Skip to content

Commit 38fbf21

Browse files
[moe] Add expert count sweep script for E in {128,256,512}
Adds experiments/grug/moe/sweep_expert_count.py to sweep num_experts across {128, 256, 512} using the grug MoE template, supporting the 10T gate investigation. Part of #4030 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a243fe5 commit 38fbf21

1 file changed

Lines changed: 129 additions & 0 deletions

File tree

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Sweep expert count E in {128, 256, 512} for the 10T gate MoE recipe.

This experiment varies only the number of routed experts while holding the
per-expert intermediate dimension, shared expert, K (experts per token), and
all other hyperparameters fixed. The goal is to determine whether expert
count is a significant remaining lever for the baseline recipe.

Ref: https://github.com/marin-community/marin/issues/4030
Parent sweep: https://github.com/marin-community/marin/issues/3469
Gate: https://github.com/marin-community/marin/issues/4013
"""
15+
16+
import dataclasses
17+
18+
from fray.cluster import ResourceConfig
19+
from levanter.optim import AdamConfig
20+
from levanter.tracker.wandb import WandbConfig
21+
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
22+
23+
from experiments.grug.moe.launch import (
24+
NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
25+
GrugMoeLaunchConfig,
26+
_resolve_run_id,
27+
run_grug_moe,
28+
)
29+
from experiments.grug.moe.model import GrugModelConfig
30+
from experiments.grug.moe.train import GrugEvalConfig, GrugTrainerConfig
31+
32+
# ---------------------------------------------------------------------------
# Base model config for the 10T gate recipe.
#
# This mirrors the trial model dimensions but can be replaced with the final
# gate config once #4013 locks in the architecture. Only `num_experts` is
# swept; everything else stays constant across arms.
# ---------------------------------------------------------------------------
BASE_MODEL = GrugModelConfig(
    vocab_size=128_256,
    hidden_dim=512,
    intermediate_dim=1792,
    # Shared expert sized to match the routed experts' intermediate dim.
    shared_expert_intermediate_dim=1792,
    num_experts=128,  # overridden per arm
    num_experts_per_token=2,  # K: routed experts per token, fixed across arms
    num_layers=6,
    num_heads=8,
    num_kv_heads=8,  # equal to num_heads, i.e. no GQA sharing in this config
    max_seq_len=4096,
    head_dim=None,  # presumably derived from hidden_dim / num_heads — TODO confirm
)

# The swept values of E; each produces one independent training arm.
EXPERT_COUNTS = (128, 256, 512)
54+
55+
# Optimizer settings shared by every arm of the sweep.
OPTIMIZER = AdamConfig(
    learning_rate=3e-3,
    weight_decay=0.1,
    lr_schedule="cosine",
    decay=0.2,  # NOTE(review): presumably the fraction of steps spent decaying — confirm against AdamConfig
    min_lr_ratio=0.1,  # final LR floor as a fraction of peak LR
    warmup=1000,  # warmup steps (half of the 2k-step run)
)

# Trainer settings shared by every arm.
TRAINER = GrugTrainerConfig(
    z_loss_weight=1e-4,
    ema_beta=None,  # EMA disabled; eval_ema below is False accordingly
    log_every=1,
)

# Evaluation cadence shared by every arm.
EVAL = GrugEvalConfig(
    eval_batch_size=512,
    steps_per_eval=1000,
    max_eval_batches=8,
    eval_current=True,  # evaluate the live weights
    eval_ema=False,  # no EMA weights to evaluate (ema_beta=None above)
)

# Run-length and launch scalars, identical across arms.
STEPS = 2_000
BATCH_SIZE = 512
SEED = 0
# Mixed-precision policy: fp32 params, bf16 compute and output.
MP_POLICY = "params=float32,compute=bfloat16,output=bfloat16"
82+
83+
84+
def _build_step(num_experts: int) -> ExecutorStep:
    """Build an ExecutorStep for a single expert-count arm.

    Every arm shares BASE_MODEL, OPTIMIZER, TRAINER, EVAL, and the launch
    scalars; only ``num_experts`` differs. The arm is tagged ``e<E>`` so runs
    are distinguishable in W&B and in executor output paths.
    """
    arm = f"e{num_experts}"

    # The only per-arm model change: swap in this arm's expert count.
    arm_model = dataclasses.replace(BASE_MODEL, num_experts=num_experts)

    # All arms report into one W&B group so they can be compared side by side.
    tracker = WandbConfig(
        project="marin",
        tags=["grug", "moe", "sweep-E", arm],
        group="grug-moe-sweep-E",
        name=None,
    )

    launch = GrugMoeLaunchConfig(
        model=versioned(arm_model),
        data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
        output_path=this_output_path(),
        run_id=_resolve_run_id(f"grug-moe-sweep-E-{arm}"),
        resources=versioned(
            # Start with v5p-8 (matches the trial template). For large E
            # the expert mesh axis or slice count may need adjustment.
            ResourceConfig.with_tpu("v5p-8"),
        ),
        steps=versioned(STEPS),
        batch_size=versioned(BATCH_SIZE),
        seed=versioned(SEED),
        mp=versioned(MP_POLICY),
        tracker=tracker,
        optimizer=versioned(OPTIMIZER),
        grug_trainer=versioned(TRAINER),
        eval=versioned(EVAL),
    )

    return ExecutorStep(
        name=f"grug/moe-sweep-E-{arm}",
        fn=run_grug_moe,
        config=launch,
    )
120+
121+
122+
# One ExecutorStep per expert-count arm, in EXPERT_COUNTS order.
sweep_steps = list(map(_build_step, EXPERT_COUNTS))


if __name__ == "__main__":
    # Hand all arms to the executor in a single invocation.
    executor_main(
        steps=sweep_steps,
        description="Sweep expert count E in {128, 256, 512} for the 10T gate MoE recipe (#4030).",
    )

0 commit comments

Comments
 (0)