From aec55c6b5db12dcea89e924c41943a0893046ece Mon Sep 17 00:00:00 2001 From: Saken Tukenov Date: Fri, 20 Mar 2026 00:06:23 +0500 Subject: [PATCH 1/4] Add draft MLX smoke submission folder --- .../README.md | 52 + .../eval_probe.log | 1151 ++++++++++++++ .../train_gpt_mlx.py | 1104 +++++++++++++ .../train_partial.log | 1360 +++++++++++++++++ 4 files changed, 3667 insertions(+) create mode 100644 records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/README.md create mode 100644 records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/eval_probe.log create mode 100644 records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_gpt_mlx.py create mode 100644 records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_partial.log diff --git a/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/README.md b/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/README.md new file mode 100644 index 0000000000..3dce84ef01 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/README.md @@ -0,0 +1,52 @@ +This draft folder documents an in-progress local Apple Silicon reproduction using the current root `train_gpt_mlx.py`. + +This is not a finished non-record submission yet. It exists to show concrete reproduction work, local logs, and the exact script snapshot before moving to a completed cloud-backed run. + +What is included: +- `train_gpt_mlx.py`: exact MLX script snapshot used for the local runs. +- `train_partial.log`: 200-step local smoke run on an Apple M1 with 1 FineWeb train shard and the full fixed validation split. Training completed through step 200 and then entered full validation. +- `eval_probe.log`: follow-up probe using a larger validation batch size to test the local 8GB memory / eval-time tradeoff. + +Local machine: +- Apple M1 MacBook Air +- 8GB unified memory +- Python 3.14 +- MLX 0.31.1 + +Smoke configuration: +- Tokenizer / dataset: `sp1024`, `fineweb10B_sp1024` +- Train shards: `1` +- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` +- Tied embeddings: `TIE_EMBEDDINGS=1` +- Training: `ITERATIONS=200 TRAIN_BATCH_TOKENS=8192` +- Validation: full fixed `fineweb_val_*` split + +Observed outcome so far: +- Local training is stable and reproduces the baseline setup on Apple Silicon. +- Full final validation on an 8GB M1 is much slower than training, so this draft does not yet include a finished `submission.json`. +- The next step is to rerun the same baseline path on cloud GPUs and convert this folder into a completed non-record submission with final `val_bpb`, artifact bytes, and `submission.json`. + +Command used for the main smoke run: +```bash +source .venv/bin/activate +RUN_ID=stukenov_mlx_smoke \ +ITERATIONS=200 \ +TRAIN_BATCH_TOKENS=8192 \ +VAL_LOSS_EVERY=0 \ +VAL_BATCH_SIZE=8192 \ +TRAIN_LOG_EVERY=50 \ +python train_gpt_mlx.py +``` + +Command used for the eval probe: +```bash +source .venv/bin/activate +RUN_ID=stukenov_mlx_probe \ +ITERATIONS=1 \ +TRAIN_BATCH_TOKENS=8192 \ +VAL_LOSS_EVERY=0 \ +VAL_BATCH_SIZE=65536 \ +TRAIN_LOG_EVERY=1 \ +MAX_WALLCLOCK_SECONDS=0 \ +python train_gpt_mlx.py +``` diff --git a/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/eval_probe.log b/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/eval_probe.log new file mode 100644 index 0000000000..1a9fa26fde --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/eval_probe.log @@ -0,0 +1,1151 @@ +#!/usr/bin/env python3 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" +from __future__ import annotations + +import glob +import json +import math +import os +import pickle +import sys +import time +import uuid +import zlib +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import sentencepiece as spm + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +from mlx.utils import tree_flatten, tree_unflatten + +# ============================================================================== +# SHARD FORMAT + COMPUTE DTYPE +# ============================================================================== + +COMPUTE_DTYPE = mx.bfloat16 + +# ============================================================================== +# HYPERPARAMETERS +# ============================================================================== +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap +class Hyperparameters: + # Data / tokenizer. + data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed: int = int(os.environ.get("SEED", 1337)) + + # Training loop. These defaults now mirror train_gpt.py on a single process. + iterations: int = int(os.environ.get("ITERATIONS", 20_000)) + val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0)) + # Validation always uses the full fineweb_val split. + val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8)) + train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024))) + # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak + # memory pressure without changing the effective optimizer batch. + mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) + # Force MLX to materialize the graph after every sub-batch, preventing lazy + # graph buildup across accumulation steps. Keeps peak memory low on 16GB machines. + # Disable on 32GB+ unified memory for better throughput (MLX_EAGER_EVAL=0). + mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1"))) + warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) + warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200)) + max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + + # Model (defaults match the current baseline setup). + vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers: int = int(os.environ.get("NUM_LAYERS", 9)) + model_dim: int = int(os.environ.get("MODEL_DIM", 512)) + num_heads: int = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) + mlp_mult: int = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) + logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) + qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Optimizer. We keep the same per-group defaults as train_gpt.py. + beta1: float = float(os.environ.get("BETA1", 0.9)) + beta2: float = float(os.environ.get("BETA2", 0.95)) + adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) + tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) + matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + out_dir: str = os.environ.get("OUT_DIR", "logs") + + @property + def train_files(self) -> str: + return f"{self.data_path}/fineweb_train_*.bin" + + @property + def val_files(self) -> str: + return f"{self.data_path}/fineweb_val_*.bin" + + @property + def microbatch_tokens(self) -> int: + return self.train_batch_tokens // self.grad_accum_steps + + def lr_mul(self, step: int, elapsed_ms: float) -> float: + if self.warmdown_iters <= 0: + return 1.0 + if self.max_wallclock_seconds <= 0: + warmdown_start = max(self.iterations - self.warmdown_iters, 0) + return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = self.warmdown_iters * step_ms + remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) + + +def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]: + usable_total = (total_tokens // seq_len) * seq_len + if usable_total <= 0: + raise ValueError(f"token budget too small for seq_len={seq_len}") + usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) + chunks: list[int] = [] + remaining = usable_total + while remaining > 0: + chunk = min(remaining, usable_chunk) + chunks.append(chunk) + remaining -= chunk + return chunks + + +def accumulate_flat_grads( + accum: dict[str, mx.array] | None, + grads_tree: dict, + scale: float, +) -> dict[str, mx.array]: + flat = dict(tree_flatten(grads_tree)) + if accum is None: + return {k: g * scale for k, g in flat.items()} + for k, g in flat.items(): + accum[k] = accum[k] + g * scale + return accum + + +# ============================================================================== +# MATH HELPERS +# ============================================================================== + +def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: + return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype) + + +def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + # Background on Muon: https://kellerjordan.github.io/posts/muon/ + a, b, c = 3.4445, -4.7750, 2.0315 + x = g.astype(mx.float32) + x = x / (mx.sqrt(mx.sum(x * x)) + eps) + transposed = x.shape[0] > x.shape[1] + if transposed: + x = x.T + for _ in range(steps): + a_mat = x @ x.T + b_mat = b * a_mat + c * (a_mat @ a_mat) + x = a * x + b_mat @ x + if transposed: + x = x.T + return x.astype(g.dtype) + + +def load_data_shard(path: Path) -> np.ndarray: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + if self.file_idx == 0: + self.epoch += 1 + if self.log_fn is not None: + self.log_fn( + f"WARNING: starting epoch:{self.epoch} " + f"dataset:{self.dataset_name} train_shards:{len(self.files)}" + ) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> np.ndarray: + chunks: list[np.ndarray] = [] + left = n + while left > 0: + if self.pos >= self.tokens.size: + self.next_file() + k = min(left, int(self.tokens.size - self.pos)) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + left -= k + return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0) + + +class TokenLoader: + def __init__( + self, + pattern: str, + log_fn: Callable[[str], None] | None = None, + dataset_name: str = "", + ): + self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name) + + def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]: + usable = (batch_tokens // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"token budget too small for seq_len={seq_len}") + chunk = self.stream.take(usable + 1) + x = chunk[:-1].reshape(-1, seq_len) + y = chunk[1:].reshape(-1, seq_len) + return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32) + + +# ============================================================================== +# MODEL BLOCKS +# ============================================================================== + +class CastedLinear(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32) + + def __call__(self, x: mx.array) -> mx.array: + return x @ self.weight.astype(x.dtype).T + + +class RMSNormNoWeight(nn.Module): + # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks. + def __call__(self, x: mx.array) -> mx.array: + return rms_norm(x) + + +class CausalSelfAttention(nn.Module): + # - separate q/k/v projections + # - RMSNorm on q and k before attention + # - RoPE on q and k + # - causal masked SDPA + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim) + self.c_k = CastedLinear(dim, kv_dim) + self.c_v = CastedLinear(dim, kv_dim) + self.proj = CastedLinear(dim, dim) + self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init + self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base) + self.scale = self.head_dim ** -0.5 + + def __call__(self, x: mx.array) -> mx.array: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + + q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE)) + k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE)) + q = q * self.q_gain.astype(q.dtype)[None, :, None, None] + y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal") + y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # Baseline MLP uses relu^2 instead of GELU/SiLU. It is cheap and works well in this setup. + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = dim * mlp_mult + self.fc = CastedLinear(dim, hidden) + self.proj = CastedLinear(hidden, dim) + + def __call__(self, x: mx.array) -> mx.array: + x = nn.relu(self.fc(x)) + return self.proj(x * x) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNormNoWeight() + self.mlp_norm = RMSNormNoWeight() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = mx.ones((dim,), dtype=mx.float32) + self.mlp_scale = mx.ones((dim,), dtype=mx.float32) + self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32)))) + + def __call__(self, x: mx.array, x0: mx.array) -> mx.array: + mix = self.resid_mix.astype(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + # - token embedding + RMSNorm + # - encoder half accumulates skip tensors + # - decoder half consumes reversed skips with learned skip_weights + # - tied embeddings for the LM head (the baseline default setup) + def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float, + qk_gain_init: float): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.logit_chunk_tokens = logit_chunk_tokens + self.logit_softcap = logit_softcap + + self.tok_emb = nn.Embedding(vocab_size, dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32) + self.blocks = [ + Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for i in range(num_layers) + ] + self.final_norm = RMSNormNoWeight() + + for b in self.blocks: + b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight) + b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight) + self.tok_emb.weight = ( + mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std + ).astype(COMPUTE_DTYPE) + + def softcap(self, logits: mx.array) -> mx.array: + c = self.logit_softcap + return c * mx.tanh(logits / c) + + def __call__(self, input_ids: mx.array) -> mx.array: + x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE)) + x0 = x + skips: list[mx.array] = [] + + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + # Odd layer counts have one more decoder block than encoder block. The baseline only + # applies a skip connection when one exists, then runs the remaining decoder block(s) + # without an added skip. + if skips: + x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + return self.final_norm(x) + + def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: + # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful + # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE). + x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) + y = target_ids.reshape(-1) + if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens: + logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits_proj) + return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") + + loss_sum = mx.array(0.0, dtype=mx.float32) + n = int(x.shape[0]) + for s in range(0, n, self.logit_chunk_tokens): + e = min(s + self.logit_chunk_tokens, n) + logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits_proj) + loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum") + return loss_sum / float(n) + +# ============================================================================== +# OPTIMIZERS (MUON + ADAM SPLIT) +# ============================================================================== +class Muon: + # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the + # parameter update. + def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters): + self.keys = keys + self.args = args + self.buffers = {k: mx.zeros_like(params[k]) for k in keys} + + def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]: + if self.args.muon_momentum_warmup_steps: + t = min(step / self.args.muon_momentum_warmup_steps, 1.0) + momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum + else: + momentum = self.args.muon_momentum + lr = self.args.matrix_lr * lr_mul + out: dict[str, mx.array] = {} + for k in self.keys: + p = params[k] + g = grads[k] + buf = momentum * self.buffers[k] + g + self.buffers[k] = buf + g_eff = g + momentum * buf + g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps) + scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1]))) + out[k] = p - lr * (g_ortho * scale).astype(p.dtype) + return out + + +class SplitOptimizers: + # - embeddings: Adam with the tied-embedding LR + # - block matrices (2D): Muon + # - block scalars + skip weights: Adam + # This preserves the high-level optimization behavior even though MLX internals differ. + def __init__(self, model: GPT, args: Hyperparameters): + self.args = args + params = dict(tree_flatten(model.parameters())) + self.embed_key = "tok_emb.weight" + self.matrix_keys = [ + k + for k, p in params.items() + if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + self.scalar_keys = [ + k + for k, p in params.items() + if k == "skip_weights" or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS))) + ] + + self.muon = Muon(self.matrix_keys, params, args) + self.adam_embed = optim.Adam( + learning_rate=args.tied_embed_lr, + betas=[args.beta1, args.beta2], + eps=args.adam_eps, + bias_correction=True, + ) + self.adam_scalar = optim.Adam( + learning_rate=args.scalar_lr, + betas=[args.beta1, args.beta2], + eps=args.adam_eps, + bias_correction=True, + ) + + def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None: + params = dict(tree_flatten(model.parameters())) + grads = dict(tree_flatten(grads_tree)) + updated = dict(params) + + updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul)) + + self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul + updated.update( + self.adam_embed.apply_gradients( + {self.embed_key: grads[self.embed_key]}, + {self.embed_key: params[self.embed_key]}, + ) + ) + + self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul + scalar_grads = {k: grads[k] for k in self.scalar_keys} + scalar_params = {k: params[k] for k in self.scalar_keys} + updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params)) + + model.update(tree_unflatten(list(updated.items()))) + +# ============================================================================== +# QUANTIZATION (INT8 + ZLIB) +# ============================================================================== +# - per-row int8 for 2D float tensors +# - per-tensor int8 for other float tensors +# - fp16 passthrough for small float tensors +# - exact passthrough for non-floats + +MX_DTYPE_FROM_NAME = { + "float32": mx.float32, + "float16": mx.float16, + "bfloat16": mx.bfloat16, +} + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = np.float16 +INT8_PER_ROW_SCALE_DTYPE = np.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + + +def _np_float32(arr: mx.array) -> np.ndarray: + return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False) + + +def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return np.ascontiguousarray(_np_float32(arr)) + if arr.dtype in {mx.float32, mx.bfloat16}: + passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1] + return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False)) + return np.ascontiguousarray(np.array(arr, copy=True)) + + +def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]: + f32 = _np_float32(arr) + if f32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32) + clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) + scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False) + q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False) + return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False)) + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0 + scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32) + q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False) + return np.ascontiguousarray(q), scale + + +def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]: + quantized: dict[str, np.ndarray] = {} + scales: dict[str, np.ndarray] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, np.ndarray] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, arr in flat_state.items(): + stats["param_count"] += int(arr.size) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += int(arr.nbytes) + if not mx.issubdtype(arr.dtype, mx.floating): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = np.ascontiguousarray(np.array(arr)) + stats["int8_payload_bytes"] += int(passthrough[name].nbytes) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_array(name, arr, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += int(kept.nbytes) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_array(arr) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(arr.dtype).split(".")[-1] + stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]: + out: dict[str, mx.array] = {} + qmeta = quant_obj.get("qmeta", {}) + passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {}) + for name, q in quant_obj["quantized"].items(): + q_np = np.asarray(q, dtype=np.int8) + dtype_name = quant_obj["dtypes"][name] + scale = np.asarray(quant_obj["scales"][name], dtype=np.float32) + if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0: + # Broadcast the saved row scale back across trailing dimensions. + out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1)) + else: + out_arr = q_np.astype(np.float32) * float(scale) + out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name]) + for name, arr in quant_obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_arr = np.array(arr, copy=True) + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype]) + else: + out[name] = mx.array(out_arr) + return out + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_lut = np.zeros((table_size,), dtype=np.int16) + has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_lut[token_id] = False + if sp.is_byte(token_id): + base_bytes_lut[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_lut[token_id] = True + piece = piece[1:] + base_bytes_lut[token_id] = len(piece.encode("utf-8")) + return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut + + +def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]: + # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we + # decode bytes with the exact tokenizer that produced the shards. The manifest + # lets the training script fail fast on accidental dataset/tokenizer mismatches. + dataset_dir = Path(data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + if len(dataset_dir.parents) < 2: + return dataset_dir.name, actual_train_files, None + manifest_path = dataset_dir.parents[1] / "manifest.json" + if not manifest_path.is_file(): + return dataset_dir.name, actual_train_files, None + + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None) + if dataset_entry is None: + return dataset_dir.name, actual_train_files, None + + tokenizer_name = dataset_entry.get("tokenizer_name") + tokenizer_entry = ( + next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) + if tokenizer_name + else None + ) + expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name + if expected_name and Path(tokenizer_path).name != expected_name: + raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}") + expected_train_files = (dataset_entry.get("stats") or {}).get("files_train") + if expected_train_files is not None: + expected_train_files = int(expected_train_files) + if actual_train_files > expected_train_files: + raise ValueError( + f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, " + f"manifest says {expected_train_files}" + ) + return dataset_dir.name, actual_train_files, expected_train_files + + +def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0)) + usable = ((tokens.size - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def loss_and_grad_chunked( + args: Hyperparameters, + train_loader: TokenLoader, + compiled_loss_and_grad, +) -> tuple[mx.array, dict]: + chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens) + total_tokens = float(sum(chunk_sizes)) + loss_value = mx.array(0.0, dtype=mx.float32) + grad_accum: dict[str, mx.array] | None = None + for chunk_tokens in chunk_sizes: + x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len) + loss, grads = compiled_loss_and_grad(x, y) + scale = float(y.size) / total_tokens + loss_value = loss_value + loss.astype(mx.float32) * scale + grad_accum = accumulate_flat_grads(grad_accum, grads, scale) + if args.mlx_eager_eval: + mx.eval(loss_value, grad_accum) # materialize each chunk to cap peak memory + return loss_value, tree_unflatten(list(grad_accum.items())) + + +def eval_val( + args: Hyperparameters, + compiled_loss, + val_tokens: np.ndarray, + base_bytes_lut: np.ndarray, + has_leading_space_lut: np.ndarray, + is_boundary_token_lut: np.ndarray, + log_fn: Callable[[str], None] | None = None, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + val_batch_tokens = args.val_batch_size // args.grad_accum_steps + if val_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + val_batch_seqs = val_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.size - 1) // args.train_seq_len + total_batches = max((total_seqs + val_batch_seqs - 1) // val_batch_seqs, 1) + total_loss_sum = 0.0 + total_tokens = 0.0 + total_bytes = 0.0 + for batch_idx, batch_seq_start in enumerate(range(0, total_seqs, val_batch_seqs), start=1): + batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + chunk = val_tokens[raw_start:raw_end] + x_np = chunk[:-1].reshape(-1, args.train_seq_len) + y_np = chunk[1:].reshape(-1, args.train_seq_len) + x = mx.array(x_np, dtype=mx.int32) + y = mx.array(y_np, dtype=mx.int32) + chunk_token_count = float(y.size) + batch_loss = compiled_loss(x, y).astype(mx.float32) + mx.eval(batch_loss) + total_loss_sum += float(batch_loss.item()) * chunk_token_count + prev_ids = x_np.reshape(-1) + tgt_ids = y_np.reshape(-1) + bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True) + bytes_np += ( + has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] + ).astype(np.int16, copy=False) + total_tokens += chunk_token_count + total_bytes += float(bytes_np.astype(np.float64).sum()) + if log_fn is not None and total_batches > 1 and ( + batch_idx == 1 or batch_idx == total_batches or batch_idx % 25 == 0 + ): + log_fn(f"val_progress:{batch_idx}/{total_batches}") + val_loss = total_loss_sum / total_tokens + bits_per_token = val_loss / math.log(2.0) + val_bpb = bits_per_token * (total_tokens / total_bytes) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict: + if max_norm <= 0: + return grads_tree + flat = dict(tree_flatten(grads_tree)) + total_sq = 0.0 + for grad in flat.values(): + total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64)) + if total_sq <= 0.0: + return grads_tree + total_norm = math.sqrt(total_sq) + if total_norm <= max_norm: + return grads_tree + scale = max_norm / (total_norm + 1e-12) + return tree_unflatten([(k, g * scale) for k, g in flat.items()]) + + +def main() -> None: + # ============================================================================== + # TOKENIZER + VALIDATION METRIC SETUP + # ============================================================================== + args = Hyperparameters() + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + logfile = out_dir / f"{args.run_id}.txt" + print(logfile) + + def log(msg: str, console: bool = True) -> None: + if console: + print(msg) + with logfile.open("a", encoding="utf-8") as f: + print(msg, file=f) + + code = Path(__file__).read_text(encoding="utf-8") + log(code, console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running MLX {mx.__version__}", console=False) + log("=" * 100, console=False) + + if not args.tie_embeddings: + raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings") + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair( + args.data_path, + args.tokenizer_path, + ) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size + ) + + # ============================================================================== + # TRAINING SETUP + # ============================================================================== + mx.random.seed(args.seed) + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + # ============================================================================== + # MODEL + OPTIMIZER SETUP + # ============================================================================== + model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + logit_chunk_tokens=args.logit_chunk_tokens, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + tied_embed_init_std=args.tied_embed_init_std, + qk_gain_init=args.qk_gain_init, + ) + opt = SplitOptimizers(model, args) + + # ============================================================================== + # COMPILED TRAIN / EVAL FUNCTIONS (MLX) + # ============================================================================== + # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example + # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs". + # Compiling the model-bound functions and capturing the full model state fixes that while still + # returning gradients only for trainable parameters via nn.value_and_grad(...). + compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) + compiled_loss_and_grad = mx.compile( + nn.value_and_grad(model, lambda x, y: model.loss(x, y)), + inputs=model.state, + outputs=model.state, + ) + + # Print config once so logs are self-describing. + n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) + log(f"run_id:{args.run_id}") + log(f"mlx_version:{mx.__version__}") + log(f"train_loader:shards pattern={args.train_files}") + log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}") + if expected_train_files is None: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}") + elif actual_train_files < expected_train_files: + log( + f"WARNING: train_loader:subset dataset:{dataset_name} " + f"train_shards:{actual_train_files}/{expected_train_files} " + f"new epochs will arrive sooner than the full dataset" + ) + else: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}") + log(f"tokenizer_path:{args.tokenizer_path}") + log( + f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} " + f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} " + f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}" + ) + log( + f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} " + f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} " + f"val_batch_size:{args.val_batch_size} " + f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}") + log( + f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} " + f"embed_lr:{args.tied_embed_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} " + f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}" + ) + log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log(f"compute_dtype:{COMPUTE_DTYPE} compile:True") + log( + f"dtypes tok_emb:{model.tok_emb.weight.dtype} " + f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} " + f"skip_weights:{model.skip_weights.dtype}" + ) + + # ============================================================================== + # TRAINING LOOP + # ============================================================================== + if args.warmup_steps > 0: + # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us + # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs. + # Instead we run the real train shapes, force the loss/grads to materialize, and then reset + # the loader so measured training still starts from the true init and token window. + for warmup_step in range(args.warmup_steps): + accum: dict[str, mx.array] | None = None + warmup_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + mx.eval(warmup_loss, accum) + mx.synchronize() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + + # Prime the standalone eval graph once too. It is compiled separately from value_and_grad. + val_batch_tokens = args.val_batch_size // args.grad_accum_steps + if val_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len) + warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1] + x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32) + y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32) + warm_val_loss = compiled_loss(x_val, y_val) + mx.eval(warm_val_loss) + mx.synchronize() + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + train_time_ms = 0.0 + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + stop_after_step: int | None = None + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + train_time_ms += 1000.0 * (time.perf_counter() - t0) + # Validation always scans the same fixed full validation split. + val_loss, val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + log_fn=log, + ) + if step % 25 == 0 or last_step: + log( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms" + ) + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0)) + step_t0 = time.perf_counter() + + accum: dict[str, mx.array] | None = None + train_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + train_loss = train_loss + loss.astype(mx.float32) * grad_scale + if args.mlx_eager_eval: + mx.eval(train_loss, accum) # materialize each microbatch to cap peak memory + + grads = tree_unflatten(list(accum.items())) + grads = clip_grad_tree(grads, args.grad_clip_norm) + train_loss_value = float(train_loss.item()) + opt.step(model, grads, step=step, lr_mul=lr_mul) + mx.synchronize() + + step_ms = 1000.0 * (time.perf_counter() - step_t0) + approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0) + tok_s = args.train_batch_tokens / (step_ms / 1000.0) + step += 1 + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): + log( + f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} " + f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}" + ) + if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms: + stop_after_step = step + + # ============================================================================== + # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL + # ============================================================================== + # We always write a raw artifact and a quantized artifact, then validate the + # quantized roundtrip directly by loading the dequantized tensors back into the + # model and running one final validation pass. + out_path = out_dir / f"{args.run_id}_mlx_model.npz" + flat_state = {k: v for k, v in tree_flatten(model.state)} + mx.savez(str(out_path), **flat_state) + log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") + + quant_obj, quant_stats = quantize_state_dict_int8(flat_state) + quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) + quant_blob = zlib.compress(quant_raw, level=9) + quant_serialized_bytes = len(quant_raw) + quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz" + with quant_path.open("wb") as f: + f.write(quant_blob) + quant_file_bytes = quant_path.stat().st_size + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log( + f"serialized_model_int8_zlib:{quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)" + ) + + with quant_path.open("rb") as f: + quant_blob_disk = f.read() + quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk))) + model.update(tree_unflatten(list(quant_flat.items()))) + q_t0 = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + log_fn=log, + ) + q_eval_ms = 1000.0 * (time.perf_counter() - q_t0) + log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms") + log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.14.2 (main, Dec 5 2025, 16:49:16) [Clang 17.0.0 (clang-1700.4.4.1)] +Running MLX 0.31.1 +==================================================================================================== +run_id:stukenov_mlx_probe +mlx_version:0.31.1 +train_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_train_*.bin +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +WARNING: train_loader:subset dataset:fineweb10B_sp1024 train_shards:1/195 new epochs will arrive sooner than the full dataset +tokenizer_path:./data/tokenizers/fineweb_1024_bpe.model +model_params:17059912 vocab_size:1024 layers:9 dim:512 heads:8 kv_heads:4 seq_len:1024 tie_embeddings:True +iterations:1 train_batch_tokens:8192 grad_accum_steps:8 microbatch_tokens:1024 microbatch_batch_size:1 val_batch_size:65536 warmup_steps:20 max_wallclock_seconds:0.000 +mlx_max_microbatch_tokens:8192 +optimizer:muon+adam muon_matrix_params:54 scalar_params:37 embed_lr:0.05 matrix_lr:0.04 scalar_lr:0.04 muon_momentum:0.95 muon_steps:5 +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +compute_dtype:mlx.core.bfloat16 compile:True +dtypes tok_emb:mlx.core.bfloat16 linear_weight:mlx.core.float32 skip_weights:mlx.core.float32 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/1 train_loss:6.9428 train_time:1651ms step_avg:1651.25ms tok_s:4962 +val_progress:1/7571 +val_progress:25/7571 +val_progress:50/7571 +val_progress:75/7571 +val_progress:100/7571 +val_progress:125/7571 +val_progress:150/7571 +val_progress:175/7571 diff --git a/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_gpt_mlx.py b/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_gpt_mlx.py new file mode 100644 index 0000000000..7b9e935aa6 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_gpt_mlx.py @@ -0,0 +1,1104 @@ +#!/usr/bin/env python3 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" +from __future__ import annotations + +import glob +import json +import math +import os +import pickle +import sys +import time +import uuid +import zlib +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import sentencepiece as spm + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +from mlx.utils import tree_flatten, tree_unflatten + +# ============================================================================== +# SHARD FORMAT + COMPUTE DTYPE +# ============================================================================== + +COMPUTE_DTYPE = mx.bfloat16 + +# ============================================================================== +# HYPERPARAMETERS +# ============================================================================== +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap +class Hyperparameters: + # Data / tokenizer. + data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed: int = int(os.environ.get("SEED", 1337)) + + # Training loop. These defaults now mirror train_gpt.py on a single process. + iterations: int = int(os.environ.get("ITERATIONS", 20_000)) + val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0)) + # Validation always uses the full fineweb_val split. + val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8)) + train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024))) + # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak + # memory pressure without changing the effective optimizer batch. + mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) + # Force MLX to materialize the graph after every sub-batch, preventing lazy + # graph buildup across accumulation steps. Keeps peak memory low on 16GB machines. + # Disable on 32GB+ unified memory for better throughput (MLX_EAGER_EVAL=0). + mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1"))) + warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) + warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200)) + max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + + # Model (defaults match the current baseline setup). + vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers: int = int(os.environ.get("NUM_LAYERS", 9)) + model_dim: int = int(os.environ.get("MODEL_DIM", 512)) + num_heads: int = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) + mlp_mult: int = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) + logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) + qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Optimizer. We keep the same per-group defaults as train_gpt.py. + beta1: float = float(os.environ.get("BETA1", 0.9)) + beta2: float = float(os.environ.get("BETA2", 0.95)) + adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) + tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) + matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + out_dir: str = os.environ.get("OUT_DIR", "logs") + + @property + def train_files(self) -> str: + return f"{self.data_path}/fineweb_train_*.bin" + + @property + def val_files(self) -> str: + return f"{self.data_path}/fineweb_val_*.bin" + + @property + def microbatch_tokens(self) -> int: + return self.train_batch_tokens // self.grad_accum_steps + + def lr_mul(self, step: int, elapsed_ms: float) -> float: + if self.warmdown_iters <= 0: + return 1.0 + if self.max_wallclock_seconds <= 0: + warmdown_start = max(self.iterations - self.warmdown_iters, 0) + return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = self.warmdown_iters * step_ms + remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) + + +def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]: + usable_total = (total_tokens // seq_len) * seq_len + if usable_total <= 0: + raise ValueError(f"token budget too small for seq_len={seq_len}") + usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) + chunks: list[int] = [] + remaining = usable_total + while remaining > 0: + chunk = min(remaining, usable_chunk) + chunks.append(chunk) + remaining -= chunk + return chunks + + +def accumulate_flat_grads( + accum: dict[str, mx.array] | None, + grads_tree: dict, + scale: float, +) -> dict[str, mx.array]: + flat = dict(tree_flatten(grads_tree)) + if accum is None: + return {k: g * scale for k, g in flat.items()} + for k, g in flat.items(): + accum[k] = accum[k] + g * scale + return accum + + +# ============================================================================== +# MATH HELPERS +# ============================================================================== + +def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: + return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype) + + +def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + # Background on Muon: https://kellerjordan.github.io/posts/muon/ + a, b, c = 3.4445, -4.7750, 2.0315 + x = g.astype(mx.float32) + x = x / (mx.sqrt(mx.sum(x * x)) + eps) + transposed = x.shape[0] > x.shape[1] + if transposed: + x = x.T + for _ in range(steps): + a_mat = x @ x.T + b_mat = b * a_mat + c * (a_mat @ a_mat) + x = a * x + b_mat @ x + if transposed: + x = x.T + return x.astype(g.dtype) + + +def load_data_shard(path: Path) -> np.ndarray: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + if self.file_idx == 0: + self.epoch += 1 + if self.log_fn is not None: + self.log_fn( + f"WARNING: starting epoch:{self.epoch} " + f"dataset:{self.dataset_name} train_shards:{len(self.files)}" + ) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> np.ndarray: + chunks: list[np.ndarray] = [] + left = n + while left > 0: + if self.pos >= self.tokens.size: + self.next_file() + k = min(left, int(self.tokens.size - self.pos)) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + left -= k + return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0) + + +class TokenLoader: + def __init__( + self, + pattern: str, + log_fn: Callable[[str], None] | None = None, + dataset_name: str = "", + ): + self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name) + + def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]: + usable = (batch_tokens // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"token budget too small for seq_len={seq_len}") + chunk = self.stream.take(usable + 1) + x = chunk[:-1].reshape(-1, seq_len) + y = chunk[1:].reshape(-1, seq_len) + return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32) + + +# ============================================================================== +# MODEL BLOCKS +# ============================================================================== + +class CastedLinear(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32) + + def __call__(self, x: mx.array) -> mx.array: + return x @ self.weight.astype(x.dtype).T + + +class RMSNormNoWeight(nn.Module): + # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks. + def __call__(self, x: mx.array) -> mx.array: + return rms_norm(x) + + +class CausalSelfAttention(nn.Module): + # - separate q/k/v projections + # - RMSNorm on q and k before attention + # - RoPE on q and k + # - causal masked SDPA + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim) + self.c_k = CastedLinear(dim, kv_dim) + self.c_v = CastedLinear(dim, kv_dim) + self.proj = CastedLinear(dim, dim) + self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init + self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base) + self.scale = self.head_dim ** -0.5 + + def __call__(self, x: mx.array) -> mx.array: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + + q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE)) + k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE)) + q = q * self.q_gain.astype(q.dtype)[None, :, None, None] + y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal") + y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # Baseline MLP uses relu^2 instead of GELU/SiLU. It is cheap and works well in this setup. + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = dim * mlp_mult + self.fc = CastedLinear(dim, hidden) + self.proj = CastedLinear(hidden, dim) + + def __call__(self, x: mx.array) -> mx.array: + x = nn.relu(self.fc(x)) + return self.proj(x * x) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNormNoWeight() + self.mlp_norm = RMSNormNoWeight() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = mx.ones((dim,), dtype=mx.float32) + self.mlp_scale = mx.ones((dim,), dtype=mx.float32) + self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32)))) + + def __call__(self, x: mx.array, x0: mx.array) -> mx.array: + mix = self.resid_mix.astype(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + # - token embedding + RMSNorm + # - encoder half accumulates skip tensors + # - decoder half consumes reversed skips with learned skip_weights + # - tied embeddings for the LM head (the baseline default setup) + def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float, + qk_gain_init: float): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.logit_chunk_tokens = logit_chunk_tokens + self.logit_softcap = logit_softcap + + self.tok_emb = nn.Embedding(vocab_size, dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32) + self.blocks = [ + Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for i in range(num_layers) + ] + self.final_norm = RMSNormNoWeight() + + for b in self.blocks: + b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight) + b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight) + self.tok_emb.weight = ( + mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std + ).astype(COMPUTE_DTYPE) + + def softcap(self, logits: mx.array) -> mx.array: + c = self.logit_softcap + return c * mx.tanh(logits / c) + + def __call__(self, input_ids: mx.array) -> mx.array: + x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE)) + x0 = x + skips: list[mx.array] = [] + + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + # Odd layer counts have one more decoder block than encoder block. The baseline only + # applies a skip connection when one exists, then runs the remaining decoder block(s) + # without an added skip. + if skips: + x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + return self.final_norm(x) + + def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: + # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful + # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE). + x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) + y = target_ids.reshape(-1) + if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens: + logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits_proj) + return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") + + loss_sum = mx.array(0.0, dtype=mx.float32) + n = int(x.shape[0]) + for s in range(0, n, self.logit_chunk_tokens): + e = min(s + self.logit_chunk_tokens, n) + logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits_proj) + loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum") + return loss_sum / float(n) + +# ============================================================================== +# OPTIMIZERS (MUON + ADAM SPLIT) +# ============================================================================== +class Muon: + # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the + # parameter update. + def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters): + self.keys = keys + self.args = args + self.buffers = {k: mx.zeros_like(params[k]) for k in keys} + + def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]: + if self.args.muon_momentum_warmup_steps: + t = min(step / self.args.muon_momentum_warmup_steps, 1.0) + momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum + else: + momentum = self.args.muon_momentum + lr = self.args.matrix_lr * lr_mul + out: dict[str, mx.array] = {} + for k in self.keys: + p = params[k] + g = grads[k] + buf = momentum * self.buffers[k] + g + self.buffers[k] = buf + g_eff = g + momentum * buf + g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps) + scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1]))) + out[k] = p - lr * (g_ortho * scale).astype(p.dtype) + return out + + +class SplitOptimizers: + # - embeddings: Adam with the tied-embedding LR + # - block matrices (2D): Muon + # - block scalars + skip weights: Adam + # This preserves the high-level optimization behavior even though MLX internals differ. + def __init__(self, model: GPT, args: Hyperparameters): + self.args = args + params = dict(tree_flatten(model.parameters())) + self.embed_key = "tok_emb.weight" + self.matrix_keys = [ + k + for k, p in params.items() + if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + self.scalar_keys = [ + k + for k, p in params.items() + if k == "skip_weights" or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS))) + ] + + self.muon = Muon(self.matrix_keys, params, args) + self.adam_embed = optim.Adam( + learning_rate=args.tied_embed_lr, + betas=[args.beta1, args.beta2], + eps=args.adam_eps, + bias_correction=True, + ) + self.adam_scalar = optim.Adam( + learning_rate=args.scalar_lr, + betas=[args.beta1, args.beta2], + eps=args.adam_eps, + bias_correction=True, + ) + + def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None: + params = dict(tree_flatten(model.parameters())) + grads = dict(tree_flatten(grads_tree)) + updated = dict(params) + + updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul)) + + self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul + updated.update( + self.adam_embed.apply_gradients( + {self.embed_key: grads[self.embed_key]}, + {self.embed_key: params[self.embed_key]}, + ) + ) + + self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul + scalar_grads = {k: grads[k] for k in self.scalar_keys} + scalar_params = {k: params[k] for k in self.scalar_keys} + updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params)) + + model.update(tree_unflatten(list(updated.items()))) + +# ============================================================================== +# QUANTIZATION (INT8 + ZLIB) +# ============================================================================== +# - per-row int8 for 2D float tensors +# - per-tensor int8 for other float tensors +# - fp16 passthrough for small float tensors +# - exact passthrough for non-floats + +MX_DTYPE_FROM_NAME = { + "float32": mx.float32, + "float16": mx.float16, + "bfloat16": mx.bfloat16, +} + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = np.float16 +INT8_PER_ROW_SCALE_DTYPE = np.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + + +def _np_float32(arr: mx.array) -> np.ndarray: + return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False) + + +def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return np.ascontiguousarray(_np_float32(arr)) + if arr.dtype in {mx.float32, mx.bfloat16}: + passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1] + return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False)) + return np.ascontiguousarray(np.array(arr, copy=True)) + + +def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]: + f32 = _np_float32(arr) + if f32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32) + clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) + scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False) + q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False) + return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False)) + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0 + scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32) + q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False) + return np.ascontiguousarray(q), scale + + +def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]: + quantized: dict[str, np.ndarray] = {} + scales: dict[str, np.ndarray] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, np.ndarray] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, arr in flat_state.items(): + stats["param_count"] += int(arr.size) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += int(arr.nbytes) + if not mx.issubdtype(arr.dtype, mx.floating): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = np.ascontiguousarray(np.array(arr)) + stats["int8_payload_bytes"] += int(passthrough[name].nbytes) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_array(name, arr, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += int(kept.nbytes) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_array(arr) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(arr.dtype).split(".")[-1] + stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]: + out: dict[str, mx.array] = {} + qmeta = quant_obj.get("qmeta", {}) + passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {}) + for name, q in quant_obj["quantized"].items(): + q_np = np.asarray(q, dtype=np.int8) + dtype_name = quant_obj["dtypes"][name] + scale = np.asarray(quant_obj["scales"][name], dtype=np.float32) + if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0: + # Broadcast the saved row scale back across trailing dimensions. + out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1)) + else: + out_arr = q_np.astype(np.float32) * float(scale) + out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name]) + for name, arr in quant_obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_arr = np.array(arr, copy=True) + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype]) + else: + out[name] = mx.array(out_arr) + return out + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_lut = np.zeros((table_size,), dtype=np.int16) + has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_lut[token_id] = False + if sp.is_byte(token_id): + base_bytes_lut[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_lut[token_id] = True + piece = piece[1:] + base_bytes_lut[token_id] = len(piece.encode("utf-8")) + return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut + + +def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]: + # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we + # decode bytes with the exact tokenizer that produced the shards. The manifest + # lets the training script fail fast on accidental dataset/tokenizer mismatches. + dataset_dir = Path(data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + if len(dataset_dir.parents) < 2: + return dataset_dir.name, actual_train_files, None + manifest_path = dataset_dir.parents[1] / "manifest.json" + if not manifest_path.is_file(): + return dataset_dir.name, actual_train_files, None + + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None) + if dataset_entry is None: + return dataset_dir.name, actual_train_files, None + + tokenizer_name = dataset_entry.get("tokenizer_name") + tokenizer_entry = ( + next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) + if tokenizer_name + else None + ) + expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name + if expected_name and Path(tokenizer_path).name != expected_name: + raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}") + expected_train_files = (dataset_entry.get("stats") or {}).get("files_train") + if expected_train_files is not None: + expected_train_files = int(expected_train_files) + if actual_train_files > expected_train_files: + raise ValueError( + f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, " + f"manifest says {expected_train_files}" + ) + return dataset_dir.name, actual_train_files, expected_train_files + + +def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0)) + usable = ((tokens.size - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def loss_and_grad_chunked( + args: Hyperparameters, + train_loader: TokenLoader, + compiled_loss_and_grad, +) -> tuple[mx.array, dict]: + chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens) + total_tokens = float(sum(chunk_sizes)) + loss_value = mx.array(0.0, dtype=mx.float32) + grad_accum: dict[str, mx.array] | None = None + for chunk_tokens in chunk_sizes: + x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len) + loss, grads = compiled_loss_and_grad(x, y) + scale = float(y.size) / total_tokens + loss_value = loss_value + loss.astype(mx.float32) * scale + grad_accum = accumulate_flat_grads(grad_accum, grads, scale) + if args.mlx_eager_eval: + mx.eval(loss_value, grad_accum) # materialize each chunk to cap peak memory + return loss_value, tree_unflatten(list(grad_accum.items())) + + +def eval_val( + args: Hyperparameters, + compiled_loss, + val_tokens: np.ndarray, + base_bytes_lut: np.ndarray, + has_leading_space_lut: np.ndarray, + is_boundary_token_lut: np.ndarray, + log_fn: Callable[[str], None] | None = None, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + val_batch_tokens = args.val_batch_size // args.grad_accum_steps + if val_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + val_batch_seqs = val_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.size - 1) // args.train_seq_len + total_batches = max((total_seqs + val_batch_seqs - 1) // val_batch_seqs, 1) + total_loss_sum = 0.0 + total_tokens = 0.0 + total_bytes = 0.0 + for batch_idx, batch_seq_start in enumerate(range(0, total_seqs, val_batch_seqs), start=1): + batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + chunk = val_tokens[raw_start:raw_end] + x_np = chunk[:-1].reshape(-1, args.train_seq_len) + y_np = chunk[1:].reshape(-1, args.train_seq_len) + x = mx.array(x_np, dtype=mx.int32) + y = mx.array(y_np, dtype=mx.int32) + chunk_token_count = float(y.size) + batch_loss = compiled_loss(x, y).astype(mx.float32) + mx.eval(batch_loss) + total_loss_sum += float(batch_loss.item()) * chunk_token_count + prev_ids = x_np.reshape(-1) + tgt_ids = y_np.reshape(-1) + bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True) + bytes_np += ( + has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] + ).astype(np.int16, copy=False) + total_tokens += chunk_token_count + total_bytes += float(bytes_np.astype(np.float64).sum()) + if log_fn is not None and total_batches > 1 and ( + batch_idx == 1 or batch_idx == total_batches or batch_idx % 25 == 0 + ): + log_fn(f"val_progress:{batch_idx}/{total_batches}") + val_loss = total_loss_sum / total_tokens + bits_per_token = val_loss / math.log(2.0) + val_bpb = bits_per_token * (total_tokens / total_bytes) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict: + if max_norm <= 0: + return grads_tree + flat = dict(tree_flatten(grads_tree)) + total_sq = 0.0 + for grad in flat.values(): + total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64)) + if total_sq <= 0.0: + return grads_tree + total_norm = math.sqrt(total_sq) + if total_norm <= max_norm: + return grads_tree + scale = max_norm / (total_norm + 1e-12) + return tree_unflatten([(k, g * scale) for k, g in flat.items()]) + + +def main() -> None: + # ============================================================================== + # TOKENIZER + VALIDATION METRIC SETUP + # ============================================================================== + args = Hyperparameters() + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + logfile = out_dir / f"{args.run_id}.txt" + print(logfile) + + def log(msg: str, console: bool = True) -> None: + if console: + print(msg) + with logfile.open("a", encoding="utf-8") as f: + print(msg, file=f) + + code = Path(__file__).read_text(encoding="utf-8") + log(code, console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running MLX {mx.__version__}", console=False) + log("=" * 100, console=False) + + if not args.tie_embeddings: + raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings") + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair( + args.data_path, + args.tokenizer_path, + ) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size + ) + + # ============================================================================== + # TRAINING SETUP + # ============================================================================== + mx.random.seed(args.seed) + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + # ============================================================================== + # MODEL + OPTIMIZER SETUP + # ============================================================================== + model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + logit_chunk_tokens=args.logit_chunk_tokens, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + tied_embed_init_std=args.tied_embed_init_std, + qk_gain_init=args.qk_gain_init, + ) + opt = SplitOptimizers(model, args) + + # ============================================================================== + # COMPILED TRAIN / EVAL FUNCTIONS (MLX) + # ============================================================================== + # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example + # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs". + # Compiling the model-bound functions and capturing the full model state fixes that while still + # returning gradients only for trainable parameters via nn.value_and_grad(...). + compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) + compiled_loss_and_grad = mx.compile( + nn.value_and_grad(model, lambda x, y: model.loss(x, y)), + inputs=model.state, + outputs=model.state, + ) + + # Print config once so logs are self-describing. + n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) + log(f"run_id:{args.run_id}") + log(f"mlx_version:{mx.__version__}") + log(f"train_loader:shards pattern={args.train_files}") + log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}") + if expected_train_files is None: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}") + elif actual_train_files < expected_train_files: + log( + f"WARNING: train_loader:subset dataset:{dataset_name} " + f"train_shards:{actual_train_files}/{expected_train_files} " + f"new epochs will arrive sooner than the full dataset" + ) + else: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}") + log(f"tokenizer_path:{args.tokenizer_path}") + log( + f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} " + f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} " + f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}" + ) + log( + f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} " + f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} " + f"val_batch_size:{args.val_batch_size} " + f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}") + log( + f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} " + f"embed_lr:{args.tied_embed_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} " + f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}" + ) + log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log(f"compute_dtype:{COMPUTE_DTYPE} compile:True") + log( + f"dtypes tok_emb:{model.tok_emb.weight.dtype} " + f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} " + f"skip_weights:{model.skip_weights.dtype}" + ) + + # ============================================================================== + # TRAINING LOOP + # ============================================================================== + if args.warmup_steps > 0: + # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us + # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs. + # Instead we run the real train shapes, force the loss/grads to materialize, and then reset + # the loader so measured training still starts from the true init and token window. + for warmup_step in range(args.warmup_steps): + accum: dict[str, mx.array] | None = None + warmup_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + mx.eval(warmup_loss, accum) + mx.synchronize() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + + # Prime the standalone eval graph once too. It is compiled separately from value_and_grad. + val_batch_tokens = args.val_batch_size // args.grad_accum_steps + if val_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len) + warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1] + x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32) + y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32) + warm_val_loss = compiled_loss(x_val, y_val) + mx.eval(warm_val_loss) + mx.synchronize() + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + train_time_ms = 0.0 + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + stop_after_step: int | None = None + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + train_time_ms += 1000.0 * (time.perf_counter() - t0) + # Validation always scans the same fixed full validation split. + val_loss, val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + log_fn=log, + ) + if step % 25 == 0 or last_step: + log( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms" + ) + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0)) + step_t0 = time.perf_counter() + + accum: dict[str, mx.array] | None = None + train_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + train_loss = train_loss + loss.astype(mx.float32) * grad_scale + if args.mlx_eager_eval: + mx.eval(train_loss, accum) # materialize each microbatch to cap peak memory + + grads = tree_unflatten(list(accum.items())) + grads = clip_grad_tree(grads, args.grad_clip_norm) + train_loss_value = float(train_loss.item()) + opt.step(model, grads, step=step, lr_mul=lr_mul) + mx.synchronize() + + step_ms = 1000.0 * (time.perf_counter() - step_t0) + approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0) + tok_s = args.train_batch_tokens / (step_ms / 1000.0) + step += 1 + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): + log( + f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} " + f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}" + ) + if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms: + stop_after_step = step + + # ============================================================================== + # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL + # ============================================================================== + # We always write a raw artifact and a quantized artifact, then validate the + # quantized roundtrip directly by loading the dequantized tensors back into the + # model and running one final validation pass. + out_path = out_dir / f"{args.run_id}_mlx_model.npz" + flat_state = {k: v for k, v in tree_flatten(model.state)} + mx.savez(str(out_path), **flat_state) + log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") + + quant_obj, quant_stats = quantize_state_dict_int8(flat_state) + quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) + quant_blob = zlib.compress(quant_raw, level=9) + quant_serialized_bytes = len(quant_raw) + quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz" + with quant_path.open("wb") as f: + f.write(quant_blob) + quant_file_bytes = quant_path.stat().st_size + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log( + f"serialized_model_int8_zlib:{quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)" + ) + + with quant_path.open("rb") as f: + quant_blob_disk = f.read() + quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk))) + model.update(tree_unflatten(list(quant_flat.items()))) + q_t0 = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + log_fn=log, + ) + q_eval_ms = 1000.0 * (time.perf_counter() - q_t0) + log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms") + log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_partial.log b/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_partial.log new file mode 100644 index 0000000000..91cec65c1e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_partial.log @@ -0,0 +1,1360 @@ +#!/usr/bin/env python3 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" +from __future__ import annotations + +import glob +import json +import math +import os +import pickle +import sys +import time +import uuid +import zlib +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import sentencepiece as spm + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +from mlx.utils import tree_flatten, tree_unflatten + +# ============================================================================== +# SHARD FORMAT + COMPUTE DTYPE +# ============================================================================== + +COMPUTE_DTYPE = mx.bfloat16 + +# ============================================================================== +# HYPERPARAMETERS +# ============================================================================== +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap +class Hyperparameters: + # Data / tokenizer. + data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed: int = int(os.environ.get("SEED", 1337)) + + # Training loop. These defaults now mirror train_gpt.py on a single process. + iterations: int = int(os.environ.get("ITERATIONS", 20_000)) + val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0)) + # Validation always uses the full fineweb_val split. + val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8)) + train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024))) + # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak + # memory pressure without changing the effective optimizer batch. + mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) + # Force MLX to materialize the graph after every sub-batch, preventing lazy + # graph buildup across accumulation steps. Keeps peak memory low on 16GB machines. + # Disable on 32GB+ unified memory for better throughput (MLX_EAGER_EVAL=0). + mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1"))) + warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) + warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200)) + max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + + # Model (defaults match the current baseline setup). + vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers: int = int(os.environ.get("NUM_LAYERS", 9)) + model_dim: int = int(os.environ.get("MODEL_DIM", 512)) + num_heads: int = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) + mlp_mult: int = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) + logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) + qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Optimizer. We keep the same per-group defaults as train_gpt.py. + beta1: float = float(os.environ.get("BETA1", 0.9)) + beta2: float = float(os.environ.get("BETA2", 0.95)) + adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) + tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) + matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + out_dir: str = os.environ.get("OUT_DIR", "logs") + + @property + def train_files(self) -> str: + return f"{self.data_path}/fineweb_train_*.bin" + + @property + def val_files(self) -> str: + return f"{self.data_path}/fineweb_val_*.bin" + + @property + def microbatch_tokens(self) -> int: + return self.train_batch_tokens // self.grad_accum_steps + + def lr_mul(self, step: int, elapsed_ms: float) -> float: + if self.warmdown_iters <= 0: + return 1.0 + if self.max_wallclock_seconds <= 0: + warmdown_start = max(self.iterations - self.warmdown_iters, 0) + return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = self.warmdown_iters * step_ms + remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) + + +def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]: + usable_total = (total_tokens // seq_len) * seq_len + if usable_total <= 0: + raise ValueError(f"token budget too small for seq_len={seq_len}") + usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) + chunks: list[int] = [] + remaining = usable_total + while remaining > 0: + chunk = min(remaining, usable_chunk) + chunks.append(chunk) + remaining -= chunk + return chunks + + +def accumulate_flat_grads( + accum: dict[str, mx.array] | None, + grads_tree: dict, + scale: float, +) -> dict[str, mx.array]: + flat = dict(tree_flatten(grads_tree)) + if accum is None: + return {k: g * scale for k, g in flat.items()} + for k, g in flat.items(): + accum[k] = accum[k] + g * scale + return accum + + +# ============================================================================== +# MATH HELPERS +# ============================================================================== + +def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: + return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype) + + +def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + # Background on Muon: https://kellerjordan.github.io/posts/muon/ + a, b, c = 3.4445, -4.7750, 2.0315 + x = g.astype(mx.float32) + x = x / (mx.sqrt(mx.sum(x * x)) + eps) + transposed = x.shape[0] > x.shape[1] + if transposed: + x = x.T + for _ in range(steps): + a_mat = x @ x.T + b_mat = b * a_mat + c * (a_mat @ a_mat) + x = a * x + b_mat @ x + if transposed: + x = x.T + return x.astype(g.dtype) + + +def load_data_shard(path: Path) -> np.ndarray: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + if self.file_idx == 0: + self.epoch += 1 + if self.log_fn is not None: + self.log_fn( + f"WARNING: starting epoch:{self.epoch} " + f"dataset:{self.dataset_name} train_shards:{len(self.files)}" + ) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> np.ndarray: + chunks: list[np.ndarray] = [] + left = n + while left > 0: + if self.pos >= self.tokens.size: + self.next_file() + k = min(left, int(self.tokens.size - self.pos)) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + left -= k + return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0) + + +class TokenLoader: + def __init__( + self, + pattern: str, + log_fn: Callable[[str], None] | None = None, + dataset_name: str = "", + ): + self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name) + + def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]: + usable = (batch_tokens // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"token budget too small for seq_len={seq_len}") + chunk = self.stream.take(usable + 1) + x = chunk[:-1].reshape(-1, seq_len) + y = chunk[1:].reshape(-1, seq_len) + return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32) + + +# ============================================================================== +# MODEL BLOCKS +# ============================================================================== + +class CastedLinear(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32) + + def __call__(self, x: mx.array) -> mx.array: + return x @ self.weight.astype(x.dtype).T + + +class RMSNormNoWeight(nn.Module): + # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks. + def __call__(self, x: mx.array) -> mx.array: + return rms_norm(x) + + +class CausalSelfAttention(nn.Module): + # - separate q/k/v projections + # - RMSNorm on q and k before attention + # - RoPE on q and k + # - causal masked SDPA + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim) + self.c_k = CastedLinear(dim, kv_dim) + self.c_v = CastedLinear(dim, kv_dim) + self.proj = CastedLinear(dim, dim) + self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init + self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base) + self.scale = self.head_dim ** -0.5 + + def __call__(self, x: mx.array) -> mx.array: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + + q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE)) + k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE)) + q = q * self.q_gain.astype(q.dtype)[None, :, None, None] + y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal") + y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # Baseline MLP uses relu^2 instead of GELU/SiLU. It is cheap and works well in this setup. + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = dim * mlp_mult + self.fc = CastedLinear(dim, hidden) + self.proj = CastedLinear(hidden, dim) + + def __call__(self, x: mx.array) -> mx.array: + x = nn.relu(self.fc(x)) + return self.proj(x * x) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNormNoWeight() + self.mlp_norm = RMSNormNoWeight() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = mx.ones((dim,), dtype=mx.float32) + self.mlp_scale = mx.ones((dim,), dtype=mx.float32) + self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32)))) + + def __call__(self, x: mx.array, x0: mx.array) -> mx.array: + mix = self.resid_mix.astype(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + # - token embedding + RMSNorm + # - encoder half accumulates skip tensors + # - decoder half consumes reversed skips with learned skip_weights + # - tied embeddings for the LM head (the baseline default setup) + def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float, + qk_gain_init: float): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.logit_chunk_tokens = logit_chunk_tokens + self.logit_softcap = logit_softcap + + self.tok_emb = nn.Embedding(vocab_size, dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32) + self.blocks = [ + Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for i in range(num_layers) + ] + self.final_norm = RMSNormNoWeight() + + for b in self.blocks: + b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight) + b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight) + self.tok_emb.weight = ( + mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std + ).astype(COMPUTE_DTYPE) + + def softcap(self, logits: mx.array) -> mx.array: + c = self.logit_softcap + return c * mx.tanh(logits / c) + + def __call__(self, input_ids: mx.array) -> mx.array: + x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE)) + x0 = x + skips: list[mx.array] = [] + + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + # Odd layer counts have one more decoder block than encoder block. The baseline only + # applies a skip connection when one exists, then runs the remaining decoder block(s) + # without an added skip. + if skips: + x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + return self.final_norm(x) + + def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: + # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful + # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE). + x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) + y = target_ids.reshape(-1) + if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens: + logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits_proj) + return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") + + loss_sum = mx.array(0.0, dtype=mx.float32) + n = int(x.shape[0]) + for s in range(0, n, self.logit_chunk_tokens): + e = min(s + self.logit_chunk_tokens, n) + logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits_proj) + loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum") + return loss_sum / float(n) + +# ============================================================================== +# OPTIMIZERS (MUON + ADAM SPLIT) +# ============================================================================== +class Muon: + # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the + # parameter update. + def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters): + self.keys = keys + self.args = args + self.buffers = {k: mx.zeros_like(params[k]) for k in keys} + + def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]: + if self.args.muon_momentum_warmup_steps: + t = min(step / self.args.muon_momentum_warmup_steps, 1.0) + momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum + else: + momentum = self.args.muon_momentum + lr = self.args.matrix_lr * lr_mul + out: dict[str, mx.array] = {} + for k in self.keys: + p = params[k] + g = grads[k] + buf = momentum * self.buffers[k] + g + self.buffers[k] = buf + g_eff = g + momentum * buf + g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps) + scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1]))) + out[k] = p - lr * (g_ortho * scale).astype(p.dtype) + return out + + +class SplitOptimizers: + # - embeddings: Adam with the tied-embedding LR + # - block matrices (2D): Muon + # - block scalars + skip weights: Adam + # This preserves the high-level optimization behavior even though MLX internals differ. + def __init__(self, model: GPT, args: Hyperparameters): + self.args = args + params = dict(tree_flatten(model.parameters())) + self.embed_key = "tok_emb.weight" + self.matrix_keys = [ + k + for k, p in params.items() + if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + self.scalar_keys = [ + k + for k, p in params.items() + if k == "skip_weights" or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS))) + ] + + self.muon = Muon(self.matrix_keys, params, args) + self.adam_embed = optim.Adam( + learning_rate=args.tied_embed_lr, + betas=[args.beta1, args.beta2], + eps=args.adam_eps, + bias_correction=True, + ) + self.adam_scalar = optim.Adam( + learning_rate=args.scalar_lr, + betas=[args.beta1, args.beta2], + eps=args.adam_eps, + bias_correction=True, + ) + + def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None: + params = dict(tree_flatten(model.parameters())) + grads = dict(tree_flatten(grads_tree)) + updated = dict(params) + + updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul)) + + self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul + updated.update( + self.adam_embed.apply_gradients( + {self.embed_key: grads[self.embed_key]}, + {self.embed_key: params[self.embed_key]}, + ) + ) + + self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul + scalar_grads = {k: grads[k] for k in self.scalar_keys} + scalar_params = {k: params[k] for k in self.scalar_keys} + updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params)) + + model.update(tree_unflatten(list(updated.items()))) + +# ============================================================================== +# QUANTIZATION (INT8 + ZLIB) +# ============================================================================== +# - per-row int8 for 2D float tensors +# - per-tensor int8 for other float tensors +# - fp16 passthrough for small float tensors +# - exact passthrough for non-floats + +MX_DTYPE_FROM_NAME = { + "float32": mx.float32, + "float16": mx.float16, + "bfloat16": mx.bfloat16, +} + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = np.float16 +INT8_PER_ROW_SCALE_DTYPE = np.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + + +def _np_float32(arr: mx.array) -> np.ndarray: + return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False) + + +def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return np.ascontiguousarray(_np_float32(arr)) + if arr.dtype in {mx.float32, mx.bfloat16}: + passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1] + return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False)) + return np.ascontiguousarray(np.array(arr, copy=True)) + + +def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]: + f32 = _np_float32(arr) + if f32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32) + clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) + scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False) + q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False) + return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False)) + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0 + scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32) + q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False) + return np.ascontiguousarray(q), scale + + +def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]: + quantized: dict[str, np.ndarray] = {} + scales: dict[str, np.ndarray] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, np.ndarray] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, arr in flat_state.items(): + stats["param_count"] += int(arr.size) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += int(arr.nbytes) + if not mx.issubdtype(arr.dtype, mx.floating): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = np.ascontiguousarray(np.array(arr)) + stats["int8_payload_bytes"] += int(passthrough[name].nbytes) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_array(name, arr, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += int(kept.nbytes) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_array(arr) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(arr.dtype).split(".")[-1] + stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]: + out: dict[str, mx.array] = {} + qmeta = quant_obj.get("qmeta", {}) + passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {}) + for name, q in quant_obj["quantized"].items(): + q_np = np.asarray(q, dtype=np.int8) + dtype_name = quant_obj["dtypes"][name] + scale = np.asarray(quant_obj["scales"][name], dtype=np.float32) + if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0: + # Broadcast the saved row scale back across trailing dimensions. + out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1)) + else: + out_arr = q_np.astype(np.float32) * float(scale) + out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name]) + for name, arr in quant_obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_arr = np.array(arr, copy=True) + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype]) + else: + out[name] = mx.array(out_arr) + return out + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_lut = np.zeros((table_size,), dtype=np.int16) + has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_lut[token_id] = False + if sp.is_byte(token_id): + base_bytes_lut[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_lut[token_id] = True + piece = piece[1:] + base_bytes_lut[token_id] = len(piece.encode("utf-8")) + return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut + + +def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]: + # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we + # decode bytes with the exact tokenizer that produced the shards. The manifest + # lets the training script fail fast on accidental dataset/tokenizer mismatches. + dataset_dir = Path(data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + if len(dataset_dir.parents) < 2: + return dataset_dir.name, actual_train_files, None + manifest_path = dataset_dir.parents[1] / "manifest.json" + if not manifest_path.is_file(): + return dataset_dir.name, actual_train_files, None + + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None) + if dataset_entry is None: + return dataset_dir.name, actual_train_files, None + + tokenizer_name = dataset_entry.get("tokenizer_name") + tokenizer_entry = ( + next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) + if tokenizer_name + else None + ) + expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name + if expected_name and Path(tokenizer_path).name != expected_name: + raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}") + expected_train_files = (dataset_entry.get("stats") or {}).get("files_train") + if expected_train_files is not None: + expected_train_files = int(expected_train_files) + if actual_train_files > expected_train_files: + raise ValueError( + f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, " + f"manifest says {expected_train_files}" + ) + return dataset_dir.name, actual_train_files, expected_train_files + + +def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0)) + usable = ((tokens.size - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def loss_and_grad_chunked( + args: Hyperparameters, + train_loader: TokenLoader, + compiled_loss_and_grad, +) -> tuple[mx.array, dict]: + chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens) + total_tokens = float(sum(chunk_sizes)) + loss_value = mx.array(0.0, dtype=mx.float32) + grad_accum: dict[str, mx.array] | None = None + for chunk_tokens in chunk_sizes: + x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len) + loss, grads = compiled_loss_and_grad(x, y) + scale = float(y.size) / total_tokens + loss_value = loss_value + loss.astype(mx.float32) * scale + grad_accum = accumulate_flat_grads(grad_accum, grads, scale) + if args.mlx_eager_eval: + mx.eval(loss_value, grad_accum) # materialize each chunk to cap peak memory + return loss_value, tree_unflatten(list(grad_accum.items())) + + +def eval_val( + args: Hyperparameters, + compiled_loss, + val_tokens: np.ndarray, + base_bytes_lut: np.ndarray, + has_leading_space_lut: np.ndarray, + is_boundary_token_lut: np.ndarray, + log_fn: Callable[[str], None] | None = None, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + val_batch_tokens = args.val_batch_size // args.grad_accum_steps + if val_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + val_batch_seqs = val_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.size - 1) // args.train_seq_len + total_batches = max((total_seqs + val_batch_seqs - 1) // val_batch_seqs, 1) + total_loss_sum = 0.0 + total_tokens = 0.0 + total_bytes = 0.0 + for batch_idx, batch_seq_start in enumerate(range(0, total_seqs, val_batch_seqs), start=1): + batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + chunk = val_tokens[raw_start:raw_end] + x_np = chunk[:-1].reshape(-1, args.train_seq_len) + y_np = chunk[1:].reshape(-1, args.train_seq_len) + x = mx.array(x_np, dtype=mx.int32) + y = mx.array(y_np, dtype=mx.int32) + chunk_token_count = float(y.size) + batch_loss = compiled_loss(x, y).astype(mx.float32) + mx.eval(batch_loss) + total_loss_sum += float(batch_loss.item()) * chunk_token_count + prev_ids = x_np.reshape(-1) + tgt_ids = y_np.reshape(-1) + bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True) + bytes_np += ( + has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] + ).astype(np.int16, copy=False) + total_tokens += chunk_token_count + total_bytes += float(bytes_np.astype(np.float64).sum()) + if log_fn is not None and total_batches > 1 and ( + batch_idx == 1 or batch_idx == total_batches or batch_idx % 25 == 0 + ): + log_fn(f"val_progress:{batch_idx}/{total_batches}") + val_loss = total_loss_sum / total_tokens + bits_per_token = val_loss / math.log(2.0) + val_bpb = bits_per_token * (total_tokens / total_bytes) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict: + if max_norm <= 0: + return grads_tree + flat = dict(tree_flatten(grads_tree)) + total_sq = 0.0 + for grad in flat.values(): + total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64)) + if total_sq <= 0.0: + return grads_tree + total_norm = math.sqrt(total_sq) + if total_norm <= max_norm: + return grads_tree + scale = max_norm / (total_norm + 1e-12) + return tree_unflatten([(k, g * scale) for k, g in flat.items()]) + + +def main() -> None: + # ============================================================================== + # TOKENIZER + VALIDATION METRIC SETUP + # ============================================================================== + args = Hyperparameters() + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + logfile = out_dir / f"{args.run_id}.txt" + print(logfile) + + def log(msg: str, console: bool = True) -> None: + if console: + print(msg) + with logfile.open("a", encoding="utf-8") as f: + print(msg, file=f) + + code = Path(__file__).read_text(encoding="utf-8") + log(code, console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running MLX {mx.__version__}", console=False) + log("=" * 100, console=False) + + if not args.tie_embeddings: + raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings") + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair( + args.data_path, + args.tokenizer_path, + ) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size + ) + + # ============================================================================== + # TRAINING SETUP + # ============================================================================== + mx.random.seed(args.seed) + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + # ============================================================================== + # MODEL + OPTIMIZER SETUP + # ============================================================================== + model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + logit_chunk_tokens=args.logit_chunk_tokens, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + tied_embed_init_std=args.tied_embed_init_std, + qk_gain_init=args.qk_gain_init, + ) + opt = SplitOptimizers(model, args) + + # ============================================================================== + # COMPILED TRAIN / EVAL FUNCTIONS (MLX) + # ============================================================================== + # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example + # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs". + # Compiling the model-bound functions and capturing the full model state fixes that while still + # returning gradients only for trainable parameters via nn.value_and_grad(...). + compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) + compiled_loss_and_grad = mx.compile( + nn.value_and_grad(model, lambda x, y: model.loss(x, y)), + inputs=model.state, + outputs=model.state, + ) + + # Print config once so logs are self-describing. + n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) + log(f"run_id:{args.run_id}") + log(f"mlx_version:{mx.__version__}") + log(f"train_loader:shards pattern={args.train_files}") + log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}") + if expected_train_files is None: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}") + elif actual_train_files < expected_train_files: + log( + f"WARNING: train_loader:subset dataset:{dataset_name} " + f"train_shards:{actual_train_files}/{expected_train_files} " + f"new epochs will arrive sooner than the full dataset" + ) + else: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}") + log(f"tokenizer_path:{args.tokenizer_path}") + log( + f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} " + f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} " + f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}" + ) + log( + f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} " + f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} " + f"val_batch_size:{args.val_batch_size} " + f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}") + log( + f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} " + f"embed_lr:{args.tied_embed_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} " + f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}" + ) + log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log(f"compute_dtype:{COMPUTE_DTYPE} compile:True") + log( + f"dtypes tok_emb:{model.tok_emb.weight.dtype} " + f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} " + f"skip_weights:{model.skip_weights.dtype}" + ) + + # ============================================================================== + # TRAINING LOOP + # ============================================================================== + if args.warmup_steps > 0: + # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us + # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs. + # Instead we run the real train shapes, force the loss/grads to materialize, and then reset + # the loader so measured training still starts from the true init and token window. + for warmup_step in range(args.warmup_steps): + accum: dict[str, mx.array] | None = None + warmup_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + mx.eval(warmup_loss, accum) + mx.synchronize() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + + # Prime the standalone eval graph once too. It is compiled separately from value_and_grad. + val_batch_tokens = args.val_batch_size // args.grad_accum_steps + if val_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len) + warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1] + x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32) + y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32) + warm_val_loss = compiled_loss(x_val, y_val) + mx.eval(warm_val_loss) + mx.synchronize() + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + train_time_ms = 0.0 + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + stop_after_step: int | None = None + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + train_time_ms += 1000.0 * (time.perf_counter() - t0) + # Validation always scans the same fixed full validation split. + val_loss, val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + log_fn=log, + ) + if step % 25 == 0 or last_step: + log( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms" + ) + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0)) + step_t0 = time.perf_counter() + + accum: dict[str, mx.array] | None = None + train_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + train_loss = train_loss + loss.astype(mx.float32) * grad_scale + if args.mlx_eager_eval: + mx.eval(train_loss, accum) # materialize each microbatch to cap peak memory + + grads = tree_unflatten(list(accum.items())) + grads = clip_grad_tree(grads, args.grad_clip_norm) + train_loss_value = float(train_loss.item()) + opt.step(model, grads, step=step, lr_mul=lr_mul) + mx.synchronize() + + step_ms = 1000.0 * (time.perf_counter() - step_t0) + approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0) + tok_s = args.train_batch_tokens / (step_ms / 1000.0) + step += 1 + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): + log( + f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} " + f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}" + ) + if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms: + stop_after_step = step + + # ============================================================================== + # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL + # ============================================================================== + # We always write a raw artifact and a quantized artifact, then validate the + # quantized roundtrip directly by loading the dequantized tensors back into the + # model and running one final validation pass. + out_path = out_dir / f"{args.run_id}_mlx_model.npz" + flat_state = {k: v for k, v in tree_flatten(model.state)} + mx.savez(str(out_path), **flat_state) + log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") + + quant_obj, quant_stats = quantize_state_dict_int8(flat_state) + quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) + quant_blob = zlib.compress(quant_raw, level=9) + quant_serialized_bytes = len(quant_raw) + quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz" + with quant_path.open("wb") as f: + f.write(quant_blob) + quant_file_bytes = quant_path.stat().st_size + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log( + f"serialized_model_int8_zlib:{quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)" + ) + + with quant_path.open("rb") as f: + quant_blob_disk = f.read() + quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk))) + model.update(tree_unflatten(list(quant_flat.items()))) + q_t0 = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + log_fn=log, + ) + q_eval_ms = 1000.0 * (time.perf_counter() - q_t0) + log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms") + log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.14.2 (main, Dec 5 2025, 16:49:16) [Clang 17.0.0 (clang-1700.4.4.1)] +Running MLX 0.31.1 +==================================================================================================== +run_id:stukenov_mlx_smoke +mlx_version:0.31.1 +train_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_train_*.bin +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +WARNING: train_loader:subset dataset:fineweb10B_sp1024 train_shards:1/195 new epochs will arrive sooner than the full dataset +tokenizer_path:./data/tokenizers/fineweb_1024_bpe.model +model_params:17059912 vocab_size:1024 layers:9 dim:512 heads:8 kv_heads:4 seq_len:1024 tie_embeddings:True +iterations:200 train_batch_tokens:8192 grad_accum_steps:8 microbatch_tokens:1024 microbatch_batch_size:1 val_batch_size:8192 warmup_steps:20 max_wallclock_seconds:600.000 +mlx_max_microbatch_tokens:8192 +optimizer:muon+adam muon_matrix_params:54 scalar_params:37 embed_lr:0.05 matrix_lr:0.04 scalar_lr:0.04 muon_momentum:0.95 muon_steps:5 +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +compute_dtype:mlx.core.bfloat16 compile:True +dtypes tok_emb:mlx.core.bfloat16 linear_weight:mlx.core.float32 skip_weights:mlx.core.float32 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/200 train_loss:6.9428 train_time:1474ms step_avg:1474.19ms tok_s:5557 +step:2/200 train_loss:18.7855 train_time:3330ms step_avg:1665.02ms tok_s:4416 +step:3/200 train_loss:16.6689 train_time:5029ms step_avg:1676.42ms tok_s:4822 +step:4/200 train_loss:14.0871 train_time:6728ms step_avg:1682.08ms tok_s:4823 +step:5/200 train_loss:11.7481 train_time:8424ms step_avg:1684.83ms tok_s:4832 +step:6/200 train_loss:9.4689 train_time:10150ms step_avg:1691.60ms tok_s:4748 +step:7/200 train_loss:7.9685 train_time:11838ms step_avg:1691.14ms tok_s:4854 +step:8/200 train_loss:7.1412 train_time:13515ms step_avg:1689.41ms tok_s:4885 +step:9/200 train_loss:6.6587 train_time:15235ms step_avg:1692.74ms tok_s:4765 +step:10/200 train_loss:6.2999 train_time:16919ms step_avg:1691.86ms tok_s:4866 +step:50/200 train_loss:4.8429 train_time:84371ms step_avg:1687.41ms tok_s:4838 +step:100/200 train_loss:4.4248 train_time:168819ms step_avg:1688.19ms tok_s:4831 +step:150/200 train_loss:4.1391 train_time:271291ms step_avg:1808.61ms tok_s:4804 +step:200/200 train_loss:3.9233 train_time:362934ms step_avg:1814.67ms tok_s:4182 +val_progress:1/60568 +val_progress:25/60568 +val_progress:50/60568 +val_progress:75/60568 +val_progress:100/60568 +val_progress:125/60568 +val_progress:150/60568 +val_progress:175/60568 +val_progress:200/60568 +val_progress:225/60568 +val_progress:250/60568 +val_progress:275/60568 +val_progress:300/60568 +val_progress:325/60568 +val_progress:350/60568 +val_progress:375/60568 +val_progress:400/60568 +val_progress:425/60568 +val_progress:450/60568 +val_progress:475/60568 +val_progress:500/60568 +val_progress:525/60568 +val_progress:550/60568 +val_progress:575/60568 +val_progress:600/60568 +val_progress:625/60568 +val_progress:650/60568 +val_progress:675/60568 +val_progress:700/60568 +val_progress:725/60568 +val_progress:750/60568 +val_progress:775/60568 +val_progress:800/60568 +val_progress:825/60568 +val_progress:850/60568 +val_progress:875/60568 +val_progress:900/60568 +val_progress:925/60568 +val_progress:950/60568 +val_progress:975/60568 +val_progress:1000/60568 +val_progress:1025/60568 +val_progress:1050/60568 +val_progress:1075/60568 +val_progress:1100/60568 +val_progress:1125/60568 +val_progress:1150/60568 +val_progress:1175/60568 +val_progress:1200/60568 +val_progress:1225/60568 +val_progress:1250/60568 +val_progress:1275/60568 +val_progress:1300/60568 +val_progress:1325/60568 +val_progress:1350/60568 +val_progress:1375/60568 +val_progress:1400/60568 +val_progress:1425/60568 +val_progress:1450/60568 +val_progress:1475/60568 +val_progress:1500/60568 +val_progress:1525/60568 +val_progress:1550/60568 +val_progress:1575/60568 +val_progress:1600/60568 +val_progress:1625/60568 +val_progress:1650/60568 +val_progress:1675/60568 +val_progress:1700/60568 +val_progress:1725/60568 +val_progress:1750/60568 +val_progress:1775/60568 +val_progress:1800/60568 +val_progress:1825/60568 +val_progress:1850/60568 +val_progress:1875/60568 +val_progress:1900/60568 +val_progress:1925/60568 +val_progress:1950/60568 +val_progress:1975/60568 +val_progress:2000/60568 +val_progress:2025/60568 +val_progress:2050/60568 +val_progress:2075/60568 +val_progress:2100/60568 +val_progress:2125/60568 +val_progress:2150/60568 +val_progress:2175/60568 +val_progress:2200/60568 +val_progress:2225/60568 +val_progress:2250/60568 +val_progress:2275/60568 +val_progress:2300/60568 +val_progress:2325/60568 +val_progress:2350/60568 +val_progress:2375/60568 +val_progress:2400/60568 +val_progress:2425/60568 +val_progress:2450/60568 +val_progress:2475/60568 +val_progress:2500/60568 +val_progress:2525/60568 +val_progress:2550/60568 +val_progress:2575/60568 +val_progress:2600/60568 +val_progress:2625/60568 +val_progress:2650/60568 +val_progress:2675/60568 +val_progress:2700/60568 +val_progress:2725/60568 +val_progress:2750/60568 +val_progress:2775/60568 +val_progress:2800/60568 +val_progress:2825/60568 +val_progress:2850/60568 +val_progress:2875/60568 +val_progress:2900/60568 +val_progress:2925/60568 +val_progress:2950/60568 +val_progress:2975/60568 +val_progress:3000/60568 +val_progress:3025/60568 +val_progress:3050/60568 +val_progress:3075/60568 +val_progress:3100/60568 +val_progress:3125/60568 +val_progress:3150/60568 +val_progress:3175/60568 +val_progress:3200/60568 +val_progress:3225/60568 +val_progress:3250/60568 +val_progress:3275/60568 +val_progress:3300/60568 +val_progress:3325/60568 +val_progress:3350/60568 +val_progress:3375/60568 +val_progress:3400/60568 +val_progress:3425/60568 +val_progress:3450/60568 +val_progress:3475/60568 +val_progress:3500/60568 +val_progress:3525/60568 +val_progress:3550/60568 +val_progress:3575/60568 +val_progress:3600/60568 +val_progress:3625/60568 +val_progress:3650/60568 +val_progress:3675/60568 +val_progress:3700/60568 +val_progress:3725/60568 +val_progress:3750/60568 +val_progress:3775/60568 +val_progress:3800/60568 +val_progress:3825/60568 +val_progress:3850/60568 +val_progress:3875/60568 +val_progress:3900/60568 +val_progress:3925/60568 +val_progress:3950/60568 +val_progress:3975/60568 +val_progress:4000/60568 +val_progress:4025/60568 +val_progress:4050/60568 +val_progress:4075/60568 +val_progress:4100/60568 +val_progress:4125/60568 +val_progress:4150/60568 +val_progress:4175/60568 +val_progress:4200/60568 +val_progress:4225/60568 +val_progress:4250/60568 +val_progress:4275/60568 +val_progress:4300/60568 +val_progress:4325/60568 +val_progress:4350/60568 +val_progress:4375/60568 +val_progress:4400/60568 +val_progress:4425/60568 +val_progress:4450/60568 +val_progress:4475/60568 +val_progress:4500/60568 +val_progress:4525/60568 +val_progress:4550/60568 +val_progress:4575/60568 +val_progress:4600/60568 +val_progress:4625/60568 +val_progress:4650/60568 +val_progress:4675/60568 +val_progress:4700/60568 +val_progress:4725/60568 +val_progress:4750/60568 +val_progress:4775/60568 +val_progress:4800/60568 +val_progress:4825/60568 +val_progress:4850/60568 +val_progress:4875/60568 +val_progress:4900/60568 +val_progress:4925/60568 +val_progress:4950/60568 +val_progress:4975/60568 +val_progress:5000/60568 +val_progress:5025/60568 +val_progress:5050/60568 +val_progress:5075/60568 From 9ff83034d80103a067eaac71d80f6263349b76e9 Mon Sep 17 00:00:00 2001 From: Saken Tukenov Date: Fri, 20 Mar 2026 00:56:48 +0500 Subject: [PATCH 2/4] Replace draft WIP with finished 4090 non-record submission --- .../README.md | 52 - .../eval_probe.log | 1151 -------------- .../train_gpt_mlx.py | 1104 ------------- .../train_partial.log | 1360 ----------------- .../README.md | 57 + .../submission.json | 18 + .../train.log | 1208 +++++++++++++++ .../train_gpt.py | 1131 ++++++++++++++ 8 files changed, 2414 insertions(+), 3667 deletions(-) delete mode 100644 records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/README.md delete mode 100644 records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/eval_probe.log delete mode 100644 records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_gpt_mlx.py delete mode 100644 records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_partial.log create mode 100644 records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/README.md create mode 100644 records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/submission.json create mode 100644 records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/train.log create mode 100644 records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/train_gpt.py diff --git a/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/README.md b/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/README.md deleted file mode 100644 index 3dce84ef01..0000000000 --- a/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/README.md +++ /dev/null @@ -1,52 +0,0 @@ -This draft folder documents an in-progress local Apple Silicon reproduction using the current root `train_gpt_mlx.py`. - -This is not a finished non-record submission yet. It exists to show concrete reproduction work, local logs, and the exact script snapshot before moving to a completed cloud-backed run. - -What is included: -- `train_gpt_mlx.py`: exact MLX script snapshot used for the local runs. -- `train_partial.log`: 200-step local smoke run on an Apple M1 with 1 FineWeb train shard and the full fixed validation split. Training completed through step 200 and then entered full validation. -- `eval_probe.log`: follow-up probe using a larger validation batch size to test the local 8GB memory / eval-time tradeoff. - -Local machine: -- Apple M1 MacBook Air -- 8GB unified memory -- Python 3.14 -- MLX 0.31.1 - -Smoke configuration: -- Tokenizer / dataset: `sp1024`, `fineweb10B_sp1024` -- Train shards: `1` -- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` -- Tied embeddings: `TIE_EMBEDDINGS=1` -- Training: `ITERATIONS=200 TRAIN_BATCH_TOKENS=8192` -- Validation: full fixed `fineweb_val_*` split - -Observed outcome so far: -- Local training is stable and reproduces the baseline setup on Apple Silicon. -- Full final validation on an 8GB M1 is much slower than training, so this draft does not yet include a finished `submission.json`. -- The next step is to rerun the same baseline path on cloud GPUs and convert this folder into a completed non-record submission with final `val_bpb`, artifact bytes, and `submission.json`. - -Command used for the main smoke run: -```bash -source .venv/bin/activate -RUN_ID=stukenov_mlx_smoke \ -ITERATIONS=200 \ -TRAIN_BATCH_TOKENS=8192 \ -VAL_LOSS_EVERY=0 \ -VAL_BATCH_SIZE=8192 \ -TRAIN_LOG_EVERY=50 \ -python train_gpt_mlx.py -``` - -Command used for the eval probe: -```bash -source .venv/bin/activate -RUN_ID=stukenov_mlx_probe \ -ITERATIONS=1 \ -TRAIN_BATCH_TOKENS=8192 \ -VAL_LOSS_EVERY=0 \ -VAL_BATCH_SIZE=65536 \ -TRAIN_LOG_EVERY=1 \ -MAX_WALLCLOCK_SECONDS=0 \ -python train_gpt_mlx.py -``` diff --git a/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/eval_probe.log b/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/eval_probe.log deleted file mode 100644 index 1a9fa26fde..0000000000 --- a/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/eval_probe.log +++ /dev/null @@ -1,1151 +0,0 @@ -#!/usr/bin/env python3 -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" -from __future__ import annotations - -import glob -import json -import math -import os -import pickle -import sys -import time -import uuid -import zlib -from collections.abc import Callable -from pathlib import Path - -import numpy as np -import sentencepiece as spm - -import mlx.core as mx -import mlx.nn as nn -import mlx.optimizers as optim -from mlx.utils import tree_flatten, tree_unflatten - -# ============================================================================== -# SHARD FORMAT + COMPUTE DTYPE -# ============================================================================== - -COMPUTE_DTYPE = mx.bfloat16 - -# ============================================================================== -# HYPERPARAMETERS -# ============================================================================== -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap -class Hyperparameters: - # Data / tokenizer. - data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed: int = int(os.environ.get("SEED", 1337)) - - # Training loop. These defaults now mirror train_gpt.py on a single process. - iterations: int = int(os.environ.get("ITERATIONS", 20_000)) - val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0)) - # Validation always uses the full fineweb_val split. - val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8)) - train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024))) - # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak - # memory pressure without changing the effective optimizer batch. - mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) - # Force MLX to materialize the graph after every sub-batch, preventing lazy - # graph buildup across accumulation steps. Keeps peak memory low on 16GB machines. - # Disable on 32GB+ unified memory for better throughput (MLX_EAGER_EVAL=0). - mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1"))) - warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) - warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200)) - max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - - # Model (defaults match the current baseline setup). - vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers: int = int(os.environ.get("NUM_LAYERS", 9)) - model_dim: int = int(os.environ.get("MODEL_DIM", 512)) - num_heads: int = int(os.environ.get("NUM_HEADS", 8)) - num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) - mlp_mult: int = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) - logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) - qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Optimizer. We keep the same per-group defaults as train_gpt.py. - beta1: float = float(os.environ.get("BETA1", 0.9)) - beta2: float = float(os.environ.get("BETA2", 0.95)) - adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) - tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) - matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - out_dir: str = os.environ.get("OUT_DIR", "logs") - - @property - def train_files(self) -> str: - return f"{self.data_path}/fineweb_train_*.bin" - - @property - def val_files(self) -> str: - return f"{self.data_path}/fineweb_val_*.bin" - - @property - def microbatch_tokens(self) -> int: - return self.train_batch_tokens // self.grad_accum_steps - - def lr_mul(self, step: int, elapsed_ms: float) -> float: - if self.warmdown_iters <= 0: - return 1.0 - if self.max_wallclock_seconds <= 0: - warmdown_start = max(self.iterations - self.warmdown_iters, 0) - return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = self.warmdown_iters * step_ms - remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) - - -def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]: - usable_total = (total_tokens // seq_len) * seq_len - if usable_total <= 0: - raise ValueError(f"token budget too small for seq_len={seq_len}") - usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) - chunks: list[int] = [] - remaining = usable_total - while remaining > 0: - chunk = min(remaining, usable_chunk) - chunks.append(chunk) - remaining -= chunk - return chunks - - -def accumulate_flat_grads( - accum: dict[str, mx.array] | None, - grads_tree: dict, - scale: float, -) -> dict[str, mx.array]: - flat = dict(tree_flatten(grads_tree)) - if accum is None: - return {k: g * scale for k, g in flat.items()} - for k, g in flat.items(): - accum[k] = accum[k] + g * scale - return accum - - -# ============================================================================== -# MATH HELPERS -# ============================================================================== - -def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: - return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype) - - -def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - # Background on Muon: https://kellerjordan.github.io/posts/muon/ - a, b, c = 3.4445, -4.7750, 2.0315 - x = g.astype(mx.float32) - x = x / (mx.sqrt(mx.sum(x * x)) + eps) - transposed = x.shape[0] > x.shape[1] - if transposed: - x = x.T - for _ in range(steps): - a_mat = x @ x.T - b_mat = b * a_mat + c * (a_mat @ a_mat) - x = a * x + b_mat @ x - if transposed: - x = x.T - return x.astype(g.dtype) - - -def load_data_shard(path: Path) -> np.ndarray: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - if self.file_idx == 0: - self.epoch += 1 - if self.log_fn is not None: - self.log_fn( - f"WARNING: starting epoch:{self.epoch} " - f"dataset:{self.dataset_name} train_shards:{len(self.files)}" - ) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> np.ndarray: - chunks: list[np.ndarray] = [] - left = n - while left > 0: - if self.pos >= self.tokens.size: - self.next_file() - k = min(left, int(self.tokens.size - self.pos)) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - left -= k - return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0) - - -class TokenLoader: - def __init__( - self, - pattern: str, - log_fn: Callable[[str], None] | None = None, - dataset_name: str = "", - ): - self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name) - - def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]: - usable = (batch_tokens // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"token budget too small for seq_len={seq_len}") - chunk = self.stream.take(usable + 1) - x = chunk[:-1].reshape(-1, seq_len) - y = chunk[1:].reshape(-1, seq_len) - return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32) - - -# ============================================================================== -# MODEL BLOCKS -# ============================================================================== - -class CastedLinear(nn.Module): - def __init__(self, in_dim: int, out_dim: int): - super().__init__() - self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32) - - def __call__(self, x: mx.array) -> mx.array: - return x @ self.weight.astype(x.dtype).T - - -class RMSNormNoWeight(nn.Module): - # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks. - def __call__(self, x: mx.array) -> mx.array: - return rms_norm(x) - - -class CausalSelfAttention(nn.Module): - # - separate q/k/v projections - # - RMSNorm on q and k before attention - # - RoPE on q and k - # - causal masked SDPA - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim) - self.c_k = CastedLinear(dim, kv_dim) - self.c_v = CastedLinear(dim, kv_dim) - self.proj = CastedLinear(dim, dim) - self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init - self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base) - self.scale = self.head_dim ** -0.5 - - def __call__(self, x: mx.array) -> mx.array: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) - - q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE)) - k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE)) - q = q * self.q_gain.astype(q.dtype)[None, :, None, None] - y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal") - y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # Baseline MLP uses relu^2 instead of GELU/SiLU. It is cheap and works well in this setup. - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = dim * mlp_mult - self.fc = CastedLinear(dim, hidden) - self.proj = CastedLinear(hidden, dim) - - def __call__(self, x: mx.array) -> mx.array: - x = nn.relu(self.fc(x)) - return self.proj(x * x) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNormNoWeight() - self.mlp_norm = RMSNormNoWeight() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = mx.ones((dim,), dtype=mx.float32) - self.mlp_scale = mx.ones((dim,), dtype=mx.float32) - self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32)))) - - def __call__(self, x: mx.array, x0: mx.array) -> mx.array: - mix = self.resid_mix.astype(x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - # - token embedding + RMSNorm - # - encoder half accumulates skip tensors - # - decoder half consumes reversed skips with learned skip_weights - # - tied embeddings for the LM head (the baseline default setup) - def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, - logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float, - qk_gain_init: float): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.logit_chunk_tokens = logit_chunk_tokens - self.logit_softcap = logit_softcap - - self.tok_emb = nn.Embedding(vocab_size, dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32) - self.blocks = [ - Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for i in range(num_layers) - ] - self.final_norm = RMSNormNoWeight() - - for b in self.blocks: - b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight) - b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight) - self.tok_emb.weight = ( - mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std - ).astype(COMPUTE_DTYPE) - - def softcap(self, logits: mx.array) -> mx.array: - c = self.logit_softcap - return c * mx.tanh(logits / c) - - def __call__(self, input_ids: mx.array) -> mx.array: - x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE)) - x0 = x - skips: list[mx.array] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - # Odd layer counts have one more decoder block than encoder block. The baseline only - # applies a skip connection when one exists, then runs the remaining decoder block(s) - # without an added skip. - if skips: - x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - return self.final_norm(x) - - def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: - # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful - # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE). - x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) - y = target_ids.reshape(-1) - if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens: - logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T - logits = self.softcap(logits_proj) - return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") - - loss_sum = mx.array(0.0, dtype=mx.float32) - n = int(x.shape[0]) - for s in range(0, n, self.logit_chunk_tokens): - e = min(s + self.logit_chunk_tokens, n) - logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T - logits = self.softcap(logits_proj) - loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum") - return loss_sum / float(n) - -# ============================================================================== -# OPTIMIZERS (MUON + ADAM SPLIT) -# ============================================================================== -class Muon: - # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the - # parameter update. - def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters): - self.keys = keys - self.args = args - self.buffers = {k: mx.zeros_like(params[k]) for k in keys} - - def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]: - if self.args.muon_momentum_warmup_steps: - t = min(step / self.args.muon_momentum_warmup_steps, 1.0) - momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum - else: - momentum = self.args.muon_momentum - lr = self.args.matrix_lr * lr_mul - out: dict[str, mx.array] = {} - for k in self.keys: - p = params[k] - g = grads[k] - buf = momentum * self.buffers[k] + g - self.buffers[k] = buf - g_eff = g + momentum * buf - g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps) - scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1]))) - out[k] = p - lr * (g_ortho * scale).astype(p.dtype) - return out - - -class SplitOptimizers: - # - embeddings: Adam with the tied-embedding LR - # - block matrices (2D): Muon - # - block scalars + skip weights: Adam - # This preserves the high-level optimization behavior even though MLX internals differ. - def __init__(self, model: GPT, args: Hyperparameters): - self.args = args - params = dict(tree_flatten(model.parameters())) - self.embed_key = "tok_emb.weight" - self.matrix_keys = [ - k - for k, p in params.items() - if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - self.scalar_keys = [ - k - for k, p in params.items() - if k == "skip_weights" or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS))) - ] - - self.muon = Muon(self.matrix_keys, params, args) - self.adam_embed = optim.Adam( - learning_rate=args.tied_embed_lr, - betas=[args.beta1, args.beta2], - eps=args.adam_eps, - bias_correction=True, - ) - self.adam_scalar = optim.Adam( - learning_rate=args.scalar_lr, - betas=[args.beta1, args.beta2], - eps=args.adam_eps, - bias_correction=True, - ) - - def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None: - params = dict(tree_flatten(model.parameters())) - grads = dict(tree_flatten(grads_tree)) - updated = dict(params) - - updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul)) - - self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul - updated.update( - self.adam_embed.apply_gradients( - {self.embed_key: grads[self.embed_key]}, - {self.embed_key: params[self.embed_key]}, - ) - ) - - self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul - scalar_grads = {k: grads[k] for k in self.scalar_keys} - scalar_params = {k: params[k] for k in self.scalar_keys} - updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params)) - - model.update(tree_unflatten(list(updated.items()))) - -# ============================================================================== -# QUANTIZATION (INT8 + ZLIB) -# ============================================================================== -# - per-row int8 for 2D float tensors -# - per-tensor int8 for other float tensors -# - fp16 passthrough for small float tensors -# - exact passthrough for non-floats - -MX_DTYPE_FROM_NAME = { - "float32": mx.float32, - "float16": mx.float16, - "bfloat16": mx.bfloat16, -} - -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = np.float16 -INT8_PER_ROW_SCALE_DTYPE = np.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - - -def _np_float32(arr: mx.array) -> np.ndarray: - return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False) - - -def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return np.ascontiguousarray(_np_float32(arr)) - if arr.dtype in {mx.float32, mx.bfloat16}: - passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1] - return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False)) - return np.ascontiguousarray(np.array(arr, copy=True)) - - -def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]: - f32 = _np_float32(arr) - if f32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32) - clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) - scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False) - q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False) - return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False)) - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0 - scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32) - q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False) - return np.ascontiguousarray(q), scale - - -def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]: - quantized: dict[str, np.ndarray] = {} - scales: dict[str, np.ndarray] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, np.ndarray] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - for name, arr in flat_state.items(): - stats["param_count"] += int(arr.size) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += int(arr.nbytes) - if not mx.issubdtype(arr.dtype, mx.floating): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = np.ascontiguousarray(np.array(arr)) - stats["int8_payload_bytes"] += int(passthrough[name].nbytes) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_array(name, arr, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += int(kept.nbytes) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_array(arr) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(arr.dtype).split(".")[-1] - stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes) - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - - -def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]: - out: dict[str, mx.array] = {} - qmeta = quant_obj.get("qmeta", {}) - passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {}) - for name, q in quant_obj["quantized"].items(): - q_np = np.asarray(q, dtype=np.int8) - dtype_name = quant_obj["dtypes"][name] - scale = np.asarray(quant_obj["scales"][name], dtype=np.float32) - if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0: - # Broadcast the saved row scale back across trailing dimensions. - out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1)) - else: - out_arr = q_np.astype(np.float32) * float(scale) - out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name]) - for name, arr in quant_obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_arr = np.array(arr, copy=True) - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype]) - else: - out[name] = mx.array(out_arr) - return out - - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_lut = np.zeros((table_size,), dtype=np.int16) - has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_lut[token_id] = False - if sp.is_byte(token_id): - base_bytes_lut[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_lut[token_id] = True - piece = piece[1:] - base_bytes_lut[token_id] = len(piece.encode("utf-8")) - return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut - - -def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]: - # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we - # decode bytes with the exact tokenizer that produced the shards. The manifest - # lets the training script fail fast on accidental dataset/tokenizer mismatches. - dataset_dir = Path(data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - if len(dataset_dir.parents) < 2: - return dataset_dir.name, actual_train_files, None - manifest_path = dataset_dir.parents[1] / "manifest.json" - if not manifest_path.is_file(): - return dataset_dir.name, actual_train_files, None - - manifest = json.loads(manifest_path.read_text(encoding="utf-8")) - dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None) - if dataset_entry is None: - return dataset_dir.name, actual_train_files, None - - tokenizer_name = dataset_entry.get("tokenizer_name") - tokenizer_entry = ( - next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) - if tokenizer_name - else None - ) - expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name - if expected_name and Path(tokenizer_path).name != expected_name: - raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}") - expected_train_files = (dataset_entry.get("stats") or {}).get("files_train") - if expected_train_files is not None: - expected_train_files = int(expected_train_files) - if actual_train_files > expected_train_files: - raise ValueError( - f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, " - f"manifest says {expected_train_files}" - ) - return dataset_dir.name, actual_train_files, expected_train_files - - -def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0)) - usable = ((tokens.size - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def loss_and_grad_chunked( - args: Hyperparameters, - train_loader: TokenLoader, - compiled_loss_and_grad, -) -> tuple[mx.array, dict]: - chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens) - total_tokens = float(sum(chunk_sizes)) - loss_value = mx.array(0.0, dtype=mx.float32) - grad_accum: dict[str, mx.array] | None = None - for chunk_tokens in chunk_sizes: - x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len) - loss, grads = compiled_loss_and_grad(x, y) - scale = float(y.size) / total_tokens - loss_value = loss_value + loss.astype(mx.float32) * scale - grad_accum = accumulate_flat_grads(grad_accum, grads, scale) - if args.mlx_eager_eval: - mx.eval(loss_value, grad_accum) # materialize each chunk to cap peak memory - return loss_value, tree_unflatten(list(grad_accum.items())) - - -def eval_val( - args: Hyperparameters, - compiled_loss, - val_tokens: np.ndarray, - base_bytes_lut: np.ndarray, - has_leading_space_lut: np.ndarray, - is_boundary_token_lut: np.ndarray, - log_fn: Callable[[str], None] | None = None, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - val_batch_tokens = args.val_batch_size // args.grad_accum_steps - if val_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " - f"TRAIN_SEQ_LEN={args.train_seq_len}" - ) - val_batch_seqs = val_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.size - 1) // args.train_seq_len - total_batches = max((total_seqs + val_batch_seqs - 1) // val_batch_seqs, 1) - total_loss_sum = 0.0 - total_tokens = 0.0 - total_bytes = 0.0 - for batch_idx, batch_seq_start in enumerate(range(0, total_seqs, val_batch_seqs), start=1): - batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - chunk = val_tokens[raw_start:raw_end] - x_np = chunk[:-1].reshape(-1, args.train_seq_len) - y_np = chunk[1:].reshape(-1, args.train_seq_len) - x = mx.array(x_np, dtype=mx.int32) - y = mx.array(y_np, dtype=mx.int32) - chunk_token_count = float(y.size) - batch_loss = compiled_loss(x, y).astype(mx.float32) - mx.eval(batch_loss) - total_loss_sum += float(batch_loss.item()) * chunk_token_count - prev_ids = x_np.reshape(-1) - tgt_ids = y_np.reshape(-1) - bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True) - bytes_np += ( - has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] - ).astype(np.int16, copy=False) - total_tokens += chunk_token_count - total_bytes += float(bytes_np.astype(np.float64).sum()) - if log_fn is not None and total_batches > 1 and ( - batch_idx == 1 or batch_idx == total_batches or batch_idx % 25 == 0 - ): - log_fn(f"val_progress:{batch_idx}/{total_batches}") - val_loss = total_loss_sum / total_tokens - bits_per_token = val_loss / math.log(2.0) - val_bpb = bits_per_token * (total_tokens / total_bytes) - return val_loss, val_bpb - -# ----------------------------- -# TRAINING -# ----------------------------- - -def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict: - if max_norm <= 0: - return grads_tree - flat = dict(tree_flatten(grads_tree)) - total_sq = 0.0 - for grad in flat.values(): - total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64)) - if total_sq <= 0.0: - return grads_tree - total_norm = math.sqrt(total_sq) - if total_norm <= max_norm: - return grads_tree - scale = max_norm / (total_norm + 1e-12) - return tree_unflatten([(k, g * scale) for k, g in flat.items()]) - - -def main() -> None: - # ============================================================================== - # TOKENIZER + VALIDATION METRIC SETUP - # ============================================================================== - args = Hyperparameters() - out_dir = Path(args.out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - logfile = out_dir / f"{args.run_id}.txt" - print(logfile) - - def log(msg: str, console: bool = True) -> None: - if console: - print(msg) - with logfile.open("a", encoding="utf-8") as f: - print(msg, file=f) - - code = Path(__file__).read_text(encoding="utf-8") - log(code, console=False) - log("=" * 100, console=False) - log(f"Running Python {sys.version}", console=False) - log(f"Running MLX {mx.__version__}", console=False) - log("=" * 100, console=False) - - if not args.tie_embeddings: - raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings") - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair( - args.data_path, - args.tokenizer_path, - ) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size - ) - - # ============================================================================== - # TRAINING SETUP - # ============================================================================== - mx.random.seed(args.seed) - - train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) - - # ============================================================================== - # MODEL + OPTIMIZER SETUP - # ============================================================================== - model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - logit_chunk_tokens=args.logit_chunk_tokens, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - tied_embed_init_std=args.tied_embed_init_std, - qk_gain_init=args.qk_gain_init, - ) - opt = SplitOptimizers(model, args) - - # ============================================================================== - # COMPILED TRAIN / EVAL FUNCTIONS (MLX) - # ============================================================================== - # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example - # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs". - # Compiling the model-bound functions and capturing the full model state fixes that while still - # returning gradients only for trainable parameters via nn.value_and_grad(...). - compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) - compiled_loss_and_grad = mx.compile( - nn.value_and_grad(model, lambda x, y: model.loss(x, y)), - inputs=model.state, - outputs=model.state, - ) - - # Print config once so logs are self-describing. - n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) - log(f"run_id:{args.run_id}") - log(f"mlx_version:{mx.__version__}") - log(f"train_loader:shards pattern={args.train_files}") - log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}") - if expected_train_files is None: - log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}") - elif actual_train_files < expected_train_files: - log( - f"WARNING: train_loader:subset dataset:{dataset_name} " - f"train_shards:{actual_train_files}/{expected_train_files} " - f"new epochs will arrive sooner than the full dataset" - ) - else: - log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}") - log(f"tokenizer_path:{args.tokenizer_path}") - log( - f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} " - f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} " - f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}" - ) - log( - f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} " - f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} " - f"val_batch_size:{args.val_batch_size} " - f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}") - log( - f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} " - f"embed_lr:{args.tied_embed_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} " - f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}" - ) - log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log(f"compute_dtype:{COMPUTE_DTYPE} compile:True") - log( - f"dtypes tok_emb:{model.tok_emb.weight.dtype} " - f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} " - f"skip_weights:{model.skip_weights.dtype}" - ) - - # ============================================================================== - # TRAINING LOOP - # ============================================================================== - if args.warmup_steps > 0: - # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us - # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs. - # Instead we run the real train shapes, force the loss/grads to materialize, and then reset - # the loader so measured training still starts from the true init and token window. - for warmup_step in range(args.warmup_steps): - accum: dict[str, mx.array] | None = None - warmup_loss = mx.array(0.0, dtype=mx.float32) - grad_scale = 1.0 / args.grad_accum_steps - for _ in range(args.grad_accum_steps): - warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) - accum = accumulate_flat_grads(accum, grads, grad_scale) - mx.eval(warmup_loss, accum) - mx.synchronize() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - - # Prime the standalone eval graph once too. It is compiled separately from value_and_grad. - val_batch_tokens = args.val_batch_size // args.grad_accum_steps - if val_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " - f"TRAIN_SEQ_LEN={args.train_seq_len}" - ) - warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len) - warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1] - x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32) - y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32) - warm_val_loss = compiled_loss(x_val, y_val) - mx.eval(warm_val_loss) - mx.synchronize() - - train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) - - train_time_ms = 0.0 - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - stop_after_step: int | None = None - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - train_time_ms += 1000.0 * (time.perf_counter() - t0) - # Validation always scans the same fixed full validation split. - val_loss, val_bpb = eval_val( - args, - compiled_loss, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - log_fn=log, - ) - if step % 25 == 0 or last_step: - log( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms" - ) - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}") - break - - lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0)) - step_t0 = time.perf_counter() - - accum: dict[str, mx.array] | None = None - train_loss = mx.array(0.0, dtype=mx.float32) - grad_scale = 1.0 / args.grad_accum_steps - for _ in range(args.grad_accum_steps): - loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) - accum = accumulate_flat_grads(accum, grads, grad_scale) - train_loss = train_loss + loss.astype(mx.float32) * grad_scale - if args.mlx_eager_eval: - mx.eval(train_loss, accum) # materialize each microbatch to cap peak memory - - grads = tree_unflatten(list(accum.items())) - grads = clip_grad_tree(grads, args.grad_clip_norm) - train_loss_value = float(train_loss.item()) - opt.step(model, grads, step=step, lr_mul=lr_mul) - mx.synchronize() - - step_ms = 1000.0 * (time.perf_counter() - step_t0) - approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0) - tok_s = args.train_batch_tokens / (step_ms / 1000.0) - step += 1 - if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): - log( - f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} " - f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}" - ) - if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms: - stop_after_step = step - - # ============================================================================== - # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL - # ============================================================================== - # We always write a raw artifact and a quantized artifact, then validate the - # quantized roundtrip directly by loading the dequantized tensors back into the - # model and running one final validation pass. - out_path = out_dir / f"{args.run_id}_mlx_model.npz" - flat_state = {k: v for k, v in tree_flatten(model.state)} - mx.savez(str(out_path), **flat_state) - log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") - - quant_obj, quant_stats = quantize_state_dict_int8(flat_state) - quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) - quant_blob = zlib.compress(quant_raw, level=9) - quant_serialized_bytes = len(quant_raw) - quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz" - with quant_path.open("wb") as f: - f.write(quant_blob) - quant_file_bytes = quant_path.stat().st_size - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log( - f"serialized_model_int8_zlib:{quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)" - ) - - with quant_path.open("rb") as f: - quant_blob_disk = f.read() - quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk))) - model.update(tree_unflatten(list(quant_flat.items()))) - q_t0 = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - compiled_loss, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - log_fn=log, - ) - q_eval_ms = 1000.0 * (time.perf_counter() - q_t0) - log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms") - log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.14.2 (main, Dec 5 2025, 16:49:16) [Clang 17.0.0 (clang-1700.4.4.1)] -Running MLX 0.31.1 -==================================================================================================== -run_id:stukenov_mlx_probe -mlx_version:0.31.1 -train_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_train_*.bin -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -WARNING: train_loader:subset dataset:fineweb10B_sp1024 train_shards:1/195 new epochs will arrive sooner than the full dataset -tokenizer_path:./data/tokenizers/fineweb_1024_bpe.model -model_params:17059912 vocab_size:1024 layers:9 dim:512 heads:8 kv_heads:4 seq_len:1024 tie_embeddings:True -iterations:1 train_batch_tokens:8192 grad_accum_steps:8 microbatch_tokens:1024 microbatch_batch_size:1 val_batch_size:65536 warmup_steps:20 max_wallclock_seconds:0.000 -mlx_max_microbatch_tokens:8192 -optimizer:muon+adam muon_matrix_params:54 scalar_params:37 embed_lr:0.05 matrix_lr:0.04 scalar_lr:0.04 muon_momentum:0.95 muon_steps:5 -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -compute_dtype:mlx.core.bfloat16 compile:True -dtypes tok_emb:mlx.core.bfloat16 linear_weight:mlx.core.float32 skip_weights:mlx.core.float32 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/1 train_loss:6.9428 train_time:1651ms step_avg:1651.25ms tok_s:4962 -val_progress:1/7571 -val_progress:25/7571 -val_progress:50/7571 -val_progress:75/7571 -val_progress:100/7571 -val_progress:125/7571 -val_progress:150/7571 -val_progress:175/7571 diff --git a/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_gpt_mlx.py b/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_gpt_mlx.py deleted file mode 100644 index 7b9e935aa6..0000000000 --- a/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_gpt_mlx.py +++ /dev/null @@ -1,1104 +0,0 @@ -#!/usr/bin/env python3 -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" -from __future__ import annotations - -import glob -import json -import math -import os -import pickle -import sys -import time -import uuid -import zlib -from collections.abc import Callable -from pathlib import Path - -import numpy as np -import sentencepiece as spm - -import mlx.core as mx -import mlx.nn as nn -import mlx.optimizers as optim -from mlx.utils import tree_flatten, tree_unflatten - -# ============================================================================== -# SHARD FORMAT + COMPUTE DTYPE -# ============================================================================== - -COMPUTE_DTYPE = mx.bfloat16 - -# ============================================================================== -# HYPERPARAMETERS -# ============================================================================== -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap -class Hyperparameters: - # Data / tokenizer. - data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed: int = int(os.environ.get("SEED", 1337)) - - # Training loop. These defaults now mirror train_gpt.py on a single process. - iterations: int = int(os.environ.get("ITERATIONS", 20_000)) - val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0)) - # Validation always uses the full fineweb_val split. - val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8)) - train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024))) - # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak - # memory pressure without changing the effective optimizer batch. - mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) - # Force MLX to materialize the graph after every sub-batch, preventing lazy - # graph buildup across accumulation steps. Keeps peak memory low on 16GB machines. - # Disable on 32GB+ unified memory for better throughput (MLX_EAGER_EVAL=0). - mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1"))) - warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) - warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200)) - max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - - # Model (defaults match the current baseline setup). - vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers: int = int(os.environ.get("NUM_LAYERS", 9)) - model_dim: int = int(os.environ.get("MODEL_DIM", 512)) - num_heads: int = int(os.environ.get("NUM_HEADS", 8)) - num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) - mlp_mult: int = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) - logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) - qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Optimizer. We keep the same per-group defaults as train_gpt.py. - beta1: float = float(os.environ.get("BETA1", 0.9)) - beta2: float = float(os.environ.get("BETA2", 0.95)) - adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) - tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) - matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - out_dir: str = os.environ.get("OUT_DIR", "logs") - - @property - def train_files(self) -> str: - return f"{self.data_path}/fineweb_train_*.bin" - - @property - def val_files(self) -> str: - return f"{self.data_path}/fineweb_val_*.bin" - - @property - def microbatch_tokens(self) -> int: - return self.train_batch_tokens // self.grad_accum_steps - - def lr_mul(self, step: int, elapsed_ms: float) -> float: - if self.warmdown_iters <= 0: - return 1.0 - if self.max_wallclock_seconds <= 0: - warmdown_start = max(self.iterations - self.warmdown_iters, 0) - return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = self.warmdown_iters * step_ms - remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) - - -def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]: - usable_total = (total_tokens // seq_len) * seq_len - if usable_total <= 0: - raise ValueError(f"token budget too small for seq_len={seq_len}") - usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) - chunks: list[int] = [] - remaining = usable_total - while remaining > 0: - chunk = min(remaining, usable_chunk) - chunks.append(chunk) - remaining -= chunk - return chunks - - -def accumulate_flat_grads( - accum: dict[str, mx.array] | None, - grads_tree: dict, - scale: float, -) -> dict[str, mx.array]: - flat = dict(tree_flatten(grads_tree)) - if accum is None: - return {k: g * scale for k, g in flat.items()} - for k, g in flat.items(): - accum[k] = accum[k] + g * scale - return accum - - -# ============================================================================== -# MATH HELPERS -# ============================================================================== - -def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: - return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype) - - -def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - # Background on Muon: https://kellerjordan.github.io/posts/muon/ - a, b, c = 3.4445, -4.7750, 2.0315 - x = g.astype(mx.float32) - x = x / (mx.sqrt(mx.sum(x * x)) + eps) - transposed = x.shape[0] > x.shape[1] - if transposed: - x = x.T - for _ in range(steps): - a_mat = x @ x.T - b_mat = b * a_mat + c * (a_mat @ a_mat) - x = a * x + b_mat @ x - if transposed: - x = x.T - return x.astype(g.dtype) - - -def load_data_shard(path: Path) -> np.ndarray: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - if self.file_idx == 0: - self.epoch += 1 - if self.log_fn is not None: - self.log_fn( - f"WARNING: starting epoch:{self.epoch} " - f"dataset:{self.dataset_name} train_shards:{len(self.files)}" - ) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> np.ndarray: - chunks: list[np.ndarray] = [] - left = n - while left > 0: - if self.pos >= self.tokens.size: - self.next_file() - k = min(left, int(self.tokens.size - self.pos)) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - left -= k - return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0) - - -class TokenLoader: - def __init__( - self, - pattern: str, - log_fn: Callable[[str], None] | None = None, - dataset_name: str = "", - ): - self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name) - - def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]: - usable = (batch_tokens // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"token budget too small for seq_len={seq_len}") - chunk = self.stream.take(usable + 1) - x = chunk[:-1].reshape(-1, seq_len) - y = chunk[1:].reshape(-1, seq_len) - return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32) - - -# ============================================================================== -# MODEL BLOCKS -# ============================================================================== - -class CastedLinear(nn.Module): - def __init__(self, in_dim: int, out_dim: int): - super().__init__() - self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32) - - def __call__(self, x: mx.array) -> mx.array: - return x @ self.weight.astype(x.dtype).T - - -class RMSNormNoWeight(nn.Module): - # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks. - def __call__(self, x: mx.array) -> mx.array: - return rms_norm(x) - - -class CausalSelfAttention(nn.Module): - # - separate q/k/v projections - # - RMSNorm on q and k before attention - # - RoPE on q and k - # - causal masked SDPA - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim) - self.c_k = CastedLinear(dim, kv_dim) - self.c_v = CastedLinear(dim, kv_dim) - self.proj = CastedLinear(dim, dim) - self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init - self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base) - self.scale = self.head_dim ** -0.5 - - def __call__(self, x: mx.array) -> mx.array: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) - - q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE)) - k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE)) - q = q * self.q_gain.astype(q.dtype)[None, :, None, None] - y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal") - y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # Baseline MLP uses relu^2 instead of GELU/SiLU. It is cheap and works well in this setup. - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = dim * mlp_mult - self.fc = CastedLinear(dim, hidden) - self.proj = CastedLinear(hidden, dim) - - def __call__(self, x: mx.array) -> mx.array: - x = nn.relu(self.fc(x)) - return self.proj(x * x) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNormNoWeight() - self.mlp_norm = RMSNormNoWeight() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = mx.ones((dim,), dtype=mx.float32) - self.mlp_scale = mx.ones((dim,), dtype=mx.float32) - self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32)))) - - def __call__(self, x: mx.array, x0: mx.array) -> mx.array: - mix = self.resid_mix.astype(x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - # - token embedding + RMSNorm - # - encoder half accumulates skip tensors - # - decoder half consumes reversed skips with learned skip_weights - # - tied embeddings for the LM head (the baseline default setup) - def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, - logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float, - qk_gain_init: float): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.logit_chunk_tokens = logit_chunk_tokens - self.logit_softcap = logit_softcap - - self.tok_emb = nn.Embedding(vocab_size, dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32) - self.blocks = [ - Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for i in range(num_layers) - ] - self.final_norm = RMSNormNoWeight() - - for b in self.blocks: - b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight) - b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight) - self.tok_emb.weight = ( - mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std - ).astype(COMPUTE_DTYPE) - - def softcap(self, logits: mx.array) -> mx.array: - c = self.logit_softcap - return c * mx.tanh(logits / c) - - def __call__(self, input_ids: mx.array) -> mx.array: - x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE)) - x0 = x - skips: list[mx.array] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - # Odd layer counts have one more decoder block than encoder block. The baseline only - # applies a skip connection when one exists, then runs the remaining decoder block(s) - # without an added skip. - if skips: - x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - return self.final_norm(x) - - def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: - # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful - # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE). - x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) - y = target_ids.reshape(-1) - if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens: - logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T - logits = self.softcap(logits_proj) - return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") - - loss_sum = mx.array(0.0, dtype=mx.float32) - n = int(x.shape[0]) - for s in range(0, n, self.logit_chunk_tokens): - e = min(s + self.logit_chunk_tokens, n) - logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T - logits = self.softcap(logits_proj) - loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum") - return loss_sum / float(n) - -# ============================================================================== -# OPTIMIZERS (MUON + ADAM SPLIT) -# ============================================================================== -class Muon: - # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the - # parameter update. - def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters): - self.keys = keys - self.args = args - self.buffers = {k: mx.zeros_like(params[k]) for k in keys} - - def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]: - if self.args.muon_momentum_warmup_steps: - t = min(step / self.args.muon_momentum_warmup_steps, 1.0) - momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum - else: - momentum = self.args.muon_momentum - lr = self.args.matrix_lr * lr_mul - out: dict[str, mx.array] = {} - for k in self.keys: - p = params[k] - g = grads[k] - buf = momentum * self.buffers[k] + g - self.buffers[k] = buf - g_eff = g + momentum * buf - g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps) - scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1]))) - out[k] = p - lr * (g_ortho * scale).astype(p.dtype) - return out - - -class SplitOptimizers: - # - embeddings: Adam with the tied-embedding LR - # - block matrices (2D): Muon - # - block scalars + skip weights: Adam - # This preserves the high-level optimization behavior even though MLX internals differ. - def __init__(self, model: GPT, args: Hyperparameters): - self.args = args - params = dict(tree_flatten(model.parameters())) - self.embed_key = "tok_emb.weight" - self.matrix_keys = [ - k - for k, p in params.items() - if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - self.scalar_keys = [ - k - for k, p in params.items() - if k == "skip_weights" or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS))) - ] - - self.muon = Muon(self.matrix_keys, params, args) - self.adam_embed = optim.Adam( - learning_rate=args.tied_embed_lr, - betas=[args.beta1, args.beta2], - eps=args.adam_eps, - bias_correction=True, - ) - self.adam_scalar = optim.Adam( - learning_rate=args.scalar_lr, - betas=[args.beta1, args.beta2], - eps=args.adam_eps, - bias_correction=True, - ) - - def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None: - params = dict(tree_flatten(model.parameters())) - grads = dict(tree_flatten(grads_tree)) - updated = dict(params) - - updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul)) - - self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul - updated.update( - self.adam_embed.apply_gradients( - {self.embed_key: grads[self.embed_key]}, - {self.embed_key: params[self.embed_key]}, - ) - ) - - self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul - scalar_grads = {k: grads[k] for k in self.scalar_keys} - scalar_params = {k: params[k] for k in self.scalar_keys} - updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params)) - - model.update(tree_unflatten(list(updated.items()))) - -# ============================================================================== -# QUANTIZATION (INT8 + ZLIB) -# ============================================================================== -# - per-row int8 for 2D float tensors -# - per-tensor int8 for other float tensors -# - fp16 passthrough for small float tensors -# - exact passthrough for non-floats - -MX_DTYPE_FROM_NAME = { - "float32": mx.float32, - "float16": mx.float16, - "bfloat16": mx.bfloat16, -} - -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = np.float16 -INT8_PER_ROW_SCALE_DTYPE = np.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - - -def _np_float32(arr: mx.array) -> np.ndarray: - return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False) - - -def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return np.ascontiguousarray(_np_float32(arr)) - if arr.dtype in {mx.float32, mx.bfloat16}: - passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1] - return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False)) - return np.ascontiguousarray(np.array(arr, copy=True)) - - -def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]: - f32 = _np_float32(arr) - if f32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32) - clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) - scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False) - q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False) - return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False)) - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0 - scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32) - q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False) - return np.ascontiguousarray(q), scale - - -def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]: - quantized: dict[str, np.ndarray] = {} - scales: dict[str, np.ndarray] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, np.ndarray] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - for name, arr in flat_state.items(): - stats["param_count"] += int(arr.size) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += int(arr.nbytes) - if not mx.issubdtype(arr.dtype, mx.floating): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = np.ascontiguousarray(np.array(arr)) - stats["int8_payload_bytes"] += int(passthrough[name].nbytes) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_array(name, arr, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += int(kept.nbytes) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_array(arr) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(arr.dtype).split(".")[-1] - stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes) - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - - -def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]: - out: dict[str, mx.array] = {} - qmeta = quant_obj.get("qmeta", {}) - passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {}) - for name, q in quant_obj["quantized"].items(): - q_np = np.asarray(q, dtype=np.int8) - dtype_name = quant_obj["dtypes"][name] - scale = np.asarray(quant_obj["scales"][name], dtype=np.float32) - if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0: - # Broadcast the saved row scale back across trailing dimensions. - out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1)) - else: - out_arr = q_np.astype(np.float32) * float(scale) - out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name]) - for name, arr in quant_obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_arr = np.array(arr, copy=True) - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype]) - else: - out[name] = mx.array(out_arr) - return out - - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_lut = np.zeros((table_size,), dtype=np.int16) - has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_lut[token_id] = False - if sp.is_byte(token_id): - base_bytes_lut[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_lut[token_id] = True - piece = piece[1:] - base_bytes_lut[token_id] = len(piece.encode("utf-8")) - return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut - - -def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]: - # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we - # decode bytes with the exact tokenizer that produced the shards. The manifest - # lets the training script fail fast on accidental dataset/tokenizer mismatches. - dataset_dir = Path(data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - if len(dataset_dir.parents) < 2: - return dataset_dir.name, actual_train_files, None - manifest_path = dataset_dir.parents[1] / "manifest.json" - if not manifest_path.is_file(): - return dataset_dir.name, actual_train_files, None - - manifest = json.loads(manifest_path.read_text(encoding="utf-8")) - dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None) - if dataset_entry is None: - return dataset_dir.name, actual_train_files, None - - tokenizer_name = dataset_entry.get("tokenizer_name") - tokenizer_entry = ( - next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) - if tokenizer_name - else None - ) - expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name - if expected_name and Path(tokenizer_path).name != expected_name: - raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}") - expected_train_files = (dataset_entry.get("stats") or {}).get("files_train") - if expected_train_files is not None: - expected_train_files = int(expected_train_files) - if actual_train_files > expected_train_files: - raise ValueError( - f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, " - f"manifest says {expected_train_files}" - ) - return dataset_dir.name, actual_train_files, expected_train_files - - -def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0)) - usable = ((tokens.size - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def loss_and_grad_chunked( - args: Hyperparameters, - train_loader: TokenLoader, - compiled_loss_and_grad, -) -> tuple[mx.array, dict]: - chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens) - total_tokens = float(sum(chunk_sizes)) - loss_value = mx.array(0.0, dtype=mx.float32) - grad_accum: dict[str, mx.array] | None = None - for chunk_tokens in chunk_sizes: - x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len) - loss, grads = compiled_loss_and_grad(x, y) - scale = float(y.size) / total_tokens - loss_value = loss_value + loss.astype(mx.float32) * scale - grad_accum = accumulate_flat_grads(grad_accum, grads, scale) - if args.mlx_eager_eval: - mx.eval(loss_value, grad_accum) # materialize each chunk to cap peak memory - return loss_value, tree_unflatten(list(grad_accum.items())) - - -def eval_val( - args: Hyperparameters, - compiled_loss, - val_tokens: np.ndarray, - base_bytes_lut: np.ndarray, - has_leading_space_lut: np.ndarray, - is_boundary_token_lut: np.ndarray, - log_fn: Callable[[str], None] | None = None, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - val_batch_tokens = args.val_batch_size // args.grad_accum_steps - if val_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " - f"TRAIN_SEQ_LEN={args.train_seq_len}" - ) - val_batch_seqs = val_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.size - 1) // args.train_seq_len - total_batches = max((total_seqs + val_batch_seqs - 1) // val_batch_seqs, 1) - total_loss_sum = 0.0 - total_tokens = 0.0 - total_bytes = 0.0 - for batch_idx, batch_seq_start in enumerate(range(0, total_seqs, val_batch_seqs), start=1): - batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - chunk = val_tokens[raw_start:raw_end] - x_np = chunk[:-1].reshape(-1, args.train_seq_len) - y_np = chunk[1:].reshape(-1, args.train_seq_len) - x = mx.array(x_np, dtype=mx.int32) - y = mx.array(y_np, dtype=mx.int32) - chunk_token_count = float(y.size) - batch_loss = compiled_loss(x, y).astype(mx.float32) - mx.eval(batch_loss) - total_loss_sum += float(batch_loss.item()) * chunk_token_count - prev_ids = x_np.reshape(-1) - tgt_ids = y_np.reshape(-1) - bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True) - bytes_np += ( - has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] - ).astype(np.int16, copy=False) - total_tokens += chunk_token_count - total_bytes += float(bytes_np.astype(np.float64).sum()) - if log_fn is not None and total_batches > 1 and ( - batch_idx == 1 or batch_idx == total_batches or batch_idx % 25 == 0 - ): - log_fn(f"val_progress:{batch_idx}/{total_batches}") - val_loss = total_loss_sum / total_tokens - bits_per_token = val_loss / math.log(2.0) - val_bpb = bits_per_token * (total_tokens / total_bytes) - return val_loss, val_bpb - -# ----------------------------- -# TRAINING -# ----------------------------- - -def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict: - if max_norm <= 0: - return grads_tree - flat = dict(tree_flatten(grads_tree)) - total_sq = 0.0 - for grad in flat.values(): - total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64)) - if total_sq <= 0.0: - return grads_tree - total_norm = math.sqrt(total_sq) - if total_norm <= max_norm: - return grads_tree - scale = max_norm / (total_norm + 1e-12) - return tree_unflatten([(k, g * scale) for k, g in flat.items()]) - - -def main() -> None: - # ============================================================================== - # TOKENIZER + VALIDATION METRIC SETUP - # ============================================================================== - args = Hyperparameters() - out_dir = Path(args.out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - logfile = out_dir / f"{args.run_id}.txt" - print(logfile) - - def log(msg: str, console: bool = True) -> None: - if console: - print(msg) - with logfile.open("a", encoding="utf-8") as f: - print(msg, file=f) - - code = Path(__file__).read_text(encoding="utf-8") - log(code, console=False) - log("=" * 100, console=False) - log(f"Running Python {sys.version}", console=False) - log(f"Running MLX {mx.__version__}", console=False) - log("=" * 100, console=False) - - if not args.tie_embeddings: - raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings") - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair( - args.data_path, - args.tokenizer_path, - ) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size - ) - - # ============================================================================== - # TRAINING SETUP - # ============================================================================== - mx.random.seed(args.seed) - - train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) - - # ============================================================================== - # MODEL + OPTIMIZER SETUP - # ============================================================================== - model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - logit_chunk_tokens=args.logit_chunk_tokens, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - tied_embed_init_std=args.tied_embed_init_std, - qk_gain_init=args.qk_gain_init, - ) - opt = SplitOptimizers(model, args) - - # ============================================================================== - # COMPILED TRAIN / EVAL FUNCTIONS (MLX) - # ============================================================================== - # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example - # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs". - # Compiling the model-bound functions and capturing the full model state fixes that while still - # returning gradients only for trainable parameters via nn.value_and_grad(...). - compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) - compiled_loss_and_grad = mx.compile( - nn.value_and_grad(model, lambda x, y: model.loss(x, y)), - inputs=model.state, - outputs=model.state, - ) - - # Print config once so logs are self-describing. - n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) - log(f"run_id:{args.run_id}") - log(f"mlx_version:{mx.__version__}") - log(f"train_loader:shards pattern={args.train_files}") - log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}") - if expected_train_files is None: - log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}") - elif actual_train_files < expected_train_files: - log( - f"WARNING: train_loader:subset dataset:{dataset_name} " - f"train_shards:{actual_train_files}/{expected_train_files} " - f"new epochs will arrive sooner than the full dataset" - ) - else: - log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}") - log(f"tokenizer_path:{args.tokenizer_path}") - log( - f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} " - f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} " - f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}" - ) - log( - f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} " - f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} " - f"val_batch_size:{args.val_batch_size} " - f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}") - log( - f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} " - f"embed_lr:{args.tied_embed_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} " - f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}" - ) - log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log(f"compute_dtype:{COMPUTE_DTYPE} compile:True") - log( - f"dtypes tok_emb:{model.tok_emb.weight.dtype} " - f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} " - f"skip_weights:{model.skip_weights.dtype}" - ) - - # ============================================================================== - # TRAINING LOOP - # ============================================================================== - if args.warmup_steps > 0: - # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us - # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs. - # Instead we run the real train shapes, force the loss/grads to materialize, and then reset - # the loader so measured training still starts from the true init and token window. - for warmup_step in range(args.warmup_steps): - accum: dict[str, mx.array] | None = None - warmup_loss = mx.array(0.0, dtype=mx.float32) - grad_scale = 1.0 / args.grad_accum_steps - for _ in range(args.grad_accum_steps): - warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) - accum = accumulate_flat_grads(accum, grads, grad_scale) - mx.eval(warmup_loss, accum) - mx.synchronize() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - - # Prime the standalone eval graph once too. It is compiled separately from value_and_grad. - val_batch_tokens = args.val_batch_size // args.grad_accum_steps - if val_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " - f"TRAIN_SEQ_LEN={args.train_seq_len}" - ) - warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len) - warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1] - x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32) - y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32) - warm_val_loss = compiled_loss(x_val, y_val) - mx.eval(warm_val_loss) - mx.synchronize() - - train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) - - train_time_ms = 0.0 - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - stop_after_step: int | None = None - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - train_time_ms += 1000.0 * (time.perf_counter() - t0) - # Validation always scans the same fixed full validation split. - val_loss, val_bpb = eval_val( - args, - compiled_loss, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - log_fn=log, - ) - if step % 25 == 0 or last_step: - log( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms" - ) - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}") - break - - lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0)) - step_t0 = time.perf_counter() - - accum: dict[str, mx.array] | None = None - train_loss = mx.array(0.0, dtype=mx.float32) - grad_scale = 1.0 / args.grad_accum_steps - for _ in range(args.grad_accum_steps): - loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) - accum = accumulate_flat_grads(accum, grads, grad_scale) - train_loss = train_loss + loss.astype(mx.float32) * grad_scale - if args.mlx_eager_eval: - mx.eval(train_loss, accum) # materialize each microbatch to cap peak memory - - grads = tree_unflatten(list(accum.items())) - grads = clip_grad_tree(grads, args.grad_clip_norm) - train_loss_value = float(train_loss.item()) - opt.step(model, grads, step=step, lr_mul=lr_mul) - mx.synchronize() - - step_ms = 1000.0 * (time.perf_counter() - step_t0) - approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0) - tok_s = args.train_batch_tokens / (step_ms / 1000.0) - step += 1 - if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): - log( - f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} " - f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}" - ) - if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms: - stop_after_step = step - - # ============================================================================== - # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL - # ============================================================================== - # We always write a raw artifact and a quantized artifact, then validate the - # quantized roundtrip directly by loading the dequantized tensors back into the - # model and running one final validation pass. - out_path = out_dir / f"{args.run_id}_mlx_model.npz" - flat_state = {k: v for k, v in tree_flatten(model.state)} - mx.savez(str(out_path), **flat_state) - log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") - - quant_obj, quant_stats = quantize_state_dict_int8(flat_state) - quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) - quant_blob = zlib.compress(quant_raw, level=9) - quant_serialized_bytes = len(quant_raw) - quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz" - with quant_path.open("wb") as f: - f.write(quant_blob) - quant_file_bytes = quant_path.stat().st_size - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log( - f"serialized_model_int8_zlib:{quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)" - ) - - with quant_path.open("rb") as f: - quant_blob_disk = f.read() - quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk))) - model.update(tree_unflatten(list(quant_flat.items()))) - q_t0 = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - compiled_loss, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - log_fn=log, - ) - q_eval_ms = 1000.0 * (time.perf_counter() - q_t0) - log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms") - log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - -if __name__ == "__main__": - main() diff --git a/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_partial.log b/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_partial.log deleted file mode 100644 index 91cec65c1e..0000000000 --- a/records/track_non_record_16mb/2026-03-19_MLXSmokeLocalM1_SP1024_WIP/train_partial.log +++ /dev/null @@ -1,1360 +0,0 @@ -#!/usr/bin/env python3 -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" -from __future__ import annotations - -import glob -import json -import math -import os -import pickle -import sys -import time -import uuid -import zlib -from collections.abc import Callable -from pathlib import Path - -import numpy as np -import sentencepiece as spm - -import mlx.core as mx -import mlx.nn as nn -import mlx.optimizers as optim -from mlx.utils import tree_flatten, tree_unflatten - -# ============================================================================== -# SHARD FORMAT + COMPUTE DTYPE -# ============================================================================== - -COMPUTE_DTYPE = mx.bfloat16 - -# ============================================================================== -# HYPERPARAMETERS -# ============================================================================== -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap -class Hyperparameters: - # Data / tokenizer. - data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed: int = int(os.environ.get("SEED", 1337)) - - # Training loop. These defaults now mirror train_gpt.py on a single process. - iterations: int = int(os.environ.get("ITERATIONS", 20_000)) - val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0)) - # Validation always uses the full fineweb_val split. - val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8)) - train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024))) - # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak - # memory pressure without changing the effective optimizer batch. - mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) - # Force MLX to materialize the graph after every sub-batch, preventing lazy - # graph buildup across accumulation steps. Keeps peak memory low on 16GB machines. - # Disable on 32GB+ unified memory for better throughput (MLX_EAGER_EVAL=0). - mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1"))) - warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) - warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200)) - max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - - # Model (defaults match the current baseline setup). - vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers: int = int(os.environ.get("NUM_LAYERS", 9)) - model_dim: int = int(os.environ.get("MODEL_DIM", 512)) - num_heads: int = int(os.environ.get("NUM_HEADS", 8)) - num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) - mlp_mult: int = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) - logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) - qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Optimizer. We keep the same per-group defaults as train_gpt.py. - beta1: float = float(os.environ.get("BETA1", 0.9)) - beta2: float = float(os.environ.get("BETA2", 0.95)) - adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) - tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) - matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - out_dir: str = os.environ.get("OUT_DIR", "logs") - - @property - def train_files(self) -> str: - return f"{self.data_path}/fineweb_train_*.bin" - - @property - def val_files(self) -> str: - return f"{self.data_path}/fineweb_val_*.bin" - - @property - def microbatch_tokens(self) -> int: - return self.train_batch_tokens // self.grad_accum_steps - - def lr_mul(self, step: int, elapsed_ms: float) -> float: - if self.warmdown_iters <= 0: - return 1.0 - if self.max_wallclock_seconds <= 0: - warmdown_start = max(self.iterations - self.warmdown_iters, 0) - return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = self.warmdown_iters * step_ms - remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) - - -def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]: - usable_total = (total_tokens // seq_len) * seq_len - if usable_total <= 0: - raise ValueError(f"token budget too small for seq_len={seq_len}") - usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) - chunks: list[int] = [] - remaining = usable_total - while remaining > 0: - chunk = min(remaining, usable_chunk) - chunks.append(chunk) - remaining -= chunk - return chunks - - -def accumulate_flat_grads( - accum: dict[str, mx.array] | None, - grads_tree: dict, - scale: float, -) -> dict[str, mx.array]: - flat = dict(tree_flatten(grads_tree)) - if accum is None: - return {k: g * scale for k, g in flat.items()} - for k, g in flat.items(): - accum[k] = accum[k] + g * scale - return accum - - -# ============================================================================== -# MATH HELPERS -# ============================================================================== - -def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: - return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype) - - -def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - # Background on Muon: https://kellerjordan.github.io/posts/muon/ - a, b, c = 3.4445, -4.7750, 2.0315 - x = g.astype(mx.float32) - x = x / (mx.sqrt(mx.sum(x * x)) + eps) - transposed = x.shape[0] > x.shape[1] - if transposed: - x = x.T - for _ in range(steps): - a_mat = x @ x.T - b_mat = b * a_mat + c * (a_mat @ a_mat) - x = a * x + b_mat @ x - if transposed: - x = x.T - return x.astype(g.dtype) - - -def load_data_shard(path: Path) -> np.ndarray: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - if self.file_idx == 0: - self.epoch += 1 - if self.log_fn is not None: - self.log_fn( - f"WARNING: starting epoch:{self.epoch} " - f"dataset:{self.dataset_name} train_shards:{len(self.files)}" - ) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> np.ndarray: - chunks: list[np.ndarray] = [] - left = n - while left > 0: - if self.pos >= self.tokens.size: - self.next_file() - k = min(left, int(self.tokens.size - self.pos)) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - left -= k - return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0) - - -class TokenLoader: - def __init__( - self, - pattern: str, - log_fn: Callable[[str], None] | None = None, - dataset_name: str = "", - ): - self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name) - - def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]: - usable = (batch_tokens // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"token budget too small for seq_len={seq_len}") - chunk = self.stream.take(usable + 1) - x = chunk[:-1].reshape(-1, seq_len) - y = chunk[1:].reshape(-1, seq_len) - return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32) - - -# ============================================================================== -# MODEL BLOCKS -# ============================================================================== - -class CastedLinear(nn.Module): - def __init__(self, in_dim: int, out_dim: int): - super().__init__() - self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32) - - def __call__(self, x: mx.array) -> mx.array: - return x @ self.weight.astype(x.dtype).T - - -class RMSNormNoWeight(nn.Module): - # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks. - def __call__(self, x: mx.array) -> mx.array: - return rms_norm(x) - - -class CausalSelfAttention(nn.Module): - # - separate q/k/v projections - # - RMSNorm on q and k before attention - # - RoPE on q and k - # - causal masked SDPA - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim) - self.c_k = CastedLinear(dim, kv_dim) - self.c_v = CastedLinear(dim, kv_dim) - self.proj = CastedLinear(dim, dim) - self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init - self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base) - self.scale = self.head_dim ** -0.5 - - def __call__(self, x: mx.array) -> mx.array: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) - - q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE)) - k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE)) - q = q * self.q_gain.astype(q.dtype)[None, :, None, None] - y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal") - y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # Baseline MLP uses relu^2 instead of GELU/SiLU. It is cheap and works well in this setup. - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = dim * mlp_mult - self.fc = CastedLinear(dim, hidden) - self.proj = CastedLinear(hidden, dim) - - def __call__(self, x: mx.array) -> mx.array: - x = nn.relu(self.fc(x)) - return self.proj(x * x) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNormNoWeight() - self.mlp_norm = RMSNormNoWeight() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = mx.ones((dim,), dtype=mx.float32) - self.mlp_scale = mx.ones((dim,), dtype=mx.float32) - self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32)))) - - def __call__(self, x: mx.array, x0: mx.array) -> mx.array: - mix = self.resid_mix.astype(x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - # - token embedding + RMSNorm - # - encoder half accumulates skip tensors - # - decoder half consumes reversed skips with learned skip_weights - # - tied embeddings for the LM head (the baseline default setup) - def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, - logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float, - qk_gain_init: float): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.logit_chunk_tokens = logit_chunk_tokens - self.logit_softcap = logit_softcap - - self.tok_emb = nn.Embedding(vocab_size, dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32) - self.blocks = [ - Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for i in range(num_layers) - ] - self.final_norm = RMSNormNoWeight() - - for b in self.blocks: - b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight) - b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight) - self.tok_emb.weight = ( - mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std - ).astype(COMPUTE_DTYPE) - - def softcap(self, logits: mx.array) -> mx.array: - c = self.logit_softcap - return c * mx.tanh(logits / c) - - def __call__(self, input_ids: mx.array) -> mx.array: - x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE)) - x0 = x - skips: list[mx.array] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - # Odd layer counts have one more decoder block than encoder block. The baseline only - # applies a skip connection when one exists, then runs the remaining decoder block(s) - # without an added skip. - if skips: - x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - return self.final_norm(x) - - def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: - # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful - # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE). - x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) - y = target_ids.reshape(-1) - if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens: - logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T - logits = self.softcap(logits_proj) - return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") - - loss_sum = mx.array(0.0, dtype=mx.float32) - n = int(x.shape[0]) - for s in range(0, n, self.logit_chunk_tokens): - e = min(s + self.logit_chunk_tokens, n) - logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T - logits = self.softcap(logits_proj) - loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum") - return loss_sum / float(n) - -# ============================================================================== -# OPTIMIZERS (MUON + ADAM SPLIT) -# ============================================================================== -class Muon: - # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the - # parameter update. - def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters): - self.keys = keys - self.args = args - self.buffers = {k: mx.zeros_like(params[k]) for k in keys} - - def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]: - if self.args.muon_momentum_warmup_steps: - t = min(step / self.args.muon_momentum_warmup_steps, 1.0) - momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum - else: - momentum = self.args.muon_momentum - lr = self.args.matrix_lr * lr_mul - out: dict[str, mx.array] = {} - for k in self.keys: - p = params[k] - g = grads[k] - buf = momentum * self.buffers[k] + g - self.buffers[k] = buf - g_eff = g + momentum * buf - g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps) - scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1]))) - out[k] = p - lr * (g_ortho * scale).astype(p.dtype) - return out - - -class SplitOptimizers: - # - embeddings: Adam with the tied-embedding LR - # - block matrices (2D): Muon - # - block scalars + skip weights: Adam - # This preserves the high-level optimization behavior even though MLX internals differ. - def __init__(self, model: GPT, args: Hyperparameters): - self.args = args - params = dict(tree_flatten(model.parameters())) - self.embed_key = "tok_emb.weight" - self.matrix_keys = [ - k - for k, p in params.items() - if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - self.scalar_keys = [ - k - for k, p in params.items() - if k == "skip_weights" or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS))) - ] - - self.muon = Muon(self.matrix_keys, params, args) - self.adam_embed = optim.Adam( - learning_rate=args.tied_embed_lr, - betas=[args.beta1, args.beta2], - eps=args.adam_eps, - bias_correction=True, - ) - self.adam_scalar = optim.Adam( - learning_rate=args.scalar_lr, - betas=[args.beta1, args.beta2], - eps=args.adam_eps, - bias_correction=True, - ) - - def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None: - params = dict(tree_flatten(model.parameters())) - grads = dict(tree_flatten(grads_tree)) - updated = dict(params) - - updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul)) - - self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul - updated.update( - self.adam_embed.apply_gradients( - {self.embed_key: grads[self.embed_key]}, - {self.embed_key: params[self.embed_key]}, - ) - ) - - self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul - scalar_grads = {k: grads[k] for k in self.scalar_keys} - scalar_params = {k: params[k] for k in self.scalar_keys} - updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params)) - - model.update(tree_unflatten(list(updated.items()))) - -# ============================================================================== -# QUANTIZATION (INT8 + ZLIB) -# ============================================================================== -# - per-row int8 for 2D float tensors -# - per-tensor int8 for other float tensors -# - fp16 passthrough for small float tensors -# - exact passthrough for non-floats - -MX_DTYPE_FROM_NAME = { - "float32": mx.float32, - "float16": mx.float16, - "bfloat16": mx.bfloat16, -} - -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = np.float16 -INT8_PER_ROW_SCALE_DTYPE = np.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - - -def _np_float32(arr: mx.array) -> np.ndarray: - return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False) - - -def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return np.ascontiguousarray(_np_float32(arr)) - if arr.dtype in {mx.float32, mx.bfloat16}: - passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1] - return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False)) - return np.ascontiguousarray(np.array(arr, copy=True)) - - -def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]: - f32 = _np_float32(arr) - if f32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32) - clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) - scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False) - q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False) - return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False)) - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0 - scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32) - q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False) - return np.ascontiguousarray(q), scale - - -def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]: - quantized: dict[str, np.ndarray] = {} - scales: dict[str, np.ndarray] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, np.ndarray] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - for name, arr in flat_state.items(): - stats["param_count"] += int(arr.size) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += int(arr.nbytes) - if not mx.issubdtype(arr.dtype, mx.floating): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = np.ascontiguousarray(np.array(arr)) - stats["int8_payload_bytes"] += int(passthrough[name].nbytes) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_array(name, arr, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += int(kept.nbytes) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_array(arr) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(arr.dtype).split(".")[-1] - stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes) - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - - -def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]: - out: dict[str, mx.array] = {} - qmeta = quant_obj.get("qmeta", {}) - passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {}) - for name, q in quant_obj["quantized"].items(): - q_np = np.asarray(q, dtype=np.int8) - dtype_name = quant_obj["dtypes"][name] - scale = np.asarray(quant_obj["scales"][name], dtype=np.float32) - if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0: - # Broadcast the saved row scale back across trailing dimensions. - out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1)) - else: - out_arr = q_np.astype(np.float32) * float(scale) - out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name]) - for name, arr in quant_obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_arr = np.array(arr, copy=True) - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype]) - else: - out[name] = mx.array(out_arr) - return out - - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_lut = np.zeros((table_size,), dtype=np.int16) - has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_lut[token_id] = False - if sp.is_byte(token_id): - base_bytes_lut[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_lut[token_id] = True - piece = piece[1:] - base_bytes_lut[token_id] = len(piece.encode("utf-8")) - return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut - - -def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]: - # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we - # decode bytes with the exact tokenizer that produced the shards. The manifest - # lets the training script fail fast on accidental dataset/tokenizer mismatches. - dataset_dir = Path(data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - if len(dataset_dir.parents) < 2: - return dataset_dir.name, actual_train_files, None - manifest_path = dataset_dir.parents[1] / "manifest.json" - if not manifest_path.is_file(): - return dataset_dir.name, actual_train_files, None - - manifest = json.loads(manifest_path.read_text(encoding="utf-8")) - dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None) - if dataset_entry is None: - return dataset_dir.name, actual_train_files, None - - tokenizer_name = dataset_entry.get("tokenizer_name") - tokenizer_entry = ( - next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) - if tokenizer_name - else None - ) - expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name - if expected_name and Path(tokenizer_path).name != expected_name: - raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}") - expected_train_files = (dataset_entry.get("stats") or {}).get("files_train") - if expected_train_files is not None: - expected_train_files = int(expected_train_files) - if actual_train_files > expected_train_files: - raise ValueError( - f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, " - f"manifest says {expected_train_files}" - ) - return dataset_dir.name, actual_train_files, expected_train_files - - -def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0)) - usable = ((tokens.size - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def loss_and_grad_chunked( - args: Hyperparameters, - train_loader: TokenLoader, - compiled_loss_and_grad, -) -> tuple[mx.array, dict]: - chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens) - total_tokens = float(sum(chunk_sizes)) - loss_value = mx.array(0.0, dtype=mx.float32) - grad_accum: dict[str, mx.array] | None = None - for chunk_tokens in chunk_sizes: - x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len) - loss, grads = compiled_loss_and_grad(x, y) - scale = float(y.size) / total_tokens - loss_value = loss_value + loss.astype(mx.float32) * scale - grad_accum = accumulate_flat_grads(grad_accum, grads, scale) - if args.mlx_eager_eval: - mx.eval(loss_value, grad_accum) # materialize each chunk to cap peak memory - return loss_value, tree_unflatten(list(grad_accum.items())) - - -def eval_val( - args: Hyperparameters, - compiled_loss, - val_tokens: np.ndarray, - base_bytes_lut: np.ndarray, - has_leading_space_lut: np.ndarray, - is_boundary_token_lut: np.ndarray, - log_fn: Callable[[str], None] | None = None, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - val_batch_tokens = args.val_batch_size // args.grad_accum_steps - if val_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " - f"TRAIN_SEQ_LEN={args.train_seq_len}" - ) - val_batch_seqs = val_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.size - 1) // args.train_seq_len - total_batches = max((total_seqs + val_batch_seqs - 1) // val_batch_seqs, 1) - total_loss_sum = 0.0 - total_tokens = 0.0 - total_bytes = 0.0 - for batch_idx, batch_seq_start in enumerate(range(0, total_seqs, val_batch_seqs), start=1): - batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - chunk = val_tokens[raw_start:raw_end] - x_np = chunk[:-1].reshape(-1, args.train_seq_len) - y_np = chunk[1:].reshape(-1, args.train_seq_len) - x = mx.array(x_np, dtype=mx.int32) - y = mx.array(y_np, dtype=mx.int32) - chunk_token_count = float(y.size) - batch_loss = compiled_loss(x, y).astype(mx.float32) - mx.eval(batch_loss) - total_loss_sum += float(batch_loss.item()) * chunk_token_count - prev_ids = x_np.reshape(-1) - tgt_ids = y_np.reshape(-1) - bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True) - bytes_np += ( - has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] - ).astype(np.int16, copy=False) - total_tokens += chunk_token_count - total_bytes += float(bytes_np.astype(np.float64).sum()) - if log_fn is not None and total_batches > 1 and ( - batch_idx == 1 or batch_idx == total_batches or batch_idx % 25 == 0 - ): - log_fn(f"val_progress:{batch_idx}/{total_batches}") - val_loss = total_loss_sum / total_tokens - bits_per_token = val_loss / math.log(2.0) - val_bpb = bits_per_token * (total_tokens / total_bytes) - return val_loss, val_bpb - -# ----------------------------- -# TRAINING -# ----------------------------- - -def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict: - if max_norm <= 0: - return grads_tree - flat = dict(tree_flatten(grads_tree)) - total_sq = 0.0 - for grad in flat.values(): - total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64)) - if total_sq <= 0.0: - return grads_tree - total_norm = math.sqrt(total_sq) - if total_norm <= max_norm: - return grads_tree - scale = max_norm / (total_norm + 1e-12) - return tree_unflatten([(k, g * scale) for k, g in flat.items()]) - - -def main() -> None: - # ============================================================================== - # TOKENIZER + VALIDATION METRIC SETUP - # ============================================================================== - args = Hyperparameters() - out_dir = Path(args.out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - logfile = out_dir / f"{args.run_id}.txt" - print(logfile) - - def log(msg: str, console: bool = True) -> None: - if console: - print(msg) - with logfile.open("a", encoding="utf-8") as f: - print(msg, file=f) - - code = Path(__file__).read_text(encoding="utf-8") - log(code, console=False) - log("=" * 100, console=False) - log(f"Running Python {sys.version}", console=False) - log(f"Running MLX {mx.__version__}", console=False) - log("=" * 100, console=False) - - if not args.tie_embeddings: - raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings") - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair( - args.data_path, - args.tokenizer_path, - ) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size - ) - - # ============================================================================== - # TRAINING SETUP - # ============================================================================== - mx.random.seed(args.seed) - - train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) - - # ============================================================================== - # MODEL + OPTIMIZER SETUP - # ============================================================================== - model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - logit_chunk_tokens=args.logit_chunk_tokens, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - tied_embed_init_std=args.tied_embed_init_std, - qk_gain_init=args.qk_gain_init, - ) - opt = SplitOptimizers(model, args) - - # ============================================================================== - # COMPILED TRAIN / EVAL FUNCTIONS (MLX) - # ============================================================================== - # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example - # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs". - # Compiling the model-bound functions and capturing the full model state fixes that while still - # returning gradients only for trainable parameters via nn.value_and_grad(...). - compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) - compiled_loss_and_grad = mx.compile( - nn.value_and_grad(model, lambda x, y: model.loss(x, y)), - inputs=model.state, - outputs=model.state, - ) - - # Print config once so logs are self-describing. - n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) - log(f"run_id:{args.run_id}") - log(f"mlx_version:{mx.__version__}") - log(f"train_loader:shards pattern={args.train_files}") - log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}") - if expected_train_files is None: - log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}") - elif actual_train_files < expected_train_files: - log( - f"WARNING: train_loader:subset dataset:{dataset_name} " - f"train_shards:{actual_train_files}/{expected_train_files} " - f"new epochs will arrive sooner than the full dataset" - ) - else: - log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}") - log(f"tokenizer_path:{args.tokenizer_path}") - log( - f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} " - f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} " - f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}" - ) - log( - f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} " - f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} " - f"val_batch_size:{args.val_batch_size} " - f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}") - log( - f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} " - f"embed_lr:{args.tied_embed_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} " - f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}" - ) - log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log(f"compute_dtype:{COMPUTE_DTYPE} compile:True") - log( - f"dtypes tok_emb:{model.tok_emb.weight.dtype} " - f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} " - f"skip_weights:{model.skip_weights.dtype}" - ) - - # ============================================================================== - # TRAINING LOOP - # ============================================================================== - if args.warmup_steps > 0: - # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us - # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs. - # Instead we run the real train shapes, force the loss/grads to materialize, and then reset - # the loader so measured training still starts from the true init and token window. - for warmup_step in range(args.warmup_steps): - accum: dict[str, mx.array] | None = None - warmup_loss = mx.array(0.0, dtype=mx.float32) - grad_scale = 1.0 / args.grad_accum_steps - for _ in range(args.grad_accum_steps): - warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) - accum = accumulate_flat_grads(accum, grads, grad_scale) - mx.eval(warmup_loss, accum) - mx.synchronize() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - - # Prime the standalone eval graph once too. It is compiled separately from value_and_grad. - val_batch_tokens = args.val_batch_size // args.grad_accum_steps - if val_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " - f"TRAIN_SEQ_LEN={args.train_seq_len}" - ) - warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len) - warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1] - x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32) - y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32) - warm_val_loss = compiled_loss(x_val, y_val) - mx.eval(warm_val_loss) - mx.synchronize() - - train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) - - train_time_ms = 0.0 - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - stop_after_step: int | None = None - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - train_time_ms += 1000.0 * (time.perf_counter() - t0) - # Validation always scans the same fixed full validation split. - val_loss, val_bpb = eval_val( - args, - compiled_loss, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - log_fn=log, - ) - if step % 25 == 0 or last_step: - log( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms" - ) - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}") - break - - lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0)) - step_t0 = time.perf_counter() - - accum: dict[str, mx.array] | None = None - train_loss = mx.array(0.0, dtype=mx.float32) - grad_scale = 1.0 / args.grad_accum_steps - for _ in range(args.grad_accum_steps): - loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) - accum = accumulate_flat_grads(accum, grads, grad_scale) - train_loss = train_loss + loss.astype(mx.float32) * grad_scale - if args.mlx_eager_eval: - mx.eval(train_loss, accum) # materialize each microbatch to cap peak memory - - grads = tree_unflatten(list(accum.items())) - grads = clip_grad_tree(grads, args.grad_clip_norm) - train_loss_value = float(train_loss.item()) - opt.step(model, grads, step=step, lr_mul=lr_mul) - mx.synchronize() - - step_ms = 1000.0 * (time.perf_counter() - step_t0) - approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0) - tok_s = args.train_batch_tokens / (step_ms / 1000.0) - step += 1 - if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): - log( - f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} " - f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}" - ) - if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms: - stop_after_step = step - - # ============================================================================== - # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL - # ============================================================================== - # We always write a raw artifact and a quantized artifact, then validate the - # quantized roundtrip directly by loading the dequantized tensors back into the - # model and running one final validation pass. - out_path = out_dir / f"{args.run_id}_mlx_model.npz" - flat_state = {k: v for k, v in tree_flatten(model.state)} - mx.savez(str(out_path), **flat_state) - log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") - - quant_obj, quant_stats = quantize_state_dict_int8(flat_state) - quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) - quant_blob = zlib.compress(quant_raw, level=9) - quant_serialized_bytes = len(quant_raw) - quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz" - with quant_path.open("wb") as f: - f.write(quant_blob) - quant_file_bytes = quant_path.stat().st_size - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log( - f"serialized_model_int8_zlib:{quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)" - ) - - with quant_path.open("rb") as f: - quant_blob_disk = f.read() - quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk))) - model.update(tree_unflatten(list(quant_flat.items()))) - q_t0 = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - compiled_loss, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - log_fn=log, - ) - q_eval_ms = 1000.0 * (time.perf_counter() - q_t0) - log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms") - log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.14.2 (main, Dec 5 2025, 16:49:16) [Clang 17.0.0 (clang-1700.4.4.1)] -Running MLX 0.31.1 -==================================================================================================== -run_id:stukenov_mlx_smoke -mlx_version:0.31.1 -train_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_train_*.bin -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -WARNING: train_loader:subset dataset:fineweb10B_sp1024 train_shards:1/195 new epochs will arrive sooner than the full dataset -tokenizer_path:./data/tokenizers/fineweb_1024_bpe.model -model_params:17059912 vocab_size:1024 layers:9 dim:512 heads:8 kv_heads:4 seq_len:1024 tie_embeddings:True -iterations:200 train_batch_tokens:8192 grad_accum_steps:8 microbatch_tokens:1024 microbatch_batch_size:1 val_batch_size:8192 warmup_steps:20 max_wallclock_seconds:600.000 -mlx_max_microbatch_tokens:8192 -optimizer:muon+adam muon_matrix_params:54 scalar_params:37 embed_lr:0.05 matrix_lr:0.04 scalar_lr:0.04 muon_momentum:0.95 muon_steps:5 -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -compute_dtype:mlx.core.bfloat16 compile:True -dtypes tok_emb:mlx.core.bfloat16 linear_weight:mlx.core.float32 skip_weights:mlx.core.float32 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/200 train_loss:6.9428 train_time:1474ms step_avg:1474.19ms tok_s:5557 -step:2/200 train_loss:18.7855 train_time:3330ms step_avg:1665.02ms tok_s:4416 -step:3/200 train_loss:16.6689 train_time:5029ms step_avg:1676.42ms tok_s:4822 -step:4/200 train_loss:14.0871 train_time:6728ms step_avg:1682.08ms tok_s:4823 -step:5/200 train_loss:11.7481 train_time:8424ms step_avg:1684.83ms tok_s:4832 -step:6/200 train_loss:9.4689 train_time:10150ms step_avg:1691.60ms tok_s:4748 -step:7/200 train_loss:7.9685 train_time:11838ms step_avg:1691.14ms tok_s:4854 -step:8/200 train_loss:7.1412 train_time:13515ms step_avg:1689.41ms tok_s:4885 -step:9/200 train_loss:6.6587 train_time:15235ms step_avg:1692.74ms tok_s:4765 -step:10/200 train_loss:6.2999 train_time:16919ms step_avg:1691.86ms tok_s:4866 -step:50/200 train_loss:4.8429 train_time:84371ms step_avg:1687.41ms tok_s:4838 -step:100/200 train_loss:4.4248 train_time:168819ms step_avg:1688.19ms tok_s:4831 -step:150/200 train_loss:4.1391 train_time:271291ms step_avg:1808.61ms tok_s:4804 -step:200/200 train_loss:3.9233 train_time:362934ms step_avg:1814.67ms tok_s:4182 -val_progress:1/60568 -val_progress:25/60568 -val_progress:50/60568 -val_progress:75/60568 -val_progress:100/60568 -val_progress:125/60568 -val_progress:150/60568 -val_progress:175/60568 -val_progress:200/60568 -val_progress:225/60568 -val_progress:250/60568 -val_progress:275/60568 -val_progress:300/60568 -val_progress:325/60568 -val_progress:350/60568 -val_progress:375/60568 -val_progress:400/60568 -val_progress:425/60568 -val_progress:450/60568 -val_progress:475/60568 -val_progress:500/60568 -val_progress:525/60568 -val_progress:550/60568 -val_progress:575/60568 -val_progress:600/60568 -val_progress:625/60568 -val_progress:650/60568 -val_progress:675/60568 -val_progress:700/60568 -val_progress:725/60568 -val_progress:750/60568 -val_progress:775/60568 -val_progress:800/60568 -val_progress:825/60568 -val_progress:850/60568 -val_progress:875/60568 -val_progress:900/60568 -val_progress:925/60568 -val_progress:950/60568 -val_progress:975/60568 -val_progress:1000/60568 -val_progress:1025/60568 -val_progress:1050/60568 -val_progress:1075/60568 -val_progress:1100/60568 -val_progress:1125/60568 -val_progress:1150/60568 -val_progress:1175/60568 -val_progress:1200/60568 -val_progress:1225/60568 -val_progress:1250/60568 -val_progress:1275/60568 -val_progress:1300/60568 -val_progress:1325/60568 -val_progress:1350/60568 -val_progress:1375/60568 -val_progress:1400/60568 -val_progress:1425/60568 -val_progress:1450/60568 -val_progress:1475/60568 -val_progress:1500/60568 -val_progress:1525/60568 -val_progress:1550/60568 -val_progress:1575/60568 -val_progress:1600/60568 -val_progress:1625/60568 -val_progress:1650/60568 -val_progress:1675/60568 -val_progress:1700/60568 -val_progress:1725/60568 -val_progress:1750/60568 -val_progress:1775/60568 -val_progress:1800/60568 -val_progress:1825/60568 -val_progress:1850/60568 -val_progress:1875/60568 -val_progress:1900/60568 -val_progress:1925/60568 -val_progress:1950/60568 -val_progress:1975/60568 -val_progress:2000/60568 -val_progress:2025/60568 -val_progress:2050/60568 -val_progress:2075/60568 -val_progress:2100/60568 -val_progress:2125/60568 -val_progress:2150/60568 -val_progress:2175/60568 -val_progress:2200/60568 -val_progress:2225/60568 -val_progress:2250/60568 -val_progress:2275/60568 -val_progress:2300/60568 -val_progress:2325/60568 -val_progress:2350/60568 -val_progress:2375/60568 -val_progress:2400/60568 -val_progress:2425/60568 -val_progress:2450/60568 -val_progress:2475/60568 -val_progress:2500/60568 -val_progress:2525/60568 -val_progress:2550/60568 -val_progress:2575/60568 -val_progress:2600/60568 -val_progress:2625/60568 -val_progress:2650/60568 -val_progress:2675/60568 -val_progress:2700/60568 -val_progress:2725/60568 -val_progress:2750/60568 -val_progress:2775/60568 -val_progress:2800/60568 -val_progress:2825/60568 -val_progress:2850/60568 -val_progress:2875/60568 -val_progress:2900/60568 -val_progress:2925/60568 -val_progress:2950/60568 -val_progress:2975/60568 -val_progress:3000/60568 -val_progress:3025/60568 -val_progress:3050/60568 -val_progress:3075/60568 -val_progress:3100/60568 -val_progress:3125/60568 -val_progress:3150/60568 -val_progress:3175/60568 -val_progress:3200/60568 -val_progress:3225/60568 -val_progress:3250/60568 -val_progress:3275/60568 -val_progress:3300/60568 -val_progress:3325/60568 -val_progress:3350/60568 -val_progress:3375/60568 -val_progress:3400/60568 -val_progress:3425/60568 -val_progress:3450/60568 -val_progress:3475/60568 -val_progress:3500/60568 -val_progress:3525/60568 -val_progress:3550/60568 -val_progress:3575/60568 -val_progress:3600/60568 -val_progress:3625/60568 -val_progress:3650/60568 -val_progress:3675/60568 -val_progress:3700/60568 -val_progress:3725/60568 -val_progress:3750/60568 -val_progress:3775/60568 -val_progress:3800/60568 -val_progress:3825/60568 -val_progress:3850/60568 -val_progress:3875/60568 -val_progress:3900/60568 -val_progress:3925/60568 -val_progress:3950/60568 -val_progress:3975/60568 -val_progress:4000/60568 -val_progress:4025/60568 -val_progress:4050/60568 -val_progress:4075/60568 -val_progress:4100/60568 -val_progress:4125/60568 -val_progress:4150/60568 -val_progress:4175/60568 -val_progress:4200/60568 -val_progress:4225/60568 -val_progress:4250/60568 -val_progress:4275/60568 -val_progress:4300/60568 -val_progress:4325/60568 -val_progress:4350/60568 -val_progress:4375/60568 -val_progress:4400/60568 -val_progress:4425/60568 -val_progress:4450/60568 -val_progress:4475/60568 -val_progress:4500/60568 -val_progress:4525/60568 -val_progress:4550/60568 -val_progress:4575/60568 -val_progress:4600/60568 -val_progress:4625/60568 -val_progress:4650/60568 -val_progress:4675/60568 -val_progress:4700/60568 -val_progress:4725/60568 -val_progress:4750/60568 -val_progress:4775/60568 -val_progress:4800/60568 -val_progress:4825/60568 -val_progress:4850/60568 -val_progress:4875/60568 -val_progress:4900/60568 -val_progress:4925/60568 -val_progress:4950/60568 -val_progress:4975/60568 -val_progress:5000/60568 -val_progress:5025/60568 -val_progress:5050/60568 -val_progress:5075/60568 diff --git a/records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/README.md b/records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/README.md new file mode 100644 index 0000000000..c16c1497ce --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/README.md @@ -0,0 +1,57 @@ +This record captures a finished non-record smoke submission built from the current root `train_gpt.py`, with a small CUDA compatibility patch in the local copy used for the run. + +This run is not intended for the 10-minute leaderboard. It is a short, fully completed non-record baseline on a single RTX 4090 using the fixed full FineWeb validation split and a single training shard. The main purpose is to document a clean, reproducible CUDA submission path with final metrics, artifact bytes, and logs. + +## Why this script differs slightly from root + +The Vast.ai image used for this run shipped with a PyTorch build that does not accept the `enable_gqa=` argument on `scaled_dot_product_attention`. To keep the run reproducible on that image, the copied `train_gpt.py` expands KV heads manually when `num_kv_heads != num_heads` and then calls `scaled_dot_product_attention` without `enable_gqa`. + +The model, tokenizer, data, and training setup otherwise follow the baseline configuration. + +## Configuration + +- Track: `non-record`, unlimited compute, still under the `16,000,000` byte artifact cap +- Hardware: `1x RTX 4090` on Vast.ai +- Tokenizer / dataset: `sp1024`, full fixed `fineweb_val_*`, `1` training shard +- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` +- Tied embeddings: `TIE_EMBEDDINGS=1` +- Batching: `TRAIN_BATCH_TOKENS=8192 TRAIN_SEQ_LEN=1024` +- Validation cadence: final-only validation on the full fixed validation split +- Training length: `ITERATIONS=50` + +## Command + +```bash +RUN_ID=stukenov_4090_smoke50 \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +TRAIN_BATCH_TOKENS=8192 \ +TRAIN_SEQ_LEN=1024 \ +VAL_BATCH_SIZE=65536 \ +VAL_LOSS_EVERY=0 \ +TRAIN_LOG_EVERY=25 \ +ITERATIONS=50 \ +MAX_WALLCLOCK_SECONDS=0 \ +OMP_NUM_THREADS=1 \ +TORCH_NCCL_ASYNC_ERROR_HANDLING=1 \ +python train_gpt.py +``` + +## Key Metrics + +- Training stopped at `50/50` steps. +- Pre-quant eval at stop: `val_loss:5.3102`, `val_bpb:3.1450` +- Post-quant int8+zlib roundtrip: `val_loss:5.7139`, `val_bpb:3.3841` +- Exact printed metric: `final_int8_zlib_roundtrip_exact val_loss:5.71391837 val_bpb:3.38410431` +- Train time: `12070ms` (`step_avg:241.40ms`) +- Eval time: `28444ms` +- Peak memory: `565 MiB allocated`, `750 MiB reserved` +- Serialized model int8+zlib: `5121054 bytes` +- Code size: `47999 bytes` +- Total submission size int8+zlib: `5169053 bytes` + +## Included Files + +- `train_gpt.py` (exact code snapshot used for the run) +- `train.log` (exact training log) +- `submission.json` (metadata for this non-record run) diff --git a/records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/submission.json b/records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/submission.json new file mode 100644 index 0000000000..0096441f09 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/submission.json @@ -0,0 +1,18 @@ +{ + "author": "Saken Tukenov", + "github_id": "stukenov", + "name": "1x RTX 4090 Compat Smoke (50 steps)", + "blurb": "Finished non-record smoke run on 1x RTX 4090 using the baseline 9x512 SP-1024 architecture, one FineWeb training shard, and the full fixed validation split. Uses a small compatibility fallback in train_gpt.py to expand KV heads manually on a PyTorch 2.4 image that lacks enable_gqa support. Post-quant int8+zlib artifact remains under the 16,000,000-byte cap.", + "date": "2026-03-20", + "track": "non-record-unlimited-compute-16mb", + "val_loss": 5.71391837, + "val_bpb": 3.38410431, + "pre_quant_val_loss": 5.3102, + "pre_quant_val_bpb": 3.1450, + "step_stop": 50, + "wallclock_seconds": 12.070, + "eval_time_seconds": 28.444, + "bytes_total": 5169053, + "bytes_model_int8_zlib": 5121054, + "bytes_code": 47999 +} diff --git a/records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/train.log b/records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/train.log new file mode 100644 index 0000000000..c31e3c9cbe --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/train.log @@ -0,0 +1,1208 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # PyTorch 2.4.x does not expose enable_gqa on scaled_dot_product_attention. + # Expand KV heads manually so this script still runs on common CUDA images. + if self.num_kv_heads != self.num_heads: + repeat = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(repeat, dim=1) + v = v.repeat_interleave(repeat, dim=1) + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] +Running PyTorch 2.4.1+cu124 +Thu Mar 19 19:52:05 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 4090 On | 00000000:82:00.0 Off | Off | +| 0% 36C P8 26W / 450W | 4MiB / 24564MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17059912 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:8192 train_seq_len:1024 iterations:50 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/50 train_loss:6.9447 train_time:265ms step_avg:265.35ms +step:2/50 train_loss:6.7849 train_time:527ms step_avg:263.47ms +step:3/50 train_loss:6.5552 train_time:790ms step_avg:263.50ms +step:4/50 train_loss:6.3388 train_time:1006ms step_avg:251.44ms +step:5/50 train_loss:6.1575 train_time:1220ms step_avg:243.90ms +step:6/50 train_loss:5.9918 train_time:1428ms step_avg:238.06ms +step:7/50 train_loss:5.8398 train_time:1634ms step_avg:233.46ms +step:8/50 train_loss:5.8209 train_time:1837ms step_avg:229.62ms +step:9/50 train_loss:5.8000 train_time:2077ms step_avg:230.73ms +step:10/50 train_loss:5.7026 train_time:2303ms step_avg:230.35ms +step:25/50 train_loss:5.4326 train_time:5719ms step_avg:228.77ms +step:50/50 train_loss:5.2853 train_time:12070ms step_avg:241.40ms +step:50/50 val_loss:5.3102 val_bpb:3.1450 train_time:12070ms step_avg:241.41ms +peak memory allocated: 565 MiB reserved: 750 MiB +Serialized model: 67224578 bytes +Code size: 47999 bytes +Total submission size: 67272577 bytes +Serialized model int8+zlib: 5121054 bytes (payload:17178912 raw_torch:17223564 payload_ratio:3.91x) +Total submission size int8+zlib: 5169053 bytes +final_int8_zlib_roundtrip val_loss:5.7139 val_bpb:3.3841 eval_time:28444ms +final_int8_zlib_roundtrip_exact val_loss:5.71391837 val_bpb:3.38410431 diff --git a/records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/train_gpt.py b/records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/train_gpt.py new file mode 100644 index 0000000000..6ba4c1920d --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_CUDA4090_CompatSmoke_SP1024/train_gpt.py @@ -0,0 +1,1131 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # PyTorch 2.4.x does not expose enable_gqa on scaled_dot_product_attention. + # Expand KV heads manually so this script still runs on common CUDA images. + if self.num_kv_heads != self.num_heads: + repeat = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(repeat, dim=1) + v = v.repeat_interleave(repeat, dim=1) + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() From a1a5ba8253c9676551332bce7cecd0ab02e298f7 Mon Sep 17 00:00:00 2001 From: Saken Tukenov Date: Sun, 3 May 2026 19:43:36 +0500 Subject: [PATCH 3/4] ci: add dependabot configuration --- .github/dependabot.yml | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..7afc10865b --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + groups: + all-updates: + patterns: + - "*" From 9a58c02aac83d9d819327faa572a200c4fc56dd2 Mon Sep 17 00:00:00 2001 From: stukenov Date: Sun, 3 May 2026 20:21:35 +0500 Subject: [PATCH 4/4] ci: add minimal CI workflow --- .github/workflows/ci.yml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000000..2bed41008c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,14 @@ +name: CI +on: + push: + branches: [main] + pull_request: +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: pip install -r requirements.txt