Commit f5d3a8e

Add modular_opt variant and move Grug variant docs (#3293)
## Summary

- add `experiments/grug/modular_opt/` variant files
- align modular-opt launch/train wiring with current grug base dispatch/resources flow
- move variant-specific guidance out of `experiments/grug/README.md` into new `experiments/grug/variants.md`
- update modular-opt launch docstring wording per review

## Validation

- `uv run python -m py_compile experiments/grug/modular_opt/*.py`

## Notes

- keeps unrelated local changes out of this PR
1 parent 79ddd1d commit f5d3a8e

6 files changed

Lines changed: 1092 additions & 0 deletions

experiments/grug/README.md

Lines changed: 5 additions & 0 deletions
@@ -23,6 +23,11 @@
 2. Make model/training changes in that variant, not in shared trainer libraries.
 3. Set run knobs in `<variant>/launch.py` (run id, data mix, optimizer, TPU type).
 4. Launch from the variant's `launch.py` entrypoint.
+5. Add or update variant-specific notes in `experiments/grug/variants.md`.
+
+## Variant notes
+
+Variant-specific guidance (including modular-opt notes) lives in `experiments/grug/variants.md`.
 
 ## Quickstart launch
 
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

Lines changed: 294 additions & 0 deletions
@@ -0,0 +1,294 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Template: grug-modular-opt trial run.

This variant shows how to use optax.multi_transform to configure multiple
optimizers for different modules. See #3075 for more context.
"""

import dataclasses
import os
from dataclasses import dataclass, field
from datetime import timedelta

import jax
import jmp
import optax
from fray.cluster import ResourceConfig
from levanter.callbacks.profiler import ProfilerConfig
from levanter.checkpoint import CheckpointerConfig
from levanter.data.text import LmDataConfig
from levanter.optim import OptimizerConfig
from levanter.tracker import TrackerConfig
from levanter.tracker.wandb import WandbConfig
from levanter.trainer import TrainerConfig
from levanter.utils.jax_utils import leaf_key_paths
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
from marin.processing.tokenize import add_validation_sets_to_mixture

from experiments.defaults import default_validation_sets
from experiments.grug.modular_opt.model import GrugModelConfig
from experiments.grug.modular_opt.train import GrugEvalConfig, GrugRunConfig, GrugTrainerConfig, run_grug
from experiments.tootsie.exp1295_32b import nemotron_mix


@OptimizerConfig.register_subclass("grug_param_group_adam")
@dataclass(frozen=True)
class GrugParamGroupAdamConfig(OptimizerConfig):
    """AdamW with path-based parameter groups.

    Group routing is string-pattern based, so one rule can apply to many layers
    without listing each layer explicitly.
    """

    beta1: float = 0.9
    beta2: float = 0.95
    epsilon: float = 1e-8
    max_grad_norm: float | None = 1.0

    embed_head_lr_multiplier: float = 0.5
    embed_head_weight_decay_multiplier: float = 0.0
    embed_head_beta1: float = 0.5
    embed_head_beta2: float = 0.95

    special_lr_multiplier: float = 5.0
    special_weight_decay_multiplier: float = 0.0
    special_beta1: float = 0.9
    special_beta2: float = 0.99

    embed_head_patterns: tuple[str, ...] = ("embed", "lm_head")
    special_patterns: tuple[str, ...] = ("scalar", "gate", "lambda")
    no_decay_patterns: tuple[str, ...] = ("norm", "bias")

    def build(self, num_train_steps):
        lr_schedule = self.lr_scheduler(num_train_steps)
        embed_patterns = tuple(pattern.lower() for pattern in self.embed_head_patterns)
        special_patterns = tuple(pattern.lower() for pattern in self.special_patterns)
        no_decay_patterns = tuple(pattern.lower() for pattern in self.no_decay_patterns)

        def _group_transform(
            *,
            learning_rate: float,
            weight_decay_multiplier: float,
            beta1: float,
            beta2: float,
            lr_multiplier: float,
        ) -> optax.GradientTransformation:
            components: list[optax.GradientTransformation] = [
                optax.scale_by_adam(b1=beta1, b2=beta2, eps=self.epsilon),
            ]
            decayed_weight = self.weight_decay * weight_decay_multiplier
            if decayed_weight > 0:
                components.append(optax.add_decayed_weights(decayed_weight))
            components.append(optax.scale(-learning_rate * lr_multiplier))
            return optax.chain(*components)

        def _create_mask(params):
            paths = leaf_key_paths(params)

            def _label_for_path(_, path):
                path_str = ".".join(path) if isinstance(path, (list, tuple)) else str(path)
                path_lower = path_str.lower()
                if any(pattern in path_lower for pattern in no_decay_patterns):
                    return "no_decay"
                if any(pattern in path_lower for pattern in special_patterns):
                    return "special"
                if any(pattern in path_lower for pattern in embed_patterns):
                    return "embed_head"
                return "default"

            return jax.tree.map(_label_for_path, params, paths)

        def _optimizer(base_learning_rate):
            transforms = {
                "default": _group_transform(
                    learning_rate=base_learning_rate,
                    weight_decay_multiplier=1.0,
                    beta1=self.beta1,
                    beta2=self.beta2,
                    lr_multiplier=1.0,
                ),
                "no_decay": _group_transform(
                    learning_rate=base_learning_rate,
                    weight_decay_multiplier=0.0,
                    beta1=self.beta1,
                    beta2=self.beta2,
                    lr_multiplier=1.0,
                ),
                "embed_head": _group_transform(
                    learning_rate=base_learning_rate,
                    weight_decay_multiplier=self.embed_head_weight_decay_multiplier,
                    beta1=self.embed_head_beta1,
                    beta2=self.embed_head_beta2,
                    lr_multiplier=self.embed_head_lr_multiplier,
                ),
                "special": _group_transform(
                    learning_rate=base_learning_rate,
                    weight_decay_multiplier=self.special_weight_decay_multiplier,
                    beta1=self.special_beta1,
                    beta2=self.special_beta2,
                    lr_multiplier=self.special_lr_multiplier,
                ),
            }
            grouped = optax.multi_transform(transforms, _create_mask)
            if self.max_grad_norm is None:
                return grouped
            return optax.chain(optax.clip_by_global_norm(self.max_grad_norm), grouped)

        return optax.inject_hyperparams(_optimizer)(base_learning_rate=lr_schedule)


@dataclass(frozen=True)
class GrugModularOptLaunchConfig:
    """Last-mile run config for the modular-opt grug variant."""

    model: GrugModelConfig
    data: LmDataConfig
    output_path: str
    run_id: str
    resources: ResourceConfig
    steps: int
    batch_size: int
    seed: int
    mp: str  # jmp policy string, e.g. "params=float32,compute=bfloat16,output=bfloat16".
    tracker: TrackerConfig
    optimizer: OptimizerConfig
    grug_trainer: GrugTrainerConfig = field(default_factory=GrugTrainerConfig)
    eval: GrugEvalConfig | None = field(default_factory=GrugEvalConfig)


GRUG_130M_MODEL = GrugModelConfig(
    vocab_size=128_256,
    hidden_dim=512,
    intermediate_dim=1792,
    num_layers=6,
    num_heads=8,
    num_kv_heads=8,
    max_seq_len=4096,
    head_dim=None,
)

NEMOTRON_MIX_WITH_DEFAULT_VALIDATION = add_validation_sets_to_mixture(
    nemotron_mix,
    default_validation_sets(tokenizer=nemotron_mix.tokenizer),
)


def _resolve_run_id(default_run_id: str) -> str:
    """Resolve run id and append `FERRY_DATE` when launching from ferry workflows."""
    run_id = os.environ.get("GRUG_RUN_ID", default_run_id)
    ferry_date = os.environ.get("FERRY_DATE")
    if ferry_date:
        run_id = f"{run_id}-{ferry_date}"
    return run_id


def _resolve_tracker(tracker: TrackerConfig, run_id: str) -> TrackerConfig:
    if isinstance(tracker, WandbConfig):
        return dataclasses.replace(tracker, name=run_id)
    return tracker


def run_grug_modular_opt_trial(config: GrugModularOptLaunchConfig) -> None:
    # Map template launch knobs onto full Levanter TrainerConfig.
    trainer = TrainerConfig(
        id=config.run_id,
        seed=config.seed,
        train_batch_size=config.batch_size,
        num_train_steps=config.steps,
        profiler=ProfilerConfig(enabled=False, start_step=5, num_steps=100, perfetto_link=False),
        mp=jmp.get_policy(config.mp),
        tracker=_resolve_tracker(config.tracker, config.run_id),
        use_explicit_mesh_axes=True,
        require_accelerator=True,
        allow_nondivisible_batch_size=False,
        checkpointer=CheckpointerConfig(
            base_path=os.path.join(config.output_path, "checkpoints"),
            append_run_id_to_base_path=False,
            save_interval=timedelta(minutes=10),
            keep=[{"every": 1000}],
        ),
    )

    grug_trainer = dataclasses.replace(config.grug_trainer, trainer=trainer)

    run_config = GrugRunConfig(
        model=config.model,
        data=config.data,
        resources=config.resources,
        optimizer=config.optimizer,
        trainer=grug_trainer,
        eval=config.eval,
    )
    run_grug(run_config)


RESOLVED_RUN_ID = _resolve_run_id("grug-modular-opt-trial")


grug_modular_opt_trial = ExecutorStep(
    name="grug/modular-opt-trial",
    fn=run_grug_modular_opt_trial,
    config=GrugModularOptLaunchConfig(
        model=versioned(GRUG_130M_MODEL),
        data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
        # this_output_path() resolves to this step's output root (e.g. gs://.../grug/modular-opt-trial-<version>).
        output_path=this_output_path(),
        # Keep run id out of versioning so changing job metadata doesn't create a new output path.
        run_id=RESOLVED_RUN_ID,
        resources=versioned(ResourceConfig.with_tpu("v4-8")),
        steps=versioned(2_000),
        batch_size=versioned(512),
        seed=versioned(0),
        mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
        tracker=WandbConfig(
            project="marin",
            tags=["grug", "template", "modular_opt", "issue-3075"],
            group="grug-modular-opt-trial",
            name=None,  # filled from run_id in _resolve_tracker
        ),
        optimizer=versioned(
            GrugParamGroupAdamConfig(
                learning_rate=3e-3,
                weight_decay=0.1,
                lr_schedule="cosine",
                decay=0.2,
                min_lr_ratio=0.1,
                warmup=1000,
                embed_head_lr_multiplier=0.5,
                embed_head_weight_decay_multiplier=0.0,
                embed_head_beta1=0.5,
                embed_head_beta2=0.95,
                special_lr_multiplier=5.0,
                special_weight_decay_multiplier=0.0,
                special_beta1=0.9,
                special_beta2=0.99,
                special_patterns=("scalar", "gate", "lambda"),
            )
        ),
        grug_trainer=versioned(
            GrugTrainerConfig(
                z_loss_weight=1e-4,
                ema_beta=None,
                log_every=1,
            )
        ),
        eval=versioned(
            GrugEvalConfig(
                eval_batch_size=512,
                steps_per_eval=1000,
                max_eval_batches=8,
                eval_current=True,
                eval_ema=False,
            )
        ),
    ),
)


if __name__ == "__main__":
    executor_main(
        steps=[grug_modular_opt_trial],
        description="Template grug modular-opt 130M trial run (~2000 steps) with parameter-group Adam.",
    )
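
For readers unfamiliar with the grouping pattern that `GrugParamGroupAdamConfig.build` relies on, here is a minimal, self-contained sketch of `optax.multi_transform` with path-based parameter labels. It is not part of the commit above; the toy parameter tree, the per-group `adamw` hyperparameters, and the label names are illustrative assumptions, not values taken from the variant.

```python
import jax
import jax.numpy as jnp
import optax

# Toy parameter tree standing in for a model's parameters.
params = {
    "embed": {"weight": jnp.ones((4, 8))},
    "block_0": {"attn": {"weight": jnp.ones((8, 8)), "bias": jnp.zeros((8,))}},
    "lm_head": {"weight": jnp.ones((8, 4))},
}


def label_params(tree):
    """Route each leaf to an optimizer group based on its tree-path string."""

    def label_for(path, _leaf):
        path_str = jax.tree_util.keystr(path).lower()
        if "bias" in path_str or "norm" in path_str:
            return "no_decay"
        if "embed" in path_str or "lm_head" in path_str:
            return "embed_head"
        return "default"

    return jax.tree_util.tree_map_with_path(label_for, tree)


# One gradient transformation per label; hyperparameters here are illustrative only.
optimizer = optax.multi_transform(
    {
        "default": optax.adamw(learning_rate=3e-3, weight_decay=0.1),
        "no_decay": optax.adamw(learning_rate=3e-3, weight_decay=0.0),
        "embed_head": optax.adamw(learning_rate=1.5e-3, weight_decay=0.0),
    },
    label_params,
)

opt_state = optimizer.init(params)
grads = jax.tree.map(jnp.ones_like, params)  # stand-in gradients
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```

Because routing is done by substring match on the parameter path, one rule covers every `bias` or `embed` leaf without enumerating layers, which is the same idea behind the variant's `default`/`no_decay`/`embed_head`/`special` groups.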
