|
| 1 | +# Copyright The Marin Authors |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +"""Experiment 4045: multi-budget sliding-window attention ablation for the great 10T gate. |
| 5 | +
|
| 6 | +Runs full-attention vs sliding-window (window=4096) at multiple FLOP budgets |
| 7 | +to build an isoflop scaling curve. Each arm is compute-matched: |
| 8 | +3 * flops_per_token * batch_size * seq_len * steps ~ budget. |
| 9 | +
|
| 10 | +Tracking issue: https://github.com/marin-community/marin/issues/4045 |
| 11 | +Parent gate: https://github.com/marin-community/marin/issues/4014 |
| 12 | +""" |
| 13 | + |
| 14 | +import dataclasses |
| 15 | + |
| 16 | +from fray.cluster import ResourceConfig |
| 17 | +from levanter.optim import AdamConfig |
| 18 | +from levanter.tracker.wandb import WandbConfig |
| 19 | +from levanter.utils.flop_utils import lm_flops_per_token |
| 20 | +from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned |
| 21 | + |
| 22 | +from experiments.grug.moe.launch import ( |
| 23 | + GrugMoeLaunchConfig, |
| 24 | + GrugTrainerConfig, |
| 25 | + NEMOTRON_MIX_WITH_DEFAULT_VALIDATION, |
| 26 | + _resolve_run_id, |
| 27 | + run_grug_moe, |
| 28 | +) |
| 29 | +from experiments.grug.moe.model import GrugModelConfig |
| 30 | +from experiments.grug.moe.train import GrugEvalConfig |
| 31 | + |
| 32 | +# --------------------------------------------------------------------------- |
| 33 | +# FLOP budgets — sweep from small to moderate scale |
| 34 | +# --------------------------------------------------------------------------- |
| 35 | + |
# Total-training-FLOP budgets for the isoflop sweep, smallest first.
FLOP_BUDGETS: tuple[float, ...] = (3e18, 9e18, 1.8e19, 3e19, 9e19)

BATCH_SIZE = 512  # sequences per optimizer step (shared by every arm)
SEQ_LEN = 4096  # tokens per sequence; also used as the model's max_seq_len
VOCAB_SIZE = 128_256  # presumably the Llama-3 tokenizer vocab — TODO confirm against the data mix
NUM_EXPERTS = 8  # routed experts per MoE layer
NUM_EXPERTS_PER_TOKEN = 2  # top-k routing width
# NOTE(review): window == SEQ_LEN, so every token can attend to (nearly) the
# whole sequence — confirm the SWA arm actually differs from full attention,
# otherwise the ablation measures nothing.
SLIDING_WINDOW_SIZE = 4096
| 44 | + |
| 45 | +# --------------------------------------------------------------------------- |
| 46 | +# Model configs per budget — wider + deeper at higher budgets |
| 47 | +# --------------------------------------------------------------------------- |
| 48 | + |
| 49 | +# Each entry: (hidden_dim, intermediate_dim, shared_expert_intermediate_dim, |
| 50 | +# num_layers, num_heads, num_kv_heads) |
| 51 | +# intermediate_dim ~ 3.5 * hidden_dim (SwiGLU convention). |
| 52 | + |
# Keys must be the exact float literals used in FLOP_BUDGETS — lookup is by
# float equality, which is safe only because the same literals are reused.
_MODEL_SPECS: dict[float, tuple[int, int, int, int, int, int]] = {
    3e18: (384, 1344, 1344, 6, 6, 6),
    9e18: (512, 1792, 1792, 6, 8, 8),
    1.8e19: (512, 1792, 1792, 12, 8, 8),  # same width as 9e18, doubled depth
    3e19: (768, 2688, 2688, 12, 12, 12),
    9e19: (1024, 3584, 3584, 16, 16, 16),  # heads == kv_heads throughout (no GQA)
}
| 60 | + |
| 61 | + |
def _make_model(budget: float, *, swa: bool) -> GrugModelConfig:
    """Build the MoE model config sized for *budget*.

    With ``swa=True`` the attention layers use a sliding window of
    ``SLIDING_WINDOW_SIZE`` tokens; otherwise attention is global.
    Raises ``KeyError`` if *budget* is not a key of ``_MODEL_SPECS``.
    """
    spec = _MODEL_SPECS[budget]
    hidden_dim, intermediate_dim, shared_dim, num_layers, num_heads, num_kv_heads = spec
    window = SLIDING_WINDOW_SIZE if swa else None
    return GrugModelConfig(
        vocab_size=VOCAB_SIZE,
        hidden_dim=hidden_dim,
        intermediate_dim=intermediate_dim,
        shared_expert_intermediate_dim=shared_dim,
        num_experts=NUM_EXPERTS,
        num_experts_per_token=NUM_EXPERTS_PER_TOKEN,
        num_layers=num_layers,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        max_seq_len=SEQ_LEN,
        head_dim=None,  # let the model derive head_dim from hidden_dim / num_heads
        sliding_window=window,
    )
| 78 | + |
| 79 | + |
def model_flops_per_token(model: GrugModelConfig) -> float:
    """Estimate forward-pass FLOPs per token for *model* via levanter's helper.

    The shared expert is counted only when it is enabled (non-zero
    intermediate dim).

    NOTE(review): the helper is given ``seq_len=model.max_seq_len`` with no
    window argument, so sliding-window arms are costed as full attention —
    confirm this is the intended compute-matching convention.
    """
    num_shared = 1 if model.shared_expert_intermediate_dim > 0 else 0
    return lm_flops_per_token(
        hidden_dim=model.hidden_dim,
        intermediate_dim=model.intermediate_dim,
        shared_intermediate_dim=model.shared_expert_intermediate_dim,
        num_heads=model.num_heads,
        num_kv_heads=model.num_kv_heads,
        num_layers=model.num_layers,
        seq_len=model.max_seq_len,
        vocab_size=model.vocab_size,
        num_experts=model.num_experts,
        num_experts_per_tok=model.num_experts_per_token,
        num_shared_experts=num_shared,
        glu=True,  # SwiGLU-style MLP
    )
| 95 | + |
| 96 | + |
| 97 | +def steps_for_budget(fpt: float, budget: float) -> int: |
| 98 | + """Compute training steps so total FLOPs ~ budget.""" |
| 99 | + tokens = budget / (3 * fpt) |
| 100 | + return max(1, round(tokens / (BATCH_SIZE * SEQ_LEN))) |
| 101 | + |
| 102 | + |
| 103 | +# --------------------------------------------------------------------------- |
| 104 | +# Common training knobs |
| 105 | +# --------------------------------------------------------------------------- |
| 106 | + |
# One optimizer config shared by every arm; only the step count varies per budget.
_OPTIMIZER = AdamConfig(
    learning_rate=3e-3,  # peak learning rate
    weight_decay=0.1,
    lr_schedule="cosine",
    decay=0.2,  # presumably the fraction of training spent decaying — TODO confirm AdamConfig semantics
    min_lr_ratio=0.1,  # final LR = 0.1 * peak
    # NOTE(review): warmup is a fixed 1000 steps regardless of total steps;
    # the smallest budgets may spend a large share of training in warmup.
    warmup=1000,
)
| 115 | + |
# Trainer knobs shared by every arm.
_GRUG_TRAINER = GrugTrainerConfig(
    z_loss_weight=1e-4,  # logit z-loss regularization weight
    ema_beta=None,  # no weight EMA (matches eval_ema=False in the eval config)
    log_every=1,  # log metrics every step
)
| 121 | + |
# Validation settings shared by every arm: 8 batches of 512 every 1000 steps.
_EVAL = GrugEvalConfig(
    eval_batch_size=512,
    steps_per_eval=1000,
    max_eval_batches=8,
    eval_current=True,  # evaluate the live weights
    eval_ema=False,  # no EMA weights exist (ema_beta=None in the trainer config)
)
| 129 | + |
| 130 | + |
def _wandb(group: str) -> WandbConfig:
    """W&B tracker config for this experiment; *group* collects the two
    arms of one FLOP budget under a single run group."""
    tags = ["grug", "moe", "exp4045", "swa-ablation", "great-gate"]
    # name=None lets W&B derive the run name from the run id.
    return WandbConfig(project="marin", tags=tags, group=group, name=None)
| 138 | + |
| 139 | + |
| 140 | +# --------------------------------------------------------------------------- |
| 141 | +# Build executor steps for every (budget, swa/full) pair |
| 142 | +# --------------------------------------------------------------------------- |
| 143 | + |
| 144 | + |
def _build_steps() -> list[ExecutorStep]:
    """Build one training step per (budget, attention-variant) pair.

    Each arm is compute-matched: the step count is derived from the model's
    estimated FLOPs/token so total FLOPs ≈ the budget. Returns
    ``len(FLOP_BUDGETS) * 2`` steps (full attention and SWA per budget).
    """
    steps: list[ExecutorStep] = []
    for budget in FLOP_BUDGETS:
        # ".2g" keeps two significant digits so 1.8e19 tags as "1.8e+19";
        # the previous ".0e" format rounded it to a misleading "2e+19" in
        # run names, run ids, and W&B groups.
        budget_tag = f"{budget:.2g}"
        for swa in (False, True):
            arm = "swa-4096" if swa else "full-attn"
            model = _make_model(budget, swa=swa)
            fpt = model_flops_per_token(model)
            num_steps = steps_for_budget(fpt, budget)
            run_id = _resolve_run_id(f"exp4045-{arm}-{budget_tag}")
            step = ExecutorStep(
                name=f"grug/exp4045-{arm}-{budget_tag}",
                fn=run_grug_moe,
                config=GrugMoeLaunchConfig(
                    model=versioned(model),
                    data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
                    output_path=this_output_path(),
                    run_id=run_id,
                    resources=versioned(ResourceConfig.with_tpu("v5p-8")),
                    steps=versioned(num_steps),
                    batch_size=versioned(BATCH_SIZE),
                    seed=versioned(0),
                    mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
                    tracker=_wandb(f"exp4045-swa-ablation-{budget_tag}"),
                    optimizer=versioned(_OPTIMIZER),
                    grug_trainer=versioned(_GRUG_TRAINER),
                    eval=versioned(_EVAL),
                ),
            )
            steps.append(step)
    return steps
| 176 | + |
| 177 | + |
# Materialized at import time so external tooling can inspect the step graph.
ALL_STEPS = _build_steps()


if __name__ == "__main__":
    executor_main(
        steps=ALL_STEPS,
        description="Exp 4045: multi-budget SWA ablation for the great 10T gate. Fixes #4045.",
    )