diff --git a/experiments/grug/max_model_flop_mini/__init__.py b/experiments/grug/max_model_flop_mini/__init__.py new file mode 100644 index 0000000000..ec8bc038b7 --- /dev/null +++ b/experiments/grug/max_model_flop_mini/__init__.py @@ -0,0 +1,2 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/experiments/grug/max_model_flop_mini/launch.py b/experiments/grug/max_model_flop_mini/launch.py new file mode 100644 index 0000000000..9b871d2a30 --- /dev/null +++ b/experiments/grug/max_model_flop_mini/launch.py @@ -0,0 +1,213 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Max-features MoE trial run. + +Results: eval/paloma/c4_en/bpb: 1.1136 @ 5000 steps, ~9.14e17 model FLOPs. + +Enhancements on top of experiments/grug/moe: + +Architecture: +- QK-norm: non-parametric RMS norm on Q/K before RoPE. +- Partial RoPE: only rotates first 50% of head_dim (partial_rotary_factor=0.5). +- Parameter-free RMSNorm: no learnable scale weight. +- Embed norm: RMSNorm applied immediately after token embedding lookup. +- Per-head attention gate: learned sigmoid gate on each attention head output. +- Value embeddings (VE): auxiliary vocabulary embedding mixed into V on the + last num_ve_layers layers via learnable lambda/gate parameters. +- Residual stream mixing (x0): per-layer learnable interpolation between the + current hidden state and the original post-embed-norm state. +- Sliding window attention: alternating short/long causal windows across layers + (long every 4th layer, short = sliding_window // 2). +- Zero-init output projections: lm_head, attn w_o, dense MLP w_down, and MoE + w_down all initialized to zeros. +- Load-balancing loss and router z-loss (configurable, None disables each). + +Training / config: +- 16 experts (vs 8 in base), top-2 routing. +- Configurable ep_capacity_factor for expert parallelism. +- GrugMuonConfig (Muon) replaces AdamConfig, with 3D expert-weight support + (Newton-Schulz vmapped over the expert dim). +- Full Levanter TrainerConfig with checkpointing, profiler, mixed-precision + policy, and WandB tracking. +- Runs on Nemotron data mix with default validation sets. +""" + +import dataclasses +import os +from dataclasses import dataclass, field +from datetime import timedelta + +import jmp +from fray.cluster import ResourceConfig +from levanter.callbacks.profiler import ProfilerConfig +from levanter.checkpoint import CheckpointerConfig +from levanter.data.text import LmDataConfig +from levanter.optim import OptimizerConfig +from levanter.optim import GrugMuonConfig +from levanter.tracker import TrackerConfig +from levanter.tracker.wandb import WandbConfig +from levanter.trainer import TrainerConfig +from levanter.utils.mesh import MeshConfig +from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned +from marin.processing.tokenize import add_validation_sets_to_mixture + +from experiments.defaults import default_validation_sets +from experiments.grug.max_model_flop_mini.model import GrugModelConfig +from experiments.grug.max_model_flop_mini.train import GrugEvalConfig, GrugRunConfig, GrugTrainerConfig, run_grug +from experiments.tootsie.exp1295_32b import nemotron_mix + + +@dataclass(frozen=True) +class GrugMoeLaunchConfig: + """Last-mile run config for the MoE grug template. + + Keep this as the main entry point for day-to-day edits (model/data/optimizer/trainer/eval knobs). 
+ """ + + model: GrugModelConfig + data: LmDataConfig + output_path: str + run_id: str + steps: int + batch_size: int + seed: int + mp: str # jmp policy string, e.g. "params=float32,compute=bfloat16,output=bfloat16". + tracker: TrackerConfig + optimizer: OptimizerConfig + grug_trainer: GrugTrainerConfig = field(default_factory=GrugTrainerConfig) + eval: GrugEvalConfig | None = field(default_factory=GrugEvalConfig) + + +GRUG_MOE_TRIAL_MODEL = GrugModelConfig( + vocab_size=128_256, + hidden_dim=512, + intermediate_dim=512 * 2, + shared_expert_intermediate_dim=0, + num_experts=16, + num_experts_per_token=2, + num_layers=8, + num_heads=4, + num_kv_heads=4, + max_seq_len=2048, + head_dim=None, + initializer_std=0.02, + lbl_coef=0.01, + rzl_coef=0.001, + num_ve_layers=2, + sliding_window=2048, + rope_theta=1024, +) + +NEMOTRON_MIX_WITH_DEFAULT_VALIDATION = add_validation_sets_to_mixture( + nemotron_mix, + default_validation_sets(tokenizer=nemotron_mix.tokenizer), +) + + +def _resolve_run_id(default_run_id: str) -> str: + """Resolve run id and append `FERRY_DATE` when launching from ferry workflows.""" + run_id = os.environ.get("GRUG_RUN_ID", default_run_id) + ferry_date = os.environ.get("FERRY_DATE") + if ferry_date: + run_id = f"{run_id}-{ferry_date}" + return run_id + + +def _resolve_tracker(tracker: TrackerConfig, run_id: str) -> TrackerConfig: + if isinstance(tracker, WandbConfig): + return dataclasses.replace(tracker, name=run_id) + return tracker + + +def run_grug_moe_trial(config: GrugMoeLaunchConfig) -> None: + # Map template launch knobs onto full Levanter TrainerConfig. + trainer = TrainerConfig( + mesh=MeshConfig(axes={"data": -1, "expert": 1, "model": 1}), + id=config.run_id, + seed=config.seed, + train_batch_size=config.batch_size, + num_train_steps=config.steps, + profiler=ProfilerConfig(enabled=False, start_step=5, num_steps=100, perfetto_link=False), + mp=jmp.get_policy(config.mp), + tracker=_resolve_tracker(config.tracker, config.run_id), + use_explicit_mesh_axes=True, + require_accelerator=True, + allow_nondivisible_batch_size=False, + checkpointer=CheckpointerConfig( + base_path=os.path.join(config.output_path, "checkpoints"), + append_run_id_to_base_path=False, + save_interval=timedelta(minutes=10), + keep=[{"every": 1000}], + ), + ) + + grug_trainer = dataclasses.replace(config.grug_trainer, trainer=trainer) + + run_config = GrugRunConfig( + model=config.model, + data=config.data, + optimizer=config.optimizer, + trainer=grug_trainer, + eval=config.eval, + ) + run_grug(run_config) + + +RESOLVED_RUN_ID = _resolve_run_id("moe_feat_max_mini") + + +grug_moe_trial = ExecutorStep( + name="grug/moe_feat_max_mini", + fn=run_grug_moe_trial, + config=GrugMoeLaunchConfig( + model=versioned(GRUG_MOE_TRIAL_MODEL), + data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION, + # this_output_path() resolves to this step's output root (e.g. gs://.../grug/moe-trial-). + output_path=this_output_path(), + # Keep run id out of versioning so changing job metadata doesn't create a new output path. 
+ run_id=RESOLVED_RUN_ID, + steps=versioned(5000), + batch_size=versioned(128), + seed=versioned(0), + mp=versioned("params=float32,compute=bfloat16,output=bfloat16"), + tracker=WandbConfig( + project="dial_moe", + tags=["grug", "template", "moe"], + group="moe_feat_max", + name=None, # filled from run_id in _resolve_tracker + ), + optimizer=versioned( + GrugMuonConfig( + learning_rate=0.02, + adam_lr=0.0064, + weight_decay=0, + min_lr_ratio=0.1, + warmup=0, + momentum=0.95, + beta1=0.8, + beta2=0.95, + epsilon=1e-15, + muon_epsilon=1e-5, + max_grad_norm=1, + lr_schedule="linear", + decay=0.5, + ) + ), + grug_trainer=versioned( + GrugTrainerConfig( + z_loss_weight=0, + ema_beta=None, + log_every=1, + ) + ), + ), + resources=ResourceConfig.with_tpu("v4-8"), +) + + +if __name__ == "__main__": + executor_main( + steps=[grug_moe_trial], + description="Applying max features to small scale MoE on Nemotron mix.", + ) diff --git a/experiments/grug/max_model_flop_mini/model.py b/experiments/grug/max_model_flop_mini/model.py new file mode 100644 index 0000000000..0ce5dd485b --- /dev/null +++ b/experiments/grug/max_model_flop_mini/model.py @@ -0,0 +1,628 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""MoE grug variant model. + +This variant intentionally mirrors `experiments/grug/base/model.py` and applies +MoE-specific changes inline. Keeping the file largely self-contained follows the +grug copy-first workflow in `docs/recipes/change_grug.md`. + +Architecture changes ported from ``experiments/speedrun/new_grug_moe/moe_max2.py``: + +1. **QK-Norm** - Non-parametric RMS norm on Q and K projections before RoPE. +2. **Parameter-free RMSNorm** - Learnable ``weight`` removed from ``RMSNorm``. +3. **Zero-init output projections** - ``w_o`` (attention), ``w_down`` (DenseMLP), + and ``w_down`` (MoEMLP) are initialised to zeros for stable residual streams. +4. **Embed norm** - ``RMSNorm`` applied immediately after token embedding lookup. +5. **Per-head attention gate** - Learned sigmoid gate on attention output per head. +6. **Value Embeddings (VE)** - Auxiliary token embedding for V heads, gated and + mixed via learnable lambdas on the last ``num_ve_layers`` layers. +7. **Residual stream mixing (x0)** - Per-layer learnable interpolation between + the current hidden state and the original (post-embed-norm) embedding. +8. **Sliding window attention** - Alternating short/long causal windows across + layers when ``sliding_window`` is set. 
+""" + +from dataclasses import dataclass + +import equinox as eqx +import jax +import jax.numpy as jnp +from einops import rearrange +from haliax.jax_utils import named_call +from jax import random +from jax.sharding import PartitionSpec as P +from jax.sharding import get_abstract_mesh, reshard +from jaxtyping import Array, Float, Int, PRNGKeyArray + +from levanter.grug.attention import AttentionMask, attention +from levanter.grug.grug_moe import MoeActivation, moe_mlp +from levanter.grug.loss import fused_linear_softmax_cross_entropy_loss +from levanter.grug.sharding import Pvocab +from levanter.tracker.histogram import Histogram +from levanter.utils.activation import ActivationFunctionEnum + +_DEFAULT_EP_CAPACITY_FACTOR = 1.25 + + +def _mesh_axis_size(mesh: jax.sharding.AbstractMesh | None, axis_name: str) -> int: + if mesh is None or mesh.empty or axis_name not in mesh.shape: + raise ValueError(f"grug/moe requires an abstract mesh with axis '{axis_name}'") + return int(mesh.shape[axis_name]) + + +def _batch_spec() -> P: + return P(("data", "expert")) + + +def _rotary_cache(seq_len: int, rotary_dim: int, rope_theta: float) -> tuple[Float[Array, "S D"], Float[Array, "S D"]]: + half_dim = rotary_dim // 2 + inv_freq = 1.0 / (rope_theta ** (jnp.arange(0, half_dim, dtype=jnp.float32) / half_dim)) + positions = jnp.arange(seq_len, dtype=jnp.float32) + angles = positions[:, None] * inv_freq[None, :] + return jnp.cos(angles), jnp.sin(angles) + + +@named_call +def apply_rotary_embedding( + q: Float[Array, "B S H D"], + k: Float[Array, "B S H D"], + *, + seq_len: int, + head_dim: int, + rope_theta: float, + partial_rotary_factor: float = 0.5, +) -> tuple[Float[Array, "B S H D"], Float[Array, "B S H D"]]: + """Partial rotary embedding: only the first ``partial_rotary_factor`` of the head dim is rotated.""" + rotary_dim = int(head_dim * partial_rotary_factor) + cos, sin = _rotary_cache(seq_len, rotary_dim, rope_theta) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + + def _apply(x: Float[Array, "B S H D"]) -> Float[Array, "B S H D"]: + x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:] + x1, x2 = jnp.split(x_rot, 2, axis=-1) + x_rot = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) + return jnp.concatenate([x_rot, x_pass], axis=-1) + + return _apply(q), _apply(k) + + +@named_call +def qk_norm(x: Float[Array, "B S H D"], eps: float = 1e-6) -> Float[Array, "B S H D"]: + """Non-parametric RMS norm over the head dimension.""" + variance = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) + return (x * jax.lax.rsqrt(variance + eps)).astype(x.dtype) + + +@dataclass(frozen=True) +class GrugModelConfig: + """Hyperparameters for the compact grug MoE transformer.""" + + vocab_size: int + hidden_dim: int = 384 + intermediate_dim: int = 384 + shared_expert_intermediate_dim: int = 384 + num_experts: int = 8 + num_experts_per_token: int = 2 + num_layers: int = 8 + num_heads: int = 8 + num_kv_heads: int = 8 + head_dim: int | None = None + max_seq_len: int = 4096 + layer_norm_eps: float = 1e-5 + initializer_std: float = 0.006 + lbl_coef: float | None = None # Load-balancing loss coefficient; None disables. + rzl_coef: float | None = None # Router z-loss coefficient; None disables. 
+ rope_theta: float = 1024.0 + partial_rotary_factor: float = 0.5 + gate_input_dim: int = 12 + num_ve_layers: int = 0 + sliding_window: int | None = None + ep_capacity_factor: float = 1.25 + + def __post_init__(self) -> None: + _ = self.inferred_head_dim + if self.num_heads % self.num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads for grouped-query attention") + if self.vocab_size <= 0: + raise ValueError("vocab_size must be positive") + if self.max_seq_len <= 0: + raise ValueError("max_seq_len must be positive") + if self.num_experts <= 0: + raise ValueError("num_experts must be positive") + if self.num_experts_per_token <= 0: + raise ValueError("num_experts_per_token must be positive") + if self.num_experts_per_token > self.num_experts: + raise ValueError("num_experts_per_token must be <= num_experts") + if self.shared_expert_intermediate_dim < 0: + raise ValueError("shared_expert_intermediate_dim must be non-negative") + + @property + def inferred_head_dim(self) -> int: + if self.head_dim is not None: + return self.head_dim + if self.hidden_dim % self.num_heads != 0: + raise ValueError( + f"hidden_dim={self.hidden_dim} is not divisible by num_heads={self.num_heads}; set head_dim explicitly" + ) + return self.hidden_dim // self.num_heads + + +class CausalSelfAttention(eqx.Module): + w_q: Float[Array, "D NH"] + w_k: Float[Array, "D MH"] + w_v: Float[Array, "D MH"] + w_o: Float[Array, "NH D"] + ve_embed: jax.Array | None + value_lambda: jax.Array + ve_lambda: jax.Array + ve_gate: jax.Array + attn_gate: jax.Array + cfg: GrugModelConfig = eqx.field(static=True) + + @staticmethod + def init(cfg: GrugModelConfig, *, key: PRNGKeyArray, has_ve: bool = False) -> "CausalSelfAttention": + k_q, k_k, k_v, _k_o, k_ve = random.split(key, 5) + d, n, m, h = cfg.hidden_dim, cfg.num_heads, cfg.num_kv_heads, cfg.inferred_head_dim + g = cfg.gate_input_dim + + ve_embed = None + if has_ve: + ve_dim = m * h + ve_embed = reshard(_init_weight(k_ve, (cfg.vocab_size, ve_dim), cfg.initializer_std), Pvocab) + + return CausalSelfAttention( + w_q=reshard(_init_weight(k_q, (d, n * h), cfg.initializer_std), P("data", "model")), + w_k=reshard(_init_weight(k_k, (d, m * h), cfg.initializer_std), P("data", "model")), + w_v=reshard(_init_weight(k_v, (d, m * h), cfg.initializer_std), P("data", "model")), + w_o=reshard(jnp.zeros((n * h, d)), P("model", "data")), + ve_embed=ve_embed, + value_lambda=jnp.full((), 0.5, dtype=jnp.float32), + ve_lambda=jnp.full((), 0.5, dtype=jnp.float32), + ve_gate=reshard(jnp.zeros((g, m), dtype=jnp.float32), P(None, None)), + attn_gate=reshard(jnp.zeros((g, n), dtype=jnp.float32), P(None, None)), + cfg=cfg, + ) + + @named_call + def __call__( + self, + x: Float[Array, "B S D"], + mask: AttentionMask | jax.Array, + token_ids: Int[Array, "B S"] | None = None, + ) -> Float[Array, "B S D"]: + head_dim = self.cfg.inferred_head_dim + seq_len = x.shape[1] + batch_spec = _batch_spec() + g = self.cfg.gate_input_dim + + q = rearrange(jnp.einsum("bsh,hd->bsd", x, self.w_q), "... (n d) -> ... n d", d=head_dim) + k = rearrange(jnp.einsum("bsh,hd->bsd", x, self.w_k), "... (m d) -> ... m d", d=head_dim) + v = rearrange(jnp.einsum("bsh,hd->bsd", x, self.w_v), "... (m d) -> ... 
m d", d=head_dim) + + q = qk_norm(q) + k = qk_norm(k) + + q, k = apply_rotary_embedding( + q, + k, + seq_len=seq_len, + head_dim=head_dim, + rope_theta=self.cfg.rope_theta, + partial_rotary_factor=self.cfg.partial_rotary_factor, + ) + + if self.ve_embed is not None and token_ids is not None: + ve = self.ve_embed.at[token_ids].get(out_sharding=batch_spec) + ve_heads = rearrange(ve, "... (m d) -> ... m d", d=head_dim) + gate_out = 2 * jax.nn.sigmoid(x[..., :g] @ self.ve_gate) + v = self.value_lambda * v + self.ve_lambda * gate_out[..., None] * ve_heads + + attn_out = attention(q, k, v, mask) + + attn_gate_out = 2 * jax.nn.sigmoid(x[..., :g] @ self.attn_gate) + attn_out = attn_gate_out[..., None] * attn_out + + attn_out = rearrange(attn_out, "... n d -> ... (n d)") + return jnp.einsum("bsh,hd->bsd", attn_out, self.w_o, out_sharding=batch_spec) + + +class RMSNorm(eqx.Module): + eps: float = eqx.field(static=True) + + @staticmethod + def init(dim: int, eps: float) -> "RMSNorm": + return RMSNorm(eps=eps) + + @named_call + def __call__(self, x: Float[Array, "... D"]) -> Float[Array, "... D"]: + dtype = x.dtype + x = x.astype(jnp.float32) + variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True) + normed = x * jax.lax.rsqrt(variance + self.eps) + return normed.astype(dtype) + + +class DenseMLP(eqx.Module): + w_gate: jax.Array + w_up: jax.Array + w_down: jax.Array + + @staticmethod + def init(hidden_dim: int, intermediate_dim: int, initializer_std: float, *, key: PRNGKeyArray) -> "DenseMLP": + k_gate, k_up = random.split(key, 2) + return DenseMLP( + w_gate=reshard(_init_weight(k_gate, (hidden_dim, intermediate_dim), initializer_std), P("data", "model")), + w_up=reshard(_init_weight(k_up, (hidden_dim, intermediate_dim), initializer_std), P("data", "model")), + w_down=reshard(jnp.zeros((intermediate_dim, hidden_dim)), P("model", "data")), + ) + + @named_call + def __call__( + self, + x: Float[Array, "B S D"], + *, + activation: MoeActivation = ActivationFunctionEnum.silu, + ) -> Float[Array, "B S D"]: + if isinstance(activation, ActivationFunctionEnum): + activation_fn = activation.to_jax_fn() + else: + activation_fn = activation + + b, s, _ = x.shape + x_flat = rearrange(x, "b s d -> (b s) d") + gate = jnp.einsum("td,dm->tm", x_flat, self.w_gate) + up = jnp.einsum("td,dm->tm", x_flat, self.w_up) + out_flat = jnp.einsum("tm,md->td", activation_fn(gate) * up, self.w_down, out_sharding=_batch_spec()) + return rearrange(out_flat, "(b s) d -> b s d", b=b, s=s) + + +def _routing_stats_from_selected_experts( + selected_experts: Int[Array, "T K"], + *, + num_experts: int, +) -> dict[str, jax.Array]: + expert_counts = jnp.sum(jax.nn.one_hot(selected_experts, num_experts, dtype=jnp.float32), axis=(0, 1)) + total_assignments = jnp.maximum(jnp.sum(expert_counts), 1.0) + expert_loads = expert_counts / total_assignments + routing_entropy = -jnp.sum(expert_loads * jnp.log(expert_loads + 1e-6)) + return { + "routing_counts": expert_counts, + "routing_entropy": routing_entropy, + } + + +def _summarize_router_metrics(router_metrics: dict[str, jax.Array]) -> dict[str, jax.Array | Histogram]: + routing_entropy = router_metrics["routing_entropy_per_layer"] + routing_counts = router_metrics["routing_counts_per_layer"] + num_layers = int(routing_entropy.shape[0]) + + num_experts = int(routing_counts.shape[1]) + out: dict[str, jax.Array | Histogram] = { + "train/router/routing_entropy_mean": jnp.mean(routing_entropy), + } + for i in range(num_layers): + out[f"train/router/layer_{i}/routing_entropy"] = routing_entropy[i] 
+ out[f"train/router/layer_{i}/routing_hist"] = _histogram_from_expert_counts(routing_counts[i]) + layer_total = jnp.maximum(jnp.sum(routing_counts[i]), 1.0) + layer_loads = routing_counts[i] / layer_total + for j in range(num_experts): + out[f"moe/layer_{i}/expert_{j}/load"] = layer_loads[j] + return out + + +def _histogram_from_expert_counts(expert_counts: jax.Array) -> Histogram: + counts = jnp.asarray(expert_counts, dtype=jnp.float32) + num_experts = counts.shape[0] + expert_ids = jnp.arange(num_experts, dtype=jnp.float32) + num = jnp.sum(counts) + sum_values = jnp.sum(counts * expert_ids) + sum_squares = jnp.sum(counts * expert_ids * expert_ids) + nonzero = counts > 0 + min_value = jnp.where(nonzero, expert_ids, jnp.inf).min() + max_value = jnp.where(nonzero, expert_ids, -jnp.inf).max() + min_value = jnp.where(num > 0, min_value, 0.0) + max_value = jnp.where(num > 0, max_value, 0.0) + bucket_limits = jnp.arange(num_experts + 1, dtype=jnp.float32) + return Histogram( + min=min_value, + max=max_value, + num=num, + sum=sum_values, + sum_squares=sum_squares, + bucket_limits=bucket_limits, + bucket_counts=counts, + ) + + +class MoEMLP(eqx.Module): + router: jax.Array + w_up_gate: jax.Array + w_down: jax.Array + cfg: GrugModelConfig = eqx.field(static=True) + + @staticmethod + def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "MoEMLP": + k_router, k_w_up_gate = random.split(key, 2) + mesh = get_abstract_mesh() + + expert_axis_size = _mesh_axis_size(mesh, "expert") + if cfg.num_experts % expert_axis_size != 0: + raise ValueError(f"num_experts={cfg.num_experts} must be divisible by expert axis size={expert_axis_size}") + + d, e, i = ( + cfg.hidden_dim, + cfg.num_experts, + cfg.intermediate_dim, + ) + + return MoEMLP( + router=reshard(_init_weight(k_router, (d, e), cfg.initializer_std), P(None, None)), + w_up_gate=reshard( + _init_weight(k_w_up_gate, (e, d, 2 * i), cfg.initializer_std), P("expert", "data", "model") + ), + w_down=reshard(jnp.zeros((e, i, d)), P("expert", "model", "data")), + cfg=cfg, + ) + + @named_call + def __call__( + self, + x: Float[Array, "B S D"], + ) -> tuple[Float[Array, "B S D"], dict[str, jax.Array]]: + b, s, _ = x.shape + x_flat = rearrange(x, "b s d -> (b s) d") + router_logits = jnp.einsum("td,de->te", x_flat, reshard(self.router, P(None, None))) + topk_logits, selected_experts = jax.lax.top_k(router_logits, self.cfg.num_experts_per_token) + combine_weights = jax.nn.softmax(topk_logits, axis=-1).astype(x.dtype) + router_stats = _routing_stats_from_selected_experts(selected_experts, num_experts=self.cfg.num_experts) + + if self.cfg.lbl_coef is not None: + router_probs = jax.nn.softmax(router_logits.astype(jnp.float32), axis=-1) + expert_counts = jnp.sum( + jax.nn.one_hot(selected_experts, self.cfg.num_experts, dtype=jnp.float32), axis=(0, 1) + ) + expert_loads = expert_counts / jnp.maximum(jnp.sum(expert_counts), 1.0) + f = expert_loads * (self.cfg.num_experts / self.cfg.num_experts_per_token) + p = jnp.mean(router_probs, axis=0) + router_stats["load_balancing_loss"] = jnp.asarray(self.cfg.lbl_coef, dtype=jnp.float32) * jnp.sum(f * p) + + if self.cfg.rzl_coef is not None: + z = jax.scipy.special.logsumexp(router_logits.astype(jnp.float32), axis=-1) + router_stats["router_z_loss"] = jnp.asarray(self.cfg.rzl_coef, dtype=jnp.float32) * jnp.mean(z**2) + + routed_flat = moe_mlp( + x_flat, + selected_experts.astype(jnp.int32), + combine_weights, + self.w_up_gate, + self.w_down, + activation=ActivationFunctionEnum.silu, + mesh=get_abstract_mesh(), + 
capacity_factor=self.cfg.ep_capacity_factor, + ) + routed = rearrange(routed_flat, "(b s) d -> b s d", b=b, s=s) + routed = reshard(routed, _batch_spec()) + return routed, router_stats + + +class Block(eqx.Module): + rms_attn: RMSNorm + attn: CausalSelfAttention + rms_mlp: RMSNorm + mlp: MoEMLP + shared: DenseMLP | None + + @staticmethod + def init(cfg: GrugModelConfig, *, key: PRNGKeyArray, has_ve: bool = False) -> "Block": + attn_key, mlp_key, shared_key = random.split(key, 3) + shared = None + if cfg.shared_expert_intermediate_dim > 0: + shared = DenseMLP.init( + cfg.hidden_dim, + cfg.shared_expert_intermediate_dim, + cfg.initializer_std, + key=shared_key, + ) + return Block( + rms_attn=RMSNorm.init(cfg.hidden_dim, cfg.layer_norm_eps), + attn=CausalSelfAttention.init(cfg, key=attn_key, has_ve=has_ve), + rms_mlp=RMSNorm.init(cfg.hidden_dim, cfg.layer_norm_eps), + mlp=MoEMLP.init(cfg, key=mlp_key), + shared=shared, + ) + + @named_call + def __call__( + self, + x: Float[Array, "B S D"], + mask: AttentionMask | jax.Array, + token_ids: Int[Array, "B S"] | None = None, + x0: Float[Array, "B S D"] | None = None, + resid_lambda: jax.Array | None = None, + x0_lambda: jax.Array | None = None, + ) -> tuple[Float[Array, "B S D"], dict[str, jax.Array]]: + x = x + self.attn(self.rms_attn(x), mask, token_ids=token_ids) + if resid_lambda is not None and x0 is not None: + x = resid_lambda * x + x0_lambda * x0 + mlp_in = self.rms_mlp(x) + mlp_out, router_stats = self.mlp(mlp_in) + if self.shared is not None: + mlp_out = mlp_out + self.shared(mlp_in, activation=ActivationFunctionEnum.silu) + x = x + mlp_out + return x, router_stats + + +class Transformer(eqx.Module): + token_embed: jax.Array + output_proj: jax.Array + blocks: tuple[Block, ...] + embed_norm: RMSNorm + final_norm: RMSNorm + resid_lambdas: tuple[jax.Array, ...] + x0_lambdas: tuple[jax.Array, ...] 
+ config: GrugModelConfig = eqx.field(static=True) + + @staticmethod + def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "Transformer": + embed_key, _out_key, *block_keys = random.split(key, cfg.num_layers + 2) + token_embed = reshard(_init_weight(embed_key, (cfg.vocab_size, cfg.hidden_dim), cfg.initializer_std), Pvocab) + output_proj = reshard(jnp.zeros((cfg.hidden_dim, cfg.vocab_size)), Pvocab) + blocks = tuple( + Block.init(cfg, key=block_keys[i], has_ve=(i >= cfg.num_layers - cfg.num_ve_layers)) + for i in range(cfg.num_layers) + ) + embed_norm = RMSNorm.init(cfg.hidden_dim, cfg.layer_norm_eps) + final_norm = RMSNorm.init(cfg.hidden_dim, cfg.layer_norm_eps) + resid_lambdas = tuple(jnp.ones((), dtype=jnp.float32) for _ in range(cfg.num_layers)) + x0_lambdas = tuple(jnp.zeros((), dtype=jnp.float32) for _ in range(cfg.num_layers)) + + return Transformer( + token_embed=token_embed, + output_proj=output_proj, + blocks=blocks, + embed_norm=embed_norm, + final_norm=final_norm, + resid_lambdas=resid_lambdas, + x0_lambdas=x0_lambdas, + config=cfg, + ) + + @named_call + def __call__( + self, + token_ids: Int[Array, "B S"], + mask: AttentionMask | jax.Array | None = None, + ) -> tuple[Float[Array, "B S D"], dict[str, jax.Array]]: + if mask is None: + mask = AttentionMask.causal() + + batch_spec = _batch_spec() + hidden = self.token_embed.at[token_ids].get(out_sharding=batch_spec) + hidden = self.embed_norm(hidden) + x0 = hidden + + cfg = self.config + if cfg.sliding_window is not None: + segment_ids = mask.segment_ids if isinstance(mask, AttentionMask) else None + short_mask = AttentionMask(is_causal=True, sliding_window=cfg.sliding_window // 2, segment_ids=segment_ids) + long_mask = AttentionMask(is_causal=True, sliding_window=cfg.sliding_window, segment_ids=segment_ids) + + all_router_stats: list[dict[str, jax.Array]] = [] + for i, (resid_lambda, x0_lambda, block) in enumerate( + zip(self.resid_lambdas, self.x0_lambdas, self.blocks, strict=True) + ): + if cfg.sliding_window is not None: + layer_mask = long_mask if i % 4 == 3 else short_mask + else: + layer_mask = mask + hidden, router_stats = eqx.filter_checkpoint(block)( + hidden, + layer_mask, + token_ids=token_ids, + x0=x0, + resid_lambda=resid_lambda, + x0_lambda=x0_lambda, + ) + all_router_stats.append(router_stats) + + router_metrics: dict[str, jax.Array] = { + "routing_entropy_per_layer": jnp.stack([s["routing_entropy"] for s in all_router_stats], axis=0), + "routing_counts_per_layer": jnp.stack([s["routing_counts"] for s in all_router_stats], axis=0), + } + if "load_balancing_loss" in all_router_stats[0]: + router_metrics["load_balancing_loss"] = sum(s["load_balancing_loss"] for s in all_router_stats) + if "router_z_loss" in all_router_stats[0]: + router_metrics["router_z_loss"] = sum(s["router_z_loss"] for s in all_router_stats) + return self.final_norm(hidden), router_metrics + + @named_call + def logits( + self, + token_ids: Int[Array, "B S"], + mask: AttentionMask | jax.Array | None = None, + ) -> Float[Array, "B S V"]: + batch_spec = _batch_spec() + hidden, _ = self(token_ids, mask=mask) + return jnp.einsum("bsh,hd->bsd", hidden, self.output_proj, out_sharding=batch_spec) + + def next_token_loss( + self, + token_ids: Int[Array, "B S"], + loss_weight: Float[Array, "B S"], + *, + mask: AttentionMask | jax.Array | None = None, + reduction: str = "mean", + logsumexp_weight: float | None = None, + loss_dtype: jnp.dtype = jnp.float32, + return_router_metrics: bool = False, + ) -> jax.Array | tuple[jax.Array, dict[str, jax.Array | 
Histogram]]: + hidden, router_metrics = self(token_ids, mask=mask) + labels = jnp.concatenate([token_ids[:, 1:], token_ids[:, :1] * 0], axis=1).astype(jnp.int32) + loss_weight = loss_weight.astype(loss_dtype) + + loss = fused_linear_softmax_cross_entropy_loss( + hidden, + self.output_proj, + labels, + weight=loss_weight, + reduction=reduction, + logsumexp_weight=logsumexp_weight, + dtype=loss_dtype, + ) + + aux_loss = jnp.zeros((), dtype=loss_dtype) + if "load_balancing_loss" in router_metrics: + aux_loss = aux_loss + router_metrics["load_balancing_loss"] + if "router_z_loss" in router_metrics: + aux_loss = aux_loss + router_metrics["router_z_loss"] + loss = loss + aux_loss + + if return_router_metrics: + summarized = _summarize_router_metrics(router_metrics) + if "load_balancing_loss" in router_metrics: + summarized["train/load_balancing_loss"] = jax.lax.stop_gradient(router_metrics["load_balancing_loss"]) + if "router_z_loss" in router_metrics: + summarized["train/router_z_loss"] = jax.lax.stop_gradient(router_metrics["router_z_loss"]) + return loss, summarized + return loss + + +def _init_weight(key: PRNGKeyArray, shape: tuple[int, ...], std: float) -> Float[Array, "..."]: + return std * random.truncated_normal(key, -3, 3, shape) + + +def debug_mesh_and_token_pspec(num_devices: int) -> tuple[jax.sharding.AbstractMesh, P]: + """Return a small abstract mesh and token sharding for lowering contract tests.""" + if num_devices <= 0: + raise ValueError(f"num_devices must be positive, got {num_devices}") + # Keep expert axis at 2 when possible to exercise EP lowering, otherwise + # fall back to expert=1. + expert = 2 if num_devices % 2 == 0 else 1 + data = max(1, num_devices // expert) + mesh = jax.sharding.AbstractMesh( + axis_sizes=(data, expert, 1), + axis_names=("data", "expert", "model"), + axis_types=( + jax.sharding.AxisType.Explicit, + jax.sharding.AxisType.Explicit, + jax.sharding.AxisType.Explicit, + ), + ) + return mesh, P(("data", "expert"), None) + + +__all__ = [ + "Block", + "CausalSelfAttention", + "DenseMLP", + "GrugModelConfig", + "MoEMLP", + "MoeActivation", + "RMSNorm", + "Transformer", + "debug_mesh_and_token_pspec", + "moe_mlp", + "qk_norm", +] diff --git a/experiments/grug/max_model_flop_mini/train.py b/experiments/grug/max_model_flop_mini/train.py new file mode 100644 index 0000000000..db58866ec3 --- /dev/null +++ b/experiments/grug/max_model_flop_mini/train.py @@ -0,0 +1,511 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import dataclasses +import functools +import logging +import time +from dataclasses import dataclass, field + +import jax +import jax.numpy as jnp +import jmp +import optax +from haliax import Axis +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P +from jax.tree_util import register_dataclass +from jaxtyping import PRNGKeyArray + +import levanter.callbacks as callbacks +import levanter.tracker +from levanter.callbacks.state_adapter import StateCallbackRunner +from levanter.callbacks.watch import WatchConfig, compute_watch_stats +from levanter.checkpoint import load_checkpoint +from levanter.data import AsyncDataset, DataLoader +from levanter.data.mixture import MixtureDataset, rescale_mixture_schedule_for_batch_schedule +from levanter.data.text import GrugLmExample, LmDataConfig +from levanter.data.text.examples import grug_lm_example_from_named +from levanter.eval import TaggedEvaluator, cb_tagged_evaluate +from levanter.models.lm_model import 
LmExample +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.schedule import BatchSchedule +from levanter.trainer import TrainerConfig +from levanter.utils.flop_utils import lm_flops_per_token +from levanter.utils.jax_utils import parameter_count +from levanter.utils.logging import LoadingTimeTrackerIterator + +from experiments.grug.max_model_flop_mini.model import GrugModelConfig, Transformer + +# This file intentionally mirrors `experiments/grug/base/train.py` with +# variant-specific model/loss/FLOP wiring, per the grug copy-first workflow in +# `docs/recipes/change_grug.md`. + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class GrugTrainerConfig: + """Runtime knobs for grug training.""" + + trainer: TrainerConfig = field(default_factory=lambda: TrainerConfig(use_explicit_mesh_axes=True)) + train_batch_pspec: P = field(default_factory=lambda: P(("data",))) + data_seed: int | None = None + log_every: int = 1 + ema_beta: float | None = None # EMA coefficient for eval/checkpoint model; None disables EMA. + z_loss_weight: float = 0.0 # Weight on logsumexp (z-loss) stabilization term. + + +@dataclass(frozen=True) +class GrugEvalConfig: + """Perplexity eval settings for grug training.""" + + eval_batch_size: int = 512 + eval_batch_pspec: P = field(default_factory=lambda: P(("data",))) + steps_per_eval: int | None = 1000 + max_eval_batches: int | None = None + prefix: str = "eval" + eval_current: bool = True + eval_ema: bool = True + compute_bpb: bool = True + + +@dataclass(frozen=True) +class GrugRunConfig: + """Top-level config for grug training.""" + + model: GrugModelConfig + data: LmDataConfig + optimizer: OptimizerConfig = field(default_factory=AdamConfig) + trainer: GrugTrainerConfig = field(default_factory=GrugTrainerConfig) + eval: GrugEvalConfig | None = field(default_factory=GrugEvalConfig) + + +def build_train_dataset( + data_config: LmDataConfig, + *, + max_seq_len: int, + batch_schedule: BatchSchedule, + key: PRNGKeyArray, +) -> MixtureDataset[GrugLmExample]: + pos = Axis("position", max_seq_len) + mix_key, shuffle_key = jax.random.split(key) + weights = data_config.train_weights + if isinstance(weights, list): + weights = rescale_mixture_schedule_for_batch_schedule(weights, batch_schedule) + + initial_batch_size = batch_schedule.batch_size_at_step(0) + datasets = data_config.train_sets(pos, key=shuffle_key, initial_batch_size=initial_batch_size) + return MixtureDataset( + datasets=datasets, + weights=weights, + stop_strategy=data_config.stop_strategy, + key=mix_key, + block_size=data_config.mixture_block_size, + ) + + +def build_train_loader( + dataset: AsyncDataset[GrugLmExample], + *, + batch_schedule: BatchSchedule, + mesh: Mesh, + batch_pspec: P = P(("data",)), +) -> DataLoader[GrugLmExample]: + # DataLoader uses this batch axis mapping to shard batches across the distributed mesh. 
+ axis_resource = batch_pspec[0] + return DataLoader( + dataset, + batch_schedule.schedule, + mesh=mesh, + axis_resources={"__BATCH__": axis_resource}, + batch_axis_name="__BATCH__", + allow_nondivisible_batch_size=False, + ) + + +def build_tagged_evaluator( + *, + data_config: LmDataConfig, + max_seq_len: int, + mesh: Mesh, + eval_cfg: GrugEvalConfig, +) -> TaggedEvaluator[LmExample | GrugLmExample, Transformer] | None: + pos = Axis("position", max_seq_len) + tagged_eval_sets = data_config.tagged_eval_sets(pos) + if len(tagged_eval_sets) == 0: + logger.warning("No evaluation datasets provided.") + return None + + max_examples_per_dataset = None + if eval_cfg.max_eval_batches is not None: + max_examples_per_dataset = eval_cfg.max_eval_batches * eval_cfg.eval_batch_size + + tokenizer = data_config.the_tokenizer if eval_cfg.compute_bpb else None + batch_axis_resource = eval_cfg.eval_batch_pspec[0] + eval_axis_mapping = {"batch": batch_axis_resource} + eval_batch = Axis("batch", eval_cfg.eval_batch_size) + eval_array_sharding = NamedSharding(mesh, P(batch_axis_resource, None)) + + def eval_loss_fn(model: Transformer, batch: LmExample | GrugLmExample) -> tuple[jax.Array, jax.Array, jax.Array]: + if isinstance(batch, LmExample): + batch = grug_lm_example_from_named(batch) + per_pos_loss = model.next_token_loss( + batch.tokens, + batch.loss_weight, + mask=batch.attn_mask, + reduction="none", + logsumexp_weight=None, + ) + per_pos_loss = jax.sharding.reshard(per_pos_loss, eval_array_sharding) + per_pos_weight = jax.sharding.reshard(batch.loss_weight, eval_array_sharding) + per_pos_token_id = jnp.roll(batch.tokens, -1, axis=-1) + return per_pos_loss, per_pos_weight, per_pos_token_id + + return TaggedEvaluator( + EvalBatch=eval_batch, + tagged_eval_sets=tagged_eval_sets, + loss_fn=eval_loss_fn, + tokenizer=tokenizer, + device_mesh=mesh, + axis_mapping=eval_axis_mapping, + max_examples_per_dataset=max_examples_per_dataset, + ) + + +def _compute_flops( + *, + model_config: GrugModelConfig, +) -> tuple[float, dict[str, float]]: + flops_per_token = lm_flops_per_token( + hidden_dim=model_config.hidden_dim, + intermediate_dim=model_config.intermediate_dim, + num_layers=model_config.num_layers, + num_kv_heads=model_config.num_kv_heads, + num_heads=model_config.num_heads, + seq_len=model_config.max_seq_len, + vocab_size=model_config.vocab_size, + glu=True, + num_experts=model_config.num_experts, + num_shared_experts=1 if model_config.shared_expert_intermediate_dim > 0 else 0, + num_experts_per_tok=model_config.num_experts_per_token, + ) + flops_per_example = 3 * flops_per_token * model_config.max_seq_len + + flops_summary: dict[str, float] = { + "throughput/flops_per_token_analytic": flops_per_token, + "throughput/flops_per_example_analytic": flops_per_example, + } + + return flops_per_example, flops_summary + + +def _make_mixture_stage_callback(train_dataset: MixtureDataset, batch_schedule: BatchSchedule): + last_mixture_stage = -1 + + def log_mixture_stage(step_info): + nonlocal last_mixture_stage + seq_index = batch_schedule.global_data_offset_by_step(step_info.step) + block_id = seq_index // train_dataset.block_size + stage = train_dataset._get_stage_for_block(block_id) + if stage == last_mixture_stage: + return + + weights = train_dataset.weight_stages[stage][1] + mixture_log = {f"mixture/weight/{name}": weight for name, weight in weights.items()} + mixture_log["mixture/stage"] = stage + levanter.tracker.log(mixture_log, step=step_info.step) + last_mixture_stage = stage + + return log_mixture_stage + 
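# Back-of-envelope check of the headline number in launch.py (rough, hand-computed sketch;
# the run logs the exact analytic value via _compute_flops / lm_flops_per_token above):
# with hidden_dim=512, 8 layers, top-2 of 16 experts at intermediate_dim=1024, vocab_size=128_256,
#   flops_per_token ~ 2 * 512 * 128_256 (lm_head) + 8 * ~1.3e7 (attention + routed MLP) ~ 2.3e8
# and with seq_len=2048, batch_size=128, 5000 steps,
#   total model FLOPs ~ 3 * 2.3e8 * 2048 * 128 * 5000 ~ 9.0e17,
# in line with the "~9.14e17 model FLOPs" reported in the launch.py docstring.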
+ +@register_dataclass +@dataclass(frozen=True) +class GrugTrainState: + step: jax.Array + params: Transformer + opt_state: optax.OptState + ema_params: Transformer + + +def initial_state( + model_config: GrugModelConfig, + *, + optimizer: optax.GradientTransformation, + mp: jmp.Policy, + key: PRNGKeyArray, +) -> GrugTrainState: + params = mp.cast_to_param(Transformer.init(model_config, key=key)) + return GrugTrainState( + step=jnp.array(0, dtype=jnp.int32), + params=params, + opt_state=optimizer.init(params), + ema_params=params, + ) + + +def _make_train_step( + optimizer: optax.GradientTransformation, + mp: jmp.Policy, + *, + z_loss_weight: float, + ema_beta: float | None, + watch_config: WatchConfig | None = None, +): + one = jnp.array(1, dtype=jnp.int32) + z_loss = z_loss_weight if z_loss_weight > 0 else None + if watch_config is not None: + if isinstance(watch_config.watch_targets, str): + watch_targets = tuple(t.strip() for t in watch_config.watch_targets.split(",")) + else: + watch_targets = tuple(watch_config.watch_targets) + else: + watch_targets = () + + @functools.partial(jax.jit, donate_argnums=(0,), static_argnames=("compute_watch",)) + def train_step(state: GrugTrainState, batch, *, compute_watch: bool = False): + def loss_fn(params): + compute_params = mp.cast_to_compute(params) + return compute_params.next_token_loss( + batch.tokens, + batch.loss_weight, + mask=batch.attn_mask, + reduction="mean", + logsumexp_weight=z_loss, + return_router_metrics=True, + ) + + (loss, summarized_metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params) + metrics = {"train/loss": loss, **summarized_metrics} + updates, opt_state = optimizer.update(grads, state.opt_state, state.params) + params = optax.apply_updates(state.params, updates) + + if ema_beta is None: + ema_params = params + else: + ema_params = jax.tree_util.tree_map( + lambda old, new: ema_beta * old + (1.0 - ema_beta) * new, + state.ema_params, + params, + ) + + watch_stats = None + if watch_config is not None and compute_watch: + watch_stats = compute_watch_stats( + watch_targets=watch_targets, + include_norms=watch_config.include_norms, + include_per_parameter_norms=watch_config.include_per_parameter_norms, + include_histogram=watch_config.include_histograms, + split_scan_layers=watch_config.split_scan_layers, + params=state.params, + grads=grads, + updates=updates, + opt_state=state.opt_state, + model_tree_type=type(state.params), + ) + + next_state = dataclasses.replace( + state, + step=state.step + one, + params=params, + opt_state=opt_state, + ema_params=ema_params, + ) + + return next_state, metrics, watch_stats + + return train_step + + +def run_grug(config: GrugRunConfig) -> None: + """Entry point for the grug template training loop.""" + trainer = config.trainer.trainer + trainer.initialize() + levanter.tracker.log_configuration(config) + + run_id = trainer.id + if run_id is None: + raise ValueError("trainer.id was not initialized") + + optimizer = config.optimizer.build(trainer.num_train_steps) + watch_config = trainer.watch + train_step = _make_train_step( + optimizer, + trainer.mp, + z_loss_weight=config.trainer.z_loss_weight, + ema_beta=config.trainer.ema_beta, + watch_config=watch_config if watch_config.is_enabled else None, + ) + + data_key, model_key = jax.random.split(jax.random.PRNGKey(trainer.seed), 2) + if config.trainer.data_seed is not None: + data_key = jax.random.PRNGKey(config.trainer.data_seed) + + # Build data/model state under the trainer mesh so all arrays are sharded consistently. 
+ with trainer.use_device_mesh(): + mesh = trainer.device_mesh + batch_schedule = trainer.batch_schedule + + train_dataset = build_train_dataset( + config.data, + max_seq_len=config.model.max_seq_len, + batch_schedule=batch_schedule, + key=data_key, + ) + train_loader = build_train_loader( + train_dataset, + batch_schedule=batch_schedule, + mesh=mesh, + batch_pspec=config.trainer.train_batch_pspec, + ) + + @jax.jit + def _init_state(model_rng): + return initial_state( + config.model, + optimizer=optimizer, + mp=trainer.mp, + key=model_rng, + ) + + state = _init_state(model_key) + + checkpointer = trainer.checkpointer.create(run_id) + checkpoint_path = trainer.load_checkpoint_path + if checkpoint_path is None and checkpointer is not None: + checkpoint_path = trainer.checkpointer.expanded_path(run_id) + if checkpoint_path is None: + if trainer.load_checkpoint: + raise FileNotFoundError("load_checkpoint=True but no checkpoint path is configured.") + elif trainer.load_checkpoint is not False: + try: + state = load_checkpoint( + state, + checkpoint_path, + discover_latest=True, + axis_mapping=None, + mesh=mesh, + allow_partial=trainer.allow_partial_checkpoint, + ) + except FileNotFoundError: + if trainer.load_checkpoint is True: + raise + logger.info(f"Checkpoint not found at {checkpoint_path}. Starting from scratch.") + + levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)}) + + flops_per_example, flops_summary = _compute_flops(model_config=config.model) + levanter.tracker.log_summary(flops_summary) + + eval_cfg = config.eval + evaluator = None + if eval_cfg is not None: + evaluator = build_tagged_evaluator( + data_config=config.data, + max_seq_len=config.model.max_seq_len, + mesh=mesh, + eval_cfg=eval_cfg, + ) + + profiler_cfg = trainer.profiler + profiler_num_steps = profiler_cfg.resolve_num_profile_steps(num_train_steps=trainer.num_train_steps) + profiler_enabled = profiler_cfg.is_enabled and profiler_num_steps > 0 + + log_every = max(1, config.trainer.log_every) + iterator = LoadingTimeTrackerIterator(train_loader.iter_from_step(int(state.step))) + + state_callbacks = StateCallbackRunner[GrugTrainState]( + step_getter=lambda s: s.step, + model_getter=lambda s: s.params, + eval_model_getter=lambda s: s.ema_params, + opt_state_getter=lambda s: s.opt_state, + ) + state_callbacks.add_hook( + callbacks.log_performance_stats(config.model.max_seq_len, batch_schedule, flops_per_example), + every=log_every, + ) + state_callbacks.add_hook(callbacks.pbar_logger(total=trainer.num_train_steps), every=log_every) + state_callbacks.add_hook(callbacks.log_step_info(trainer.num_train_steps), every=log_every) + if profiler_enabled: + state_callbacks.add_hook( + callbacks.profile( + str(trainer.log_dir / run_id / "profiler"), + profiler_cfg.start_step, + profiler_num_steps, + profiler_cfg.perfetto_link, + ), + every=1, + ) + state_callbacks.add_hook(_make_mixture_stage_callback(train_dataset, batch_schedule), every=1) + if evaluator is not None and eval_cfg is not None: + interval = eval_cfg.steps_per_eval + eval_ema = eval_cfg.eval_ema and config.trainer.ema_beta is not None + if interval is not None and interval > 0 and (eval_cfg.eval_current or eval_ema): + state_callbacks.add_hook( + cb_tagged_evaluate( + evaluator, + prefix=eval_cfg.prefix, + eval_current=eval_cfg.eval_current, + eval_ema=eval_ema, + ), + every=interval, + ) + + last_loss: float | jax.Array = 0.0 + last_step_duration = 0.0 + + # Main optimization loop. 
+ try: + while int(state.step) < trainer.num_train_steps: + batch = next(iterator) + step_start = time.perf_counter() + current_step = int(state.step) + # grad_watch runs only on its configured interval. + compute_watch = ( + watch_config.is_enabled and watch_config.interval > 0 and current_step % watch_config.interval == 0 + ) + state, metrics, watch_stats = train_step(state, batch, compute_watch=compute_watch) + step = int(state.step) - 1 + + jax.block_until_ready(metrics["train/loss"]) + duration = time.perf_counter() - step_start + hook_start = time.perf_counter() + state_callbacks.run(state, loss=metrics["train/loss"], step_duration=duration) + last_loss = metrics["train/loss"] + last_step_duration = duration + levanter.tracker.log({"throughput/hook_time": time.perf_counter() - hook_start}, step=step) + levanter.tracker.log({"throughput/loading_time": iterator.this_load_time}, step=step) + aux_metrics = { + key: value + for key, value in metrics.items() + if key.startswith(("train/router/", "train/load_balancing_loss", "train/router_z_loss", "moe/")) + } + if aux_metrics: + levanter.tracker.log(aux_metrics, step=step) + + if watch_stats is not None: + levanter.tracker.log(watch_stats, step=step) + + if checkpointer is not None: + checkpointer.on_step(tree={"train_state": state}, step=int(state.step)) + finally: + # Mirror classic trainer behavior: force callbacks on the last completed step. + state_callbacks.run(state, loss=last_loss, step_duration=last_step_duration, force=True) + if checkpointer is not None: + checkpointer.on_step(tree={"train_state": state}, step=int(state.step), force=True) + checkpointer.wait_until_finished() + + levanter.tracker.current_tracker().finish() + + +__all__ = [ + "GrugEvalConfig", + "GrugRunConfig", + "GrugTrainState", + "GrugTrainerConfig", + "initial_state", + "run_grug", +] diff --git a/lib/levanter/src/levanter/optim/grugmuon.py b/lib/levanter/src/levanter/optim/grugmuon.py index 57e15a23bc..9a12a0318a 100644 --- a/lib/levanter/src/levanter/optim/grugmuon.py +++ b/lib/levanter/src/levanter/optim/grugmuon.py @@ -5,8 +5,9 @@ Muon optimizer for models using raw JAX arrays with (fan_in, fan_out) layout, such as Grug models. -All 2D arrays are routed to Muon, except those whose path contains -'embed', 'lm_head', or 'output' (case-insensitive), which use AdamW. +2D arrays and 3D arrays (where the first dim is a batch/expert dim) are routed +to Muon, except those whose path contains 'embed', 'lm_head', or 'output' +(case-insensitive), which use AdamW. """ from dataclasses import dataclass @@ -30,7 +31,7 @@ class GrugMuonConfig(MuonConfig): Muon optimizer for models that use raw JAX arrays in (fan_in, fan_out) layout. 
    Routing rules:
-    - 2D arrays whose path does NOT contain 'embed', 'lm_head', or 'output' -> Muon
+    - 2D and 3D arrays whose path does NOT contain 'embed', 'lm_head', or 'output' -> Muon
     - Everything else -> AdamW
     """
@@ -86,7 +87,7 @@ def mask_fn(param, path):
         path_lower = path_str.lower()
         if "embed" in path_lower or "lm_head" in path_lower or "output" in path_lower:
             return "adamw"
-        elif hasattr(param, "ndim") and param.ndim == 2:
+        elif hasattr(param, "ndim") and param.ndim in (2, 3):
             return "muon"
         else:
             return "adamw"
@@ -123,16 +124,29 @@ def update_fn(updates, state, params=None):
             updates = buf

         def transform_array(x):
-            if not hasattr(x, "ndim") or x.ndim != 2:
+            if not hasattr(x, "ndim") or x.ndim not in (2, 3):
                 return x
-            updated = _zeropower_via_newtonschulz(x, steps=steps, eps=muon_eps, coefficient_type=coefficient_type)
-            # Layout is (fan_in, fan_out)
-            fan_in, fan_out = updated.shape
+            from jax.sharding import PartitionSpec as P, reshard
+
+            original_spec = jax.typeof(x).sharding.spec
+            if x.ndim == 3:
+                # Keep the first dim's existing sharding, replicate only the last 2 dims for Newton-Schulz
+                x = reshard(x, P(original_spec[0], None, None))
+                updated = jax.vmap(
+                    lambda m: _newtonschulz_core(m, steps=steps, eps=muon_eps, coefficient_type=coefficient_type)
+                )(x)
+                # Layout per slice is (fan_in, fan_out)
+                fan_in, fan_out = updated.shape[1], updated.shape[2]
+            else:
+                updated = _zeropower_via_newtonschulz(x, steps=steps, eps=muon_eps, coefficient_type=coefficient_type)
+                # Layout is (fan_in, fan_out)
+                fan_in, fan_out = updated.shape
             if not use_kimi_scaling:
                 scale = jnp.sqrt(jnp.maximum(1, fan_out / fan_in))
             else:
                 scale = 0.2 * jnp.sqrt(jnp.maximum(fan_in, fan_out))
             updated *= scale
+            updated = reshard(updated, original_spec)
             return updated

         updates = jax.tree.map(transform_array, updates)
@@ -142,20 +156,12 @@ def transform_array(x):
     return optax.GradientTransformation(init_fn, update_fn)


-def _zeropower_via_newtonschulz(X, steps: int = 5, eps: float = 1e-7, coefficient_type: CoefficientType = "quintic"):
-    """Newton-Schulz iteration to orthogonalize X.
+def _newtonschulz_core(X, steps: int = 5, eps: float = 1e-7, coefficient_type: CoefficientType = "quintic"):
+    """Pure Newton-Schulz iteration on a 2D matrix.

-    Replicates the array across devices before iterating to avoid sharding
-    ambiguities in the X @ X.T contractions, then reshards back to the
-    original sharding.
+    No resharding logic so this is safe to call under vmap. Callers are
+    responsible for replicating X before calling this function.
     """
-    from jax.sharding import PartitionSpec as P, reshard
-
-    assert X.ndim == 2
-
-    orig_sharding = X.sharding if hasattr(X, "sharding") else None
-    X = reshard(X, P(None, None))
-
     coeffs = NEWTON_SCHULZ_COEFFICIENTS[coefficient_type]
     X /= jnp.linalg.norm(X) + eps

@@ -173,7 +179,17 @@ def _zeropower_via_newtonschulz(X, steps: int = 5, eps: float = 1e-7, coefficien
     if transpose:
         X = X.T
-    if orig_sharding is not None:
-        X = reshard(X, orig_sharding)
-    return X
+    return X
+
+
+def _zeropower_via_newtonschulz(X, steps: int = 5, eps: float = 1e-7, coefficient_type: CoefficientType = "quintic"):
+    """Newton-Schulz iteration to orthogonalize a 2D matrix.
+
+    Replicates X before iterating to avoid sharding ambiguities in the
+    X @ X.T contractions.
+    """
+    from jax.sharding import PartitionSpec as P, reshard
+
+    assert X.ndim == 2
+    X = reshard(X, P(None, None))
+    return _newtonschulz_core(X, steps=steps, eps=eps, coefficient_type=coefficient_type)
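For readers new to the Muon change above, the following is a minimal, self-contained sketch of what the 3D branch of transform_array adds: each expert's (fan_in, fan_out) slice of a stacked MoE weight update is orthogonalized independently by vmapping a Newton-Schulz iteration over the leading expert dimension. The coefficient values and helper names here are illustrative assumptions, not levanter's NEWTON_SCHULZ_COEFFICIENTS or API.

import jax
import jax.numpy as jnp


def _ns_orthogonalize(x: jax.Array, steps: int = 5, eps: float = 1e-7) -> jax.Array:
    # Quintic Newton-Schulz iteration commonly used for Muon; the (3.4445, -4.7750, 2.0315)
    # triple is the widely circulated choice, assumed here for illustration.
    a, b, c = 3.4445, -4.7750, 2.0315
    transpose = x.shape[0] > x.shape[1]  # iterate on the smaller Gram matrix
    if transpose:
        x = x.T
    x = x / (jnp.linalg.norm(x) + eps)
    for _ in range(steps):
        s = x @ x.T
        x = a * x + (b * s + c * s @ s) @ x
    return x.T if transpose else x


def orthogonalize_update(update: jax.Array) -> jax.Array:
    # 2D weights: orthogonalize directly. 3D expert-stacked weights, e.g. a
    # (num_experts, hidden, 2 * intermediate) w_up_gate gradient: vmap over experts.
    if update.ndim == 2:
        return _ns_orthogonalize(update)
    return jax.vmap(_ns_orthogonalize)(update)


# Example: a (16, 512, 2048) w_up_gate update is orthogonalized expert-by-expert,
# leaving the leading expert dimension intact.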