122 changes: 122 additions & 0 deletions .agents/logbooks/moe-adamh-grad-norm.md
@@ -0,0 +1,122 @@
# MoE AdamH Gradient Normalization: Research Logbook

## Scope

- Goal: test whether normalizing each module's gradients to RMS 1 before AdamH
moment updates improves Grug MoE training relative to the current AdamH
recipe.
- Primary metrics: final `eval/paloma/macro_loss` and effective speedup versus
  the compute-optimal MoE baseline.
- Secondary metrics: `throughput/tokens_per_second`,
`throughput/total_tokens`, routing balance/stability metrics.
- Constraints: compare at the README compute-optimal d512, d768, d1024, and
d1280 budgets; run gate 1 before gate 2.
- Issue: https://github.com/marin-community/marin/issues/5180

## Baseline

- Date: 2026-04-25
- Code refs:
- `experiments/grug/moe/README.md`
- `experiments/grug/moe/adamh.py`
- `experiments/grug/moe/optimizer.py`
- `experiments/grug/moe/launch.py`
- Baseline numbers: compute-optimal d512, d768, d1024, and d1280 table in
`experiments/grug/moe/README.md`.

## Experiment Log

### 2026-04-25 11:45 - Kickoff

- Hypothesis: module-wise gradient RMS normalization reduces scale mismatch
between attention, shared expert, and routed expert AdamH groups without
changing AdamH's projected parameter update rule.
- Command: local implementation and unit tests.
- Config:
- optimizer: `GrugMoeAdamHGradientNormConfig`
  - normalization: each module's gradient leaves are scaled to a combined RMS
    of 1 before AdamH moment updates (see the sketch after this entry).
- gate 1 launch: `GRUG_MOE_ADAMH_GRAD_NORM_GATE=gate1 python -m experiments.grug.moe.launch_adamh_grad_norm`
- gate 2 launch: `GRUG_MOE_ADAMH_GRAD_NORM_GATE=gate2 python -m experiments.grug.moe.launch_adamh_grad_norm`
- Result: implementation passes focused optimizer tests and pre-commit locally.
- Interpretation: ready to run gate 1 against d512 and d768 baselines.
- Next action: create the GitHub experiment issue, push the branch, and submit
the gate 1 Iris job.
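
A minimal numeric sketch of the intended normalization (illustrative only; the
real implementation is `normalize_module_gradients_to_unit_rms` in
`experiments/grug/moe/adamh.py`, and the toy module below is hypothetical):

```python
import jax.numpy as jnp

# Toy gradients for a single "module" with two leaves (e.g. a kernel and a bias).
leaves = [jnp.array([[3.0, -4.0]]), jnp.array([0.0, 0.0])]

# Combined RMS over every element of the module's leaves.
num_elements = sum(leaf.size for leaf in leaves)                 # 4
square_sum = sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)   # 25.0
rms = jnp.sqrt(square_sum / num_elements)                        # sqrt(25 / 4) = 2.5

# Every leaf in the module is scaled by the same 1 / RMS factor, so the
# module's combined gradient RMS becomes 1 before AdamH's moment updates.
normalized = [leaf / rms for leaf in leaves]
```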

### 2026-04-25 11:46 - MOE-AGNH-001 gate 1 submitted

- Hypothesis: the gradient-normalized AdamH variant can match or improve the
d512 and d768 compute-optimal MoE baselines without hurting throughput.
- Command:
`.venv/bin/iris --config lib/iris/examples/marin.yaml job run --no-wait --memory=2G --disk=4G --cpu=1 --extra=cpu --reserve v5p-8 -e WANDB_API_KEY "$WANDB_API_KEY" -e GRUG_MOE_ADAMH_GRAD_NORM_GATE gate1 -- python -m experiments.grug.moe.launch_adamh_grad_norm`
- Config:
- commit: `d97fdc3131750bfaab552d459fd14e51bf8a2860`
- issue: https://github.com/marin-community/marin/issues/5180
- PR: https://github.com/marin-community/marin/pull/5181
- Iris parent job: `/kaiyue/iris-run-job-20260425-184632`
- data browser:
https://marin.community/data-browser/experiment?path=gs%3A//marin-us-east5/experiments/launch_adamh_grad_norm-7614bf.json
- W&B d512:
https://wandb.ai/understanding-sam/marin_moe/runs/moe-adamh-grad-norm-d512-2p19e17
- W&B d768:
https://wandb.ai/understanding-sam/marin_moe/runs/moe-adamh-grad-norm-d768-1p70e18
- Result: the parent job was submitted to Iris and dispatched both Fray TPU
  child jobs. The first TPU allocation hit a device-busy bad-node signature;
  Iris retried automatically and both child jobs reached `JOB_STATE_RUNNING`.
- Interpretation: gate 1 is healthy enough to monitor through the first eval
  and the final Paloma macro loss rather than resubmitting immediately.
- Next action: babysit the d512 and d768 jobs to terminal state, compare final
metrics with the README baselines, then decide whether to launch gate 2.
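
A hedged monitoring sketch, wrapping the Iris invocation recorded later in
this log (the JSON shape returned by `job list --json` is an assumption):

```python
import json
import subprocess

# Assumed: `iris job list --json` prints a JSON array of job records.
cmd = [
    ".venv/bin/iris", "--config", "lib/iris/examples/marin.yaml",
    "job", "list", "--json", "--prefix", "/kaiyue/iris-run-job-20260425-184632",
]
jobs = json.loads(subprocess.run(cmd, check=True, capture_output=True, text=True).stdout)
for job in jobs:
    print(job)  # inspect child-job records until both reach a terminal state
```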

### 2026-04-25 12:40 - MOE-AGNH-001 d512 final

- Hypothesis: d512 should show an effective speedup greater than 1 versus the
  README compute-optimal baseline.
- Command: W&B API summary pull for
  `understanding-sam/marin_moe/moe-adamh-grad-norm-d512-2p19e17` (see the
  sketch after this entry).
- Config:
- budget: `2.19e17`
- baseline Paloma macro loss: `3.8104`
- baseline throughput: `405630` tokens/s
- variant Paloma macro loss: `3.815110206604004`
- variant throughput: `406982.727149109` tokens/s
- variant total tokens: `837156864`
- Result: d512 Iris child job succeeded. Effective speedup is `0.980893`;
loss delta is `+0.004710` and throughput delta is `+0.333%`.
- Interpretation: d512 does not clear gate 1. Since the d768 gate 1 child job
is already running, continue it to terminal state for the complete gate 1
comparison, but do not launch gate 2 unless the decision criteria are
explicitly revised.
- Next action: continue babysitting d768, then close out the issue with the
gate 1 table.
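
A minimal sketch of that summary pull (assuming the standard `wandb` public
API and that the summary keys match the metric names logged by the trainer):

```python
import wandb

api = wandb.Api()
run = api.run("understanding-sam/marin_moe/moe-adamh-grad-norm-d512-2p19e17")

# Summary key names are assumptions about what the trainer logs.
variant_loss = run.summary["eval/paloma/macro_loss"]
variant_tps = run.summary["throughput/tokens_per_second"]

baseline_loss, baseline_tps = 3.8104, 405_630  # README compute-optimal d512 row
print(f"loss delta: {variant_loss - baseline_loss:+.6f}")      # +0.004710
print(f"tok/s delta: {variant_tps / baseline_tps - 1:+.3%}")   # +0.333%
```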

### 2026-04-25 15:25 - MOE-AGNH-001 gate 1 final

- Hypothesis: the d768 point may still show an effective speedup above 1.0,
  but gate 1 requires both d512 and d768 to clear that threshold.
- Command:
`.venv/bin/iris --config lib/iris/examples/marin.yaml job list --json --prefix /kaiyue/iris-run-job-20260425-184632`
plus W&B API summaries for
`understanding-sam/marin_moe/moe-adamh-grad-norm-d512-2p19e17` and
`understanding-sam/marin_moe/moe-adamh-grad-norm-d768-1p70e18`.
- Config:
- d512 baseline: loss `3.8104`, throughput `405630` tokens/s, budget
`2.19e17`
- d512 variant: loss `3.815110206604004`, throughput `406982.727149109`
tokens/s, total tokens `837156864`, global step `6386`
- d768 baseline: loss `3.4339`, throughput `273532` tokens/s, budget
`1.70e18`
- d768 variant: loss `3.429192543029785`, throughput `274218.4171530239`
tokens/s, total tokens `2711355392`, global step `10342`
- Result:

| Scale | Baseline loss | Variant loss | Loss delta | Baseline tok/s | Variant tok/s | Tok/s delta | Effective speedup |
|-------|---------------|--------------|------------|----------------|---------------|-------------|-------------------|
| d512 | 3.8104 | 3.815110 | +0.004710 | 405,630 | 406,983 | +0.333% | 0.980893 |
| d768 | 3.4339 | 3.429193 | -0.004707 | 273,532 | 274,218 | +0.251% | 1.030269 |

- Interpretation: the d768 point is positive, but d512 fails the required
effective-speedup threshold. Gate 1 fails overall, so this variant should
not advance to gate 2 under `experiments/grug/moe/agent.md`.
- Next action: post the final GitHub issue summary and close the experiment
issue as a completed negative result.
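
A tiny sketch of the gate 1 decision applied to the table above (the >1.0
threshold is as stated in the hypothesis; the full criteria in
`experiments/grug/moe/agent.md` are not reproduced here):

```python
effective_speedups = {"d512": 0.980893, "d768": 1.030269}
gate1_passes = all(s > 1.0 for s in effective_speedups.values())  # False: d512 fails
```
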
66 changes: 65 additions & 1 deletion experiments/grug/moe/adamh.py
@@ -4,6 +4,7 @@
# Local copy of AdamH for iteration without modifying Levanter.
# Adapted from levanter.optim.adamh.

from collections import defaultdict
from typing import Any, NamedTuple

import chex
@@ -12,13 +13,53 @@
import optax
from optax import tree_utils as otu

from levanter.utils.jax_utils import leaf_key_paths


class ScaleByAdamHState(NamedTuple):
count: chex.Array
mu: optax.Updates
nu: optax.Updates


def _module_key(path: str | None) -> str | None:
if path is None:
return None
parts = path.split(".")
if len(parts) <= 1:
return path
return ".".join(parts[:-1])


def normalize_module_gradients_to_unit_rms(updates: optax.Updates, eps: float = 1e-16) -> optax.Updates:
"""Normalize each module's gradient leaves to combined RMS 1."""
leaves, treedef = jax.tree_util.tree_flatten(updates, is_leaf=lambda x: x is None)
paths = treedef.flatten_up_to(leaf_key_paths(updates))
groups: dict[str, list[int]] = defaultdict(list)

for index, (leaf, path) in enumerate(zip(leaves, paths, strict=True)):
if leaf is None or not hasattr(leaf, "shape"):
continue
module_key = _module_key(path)
if module_key is None:
continue
groups[module_key].append(index)

normalized_leaves = list(leaves)
for group_indices in groups.values():
square_sum = sum(
(jnp.sum(jnp.square(leaves[index].astype(jnp.float32))) for index in group_indices),
jnp.array(0.0, dtype=jnp.float32),
)
num_elements = sum(int(leaves[index].size) for index in group_indices)
inv_rms = jax.lax.rsqrt(square_sum / num_elements + eps)
for index in group_indices:
leaf = leaves[index]
normalized_leaves[index] = leaf * inv_rms.astype(leaf.dtype)

return jax.tree_util.tree_unflatten(treedef, normalized_leaves)


def scale_by_adamh(
b1: float = 0.9,
b2: float = 0.999,
@@ -75,4 +116,27 @@ def scale_invariant_update(p, u):
return optax.GradientTransformation(init_fn, update_fn)


__all__ = ["ScaleByAdamHState", "scale_by_adamh"]
def scale_by_adamh_with_module_gradient_normalization(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
learning_rate: float = 0.02,
mu_dtype: Any | None = None,
gradient_norm_eps: float = 1e-16,
) -> optax.GradientTransformation:
"""AdamH with module-wise gradient RMS normalization before moment updates."""
adamh = scale_by_adamh(b1=b1, b2=b2, eps=eps, learning_rate=learning_rate, mu_dtype=mu_dtype)

def update_fn(updates, state, params):
normalized_updates = normalize_module_gradients_to_unit_rms(updates, eps=gradient_norm_eps)
return adamh.update(normalized_updates, state, params)

return optax.GradientTransformation(adamh.init, update_fn)


__all__ = [
"ScaleByAdamHState",
"normalize_module_gradients_to_unit_rms",
"scale_by_adamh",
"scale_by_adamh_with_module_gradient_normalization",
]
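
A minimal usage sketch of the new transform as a standard optax
`GradientTransformation` (the toy parameter tree is hypothetical; leaves under
the same module share a single RMS factor):

```python
import jax.numpy as jnp
import optax

from experiments.grug.moe.adamh import scale_by_adamh_with_module_gradient_normalization

params = {"attn": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))}}
grads = {"attn": {"kernel": 10.0 * jnp.ones((4, 4)), "bias": jnp.ones((4,))}}

tx = scale_by_adamh_with_module_gradient_normalization(learning_rate=0.02)
state = tx.init(params)
# Gradients are rescaled per module to combined RMS 1, then fed to AdamH.
updates, state = tx.update(grads, state, params)
new_params = optax.apply_updates(params, updates)
```
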
106 changes: 106 additions & 0 deletions experiments/grug/moe/launch_adamh_grad_norm.py
@@ -0,0 +1,106 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Launch MoE AdamH gradient-normalization ablations."""

import dataclasses
import os

from fray.cluster import ResourceConfig
from levanter.tracker.wandb import WandbConfig
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned

from experiments.grug.moe.heuristic import build_from_heuristic
from experiments.grug.moe.launch import (
GrugMoeLaunchConfig,
NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
run_grug_moe_trial,
)
from experiments.grug.moe.optimizer import GrugMoeAdamHConfig, GrugMoeAdamHGradientNormConfig
from experiments.grug.moe.train import GrugEvalConfig, GrugTrainerConfig

_TARGET_STEPS: int = 2**14
_GATE_SPECS: dict[str, tuple[tuple[str, float, int], ...]] = {
"gate1": (
("d512-2p19e17", 2.19e17, 512),
("d768-1p70e18", 1.70e18, 768),
),
"gate2": (
("d1024-9p00e18", 9.00e18, 1024),
("d1280-2p83e19", 2.83e19, 1280),
),
}


def _gradient_norm_optimizer(optimizer: GrugMoeAdamHConfig) -> GrugMoeAdamHGradientNormConfig:
return GrugMoeAdamHGradientNormConfig(**dataclasses.asdict(optimizer))


def _resolve_run_id(label: str) -> str:
run_id = os.environ.get("GRUG_RUN_ID", f"moe-adamh-grad-norm-{label}")
Review comment (P2): Keep per-step run IDs unique for gated launches

When GRUG_RUN_ID is set, _resolve_run_id returns the same ID for every label, so gate1/all runs emit multiple steps with identical run_ids. In run_grug_moe_trial, that ID becomes the trainer/W&B run ID (and W&B defaults to resume="allow"), so subsequent steps can resume or overwrite earlier runs instead of producing separate experiment records. This breaks side-by-side ablation tracking for the very comparisons this launcher is meant to run.
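
One possible direction, sketched on the assumption that appending the per-step
label to an externally supplied `GRUG_RUN_ID` is acceptable (not the author's
change, just an illustration):

```python
def _resolve_run_id(label: str) -> str:
    # Hypothetical fix: keep the label in the run ID even when GRUG_RUN_ID is set,
    # so gate1/gate2/all launches emit distinct trainer/W&B run IDs per step.
    base = os.environ.get("GRUG_RUN_ID")
    run_id = f"{base}-{label}" if base else f"moe-adamh-grad-norm-{label}"
    ferry_date = os.environ.get("FERRY_DATE")
    if ferry_date:
        run_id = f"{run_id}-{ferry_date}"
    return run_id
```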


ferry_date = os.environ.get("FERRY_DATE")
if ferry_date:
run_id = f"{run_id}-{ferry_date}"
return run_id


def _make_step(label: str, budget: float, hidden_dim: int) -> ExecutorStep:
model, optimizer, batch_size, steps = build_from_heuristic(
budget=budget,
hidden_dim=hidden_dim,
target_steps=_TARGET_STEPS,
)
return ExecutorStep(
name=f"grug/moe-adamh-grad-norm/{label}",
fn=run_grug_moe_trial,
config=GrugMoeLaunchConfig(
model=versioned(model),
data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
output_path=this_output_path(),
run_id=_resolve_run_id(label),
resources=versioned(ResourceConfig.with_tpu("v5p-8")),
steps=versioned(steps),
batch_size=versioned(batch_size),
seed=versioned(0),
mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
tracker=WandbConfig(
project="marin_moe",
tags=["moe", "adamh-grad-norm"],
group="moe-adamh-grad-norm",
name=None,
),
optimizer=versioned(_gradient_norm_optimizer(optimizer)),
grug_trainer=versioned(
GrugTrainerConfig(
z_loss_weight=1e-4,
ema_beta=None,
log_every=1,
)
),
eval=versioned(
GrugEvalConfig(
eval_batch_size=512,
steps_per_eval=1000,
max_eval_batches=8,
eval_current=True,
eval_ema=False,
)
),
),
)


def _selected_specs() -> tuple[tuple[str, float, int], ...]:
gate = os.environ.get("GRUG_MOE_ADAMH_GRAD_NORM_GATE", "gate1")
if gate == "all":
return _GATE_SPECS["gate1"] + _GATE_SPECS["gate2"]
if gate not in _GATE_SPECS:
raise ValueError(f"Unknown GRUG_MOE_ADAMH_GRAD_NORM_GATE={gate!r}. Expected gate1, gate2, or all.")
return _GATE_SPECS[gate]


if __name__ == "__main__":
executor_main(
steps=[_make_step(label, budget, hidden_dim) for label, budget, hidden_dim in _selected_specs()],
description="Grug MoE AdamH with per-module gradient RMS normalization.",
)
28 changes: 24 additions & 4 deletions experiments/grug/moe/optimizer.py
@@ -7,7 +7,7 @@
import optax

from levanter.optim import OptimizerConfig
from experiments.grug.moe.adamh import scale_by_adamh
from experiments.grug.moe.adamh import scale_by_adamh, scale_by_adamh_with_module_gradient_normalization
from levanter.utils.jax_utils import leaf_key_paths


@@ -28,6 +28,9 @@ class GrugMoeAdamHConfig(OptimizerConfig):
adam_lr: float = 6e-4
expert_lr: float | None = None

def _scale_by_adamh(self, learning_rate):
return scale_by_adamh(self.beta1, self.beta2, self.epsilon, learning_rate)

def build(self, num_train_steps):
learning_rate_schedule = self.lr_scheduler(num_train_steps)
adam_lr_schedule = self.lr_scheduler(num_train_steps, override_lr=self.adam_lr)
@@ -39,14 +42,14 @@ def adamh_transform():
components = []
if self.max_grad_norm:
components.append(optax.clip_by_global_norm(self.max_grad_norm))
components.append(scale_by_adamh(self.beta1, self.beta2, self.epsilon, learning_rate))
components.append(self._scale_by_adamh(learning_rate))
return optax.chain(*components)

def adamh_expert_transform():
components = []
if self.max_grad_norm:
components.append(optax.clip_by_global_norm(self.max_grad_norm))
components.append(scale_by_adamh(self.beta1, self.beta2, self.epsilon, expert_lr))
components.append(self._scale_by_adamh(expert_lr))
return optax.chain(*components)

def adam_transform():
@@ -91,4 +94,21 @@ def mask_fn(param, path):
return jax.tree.map(mask_fn, params, paths)


__all__ = ["GrugMoeAdamHConfig"]
@OptimizerConfig.register_subclass("grug_moe_adamh_grad_norm")
@dataclass(frozen=True)
class GrugMoeAdamHGradientNormConfig(GrugMoeAdamHConfig):
"""AdamH for Grug MoE with per-module gradient RMS normalization."""

gradient_norm_eps: float = 1e-16

def _scale_by_adamh(self, learning_rate):
return scale_by_adamh_with_module_gradient_normalization(
self.beta1,
self.beta2,
self.epsilon,
learning_rate,
gradient_norm_eps=self.gradient_norm_eps,
)


__all__ = ["GrugMoeAdamHConfig", "GrugMoeAdamHGradientNormConfig"]