# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Template: grug-modular-opt trial run.

This variant shows how to use optax.multi_transform to configure multiple
optimizers for different modules. See #3075 for more context.
"""

import dataclasses
import os
from dataclasses import dataclass, field
from datetime import timedelta

import jax
import jmp
import optax
from fray.cluster import ResourceConfig
from levanter.callbacks.profiler import ProfilerConfig
from levanter.checkpoint import CheckpointerConfig
from levanter.data.text import LmDataConfig
from levanter.optim import OptimizerConfig
from levanter.tracker import TrackerConfig
from levanter.tracker.wandb import WandbConfig
from levanter.trainer import TrainerConfig
from levanter.utils.jax_utils import leaf_key_paths
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
from marin.processing.tokenize import add_validation_sets_to_mixture

from experiments.defaults import default_validation_sets
from experiments.grug.modular_opt.model import GrugModelConfig
from experiments.grug.modular_opt.train import GrugEvalConfig, GrugRunConfig, GrugTrainerConfig, run_grug
from experiments.tootsie.exp1295_32b import nemotron_mix


@OptimizerConfig.register_subclass("grug_param_group_adam")
@dataclass(frozen=True)
class GrugParamGroupAdamConfig(OptimizerConfig):
    """AdamW with path-based parameter groups.

    Group routing is string-pattern based, so one rule can apply to many layers
    without listing each layer explicitly.
    """

    beta1: float = 0.9
    beta2: float = 0.95
    epsilon: float = 1e-8
    max_grad_norm: float | None = 1.0

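    # Overrides for the embedding / LM-head group (parameters matched by embed_head_patterns).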
    embed_head_lr_multiplier: float = 0.5
    embed_head_weight_decay_multiplier: float = 0.0
    embed_head_beta1: float = 0.5
    embed_head_beta2: float = 0.95

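    # Overrides for the "special" group (parameters matched by special_patterns, e.g. scalars and gates).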
    special_lr_multiplier: float = 5.0
    special_weight_decay_multiplier: float = 0.0
    special_beta1: float = 0.9
    special_beta2: float = 0.99

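    # Case-insensitive substrings matched against each parameter's key path to pick its group.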
    embed_head_patterns: tuple[str, ...] = ("embed", "lm_head")
    special_patterns: tuple[str, ...] = ("scalar", "gate", "lambda")
    no_decay_patterns: tuple[str, ...] = ("norm", "bias")

    def build(self, num_train_steps):
        lr_schedule = self.lr_scheduler(num_train_steps)
        embed_patterns = tuple(pattern.lower() for pattern in self.embed_head_patterns)
        special_patterns = tuple(pattern.lower() for pattern in self.special_patterns)
        no_decay_patterns = tuple(pattern.lower() for pattern in self.no_decay_patterns)

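        # Per-group AdamW-style chain: Adam moment scaling, optional decoupled weight
        # decay, then scaling by the (negative) learning rate and a per-group multiplier.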
        def _group_transform(
            *,
            learning_rate: float,
            weight_decay_multiplier: float,
            beta1: float,
            beta2: float,
            lr_multiplier: float,
        ) -> optax.GradientTransformation:
            components: list[optax.GradientTransformation] = [
                optax.scale_by_adam(b1=beta1, b2=beta2, eps=self.epsilon),
            ]
            decayed_weight = self.weight_decay * weight_decay_multiplier
            if decayed_weight > 0:
                components.append(optax.add_decayed_weights(decayed_weight))
            components.append(optax.scale(-learning_rate * lr_multiplier))
            return optax.chain(*components)

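        # Label each parameter leaf by substring-matching its key path. Rules are checked
        # in order, so "no_decay" takes precedence over "special", which takes precedence
        # over "embed_head"; anything unmatched falls through to "default".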
        def _create_mask(params):
            paths = leaf_key_paths(params)

            def _label_for_path(_, path):
                path_str = ".".join(path) if isinstance(path, (list, tuple)) else str(path)
                path_lower = path_str.lower()
                if any(pattern in path_lower for pattern in no_decay_patterns):
                    return "no_decay"
                if any(pattern in path_lower for pattern in special_patterns):
                    return "special"
                if any(pattern in path_lower for pattern in embed_patterns):
                    return "embed_head"
                return "default"

            return jax.tree.map(_label_for_path, params, paths)

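        # Route each labeled group to its transform via optax.multi_transform, optionally
        # wrap the result in global-norm clipping, and let inject_hyperparams feed the
        # learning-rate schedule in as base_learning_rate.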
        def _optimizer(base_learning_rate):
            transforms = {
                "default": _group_transform(
                    learning_rate=base_learning_rate,
                    weight_decay_multiplier=1.0,
                    beta1=self.beta1,
                    beta2=self.beta2,
                    lr_multiplier=1.0,
                ),
                "no_decay": _group_transform(
                    learning_rate=base_learning_rate,
                    weight_decay_multiplier=0.0,
                    beta1=self.beta1,
                    beta2=self.beta2,
                    lr_multiplier=1.0,
                ),
                "embed_head": _group_transform(
                    learning_rate=base_learning_rate,
                    weight_decay_multiplier=self.embed_head_weight_decay_multiplier,
                    beta1=self.embed_head_beta1,
                    beta2=self.embed_head_beta2,
                    lr_multiplier=self.embed_head_lr_multiplier,
                ),
                "special": _group_transform(
                    learning_rate=base_learning_rate,
                    weight_decay_multiplier=self.special_weight_decay_multiplier,
                    beta1=self.special_beta1,
                    beta2=self.special_beta2,
                    lr_multiplier=self.special_lr_multiplier,
                ),
            }
            grouped = optax.multi_transform(transforms, _create_mask)
            if self.max_grad_norm is None:
                return grouped
            return optax.chain(optax.clip_by_global_norm(self.max_grad_norm), grouped)

        return optax.inject_hyperparams(_optimizer)(base_learning_rate=lr_schedule)


@dataclass(frozen=True)
class GrugModularOptLaunchConfig:
    """Last-mile run config for the modular-opt grug variant."""

    model: GrugModelConfig
    data: LmDataConfig
    output_path: str
    run_id: str
    resources: ResourceConfig
    steps: int
    batch_size: int
    seed: int
    mp: str  # jmp policy string, e.g. "params=float32,compute=bfloat16,output=bfloat16".
    tracker: TrackerConfig
    optimizer: OptimizerConfig
    grug_trainer: GrugTrainerConfig = field(default_factory=GrugTrainerConfig)
    eval: GrugEvalConfig | None = field(default_factory=GrugEvalConfig)


GRUG_130M_MODEL = GrugModelConfig(
    vocab_size=128_256,
    hidden_dim=512,
    intermediate_dim=1792,
    num_layers=6,
    num_heads=8,
    num_kv_heads=8,
    max_seq_len=4096,
    head_dim=None,
)

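# Nemotron training mixture augmented with the default validation sets (tokenized with
# the mixture's own tokenizer).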
NEMOTRON_MIX_WITH_DEFAULT_VALIDATION = add_validation_sets_to_mixture(
    nemotron_mix,
    default_validation_sets(tokenizer=nemotron_mix.tokenizer),
)


def _resolve_run_id(default_run_id: str) -> str:
    """Resolve the run id, appending `FERRY_DATE` when launching from ferry workflows."""
    run_id = os.environ.get("GRUG_RUN_ID", default_run_id)
    ferry_date = os.environ.get("FERRY_DATE")
    if ferry_date:
        run_id = f"{run_id}-{ferry_date}"
    return run_id


def _resolve_tracker(tracker: TrackerConfig, run_id: str) -> TrackerConfig:
    if isinstance(tracker, WandbConfig):
        return dataclasses.replace(tracker, name=run_id)
    return tracker


def run_grug_modular_opt_trial(config: GrugModularOptLaunchConfig) -> None:
    # Map the template launch knobs onto a full Levanter TrainerConfig.
    trainer = TrainerConfig(
        id=config.run_id,
        seed=config.seed,
        train_batch_size=config.batch_size,
        num_train_steps=config.steps,
        profiler=ProfilerConfig(enabled=False, start_step=5, num_steps=100, perfetto_link=False),
        mp=jmp.get_policy(config.mp),
        tracker=_resolve_tracker(config.tracker, config.run_id),
        use_explicit_mesh_axes=True,
        require_accelerator=True,
        allow_nondivisible_batch_size=False,
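        # Checkpoints go under <output_path>/checkpoints; save on a 10-minute cadence and
        # retain a checkpoint every 1000 steps.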
        checkpointer=CheckpointerConfig(
            base_path=os.path.join(config.output_path, "checkpoints"),
            append_run_id_to_base_path=False,
            save_interval=timedelta(minutes=10),
            keep=[{"every": 1000}],
        ),
    )

    grug_trainer = dataclasses.replace(config.grug_trainer, trainer=trainer)

    run_config = GrugRunConfig(
        model=config.model,
        data=config.data,
        resources=config.resources,
        optimizer=config.optimizer,
        trainer=grug_trainer,
        eval=config.eval,
    )
    run_grug(run_config)


RESOLVED_RUN_ID = _resolve_run_id("grug-modular-opt-trial")


grug_modular_opt_trial = ExecutorStep(
    name="grug/modular-opt-trial",
    fn=run_grug_modular_opt_trial,
    config=GrugModularOptLaunchConfig(
        model=versioned(GRUG_130M_MODEL),
        data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
        # this_output_path() resolves to this step's output root (e.g. gs://.../grug/modular-opt-trial-<version>).
        output_path=this_output_path(),
        # Keep the run id out of versioning so changing job metadata doesn't create a new output path.
        run_id=RESOLVED_RUN_ID,
        resources=versioned(ResourceConfig.with_tpu("v4-8")),
        steps=versioned(2_000),
        batch_size=versioned(512),
        seed=versioned(0),
        mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
        tracker=WandbConfig(
            project="marin",
            tags=["grug", "template", "modular_opt", "issue-3075"],
            group="grug-modular-opt-trial",
            name=None,  # filled in from run_id by _resolve_tracker
        ),
        optimizer=versioned(
            GrugParamGroupAdamConfig(
                learning_rate=3e-3,
                weight_decay=0.1,
                lr_schedule="cosine",
                decay=0.2,
                min_lr_ratio=0.1,
                warmup=1000,
                embed_head_lr_multiplier=0.5,
                embed_head_weight_decay_multiplier=0.0,
                embed_head_beta1=0.5,
                embed_head_beta2=0.95,
                special_lr_multiplier=5.0,
                special_weight_decay_multiplier=0.0,
                special_beta1=0.9,
                special_beta2=0.99,
                special_patterns=("scalar", "gate", "lambda"),
            )
        ),
        grug_trainer=versioned(
            GrugTrainerConfig(
                z_loss_weight=1e-4,
                ema_beta=None,
                log_every=1,
            )
        ),
        eval=versioned(
            GrugEvalConfig(
                eval_batch_size=512,
                steps_per_eval=1000,
                max_eval_batches=8,
                eval_current=True,
                eval_ema=False,
            )
        ),
    ),
)


if __name__ == "__main__":
    executor_main(
        steps=[grug_modular_opt_trial],
        description="Template grug modular-opt 130M trial run (~2000 steps) with parameter-group Adam.",
    )