diff --git a/experiments/grug/moe/train.py b/experiments/grug/moe/train.py
index acec78c972..03c4c0ff44 100644
--- a/experiments/grug/moe/train.py
+++ b/experiments/grug/moe/train.py
@@ -234,10 +234,11 @@ class GrugTrainState:
     params: Transformer
     opt_state: optax.OptState
     ema_params: Transformer | None
+    pending_qb_betas: jax.Array
 
 
 def _apply_qb_betas(model: Transformer, qb_betas: jax.Array) -> Transformer:
-    """Set router biases from QB betas (computed on previous step, applied on host)."""
+    """Set router biases from QB betas (computed on previous step)."""
     new_blocks = list(model.blocks)
     moe_idx = 0
     for i, block in enumerate(model.blocks):
@@ -260,11 +261,13 @@ def initial_state(
     ema_beta: float | None,
 ) -> GrugTrainState:
     params = mp.cast_to_param(Transformer.init(model_config, key=key))
+    num_moe_layers = sum(1 for b in params.blocks if b.mlp is not None)
     return GrugTrainState(
         step=jnp.array(0, dtype=jnp.int32),
         params=params,
         opt_state=optimizer.init(params),
         ema_params=params if ema_beta is not None else None,
+        pending_qb_betas=jnp.zeros((num_moe_layers, model_config.num_experts)),
     )
 
 
@@ -288,6 +291,14 @@ def _make_train_step(
 
     @functools.partial(jax.jit, donate_argnums=(0,), static_argnames=("compute_watch",))
     def train_step(state: GrugTrainState, batch, *, compute_watch: bool = False):
+        # Apply pending QB betas to router biases inside JIT (avoids eager
+        # host-side TPU kernel launches that can cause SPMD sync issues).
+        qb_params = _apply_qb_betas(state.params, state.pending_qb_betas)
+        if ema_beta is not None:
+            qb_ema_params = _apply_qb_betas(state.ema_params, state.pending_qb_betas)
+        else:
+            qb_ema_params = None
+
         def loss_fn(params):
             compute_params = mp.cast_to_compute(params)
             return compute_params.next_token_loss(
@@ -299,19 +310,19 @@ def loss_fn(params):
                 return_router_metrics=True,
             )
 
-        (loss, summarized_metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
+        (loss, summarized_metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(qb_params)
         metrics = {"train/loss": loss, **summarized_metrics}
 
-        updates, opt_state = optimizer.update(grads, state.opt_state, state.params)
-        params = optax.apply_updates(state.params, updates)
+        updates, opt_state = optimizer.update(grads, state.opt_state, qb_params)
+        params = optax.apply_updates(qb_params, updates)
 
         if ema_beta is None:
             ema_params = None
         else:
-            if state.ema_params is None:
+            if qb_ema_params is None:
                 raise ValueError("ema_params must be initialized when ema_beta is set.")
             ema_params = jax.tree_util.tree_map(
                 lambda old, new: ema_beta * old + (1.0 - ema_beta) * new,
-                state.ema_params,
+                qb_ema_params,
                 params,
             )
@@ -323,7 +334,7 @@ def loss_fn(params):
                 include_per_parameter_norms=watch_config.include_per_parameter_norms,
                 include_histogram=watch_config.include_histograms,
                 split_scan_layers=watch_config.split_scan_layers,
-                params=state.params,
+                params=qb_params,
                 grads=grads,
                 updates=updates,
                 opt_state=state.opt_state,
@@ -336,6 +347,7 @@ def loss_fn(params):
             params=params,
             opt_state=opt_state,
             ema_params=ema_params,
+            pending_qb_betas=metrics["qb_beta_per_layer"],
         )
 
         return next_state, metrics, watch_stats
@@ -470,22 +482,10 @@ def _init_state(model_rng):
 
     last_loss: float | jax.Array = 0.0
     last_step_duration = 0.0
-    pending_qb_betas: jax.Array | None = None
 
     # Main optimization loop.
    try:
         while int(state.step) < trainer.num_train_steps:
-            # QB: apply router bias updates from previous step (on host).
-            if pending_qb_betas is not None:
-                state = dataclasses.replace(
-                    state,
-                    params=_apply_qb_betas(state.params, pending_qb_betas),
-                    ema_params=(
-                        _apply_qb_betas(state.ema_params, pending_qb_betas) if state.ema_params is not None else None
-                    ),
-                )
-                pending_qb_betas = None
-
             with jax.profiler.TraceAnnotation("load_batch"):
                 batch = next(iterator)
             step_start = time.perf_counter()
@@ -498,7 +498,6 @@ def _init_state(model_rng):
             step = int(state.step) - 1
 
             jax.block_until_ready(metrics["train/loss"])
-            pending_qb_betas = metrics["qb_beta_per_layer"]
 
             if jnp.isnan(metrics["train/loss"]):
                 logger.error(f"NaN loss at step {int(state.step)}. Stopping training.")
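Note: below is a minimal, self-contained sketch of the pattern this patch adopts: carrying pending router-bias updates ("QB betas") inside the jitted train state and folding them into the parameters at the start of the next step, instead of applying them eagerly on the host between steps. It is illustrative only; the names (TinyState, train_step, router_bias) are hypothetical and not part of this codebase, and the "loss" here is faked so the state threading stays visible.

    from typing import NamedTuple

    import jax
    import jax.numpy as jnp


    class TinyState(NamedTuple):
        step: jax.Array
        router_bias: jax.Array    # stand-in for per-expert router biases
        pending_betas: jax.Array  # computed on step N, applied at the start of step N + 1


    @jax.jit
    def train_step(state: TinyState) -> TinyState:
        # Fold last step's pending betas into the biases inside the jitted step,
        # so no eager host-side parameter update runs between steps.
        bias = state.router_bias + state.pending_betas
        # A real step would run the forward/backward pass here; we only fake a
        # new bias correction and stash it for the next step.
        new_betas = -0.01 * jnp.sign(bias)
        return TinyState(step=state.step + 1, router_bias=bias, pending_betas=new_betas)


    state = TinyState(
        step=jnp.array(0, dtype=jnp.int32),
        router_bias=jnp.zeros(4),
        pending_betas=jnp.zeros(4),
    )
    for _ in range(3):
        state = train_step(state)

Because the pending betas live in the donated, jitted state rather than in a host-side Python variable, the bias application compiles into the same XLA program as the rest of the step, which is the property the patch relies on.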