@@ -234,10 +234,11 @@ class GrugTrainState:
     params: Transformer
     opt_state: optax.OptState
     ema_params: Transformer | None
+    pending_qb_betas: jax.Array
 
 
 def _apply_qb_betas(model: Transformer, qb_betas: jax.Array) -> Transformer:
-    """Set router biases from QB betas (computed on previous step, applied on host)."""
+    """Set router biases from QB betas (computed on previous step)."""
     new_blocks = list(model.blocks)
     moe_idx = 0
     for i, block in enumerate(model.blocks):
@@ -260,11 +261,13 @@ def initial_state(
     ema_beta: float | None,
 ) -> GrugTrainState:
     params = mp.cast_to_param(Transformer.init(model_config, key=key))
+    num_moe_layers = sum(1 for b in params.blocks if b.mlp is not None)
     return GrugTrainState(
         step=jnp.array(0, dtype=jnp.int32),
         params=params,
         opt_state=optimizer.init(params),
         ema_params=params if ema_beta is not None else None,
+        pending_qb_betas=jnp.zeros((num_moe_layers, model_config.num_experts)),
     )
 
 
@@ -288,6 +291,14 @@ def _make_train_step(
 
     @functools.partial(jax.jit, donate_argnums=(0,), static_argnames=("compute_watch",))
     def train_step(state: GrugTrainState, batch, *, compute_watch: bool = False):
+        # Apply pending QB betas to router biases inside JIT (avoids eager
+        # host-side TPU kernel launches that can cause SPMD sync issues).
+        qb_params = _apply_qb_betas(state.params, state.pending_qb_betas)
+        if ema_beta is not None:
+            qb_ema_params = _apply_qb_betas(state.ema_params, state.pending_qb_betas)
+        else:
+            qb_ema_params = None
+
         def loss_fn(params):
             compute_params = mp.cast_to_compute(params)
             return compute_params.next_token_loss(
@@ -299,19 +310,19 @@ def loss_fn(params):
                 return_router_metrics=True,
             )
 
-        (loss, summarized_metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
+        (loss, summarized_metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(qb_params)
         metrics = {"train/loss": loss, **summarized_metrics}
-        updates, opt_state = optimizer.update(grads, state.opt_state, state.params)
-        params = optax.apply_updates(state.params, updates)
+        updates, opt_state = optimizer.update(grads, state.opt_state, qb_params)
+        params = optax.apply_updates(qb_params, updates)
 
         if ema_beta is None:
             ema_params = None
         else:
-            if state.ema_params is None:
+            if qb_ema_params is None:
                 raise ValueError("ema_params must be initialized when ema_beta is set.")
             ema_params = jax.tree_util.tree_map(
                 lambda old, new: ema_beta * old + (1.0 - ema_beta) * new,
-                state.ema_params,
+                qb_ema_params,
                 params,
             )
 
@@ -323,7 +334,7 @@ def loss_fn(params):
                 include_per_parameter_norms=watch_config.include_per_parameter_norms,
                 include_histogram=watch_config.include_histograms,
                 split_scan_layers=watch_config.split_scan_layers,
-                params=state.params,
+                params=qb_params,
                 grads=grads,
                 updates=updates,
                 opt_state=state.opt_state,
@@ -336,6 +347,7 @@ def loss_fn(params):
             params=params,
             opt_state=opt_state,
             ema_params=ema_params,
+            pending_qb_betas=metrics["qb_beta_per_layer"],
         )
 
         return next_state, metrics, watch_stats
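The train_step hunk above is the core of the change: the betas computed at step N are carried in `pending_qb_betas` and folded into the parameters at the start of the jitted step N + 1, so no host-side update runs between steps. Below is a minimal, self-contained sketch of that "compute now, apply next step" pattern, using a toy state and made-up names (`ToyState`, `bias`, `pending_bias`, `toy_step`) rather than the real fields:

import jax
import jax.numpy as jnp
from typing import NamedTuple


class ToyState(NamedTuple):
    bias: jax.Array          # parameter the deferred update is applied to
    pending_bias: jax.Array  # computed at step N, applied at step N + 1


@jax.jit
def toy_step(state: ToyState, batch: jax.Array) -> ToyState:
    # Fold last step's pending value into the parameter inside jit, so the
    # whole step compiles to one program with no eager host-side mutation.
    bias = state.pending_bias
    # ... the forward/backward pass would use `bias` here ...
    new_pending = bias + batch.mean()  # stand-in for the real beta computation
    return ToyState(bias=bias, pending_bias=new_pending)


state = ToyState(bias=jnp.zeros(4), pending_bias=jnp.zeros(4))
for i in range(3):
    state = toy_step(state, jnp.ones(4) * i)

Because the pending value lives in the state carried between steps rather than in host Python, the first step sees the zero-initialized betas and every later step sees exactly the betas produced by the previous one.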
@@ -470,22 +482,10 @@ def _init_state(model_rng):
 
     last_loss: float | jax.Array = 0.0
     last_step_duration = 0.0
-    pending_qb_betas: jax.Array | None = None
 
     # Main optimization loop.
     try:
         while int(state.step) < trainer.num_train_steps:
-            # QB: apply router bias updates from previous step (on host).
-            if pending_qb_betas is not None:
-                state = dataclasses.replace(
-                    state,
-                    params=_apply_qb_betas(state.params, pending_qb_betas),
-                    ema_params=(
-                        _apply_qb_betas(state.ema_params, pending_qb_betas) if state.ema_params is not None else None
-                    ),
-                )
-            pending_qb_betas = None
-
             with jax.profiler.TraceAnnotation("load_batch"):
                 batch = next(iterator)
             step_start = time.perf_counter()
@@ -498,7 +498,6 @@ def _init_state(model_rng):
             step = int(state.step) - 1
 
             jax.block_until_ready(metrics["train/loss"])
-            pending_qb_betas = metrics["qb_beta_per_layer"]
 
             if jnp.isnan(metrics["train/loss"]):
                 logger.error(f"NaN loss at step {int(state.step)}. Stopping training.")