2424import levanter .tracker
2525from levanter .callbacks .state_adapter import StateCallbackRunner
2626from levanter .callbacks .watch import WatchConfig , compute_watch_stats
27- from experiments .grug .checkpointing import restore_grug_state_from_checkpoint
2827from levanter .data import AsyncDataset , DataLoader
2928from levanter .data .mixture import MixtureDataset , rescale_mixture_schedule_for_batch_schedule
3029from levanter .data .text import GrugLmExample , LmDataConfig
3938from levanter .utils .logging import LoadingTimeTrackerIterator
4039
4140import equinox as eqx
41+ from experiments .grug .checkpointing import restore_grug_state_from_checkpoint
4242from experiments .grug .dispatch import dispatch_grug_training_run
4343from experiments .grug .moe .model import GrugModelConfig , Transformer
4444
45+ # This file intentionally mirrors `experiments/grug/base/train.py` with
46+ # variant-specific model/loss/FLOP wiring, per the grug copy-first workflow in
47+ # `.agents/skills/change-grug/`.
48+
4549logger = logging .getLogger (__name__ )
4650
4751
@@ -53,8 +57,8 @@ class GrugTrainerConfig:
5357 train_batch_pspec : P = field (default_factory = lambda : P (("data" , "expert" )))
5458 data_seed : int | None = None
5559 log_every : int = 1
56- ema_beta : float | None = None
57- z_loss_weight : float = 0.0
60+ ema_beta : float | None = None # EMA coefficient for eval/checkpoint model; None disables EMA.
61+ z_loss_weight : float = 0.0 # Weight on logsumexp (z-loss) stabilization term.
5862
5963
6064@dataclass (frozen = True )
@@ -114,6 +118,7 @@ def build_train_loader(
114118 mesh : Mesh ,
115119 batch_pspec : P = P (("data" , "expert" )),
116120) -> DataLoader [GrugLmExample ]:
121+ # DataLoader uses this batch axis mapping to shard batches across the distributed mesh.
117122 axis_resource = batch_pspec [0 ]
118123 return DataLoader (
119124 dataset ,
@@ -181,6 +186,7 @@ def _compute_flops(
181186 flops_per_token = lm_flops_per_token (
182187 hidden_dim = model_config .hidden_dim ,
183188 intermediate_dim = model_config .intermediate_dim ,
189+ shared_intermediate_dim = model_config .shared_expert_intermediate_dim ,
184190 num_layers = model_config .num_layers ,
185191 num_kv_heads = model_config .num_kv_heads ,
186192 num_heads = model_config .num_heads ,
@@ -227,7 +233,22 @@ class GrugTrainState:
227233 step : jax .Array
228234 params : Transformer
229235 opt_state : optax .OptState
230- ema_params : Transformer
236+ ema_params : Transformer | None
237+
238+
def _apply_qb_betas(model: Transformer, qb_betas: jax.Array) -> Transformer:
    """Write QB-derived router biases into every MoE block of *model*.

    `qb_betas` holds one beta vector per MoE layer (computed on the previous
    step); each block's `router_bias` is replaced out-of-place with the
    negated, mean-centered betas. Blocks whose `mlp` is None (dense layers)
    are left untouched and do not consume an entry of `qb_betas`.
    """
    patched_blocks = list(model.blocks)
    expert_layer = 0  # index into qb_betas; advances only on MoE blocks
    for block_idx, block in enumerate(model.blocks):
        if block.mlp is None:
            continue
        bias = -qb_betas[expert_layer]
        # Center so the bias shifts routing preferences without a global offset.
        bias = bias - jnp.mean(bias)
        patched_mlp = eqx.tree_at(lambda m: m.router_bias, block.mlp, bias)
        patched_blocks[block_idx] = eqx.tree_at(lambda b: b.mlp, block, patched_mlp)
        expert_layer += 1
    return eqx.tree_at(lambda t: t.blocks, model, tuple(patched_blocks))
231252
232253
233254def initial_state (
@@ -236,13 +257,14 @@ def initial_state(
236257 optimizer : optax .GradientTransformation ,
237258 mp : jmp .Policy ,
238259 key : PRNGKeyArray ,
260+ ema_beta : float | None ,
239261) -> GrugTrainState :
240262 params = mp .cast_to_param (Transformer .init (model_config , key = key ))
241263 return GrugTrainState (
242264 step = jnp .array (0 , dtype = jnp .int32 ),
243265 params = params ,
244266 opt_state = optimizer .init (params ),
245- ema_params = params ,
267+ ema_params = params if ema_beta is not None else None ,
246268 )
247269
248270
@@ -282,23 +304,11 @@ def loss_fn(params):
282304 updates , opt_state = optimizer .update (grads , state .opt_state , state .params )
283305 params = optax .apply_updates (state .params , updates )
284306
285- # Sharded QB: set router_bias = -(qb_beta - mean(qb_beta)) inside JIT.
286- qb_betas = summarized_metrics ["qb_beta_per_layer" ]
287- new_blocks = list (params .blocks )
288- moe_idx = 0
289- for i , block in enumerate (params .blocks ):
290- if block .mlp is not None :
291- new_bias = - qb_betas [moe_idx ]
292- new_bias = new_bias - jnp .mean (new_bias )
293- new_mlp = eqx .tree_at (lambda m : m .router_bias , block .mlp , new_bias )
294- new_blocks [i ] = eqx .tree_at (lambda b : b .mlp , block , new_mlp )
295- metrics [f"moe_bias/layer_{ moe_idx } /bias_norm" ] = jnp .linalg .norm (new_bias )
296- moe_idx += 1
297- params = eqx .tree_at (lambda t : t .blocks , params , tuple (new_blocks ))
298-
299307 if ema_beta is None :
300- ema_params = params
308+ ema_params = None
301309 else :
310+ if state .ema_params is None :
311+ raise ValueError ("ema_params must be initialized when ema_beta is set." )
302312 ema_params = jax .tree_util .tree_map (
303313 lambda old , new : ema_beta * old + (1.0 - ema_beta ) * new ,
304314 state .ema_params ,
@@ -357,6 +367,7 @@ def _run_grug_local(config: GrugRunConfig) -> None:
357367 if config .trainer .data_seed is not None :
358368 data_key = jax .random .PRNGKey (config .trainer .data_seed )
359369
370+ # Build data/model state under the trainer mesh so all arrays are sharded consistently.
360371 with trainer .use_device_mesh ():
361372 mesh = trainer .device_mesh
362373 batch_schedule = trainer .batch_schedule
@@ -381,6 +392,7 @@ def _init_state(model_rng):
381392 optimizer = optimizer ,
382393 mp = trainer .mp ,
383394 key = model_rng ,
395+ ema_beta = config .trainer .ema_beta ,
384396 )
385397
386398 state = _init_state (model_key )
@@ -422,7 +434,7 @@ def _init_state(model_rng):
422434 state_callbacks = StateCallbackRunner [GrugTrainState ](
423435 step_getter = lambda s : s .step ,
424436 model_getter = lambda s : s .params ,
425- eval_model_getter = lambda s : s .ema_params ,
437+ eval_model_getter = lambda s : s .ema_params if s . ema_params is not None else s . params ,
426438 opt_state_getter = lambda s : s .opt_state ,
427439 )
428440 state_callbacks .add_hook (
@@ -458,21 +470,35 @@ def _init_state(model_rng):
458470
459471 last_loss : float | jax .Array = 0.0
460472 last_step_duration = 0.0
473+ pending_qb_betas : jax .Array | None = None
461474
462475 # Main optimization loop.
463476 try :
464477 while int (state .step ) < trainer .num_train_steps :
478+ # QB: apply router bias updates from previous step (on host).
479+ if pending_qb_betas is not None :
480+ state = dataclasses .replace (
481+ state ,
482+ params = _apply_qb_betas (state .params , pending_qb_betas ),
483+ ema_params = (
484+ _apply_qb_betas (state .ema_params , pending_qb_betas ) if state .ema_params is not None else None
485+ ),
486+ )
487+ pending_qb_betas = None
488+
465489 with jax .profiler .TraceAnnotation ("load_batch" ):
466490 batch = next (iterator )
467491 step_start = time .perf_counter ()
468492 current_step = int (state .step )
493+ # grad_watch runs only on its configured interval.
469494 compute_watch = (
470495 watch_config .is_enabled and watch_config .interval > 0 and current_step % watch_config .interval == 0
471496 )
472497 state , metrics , watch_stats = train_step (state , batch , compute_watch = compute_watch )
473498 step = int (state .step ) - 1
474499
475500 jax .block_until_ready (metrics ["train/loss" ])
501+ pending_qb_betas = metrics ["qb_beta_per_layer" ]
476502
477503 if jnp .isnan (metrics ["train/loss" ]):
478504 logger .error (f"NaN loss at step { int (state .step )} . Stopping training." )
@@ -488,8 +514,8 @@ def _init_state(model_rng):
488514 router_metrics = {
489515 key : value
490516 for key , value in metrics .items ()
491- if (key .startswith ("train/router/" ) or key .startswith ("moe/" ) or key . startswith ( " moe_bias/" ))
492- and key not in ("train/router/routing_counts_per_layer" ,)
517+ if (key .startswith ("train/router/" ) or key .startswith ("moe_bias/" ))
518+ and key not in ("train/router/routing_counts_per_layer" , "qb_beta_per_layer" )
493519 }
494520 if router_metrics :
495521 levanter .tracker .log (router_metrics , step = step )
@@ -504,7 +530,13 @@ def _init_state(model_rng):
504530
505531 if checkpointer is not None :
506532 checkpointer .on_step (tree = state , step = int (state .step ))
507- finally :
533+ except BaseException :
534+ logger .exception (
535+ "Fatal error in grug training loop; skipping final callbacks/checkpoint to preserve root cause"
536+ )
537+ raise
538+ else :
539+ # Mirror classic trainer behavior: force callbacks on the last completed step.
508540 state_callbacks .run (state , loss = last_loss , step_duration = last_step_duration , force = True )
509541 if checkpointer is not None :
510542 checkpointer .on_step (tree = state , step = int (state .step ), force = True )
0 commit comments