[grug] Add MoE AdamH gradient normalization #5181
Open · WhenWen wants to merge 4 commits into `main` from `research/moe-adamh-grad-norm`
New file (122 lines):

# MoE AdamH Gradient Normalization: Research Logbook

## Scope

- Goal: test whether normalizing each module's gradients to RMS 1 before AdamH moment updates improves Grug MoE training relative to the current AdamH recipe.
- Primary metrics: final `eval/paloma/macro_loss` and effective speedup versus the compute-optimal MoE baseline (see the note after this list).
- Secondary metrics: `throughput/tokens_per_second`, `throughput/total_tokens`, and routing balance/stability metrics.
- Constraints: compare at the README compute-optimal d512, d768, d1024, and d1280 budgets; run gate 1 before gate 2.
- Issue: https://github.com/marin-community/marin/issues/5180
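
The logbook does not spell out how effective speedup is computed. A common convention for compute-matched ablations (assumed here, not confirmed anywhere in this PR) is the compute multiplier implied by the baseline's fitted loss-versus-compute curve:

$$
\text{effective speedup} = \frac{C_{\text{baseline}}(L_{\text{variant}})}{C_{\text{variant}}}
$$

where $C_{\text{baseline}}(L)$ is the compute the fitted compute-optimal baseline would need to reach loss $L$. Under this reading, values above 1.0 mean the variant reaches its final loss with less compute than the baseline fit predicts.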

## Baseline

- Date: 2026-04-25
- Code refs:
  - `experiments/grug/moe/README.md`
  - `experiments/grug/moe/adamh.py`
  - `experiments/grug/moe/optimizer.py`
  - `experiments/grug/moe/launch.py`
- Baseline numbers: compute-optimal d512, d768, d1024, and d1280 table in `experiments/grug/moe/README.md`.

## Experiment Log

### 2026-04-25 11:45 - Kickoff

- Hypothesis: module-wise gradient RMS normalization reduces scale mismatch between the attention, shared expert, and routed expert AdamH groups without changing AdamH's projected parameter update rule.
- Command: local implementation and unit tests.
- Config:
  - optimizer: `GrugMoeAdamHGradientNormConfig`
  - normalization: each module's gradient leaves are scaled to a combined RMS of 1 before the AdamH moment updates (see the sketch after this entry).
  - gate 1 launch: `GRUG_MOE_ADAMH_GRAD_NORM_GATE=gate1 python -m experiments.grug.moe.launch_adamh_grad_norm`
  - gate 2 launch: `GRUG_MOE_ADAMH_GRAD_NORM_GATE=gate2 python -m experiments.grug.moe.launch_adamh_grad_norm`
- Result: the implementation passes the focused optimizer tests and pre-commit locally.
- Interpretation: ready to run gate 1 against the d512 and d768 baselines.
- Next action: create the GitHub experiment issue, push the branch, and submit the gate 1 Iris job.
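
As a concrete illustration of the normalization rule above, here is a minimal JAX sketch. The function name and the per-module dict layout are hypothetical; the actual implementation lives in `GrugMoeAdamHGradientNormConfig` and is not reproduced in this PR excerpt.

```python
import jax
import jax.numpy as jnp


def normalize_module_grads(grads_by_module):
    """Scale each module's gradient leaves so their combined RMS is 1.

    Hypothetical sketch of module-wise gradient RMS normalization applied
    before the optimizer's moment updates; not the PR's actual code.
    `grads_by_module` maps a module name to a pytree of gradient arrays.
    """

    def _normalize(tree):
        leaves = jax.tree_util.tree_leaves(tree)
        # Combined RMS across every leaf in this module's pytree.
        sq_sum = sum(jnp.sum(jnp.square(g)) for g in leaves)
        count = sum(g.size for g in leaves)
        rms = jnp.sqrt(sq_sum / count)
        # Guard against all-zero gradients to avoid dividing by zero.
        scale = jnp.where(rms > 0.0, 1.0 / rms, 1.0)
        return jax.tree_util.tree_map(lambda g: g * scale, tree)

    # e.g. keys: "attention", "shared_expert", "routed_experts"
    return {name: _normalize(tree) for name, tree in grads_by_module.items()}
```

Because each module is rescaled independently, the attention, shared-expert, and routed-expert groups enter the moment updates at a common gradient scale, which is exactly the scale mismatch the hypothesis targets.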

### 2026-04-25 11:46 - MOE-AGNH-001 gate 1 submitted

- Hypothesis: the gradient-normalized AdamH variant can match or improve the d512 and d768 compute-optimal MoE baselines without hurting throughput.
- Command: `.venv/bin/iris --config lib/iris/examples/marin.yaml job run --no-wait --memory=2G --disk=4G --cpu=1 --extra=cpu --reserve v5p-8 -e WANDB_API_KEY "$WANDB_API_KEY" -e GRUG_MOE_ADAMH_GRAD_NORM_GATE gate1 -- python -m experiments.grug.moe.launch_adamh_grad_norm`
- Config:
  - commit: `d97fdc3131750bfaab552d459fd14e51bf8a2860`
  - issue: https://github.com/marin-community/marin/issues/5180
  - PR: https://github.com/marin-community/marin/pull/5181
  - Iris parent job: `/kaiyue/iris-run-job-20260425-184632`
  - data browser: https://marin.community/data-browser/experiment?path=gs%3A//marin-us-east5/experiments/launch_adamh_grad_norm-7614bf.json
  - W&B d512: https://wandb.ai/understanding-sam/marin_moe/runs/moe-adamh-grad-norm-d512-2p19e17
  - W&B d768: https://wandb.ai/understanding-sam/marin_moe/runs/moe-adamh-grad-norm-d768-1p70e18
- Result: submitted to Iris; both Fray TPU child jobs were dispatched. The first TPU allocation hit a device-busy bad-node signature; Iris retried automatically and both child jobs reached `JOB_STATE_RUNNING`.
- Interpretation: gate 1 is healthy enough to monitor through the first eval and the final Paloma macro loss rather than resubmit immediately.
- Next action: babysit the d512 and d768 jobs to terminal state, compare final metrics with the README baselines, then decide whether to launch gate 2.

### 2026-04-25 12:40 - MOE-AGNH-001 d512 final

- Hypothesis: d512 should show an effective speedup greater than 1 versus the README compute-optimal baseline.
- Command: W&B API summary pull for `understanding-sam/marin_moe/moe-adamh-grad-norm-d512-2p19e17`.
- Config:
  - budget: `2.19e17`
  - baseline Paloma macro loss: `3.8104`
  - baseline throughput: `405630` tokens/s
  - variant Paloma macro loss: `3.815110206604004`
  - variant throughput: `406982.727149109` tokens/s
  - variant total tokens: `837156864`
- Result: the d512 Iris child job succeeded. Effective speedup is `0.980893`; the loss delta is `+0.004710` and the throughput delta is `+0.333%`.
- Interpretation: d512 does not clear gate 1. Since the d768 gate 1 child job is already running, continue it to terminal state for the complete gate 1 comparison, but do not launch gate 2 unless the decision criteria are explicitly revised.
- Next action: continue babysitting d768, then close out the issue with the gate 1 table.
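
For the record, the reported deltas follow directly from the logged summaries; only the effective-speedup number presumably depends on the compute-optimal baseline fit (per the note under Scope) and is not reproducible from the values shown here alone:

```python
# d512 numbers from the config block above.
baseline_loss, variant_loss = 3.8104, 3.815110206604004
baseline_tps, variant_tps = 405630.0, 406982.727149109

loss_delta = variant_loss - baseline_loss               # +0.004710
tps_delta_pct = (variant_tps / baseline_tps - 1) * 100  # +0.333%
```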

### 2026-04-25 15:25 - MOE-AGNH-001 gate 1 final

- Hypothesis: the d768 point may still show effective speedup, but gate 1 requires both d512 and d768 to exceed 1.0.
- Command: `.venv/bin/iris --config lib/iris/examples/marin.yaml job list --json --prefix /kaiyue/iris-run-job-20260425-184632` plus W&B API summaries for `understanding-sam/marin_moe/moe-adamh-grad-norm-d512-2p19e17` and `understanding-sam/marin_moe/moe-adamh-grad-norm-d768-1p70e18`.
- Config:
  - d512 baseline: loss `3.8104`, throughput `405630` tokens/s, budget `2.19e17`
  - d512 variant: loss `3.815110206604004`, throughput `406982.727149109` tokens/s, total tokens `837156864`, global step `6386`
  - d768 baseline: loss `3.4339`, throughput `273532` tokens/s, budget `1.70e18`
  - d768 variant: loss `3.429192543029785`, throughput `274218.4171530239` tokens/s, total tokens `2711355392`, global step `10342`
- Result:

| Scale | Baseline loss | Variant loss | Loss delta | Baseline tok/s | Variant tok/s | Tok/s delta | Effective speedup |
|-------|---------------|--------------|------------|----------------|---------------|-------------|-------------------|
| d512  | 3.8104        | 3.815110     | +0.004710  | 405,630        | 406,983       | +0.333%     | 0.980893          |
| d768  | 3.4339        | 3.429193     | -0.004707  | 273,532        | 274,218      | +0.251%     | 1.030269          |

- Interpretation: the d768 point is positive, but d512 fails the required effective-speedup threshold. Gate 1 fails overall, so this variant should not advance to gate 2 under `experiments/grug/moe/agent.md`; the decision check is sketched after this entry.
- Next action: post the final GitHub issue summary and close the experiment issue as a completed negative result.
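
As a quick sanity check on the decision logic described above (a sketch using the table's values, not code from the PR):

```python
# Effective speedups from the gate 1 table.
speedups = {"d512": 0.980893, "d768": 1.030269}

# Gate 1 passes only if every scale clears an effective speedup of 1.0.
gate1_passes = all(s > 1.0 for s in speedups.values())
print(gate1_passes)  # False: d512 fails, so gate 2 is not launched
```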
New file (106 lines), the `experiments.grug.moe.launch_adamh_grad_norm` launcher:

```python
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Launch MoE AdamH gradient-normalization ablations."""

import dataclasses
import os

from fray.cluster import ResourceConfig
from levanter.tracker.wandb import WandbConfig
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned

from experiments.grug.moe.heuristic import build_from_heuristic
from experiments.grug.moe.launch import (
    GrugMoeLaunchConfig,
    NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
    run_grug_moe_trial,
)
from experiments.grug.moe.optimizer import GrugMoeAdamHConfig, GrugMoeAdamHGradientNormConfig
from experiments.grug.moe.train import GrugEvalConfig, GrugTrainerConfig

_TARGET_STEPS: int = 2**14
_GATE_SPECS: dict[str, tuple[tuple[str, float, int], ...]] = {
    "gate1": (
        ("d512-2p19e17", 2.19e17, 512),
        ("d768-1p70e18", 1.70e18, 768),
    ),
    "gate2": (
        ("d1024-9p00e18", 9.00e18, 1024),
        ("d1280-2p83e19", 2.83e19, 1280),
    ),
}


def _gradient_norm_optimizer(optimizer: GrugMoeAdamHConfig) -> GrugMoeAdamHGradientNormConfig:
    return GrugMoeAdamHGradientNormConfig(**dataclasses.asdict(optimizer))


def _resolve_run_id(label: str) -> str:
    run_id = os.environ.get("GRUG_RUN_ID", f"moe-adamh-grad-norm-{label}")
    ferry_date = os.environ.get("FERRY_DATE")
    if ferry_date:
        run_id = f"{run_id}-{ferry_date}"
    return run_id


def _make_step(label: str, budget: float, hidden_dim: int) -> ExecutorStep:
    model, optimizer, batch_size, steps = build_from_heuristic(
        budget=budget,
        hidden_dim=hidden_dim,
        target_steps=_TARGET_STEPS,
    )
    return ExecutorStep(
        name=f"grug/moe-adamh-grad-norm/{label}",
        fn=run_grug_moe_trial,
        config=GrugMoeLaunchConfig(
            model=versioned(model),
            data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
            output_path=this_output_path(),
            run_id=_resolve_run_id(label),
            resources=versioned(ResourceConfig.with_tpu("v5p-8")),
            steps=versioned(steps),
            batch_size=versioned(batch_size),
            seed=versioned(0),
            mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
            tracker=WandbConfig(
                project="marin_moe",
                tags=["moe", "adamh-grad-norm"],
                group="moe-adamh-grad-norm",
                name=None,
            ),
            optimizer=versioned(_gradient_norm_optimizer(optimizer)),
            grug_trainer=versioned(
                GrugTrainerConfig(
                    z_loss_weight=1e-4,
                    ema_beta=None,
                    log_every=1,
                )
            ),
            eval=versioned(
                GrugEvalConfig(
                    eval_batch_size=512,
                    steps_per_eval=1000,
                    max_eval_batches=8,
                    eval_current=True,
                    eval_ema=False,
                )
            ),
        ),
    )


def _selected_specs() -> tuple[tuple[str, float, int], ...]:
    gate = os.environ.get("GRUG_MOE_ADAMH_GRAD_NORM_GATE", "gate1")
    if gate == "all":
        return _GATE_SPECS["gate1"] + _GATE_SPECS["gate2"]
    if gate not in _GATE_SPECS:
        raise ValueError(f"Unknown GRUG_MOE_ADAMH_GRAD_NORM_GATE={gate!r}. Expected gate1, gate2, or all.")
    return _GATE_SPECS[gate]


if __name__ == "__main__":
    executor_main(
        steps=[_make_step(label, budget, hidden_dim) for label, budget, hidden_dim in _selected_specs()],
        description="Grug MoE AdamH with per-module gradient RMS normalization.",
    )
```
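
Two observations on the launcher, both readings of the code above rather than documented behavior: `_gradient_norm_optimizer` round-trips the tuned hyperparameters through `dataclasses.asdict`, which assumes `GrugMoeAdamHGradientNormConfig` accepts exactly the fields of `GrugMoeAdamHConfig`; and gate selection is driven entirely by one environment variable, as in this hypothetical check (assuming the module is importable on the `PYTHONPATH`):

```python
import os

# Select both gates before asking the launcher which specs it would run.
os.environ["GRUG_MOE_ADAMH_GRAD_NORM_GATE"] = "all"

from experiments.grug.moe import launch_adamh_grad_norm as launcher

specs = launcher._selected_specs()
assert [label for label, _, _ in specs] == [
    "d512-2p19e17", "d768-1p70e18", "d1024-9p00e18", "d1280-2p83e19",
]
```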
Review comment:

When `GRUG_RUN_ID` is set, `_resolve_run_id` returns the same ID for every label, so `gate1`/`all` runs emit multiple steps with identical `run_id`s. In `run_grug_moe_trial`, that ID becomes the trainer/W&B run ID (and W&B defaults to `resume="allow"`), so subsequent steps can resume or overwrite earlier runs instead of producing separate experiment records. This breaks side-by-side ablation tracking for the very comparisons this launcher is meant to run.