Commit f419c74

Helw150 and claude committed
Fix grug/moe train.py: restore main's structure, add QB on top
Start from main's train.py and add only QB-specific changes:
- _apply_qb_betas host-side bias update
- pending_qb_betas pattern in training loop
- NaN loss detection
- qb_beta_per_layer metric filtering

Preserves main's EMA handling, comments, FLOP calc, and except/else checkpoint pattern.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1670695 commit f419c74

1 file changed: experiments/grug/moe/train.py (56 additions, 24 deletions)
@@ -24,7 +24,6 @@
 import levanter.tracker
 from levanter.callbacks.state_adapter import StateCallbackRunner
 from levanter.callbacks.watch import WatchConfig, compute_watch_stats
-from experiments.grug.checkpointing import restore_grug_state_from_checkpoint
 from levanter.data import AsyncDataset, DataLoader
 from levanter.data.mixture import MixtureDataset, rescale_mixture_schedule_for_batch_schedule
 from levanter.data.text import GrugLmExample, LmDataConfig
@@ -39,9 +38,14 @@
 from levanter.utils.logging import LoadingTimeTrackerIterator

 import equinox as eqx
+from experiments.grug.checkpointing import restore_grug_state_from_checkpoint
 from experiments.grug.dispatch import dispatch_grug_training_run
 from experiments.grug.moe.model import GrugModelConfig, Transformer

+# This file intentionally mirrors `experiments/grug/base/train.py` with
+# variant-specific model/loss/FLOP wiring, per the grug copy-first workflow in
+# `.agents/skills/change-grug/`.
+
 logger = logging.getLogger(__name__)

@@ -53,8 +57,8 @@ class GrugTrainerConfig:
     train_batch_pspec: P = field(default_factory=lambda: P(("data", "expert")))
     data_seed: int | None = None
     log_every: int = 1
-    ema_beta: float | None = None
-    z_loss_weight: float = 0.0
+    ema_beta: float | None = None  # EMA coefficient for eval/checkpoint model; None disables EMA.
+    z_loss_weight: float = 0.0  # Weight on logsumexp (z-loss) stabilization term.


 @dataclass(frozen=True)
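
For reference, the z-loss that the new comment refers to is the standard logsumexp penalty; a minimal sketch (illustrative only, the real loss wiring lives elsewhere in this file):

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    def z_loss(logits: jnp.ndarray, weight: float) -> jnp.ndarray:
        # Penalize the squared log-partition so logits stay near a normalized scale.
        z = logsumexp(logits, axis=-1)   # log of the softmax denominator, per position
        return weight * jnp.mean(z**2)   # weight corresponds to z_loss_weight above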
@@ -114,6 +118,7 @@ def build_train_loader(
     mesh: Mesh,
     batch_pspec: P = P(("data", "expert")),
 ) -> DataLoader[GrugLmExample]:
+    # DataLoader uses this batch axis mapping to shard batches across the distributed mesh.
     axis_resource = batch_pspec[0]
     return DataLoader(
         dataset,
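
The batch partition spec P(("data", "expert")) shards the leading batch dimension across both mesh axes jointly. A minimal sketch of the same idea on a toy mesh (assumes 4 local devices; shapes are illustrative, not this repo's actual setup):

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    devices = np.array(jax.devices()).reshape(2, 2)        # 2 "data" x 2 "expert"
    mesh = Mesh(devices, axis_names=("data", "expert"))
    sharding = NamedSharding(mesh, P(("data", "expert")))  # batch split over data*expert
    batch = jax.device_put(np.zeros((8, 128)), sharding)   # 8/4 = 2 examples per device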
@@ -181,6 +186,7 @@ def _compute_flops(
     flops_per_token = lm_flops_per_token(
         hidden_dim=model_config.hidden_dim,
         intermediate_dim=model_config.intermediate_dim,
+        shared_intermediate_dim=model_config.shared_expert_intermediate_dim,
         num_layers=model_config.num_layers,
         num_kv_heads=model_config.num_kv_heads,
         num_heads=model_config.num_heads,
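
The new shared_intermediate_dim argument folds the always-on shared expert's MLP into the per-token FLOP estimate. A hedged sketch of the kind of term it contributes, assuming a gated MLP with three weight matrices (the repo's lm_flops_per_token may count differently):

    def shared_expert_mlp_flops_per_token(hidden_dim: int, shared_intermediate_dim: int) -> int:
        # Three matmuls (gate, up, down) at 2 * hidden * intermediate FLOPs each, per token.
        return 3 * 2 * hidden_dim * shared_intermediate_dim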
@@ -227,7 +233,22 @@ class GrugTrainState:
     step: jax.Array
     params: Transformer
     opt_state: optax.OptState
-    ema_params: Transformer
+    ema_params: Transformer | None
+
+
+def _apply_qb_betas(model: Transformer, qb_betas: jax.Array) -> Transformer:
+    """Set router biases from QB betas (computed on previous step, applied on host)."""
+    new_blocks = list(model.blocks)
+    moe_idx = 0
+    for i, block in enumerate(model.blocks):
+        if block.mlp is None:
+            continue
+        new_bias = -qb_betas[moe_idx]
+        new_bias = new_bias - jnp.mean(new_bias)
+        new_mlp = eqx.tree_at(lambda m: m.router_bias, block.mlp, new_bias)
+        new_blocks[i] = eqx.tree_at(lambda b: b.mlp, block, new_mlp)
+        moe_idx += 1
+    return eqx.tree_at(lambda t: t.blocks, model, tuple(new_blocks))


 def initial_state(
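
The two bias lines in _apply_qb_betas compute -(qb_beta - mean(qb_beta)): negating first and then mean-centering gives the same zero-mean result, so QB shifts preference between experts without changing the overall router logit scale. A quick standalone check with made-up values:

    import jax.numpy as jnp

    beta = jnp.array([0.3, -0.1, 0.5, 0.1])   # one layer's QB betas, one per expert
    bias = -beta
    bias = bias - jnp.mean(bias)              # the same two steps as _apply_qb_betas
    assert jnp.allclose(bias, -(beta - jnp.mean(beta)))
    assert jnp.allclose(jnp.mean(bias), 0.0)  # centered: no net logit shift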
@@ -236,13 +257,14 @@ def initial_state(
     optimizer: optax.GradientTransformation,
     mp: jmp.Policy,
     key: PRNGKeyArray,
+    ema_beta: float | None,
 ) -> GrugTrainState:
     params = mp.cast_to_param(Transformer.init(model_config, key=key))
     return GrugTrainState(
         step=jnp.array(0, dtype=jnp.int32),
         params=params,
         opt_state=optimizer.init(params),
-        ema_params=params,
+        ema_params=params if ema_beta is not None else None,
     )

@@ -282,23 +304,11 @@ def loss_fn(params):
     updates, opt_state = optimizer.update(grads, state.opt_state, state.params)
     params = optax.apply_updates(state.params, updates)

-    # Sharded QB: set router_bias = -(qb_beta - mean(qb_beta)) inside JIT.
-    qb_betas = summarized_metrics["qb_beta_per_layer"]
-    new_blocks = list(params.blocks)
-    moe_idx = 0
-    for i, block in enumerate(params.blocks):
-        if block.mlp is not None:
-            new_bias = -qb_betas[moe_idx]
-            new_bias = new_bias - jnp.mean(new_bias)
-            new_mlp = eqx.tree_at(lambda m: m.router_bias, block.mlp, new_bias)
-            new_blocks[i] = eqx.tree_at(lambda b: b.mlp, block, new_mlp)
-            metrics[f"moe_bias/layer_{moe_idx}/bias_norm"] = jnp.linalg.norm(new_bias)
-            moe_idx += 1
-    params = eqx.tree_at(lambda t: t.blocks, params, tuple(new_blocks))
-
     if ema_beta is None:
-        ema_params = params
+        ema_params = None
     else:
+        if state.ema_params is None:
+            raise ValueError("ema_params must be initialized when ema_beta is set.")
         ema_params = jax.tree_util.tree_map(
             lambda old, new: ema_beta * old + (1.0 - ema_beta) * new,
             state.ema_params,
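
The else branch is a standard exponential moving average over the parameter pytree. The same update in isolation, with plain dicts standing in for the Transformer pytree (illustrative values):

    import jax
    import jax.numpy as jnp

    ema_beta = 0.999
    params = {"w": jnp.ones(3), "b": jnp.zeros(3)}       # fresh params after the step
    ema_params = {"w": jnp.zeros(3), "b": jnp.zeros(3)}  # running average
    # Each leaf moves a (1 - ema_beta) fraction toward the fresh parameters.
    ema_params = jax.tree_util.tree_map(
        lambda old, new: ema_beta * old + (1.0 - ema_beta) * new,
        ema_params,
        params,
    )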
@@ -357,6 +367,7 @@ def _run_grug_local(config: GrugRunConfig) -> None:
     if config.trainer.data_seed is not None:
         data_key = jax.random.PRNGKey(config.trainer.data_seed)

+    # Build data/model state under the trainer mesh so all arrays are sharded consistently.
     with trainer.use_device_mesh():
         mesh = trainer.device_mesh
         batch_schedule = trainer.batch_schedule
@@ -381,6 +392,7 @@ def _init_state(model_rng):
             optimizer=optimizer,
             mp=trainer.mp,
             key=model_rng,
+            ema_beta=config.trainer.ema_beta,
         )

     state = _init_state(model_key)
@@ -422,7 +434,7 @@ def _init_state(model_rng):
     state_callbacks = StateCallbackRunner[GrugTrainState](
         step_getter=lambda s: s.step,
         model_getter=lambda s: s.params,
-        eval_model_getter=lambda s: s.ema_params,
+        eval_model_getter=lambda s: s.ema_params if s.ema_params is not None else s.params,
         opt_state_getter=lambda s: s.opt_state,
     )
     state_callbacks.add_hook(
@@ -458,21 +470,35 @@ def _init_state(model_rng):

     last_loss: float | jax.Array = 0.0
     last_step_duration = 0.0
+    pending_qb_betas: jax.Array | None = None

     # Main optimization loop.
    try:
        while int(state.step) < trainer.num_train_steps:
+            # QB: apply router bias updates from previous step (on host).
+            if pending_qb_betas is not None:
+                state = dataclasses.replace(
+                    state,
+                    params=_apply_qb_betas(state.params, pending_qb_betas),
+                    ema_params=(
+                        _apply_qb_betas(state.ema_params, pending_qb_betas) if state.ema_params is not None else None
+                    ),
+                )
+                pending_qb_betas = None
+
            with jax.profiler.TraceAnnotation("load_batch"):
                batch = next(iterator)
            step_start = time.perf_counter()
            current_step = int(state.step)
+            # grad_watch runs only on its configured interval.
            compute_watch = (
                watch_config.is_enabled and watch_config.interval > 0 and current_step % watch_config.interval == 0
            )
            state, metrics, watch_stats = train_step(state, batch, compute_watch=compute_watch)
            step = int(state.step) - 1

            jax.block_until_ready(metrics["train/loss"])
+            pending_qb_betas = metrics["qb_beta_per_layer"]

            if jnp.isnan(metrics["train/loss"]):
                logger.error(f"NaN loss at step {int(state.step)}. Stopping training.")
@@ -488,8 +514,8 @@ def _init_state(model_rng):
            router_metrics = {
                key: value
                for key, value in metrics.items()
-                if (key.startswith("train/router/") or key.startswith("moe/") or key.startswith("moe_bias/"))
-                and key not in ("train/router/routing_counts_per_layer",)
+                if (key.startswith("train/router/") or key.startswith("moe_bias/"))
+                and key not in ("train/router/routing_counts_per_layer", "qb_beta_per_layer")
            }
            if router_metrics:
                levanter.tracker.log(router_metrics, step=step)
@@ -504,7 +530,13 @@ def _init_state(model_rng):

            if checkpointer is not None:
                checkpointer.on_step(tree=state, step=int(state.step))
-    finally:
+    except BaseException:
+        logger.exception(
+            "Fatal error in grug training loop; skipping final callbacks/checkpoint to preserve root cause"
+        )
+        raise
+    else:
+        # Mirror classic trainer behavior: force callbacks on the last completed step.
        state_callbacks.run(state, loss=last_loss, step_duration=last_step_duration, force=True)
        if checkpointer is not None:
            checkpointer.on_step(tree=state, step=int(state.step), force=True)
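
The switch from finally to except/else changes failure semantics: a crash now logs and re-raises without running the forced final callbacks or writing a checkpoint from possibly-corrupt state, while a clean exit still gets both. A minimal sketch of the control flow (names are illustrative, not this file's API):

    import logging

    logger = logging.getLogger(__name__)

    def train_and_finalize(run_loop, save_final):
        try:
            run_loop()
        except BaseException:
            # Crash path: record and propagate; no final save that could mask the root cause.
            logger.exception("fatal error in training loop")
            raise
        else:
            # Clean-exit path only: force the final callbacks and checkpoint.
            save_final()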
