|
| 1 | +# Copyright The Marin Authors |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +"""MoE hyperparameter sweep at 3e18 FLOPs. |
| 5 | +
|
| 6 | +Explores core MoE architecture knobs (expert count, shared expert, routing |
| 7 | +density, aux loss coefficients) at the 3e18 FLOP budget to establish a |
| 8 | +reference recipe for the 10T scaling path. |
| 9 | +
|
| 10 | +Each config targets 3e18 total training FLOPs (including 3x fwd+bwd multiplier). |
| 11 | +All runs use the grug MoE template on Nemotron mix with seq_len=4096 and |
| 12 | +vocab=128256. |
| 13 | +
|
| 14 | +See https://github.com/marin-community/marin/issues/4018. |
| 15 | +""" |
| 16 | + |
| 17 | +import dataclasses |
| 18 | +import logging |
| 19 | +import os |
| 20 | +from dataclasses import dataclass |
| 21 | +from datetime import timedelta |
| 22 | + |
| 23 | +import jmp |
| 24 | +from fray.cluster import ResourceConfig |
| 25 | +from levanter.checkpoint import CheckpointerConfig |
| 26 | +from levanter.data.text import LmDataConfig |
| 27 | +from levanter.optim import AdamConfig |
| 28 | +from levanter.tracker import TrackerConfig |
| 29 | +from levanter.tracker.wandb import WandbConfig |
| 30 | +from levanter.trainer import TrainerConfig |
| 31 | +from levanter.utils.flop_utils import lm_flops_per_token |
| 32 | +from levanter.utils.mesh import MeshConfig |
| 33 | +from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned |
| 34 | +from marin.processing.tokenize import add_validation_sets_to_mixture |
| 35 | + |
| 36 | +from experiments.defaults import default_validation_sets |
| 37 | +from experiments.grug.moe.launch import GrugMoeLaunchConfig, run_grug_moe |
| 38 | +from experiments.grug.moe.model import GrugModelConfig |
| 39 | +from experiments.grug.moe.train import GrugEvalConfig, GrugTrainerConfig |
| 40 | +from experiments.pretraining_datasets import nemotron_mix_block_shuffle |
| 41 | + |
logger = logging.getLogger(__name__)

# Total training compute budget per sweep arm, including the 3x fwd+bwd
# multiplier applied in _steps_for_budget (see module docstring).
TARGET_FLOPS = 3e18
# Sequence length and vocabulary size shared by every model variant.
SEQ_LEN = 4096
VOCAB_SIZE = 128_256
# Global batch size (sequences per optimizer step), default for all arms.
BATCH_SIZE = 256
| 48 | + |
# Training mixture: the Nemotron block-shuffled mix, augmented with the
# default validation sets tokenized with the same tokenizer as the mix.
NEMOTRON_MIX_WITH_VALIDATION = add_validation_sets_to_mixture(
    nemotron_mix_block_shuffle,
    default_validation_sets(tokenizer=nemotron_mix_block_shuffle.tokenizer),
)
| 53 | + |
| 54 | + |
def _steps_for_budget(model: GrugModelConfig, batch_size: int) -> int:
    """Return the number of optimizer steps that spends TARGET_FLOPS.

    Uses Levanter's per-token FLOP estimate for the given model shape and
    applies the standard 3x multiplier for forward + backward passes.
    """
    per_token_flops = lm_flops_per_token(
        hidden_dim=model.hidden_dim,
        intermediate_dim=model.intermediate_dim,
        shared_intermediate_dim=model.shared_expert_intermediate_dim,
        num_layers=model.num_layers,
        num_kv_heads=model.num_kv_heads,
        num_heads=model.num_heads,
        seq_len=model.max_seq_len,
        vocab_size=model.vocab_size,
        glu=True,
        num_experts=model.num_experts,
        # A positive shared-expert width means exactly one shared expert.
        num_shared_experts=1 if model.shared_expert_intermediate_dim > 0 else 0,
        num_experts_per_tok=model.num_experts_per_token,
    )
    tokens_per_step = model.max_seq_len * batch_size
    flops_per_step = 3 * per_token_flops * tokens_per_step
    return int(TARGET_FLOPS / flops_per_step)
| 73 | + |
| 74 | + |
# ---------------------------------------------------------------------------
# Baseline model: d=768, L=12, E=8, K=2, shared expert
# ~1300 steps at bs=256 → ~1.4B tokens
# ---------------------------------------------------------------------------
BASELINE = GrugModelConfig(
    vocab_size=VOCAB_SIZE,
    hidden_dim=768,
    intermediate_dim=2048,  # per-expert (routed) MLP width
    shared_expert_intermediate_dim=2048,  # >0 enables the shared expert
    num_experts=8,
    num_experts_per_token=2,  # top-K routing, K=2
    num_heads=12,
    num_layers=12,
    num_kv_heads=4,  # GQA: 12 query heads over 4 KV heads
    max_seq_len=SEQ_LEN,
    head_dim=None,  # presumably derived as hidden_dim / num_heads — TODO confirm
    load_balancing_loss_coef=0.01,
    router_z_loss_coef=0.001,
)
| 94 | + |
| 95 | + |
@dataclass(frozen=True)
class SweepPoint:
    """One arm of the sweep: a named model variant plus its batch size."""

    # Short slug used in the run id and executor step name (e.g. "e16-k2").
    name: str
    model: GrugModelConfig
    batch_size: int = BATCH_SIZE
| 103 | + |
| 104 | + |
def _expert_count_variants() -> list[SweepPoint]:
    """Axis 1: E in {8, 16, 32} with K=2 (FLOPs ~constant)."""
    # The baseline already has E=8; the larger counts are replace()-variants.
    variants = [SweepPoint("e8-k2", BASELINE)]
    for expert_count in (16, 32):
        wider = dataclasses.replace(BASELINE, num_experts=expert_count)
        variants.append(SweepPoint(f"e{expert_count}-k2", wider))
    return variants
| 112 | + |
| 113 | + |
def _shared_expert_variants() -> list[SweepPoint]:
    """Axis 2: shared expert on vs off."""
    # Width 0 disables the shared expert entirely.
    without_shared = dataclasses.replace(BASELINE, shared_expert_intermediate_dim=0)
    return [SweepPoint("shared-on", BASELINE), SweepPoint("shared-off", without_shared)]
| 121 | + |
| 122 | + |
def _routing_density_variants() -> list[SweepPoint]:
    """Axis 3: K=2 vs K=4. K=4 doubles routed MLP FLOPs so we halve intermediate_dim."""
    denser_routing = dataclasses.replace(
        BASELINE,
        num_experts_per_token=4,
        # Halved widths keep the K=4 arm at roughly the same FLOPs per token.
        intermediate_dim=1024,
        shared_expert_intermediate_dim=1024,
    )
    return [SweepPoint("k2-i2048", BASELINE), SweepPoint("k4-i1024", denser_routing)]
| 135 | + |
| 136 | + |
def _aux_loss_variants() -> list[SweepPoint]:
    """Axis 4: aux loss coefficient grid."""
    # Outer loop over load-balancing coefs, inner over router z-loss coefs,
    # matching the original enumeration order. rzl=0.0 is encoded as a None
    # coefficient on the model (z-loss disabled) but kept as "0.0" in the name.
    return [
        SweepPoint(
            f"lbl{lbl}-rzl{rzl}",
            dataclasses.replace(
                BASELINE,
                load_balancing_loss_coef=lbl,
                router_z_loss_coef=rzl if rzl > 0 else None,
            ),
        )
        for lbl in [0.001, 0.01, 0.1]
        for rzl in [0.0, 0.001]
    ]
| 150 | + |
| 151 | + |
def all_sweep_points() -> list[SweepPoint]:
    """Deduplicated union of all sweep axes (first occurrence of a name wins)."""
    # An insertion-ordered dict keyed by name replaces the seen-set + list pair;
    # setdefault keeps the first point registered under each name.
    by_name: dict[str, SweepPoint] = {}
    for axis in (_expert_count_variants, _shared_expert_variants, _routing_density_variants, _aux_loss_variants):
        for point in axis():
            by_name.setdefault(point.name, point)
    return list(by_name.values())
| 162 | + |
| 163 | + |
| 164 | +def _resolve_run_id(sweep_name: str) -> str: |
| 165 | + run_id = f"moe-3e18-{sweep_name}" |
| 166 | + ferry_date = os.environ.get("FERRY_DATE") |
| 167 | + if ferry_date: |
| 168 | + run_id = f"{run_id}-{ferry_date}" |
| 169 | + return run_id |
| 170 | + |
| 171 | + |
def _build_step(point: SweepPoint) -> ExecutorStep:
    """Build an ExecutorStep for one sweep arm.

    Step count is derived from the model shape so every arm lands on the same
    TARGET_FLOPS budget regardless of its per-token cost.
    """
    steps = _steps_for_budget(point.model, point.batch_size)
    run_id = _resolve_run_id(point.name)
    logger.info("Sweep point %s: %d steps at bs=%d", point.name, steps, point.batch_size)

    config = GrugMoeLaunchConfig(
        model=versioned(point.model),
        data=NEMOTRON_MIX_WITH_VALIDATION,
        output_path=this_output_path(),
        run_id=run_id,
        # One v5p-8 slice per arm; versioned() so changes re-trigger the step.
        resources=versioned(ResourceConfig.with_tpu("v5p-8")),
        steps=versioned(steps),
        batch_size=versioned(point.batch_size),
        seed=versioned(0),
        # Mixed precision: fp32 params, bf16 compute and output.
        mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
        tracker=WandbConfig(
            project="marin",
            tags=["moe", "3e18", "hparam-sweep"],
            group="moe-3e18-sweep",
            name=None,  # NOTE(review): presumably the launcher derives the W&B name from run_id — confirm
        ),
        optimizer=versioned(
            AdamConfig(
                learning_rate=3e-3,
                weight_decay=0.1,
                lr_schedule="cosine",
                decay=0.2,
                min_lr_ratio=0.1,
                warmup=200,  # warmup steps; sizable fraction of the ~1300-step baseline run
            )
        ),
        grug_trainer=versioned(
            GrugTrainerConfig(
                z_loss_weight=1e-4,
                ema_beta=None,  # EMA disabled; consistent with eval_ema=False below
                log_every=1,
            )
        ),
        eval=versioned(
            GrugEvalConfig(
                eval_batch_size=256,
                steps_per_eval=200,
                max_eval_batches=8,
                eval_current=True,
                eval_ema=False,
            )
        ),
    )

    return ExecutorStep(
        name=f"moe-3e18/{point.name}",
        fn=run_grug_moe,
        config=config,
    )
| 227 | + |
| 228 | + |
def build_sweep_steps() -> list[ExecutorStep]:
    """Materialize one ExecutorStep per deduplicated sweep point."""
    return list(map(_build_step, all_sweep_points()))
| 231 | + |
| 232 | + |
if __name__ == "__main__":
    # Launch every sweep arm as an independent executor step.
    executor_main(
        steps=build_sweep_steps(),
        description="MoE hparam sweep at 3e18 FLOPs (issue #4018).",
    )
0 commit comments