
Commit e9a5e31

Simplify QB (#4458)
I partly suspect that QB has some issues w.r.t. checkpoint resumption, related to what @ClassicLarry has been experiencing, since it handles some work in the main thread but relies on post-step synchronization. That might not be the root cause of Larry's issues, but regardless, this PR cleans up our handling of QB a bit.
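For context, the pattern this commit moves to looks roughly like the following. This is a hedged, self-contained sketch, not the repo's real API (`TrainState`, `pending_bias`, and the toy loss are illustrative): deferred parameter edits ride along in the donated train state and are applied at the top of the jitted step, so no eager host-side updates happen between steps.

import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp


class TrainState(NamedTuple):
    params: jax.Array        # stand-in for the full parameter pytree
    pending_bias: jax.Array  # computed on the previous step, applied on this one


@functools.partial(jax.jit, donate_argnums=(0,))
def train_step(state: TrainState, batch: jax.Array):
    # Apply the deferred update from the previous step *inside* the JIT,
    # so every device runs the same SPMD program with no host round-trip.
    params = state.params + state.pending_bias

    # A toy loss stands in for the real forward/backward pass.
    loss_fn = lambda p: jnp.mean((p * batch) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = params - 0.1 * grads

    # Stash the *next* step's deferred update in the returned state.
    next_pending = -0.01 * jnp.sign(grads)
    return TrainState(params, next_pending), loss


state = TrainState(jnp.ones(4), jnp.zeros(4))
state, loss = train_step(state, jnp.arange(4.0))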
1 parent be63d53 commit e9a5e31

1 file changed

Lines changed: 19 additions & 20 deletions

File tree

experiments/grug/moe/train.py

@@ -234,10 +234,11 @@ class GrugTrainState:
     params: Transformer
     opt_state: optax.OptState
     ema_params: Transformer | None
+    pending_qb_betas: jax.Array


 def _apply_qb_betas(model: Transformer, qb_betas: jax.Array) -> Transformer:
-    """Set router biases from QB betas (computed on previous step, applied on host)."""
+    """Set router biases from QB betas (computed on previous step)."""
     new_blocks = list(model.blocks)
     moe_idx = 0
     for i, block in enumerate(model.blocks):
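Based on the visible hunk, `_apply_qb_betas` walks the blocks, skips dense layers, and overwrites each MoE router's bias with the corresponding row of `qb_betas`. Here is a hedged reconstruction with stand-in types: `Router`, `Block`, and the `bias` field are assumptions (the router is collapsed into the `mlp` field for brevity); only the loop skeleton is from the diff.

import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Router:
    bias: jax.Array  # shape (num_experts,)


@dataclasses.dataclass(frozen=True)
class Block:
    mlp: Router | None  # None marks a dense (non-MoE) block


@dataclasses.dataclass(frozen=True)
class Transformer:
    blocks: tuple[Block, ...]


def apply_qb_betas(model: Transformer, qb_betas: jax.Array) -> Transformer:
    """Set router biases from a (num_moe_layers, num_experts) beta matrix."""
    new_blocks = list(model.blocks)
    moe_idx = 0
    for i, block in enumerate(model.blocks):
        if block.mlp is None:
            continue  # dense block: nothing to update
        new_blocks[i] = dataclasses.replace(
            block, mlp=dataclasses.replace(block.mlp, bias=qb_betas[moe_idx])
        )
        moe_idx += 1
    return dataclasses.replace(model, blocks=tuple(new_blocks))


model = Transformer(blocks=(Block(mlp=None), Block(mlp=Router(bias=jnp.zeros(4)))))
model = apply_qb_betas(model, jnp.ones((1, 4)))  # one MoE layer, four experts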
@@ -260,11 +261,13 @@ def initial_state(
     ema_beta: float | None,
 ) -> GrugTrainState:
     params = mp.cast_to_param(Transformer.init(model_config, key=key))
+    num_moe_layers = sum(1 for b in params.blocks if b.mlp is not None)
     return GrugTrainState(
         step=jnp.array(0, dtype=jnp.int32),
         params=params,
         opt_state=optimizer.init(params),
         ema_params=params if ema_beta is not None else None,
+        pending_qb_betas=jnp.zeros((num_moe_layers, model_config.num_experts)),
     )

@@ -288,6 +291,14 @@ def _make_train_step(

     @functools.partial(jax.jit, donate_argnums=(0,), static_argnames=("compute_watch",))
     def train_step(state: GrugTrainState, batch, *, compute_watch: bool = False):
+        # Apply pending QB betas to router biases inside JIT (avoids eager
+        # host-side TPU kernel launches that can cause SPMD sync issues).
+        qb_params = _apply_qb_betas(state.params, state.pending_qb_betas)
+        if ema_beta is not None:
+            qb_ema_params = _apply_qb_betas(state.ema_params, state.pending_qb_betas)
+        else:
+            qb_ema_params = None
+
         def loss_fn(params):
             compute_params = mp.cast_to_compute(params)
             return compute_params.next_token_loss(
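Design note: because `train_step` donates its first argument and `pending_qb_betas` is now a field of `GrugTrainState`, the betas computed at step N are consumed at step N+1 entirely on-device; the host loop no longer needs to mutate `state` between steps (the removed loop body in a later hunk).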
@@ -299,19 +310,19 @@ def loss_fn(params):
                return_router_metrics=True,
            )

-        (loss, summarized_metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
+        (loss, summarized_metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(qb_params)
         metrics = {"train/loss": loss, **summarized_metrics}
-        updates, opt_state = optimizer.update(grads, state.opt_state, state.params)
-        params = optax.apply_updates(state.params, updates)
+        updates, opt_state = optimizer.update(grads, state.opt_state, qb_params)
+        params = optax.apply_updates(qb_params, updates)

         if ema_beta is None:
             ema_params = None
         else:
-            if state.ema_params is None:
+            if qb_ema_params is None:
                 raise ValueError("ema_params must be initialized when ema_beta is set.")
             ema_params = jax.tree_util.tree_map(
                 lambda old, new: ema_beta * old + (1.0 - ema_beta) * new,
-                state.ema_params,
+                qb_ema_params,
                 params,
             )

@@ -323,7 +334,7 @@ def loss_fn(params):
                include_per_parameter_norms=watch_config.include_per_parameter_norms,
                include_histogram=watch_config.include_histograms,
                split_scan_layers=watch_config.split_scan_layers,
-                params=state.params,
+                params=qb_params,
                grads=grads,
                updates=updates,
                opt_state=state.opt_state,
@@ -336,6 +347,7 @@ def loss_fn(params):
             params=params,
             opt_state=opt_state,
             ema_params=ema_params,
+            pending_qb_betas=metrics["qb_beta_per_layer"],
         )

         return next_state, metrics, watch_stats
@@ -470,22 +482,10 @@ def _init_state(model_rng):

     last_loss: float | jax.Array = 0.0
     last_step_duration = 0.0
-    pending_qb_betas: jax.Array | None = None

     # Main optimization loop.
     try:
         while int(state.step) < trainer.num_train_steps:
-            # QB: apply router bias updates from previous step (on host).
-            if pending_qb_betas is not None:
-                state = dataclasses.replace(
-                    state,
-                    params=_apply_qb_betas(state.params, pending_qb_betas),
-                    ema_params=(
-                        _apply_qb_betas(state.ema_params, pending_qb_betas) if state.ema_params is not None else None
-                    ),
-                )
-                pending_qb_betas = None
-
             with jax.profiler.TraceAnnotation("load_batch"):
                 batch = next(iterator)
                 step_start = time.perf_counter()
@@ -498,7 +498,6 @@ def _init_state(model_rng):
             step = int(state.step) - 1

             jax.block_until_ready(metrics["train/loss"])
-            pending_qb_betas = metrics["qb_beta_per_layer"]

             if jnp.isnan(metrics["train/loss"]):
                 logger.error(f"NaN loss at step {int(state.step)}. Stopping training.")
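One consequence for the checkpoint-resumption concern in the commit message: since `pending_qb_betas` now lives inside `GrugTrainState`, any pytree-based checkpointer saves and restores it alongside `params` and `opt_state`, so a restart can no longer silently drop a pending host-side update. A minimal sketch of that round-trip, assuming a generic pytree checkpointer (not this repo's actual checkpoint code):

import jax
import numpy as np


def save_state(state, path: str) -> None:
    # Flatten the train state (params, opt_state, pending_qb_betas, ...)
    # into leaves and persist them in order.
    leaves = jax.tree_util.tree_leaves(state)
    np.savez(path, *[np.asarray(leaf) for leaf in leaves])


def restore_state(template, path: str):
    # Rebuild the same pytree structure, pending_qb_betas included.
    with np.load(path) as data:
        leaves = [jax.numpy.asarray(data[name]) for name in data.files]
    return jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(template), leaves)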
