|
| 1 | +# Copyright The Marin Authors |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +"""Great 10T ablation: sweep num_dense_layers across isoflop budgets. |
| 5 | +
|
| 6 | +Generates an isoflop grid varying num_dense_layers (0, 1, 2, 4) at three FLOP |
| 7 | +budgets (1e18, 3e18, 1e19) with five hidden dims per budget. Architecture |
| 8 | +follows the E=128, K=4 MoE recipe from iteration-02 scaling work. |
| 9 | +
|
| 10 | +See: https://github.com/marin-community/marin/issues/4040 |
| 11 | +Parent sweep: https://github.com/marin-community/marin/issues/3469 |
| 12 | +Gate: https://github.com/marin-community/marin/issues/4014 |
| 13 | +""" |
| 14 | + |
| 15 | +import math |
| 16 | +from dataclasses import replace |
| 17 | + |
| 18 | +from fray.cluster import ResourceConfig |
| 19 | +from levanter.optim import AdamConfig |
| 20 | +from levanter.tracker.wandb import WandbConfig |
| 21 | +from levanter.utils.flop_utils import lm_flops_per_token |
| 22 | +from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned |
| 23 | + |
| 24 | +from experiments.grug.moe.launch import ( |
| 25 | + NEMOTRON_MIX_WITH_DEFAULT_VALIDATION, |
| 26 | + GrugMoeLaunchConfig, |
| 27 | + run_grug_moe, |
| 28 | +) |
| 29 | +from experiments.grug.moe.model import GrugModelConfig |
| 30 | +from experiments.grug.moe.train import GrugEvalConfig, GrugTrainerConfig |
| 31 | + |
# ---------------------------------------------------------------------------
# Sweep axes
# ---------------------------------------------------------------------------
# Full grid is 3 budgets x 5 widths x 4 dense-layer counts; build_sweep_steps
# skips combinations where num_dense exceeds the model's total depth.
FLOP_BUDGETS: tuple[float, ...] = (1e18, 3e18, 1e19)
HIDDEN_DIMS: tuple[int, ...] = (512, 768, 1024, 1536, 2048)
NUM_DENSE_LAYERS_VALUES: tuple[int, ...] = (0, 1, 2, 4)

# ---------------------------------------------------------------------------
# Architecture constants (E=128, K=4 recipe)
# ---------------------------------------------------------------------------
VOCAB_SIZE = 128_256
NUM_EXPERTS = 128  # E: total experts per MoE layer
NUM_EXPERTS_PER_TOKEN = 4  # K: experts routed per token
SEQ_LEN = 4096
TARGET_STEPS_LOG2 = 14  # ~16384 steps per run

# Scaling heuristics (from iteration-02 MoE recipe).
# Depth: num_layers = round(hidden / (RATIO + FACTOR * log2(hidden) - OFFSET)).
BASE_HIDDEN_LAYER_RATIO = 64
LAYER_SCALING_FACTOR = 4.0
LAYER_FORMULA_OFFSET = 9
# LR: min(MAX_LR, LR_CONSTANT * sqrt(batch_size) / hidden_dim).
LR_CONSTANT = 0.33
MAX_LR = 0.01
# beta2: max(0.95, BETA2_BASE ** (batch_size / BETA2_BATCH_DIVISOR)).
BETA2_BASE = 0.98
BETA2_BATCH_DIVISOR = 128
| 56 | + |
| 57 | + |
def _compute_num_layers(hidden_dim: int) -> int:
    """Depth heuristic from the iteration-02 recipe.

    Divides the width by a slowly growing function of log2(width), so deeper
    models accompany wider ones but at a sub-linear rate.
    """
    width_pow = math.log2(hidden_dim)
    divisor = BASE_HIDDEN_LAYER_RATIO + LAYER_SCALING_FACTOR * width_pow - LAYER_FORMULA_OFFSET
    return round(hidden_dim / divisor)
| 61 | + |
| 62 | + |
| 63 | +def _round_up_pow2(x: float) -> int: |
| 64 | + if x <= 1: |
| 65 | + return 1 |
| 66 | + return 2 ** math.ceil(math.log2(x)) |
| 67 | + |
| 68 | + |
def _flops_per_token_for_config(cfg: GrugModelConfig) -> float:
    """Compute forward-pass FLOPs per token for a mixed dense/MoE model.

    Costs the MoE and dense layer stacks with two separate
    ``lm_flops_per_token`` calls and sums them. The vocab (embedding/LM-head)
    contribution is attributed only to the MoE call — the dense call passes
    ``vocab_size=0`` — so it is not double-counted when both stacks are
    present (previously both calls passed the full vocab size, inflating the
    estimate and skewing the derived isoflop token budgets).
    """
    num_moe_layers = cfg.num_layers - cfg.num_dense_layers
    moe_flops = lm_flops_per_token(
        hidden_dim=cfg.hidden_dim,
        intermediate_dim=cfg.intermediate_dim,
        shared_intermediate_dim=cfg.shared_expert_intermediate_dim,
        num_layers=num_moe_layers,
        num_kv_heads=cfg.num_kv_heads,
        num_heads=cfg.num_heads,
        seq_len=cfg.max_seq_len,
        vocab_size=cfg.vocab_size,
        glu=True,
        num_experts=cfg.num_experts,
        num_shared_experts=1 if cfg.shared_expert_intermediate_dim > 0 else 0,
        num_experts_per_tok=cfg.num_experts_per_token,
    )
    if cfg.num_dense_layers > 0:
        dense_flops = lm_flops_per_token(
            hidden_dim=cfg.hidden_dim,
            intermediate_dim=cfg.resolved_dense_intermediate_dim,
            shared_intermediate_dim=0,
            num_layers=cfg.num_dense_layers,
            num_kv_heads=cfg.num_kv_heads,
            num_heads=cfg.num_heads,
            seq_len=cfg.max_seq_len,
            # Vocab/LM-head FLOPs are already counted in the MoE call above.
            vocab_size=0,
            glu=True,
            num_experts=1,
            num_shared_experts=0,
            num_experts_per_tok=1,
        )
    else:
        dense_flops = 0.0
    return moe_flops + dense_flops
| 104 | + |
| 105 | + |
def _build_model_config(hidden_dim: int, num_dense_layers: int) -> GrugModelConfig:
    """Instantiate the E=128/K=4 recipe at the given width and dense-layer count.

    Depth comes from the iteration-02 heuristic; the requested number of dense
    layers is clamped so it never exceeds the total depth. Head count scales
    as width/128 (at least one head), and KV heads match query heads (no GQA).
    """
    depth = _compute_num_layers(hidden_dim)
    heads = max(1, hidden_dim // 128)
    return GrugModelConfig(
        vocab_size=VOCAB_SIZE,
        hidden_dim=hidden_dim,
        intermediate_dim=hidden_dim // 2,
        shared_expert_intermediate_dim=hidden_dim,
        num_experts=NUM_EXPERTS,
        num_experts_per_token=NUM_EXPERTS_PER_TOKEN,
        num_layers=depth,
        num_heads=heads,
        num_kv_heads=heads,
        max_seq_len=SEQ_LEN,
        num_dense_layers=min(num_dense_layers, depth),
    )
| 125 | + |
| 126 | + |
def _build_sweep_step(
    budget: float,
    hidden_dim: int,
    num_dense_layers: int,
) -> ExecutorStep:
    """Build one isoflop training run for the given budget and model shape.

    Derives the token budget from total FLOPs (using the ~3x-forward
    multiplier for forward+backward), rounds the batch size to a power of two
    targeting ~2**TARGET_STEPS_LOG2 steps, then recomputes the exact step
    count and batch-size-dependent optimizer hyperparameters.
    """
    model = _build_model_config(hidden_dim, num_dense_layers)
    fpt = _flops_per_token_for_config(model)
    # Train FLOPs ~= 3 * forward FLOPs/token * tokens, so tokens = budget / (3 * fpt).
    tokens = budget / (3 * fpt)
    target_steps = 2**TARGET_STEPS_LOG2

    # Pick a power-of-two batch size (floor 8) that lands near target_steps.
    batch_exact = tokens / (target_steps * SEQ_LEN)
    effective_bs = _round_up_pow2(batch_exact)
    effective_bs = max(8, effective_bs)

    # sqrt-batch LR scaling divided by width, capped at MAX_LR; beta2 shrinks
    # geometrically with batch size, floored at 0.95 (iteration-02 heuristics).
    lr = min(MAX_LR, (LR_CONSTANT * math.sqrt(effective_bs)) / hidden_dim)
    beta2 = max(0.95, BETA2_BASE ** (effective_bs / BETA2_BATCH_DIVISOR))
    # Recompute steps from the rounded batch size so the token budget is preserved.
    steps = max(1, round(tokens / (effective_bs * SEQ_LEN)))

    # Tags feed both the executor step name and the W&B run id/tags.
    budget_tag = f"{budget:.0e}"
    dense_tag = f"dense{num_dense_layers}"
    run_id = f"great-10t-fkd-{budget_tag}-d{hidden_dim}-{dense_tag}"

    return ExecutorStep(
        name=f"grug/great-10t-fkd/{budget_tag}/d{hidden_dim}/{dense_tag}",
        fn=run_grug_moe,
        config=GrugMoeLaunchConfig(
            model=versioned(model),
            data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
            output_path=this_output_path(),
            run_id=run_id,
            resources=versioned(ResourceConfig.with_tpu("v5p-8")),
            steps=versioned(steps),
            batch_size=versioned(effective_bs),
            seed=versioned(0),
            mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
            tracker=WandbConfig(
                project="marin",
                tags=["grug", "moe", "great-10t", "first-k-dense", budget_tag, dense_tag],
                group="great-10t-first-k-dense",
                name=None,
            ),
            # versioned(...) marks fields that should invalidate cached outputs
            # when changed.
            optimizer=versioned(
                AdamConfig(
                    learning_rate=lr,
                    weight_decay=0.1,
                    lr_schedule="cosine",
                    decay=0.2,
                    min_lr_ratio=0.1,
                    warmup=0.1,
                    beta2=beta2,
                )
            ),
            grug_trainer=versioned(
                GrugTrainerConfig(
                    z_loss_weight=1e-4,
                    ema_beta=None,  # no EMA of weights for these ablations
                    log_every=1,
                )
            ),
            eval=versioned(
                GrugEvalConfig(
                    # Cap eval batch at the train batch size for small runs.
                    eval_batch_size=min(512, effective_bs),
                    steps_per_eval=1000,
                    max_eval_batches=8,
                    eval_current=True,
                    eval_ema=False,
                )
            ),
        ),
    )
| 197 | + |
| 198 | + |
def build_sweep_steps() -> list[ExecutorStep]:
    """Generate all ExecutorSteps for the first-k-dense isoflop ablation.

    Skips (hidden_dim, num_dense) combinations whose requested dense-layer
    count exceeds the model's total depth, since the clamped config would
    duplicate a smaller num_dense run.
    """
    steps: list[ExecutorStep] = []
    for budget in FLOP_BUDGETS:
        for hidden_dim in HIDDEN_DIMS:
            # Depth depends only on hidden_dim; hoist out of the inner loop.
            num_layers = _compute_num_layers(hidden_dim)
            for num_dense in NUM_DENSE_LAYERS_VALUES:
                if num_dense > num_layers:
                    continue
                steps.append(_build_sweep_step(budget, hidden_dim, num_dense))
    return steps
| 210 | + |
| 211 | + |
# Built at import time so the executor (and tooling) can discover the sweep.
ALL_STEPS = build_sweep_steps()


if __name__ == "__main__":
    executor_main(
        steps=ALL_STEPS,
        description="Great 10T ablation: sweep num_dense_layers across isoflop budgets (issue #4040).",
    )