Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions experiments/grug/moe/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,11 @@ class GrugTrainState:
params: Transformer
opt_state: optax.OptState
ema_params: Transformer | None
pending_qb_betas: jax.Array


def _apply_qb_betas(model: Transformer, qb_betas: jax.Array) -> Transformer:
"""Set router biases from QB betas (computed on previous step, applied on host)."""
"""Set router biases from QB betas (computed on previous step)."""
new_blocks = list(model.blocks)
moe_idx = 0
for i, block in enumerate(model.blocks):
Expand All @@ -260,11 +261,13 @@ def initial_state(
ema_beta: float | None,
) -> GrugTrainState:
params = mp.cast_to_param(Transformer.init(model_config, key=key))
num_moe_layers = sum(1 for b in params.blocks if b.mlp is not None)
return GrugTrainState(
step=jnp.array(0, dtype=jnp.int32),
params=params,
opt_state=optimizer.init(params),
ema_params=params if ema_beta is not None else None,
pending_qb_betas=jnp.zeros((num_moe_layers, model_config.num_experts)),
)


Expand All @@ -288,6 +291,14 @@ def _make_train_step(

@functools.partial(jax.jit, donate_argnums=(0,), static_argnames=("compute_watch",))
def train_step(state: GrugTrainState, batch, *, compute_watch: bool = False):
# Apply pending QB betas to router biases inside JIT (avoids eager
# host-side TPU kernel launches that can cause SPMD sync issues).
qb_params = _apply_qb_betas(state.params, state.pending_qb_betas)
if ema_beta is not None:
qb_ema_params = _apply_qb_betas(state.ema_params, state.pending_qb_betas)
else:
qb_ema_params = None

def loss_fn(params):
compute_params = mp.cast_to_compute(params)
return compute_params.next_token_loss(
Expand All @@ -299,19 +310,19 @@ def loss_fn(params):
return_router_metrics=True,
)

(loss, summarized_metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
(loss, summarized_metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(qb_params)
metrics = {"train/loss": loss, **summarized_metrics}
updates, opt_state = optimizer.update(grads, state.opt_state, state.params)
params = optax.apply_updates(state.params, updates)
updates, opt_state = optimizer.update(grads, state.opt_state, qb_params)
params = optax.apply_updates(qb_params, updates)

if ema_beta is None:
ema_params = None
else:
if state.ema_params is None:
if qb_ema_params is None:
raise ValueError("ema_params must be initialized when ema_beta is set.")
ema_params = jax.tree_util.tree_map(
lambda old, new: ema_beta * old + (1.0 - ema_beta) * new,
state.ema_params,
qb_ema_params,
params,
)

Expand All @@ -323,7 +334,7 @@ def loss_fn(params):
include_per_parameter_norms=watch_config.include_per_parameter_norms,
include_histogram=watch_config.include_histograms,
split_scan_layers=watch_config.split_scan_layers,
params=state.params,
params=qb_params,
grads=grads,
updates=updates,
opt_state=state.opt_state,
Expand All @@ -336,6 +347,7 @@ def loss_fn(params):
params=params,
opt_state=opt_state,
ema_params=ema_params,
pending_qb_betas=metrics["qb_beta_per_layer"],
)

return next_state, metrics, watch_stats
Expand Down Expand Up @@ -470,22 +482,10 @@ def _init_state(model_rng):

last_loss: float | jax.Array = 0.0
last_step_duration = 0.0
pending_qb_betas: jax.Array | None = None

# Main optimization loop.
try:
while int(state.step) < trainer.num_train_steps:
# QB: apply router bias updates from previous step (on host).
if pending_qb_betas is not None:
state = dataclasses.replace(
state,
params=_apply_qb_betas(state.params, pending_qb_betas),
ema_params=(
_apply_qb_betas(state.ema_params, pending_qb_betas) if state.ema_params is not None else None
),
)
pending_qb_betas = None

with jax.profiler.TraceAnnotation("load_batch"):
batch = next(iterator)
step_start = time.perf_counter()
Expand All @@ -498,7 +498,6 @@ def _init_state(model_rng):
step = int(state.step) - 1

jax.block_until_ready(metrics["train/loss"])
pending_qb_betas = metrics["qb_beta_per_layer"]

if jnp.isnan(metrics["train/loss"]):
logger.error(f"NaN loss at step {int(state.step)}. Stopping training.")
Expand Down
Loading