Commit 5afc883
committed: [grug] Add MoE ragged debug launch artifacts
1 parent 0204cbe

4 files changed: 342 additions & 7 deletions

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
# Debugging log for ragged-all-to-all

Investigating why the 1e23 Grug MoE run diverges under `ragged_all_to_all` while the ring configuration is healthy.
## Initial status

The `moe_1e23_d5120_bs2048_ep8_ragged_48l_rayuvtpu_20260417_011404` run started with healthy step-0 metrics, then diverged later in training. By step 1000 it was far behind the old `ep4_ring` baseline, and by step 1250 it showed multi-million gradient norms, followed by a `NaN` eval at step 1259.
8+
9+
## Hypothesis 1
10+
11+
The ragged dispatch path is not semantically equivalent to the ring path under expert parallelism and later-training router distributions. Capacity clipping, dropped-assignment accounting, or recombination may diverge in a way that only appears once routing sharpens.
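To make the hypothesis concrete, here is a host-side sketch of capacity clipping and dropped-assignment accounting. This is illustrative only: `clip_to_capacity` is a hypothetical helper, not the grug implementation. It shows why a divergence between implementations could surface only once routing sharpens — uniform routing fits under capacity, while a sharpened router overflows and drops assignments.

```python
import numpy as np

def clip_to_capacity(selected_experts: np.ndarray, num_experts: int, capacity: int):
    """Per-expert capacity clipping: keep the first `capacity` assignments per
    expert (in token order) and count the rest as dropped.

    selected_experts: (tokens, topk) int array of expert ids.
    Returns a keep mask of the same shape plus the dropped-assignment count.
    """
    flat = selected_experts.reshape(-1)
    # Position of each assignment within its expert's queue, in token order.
    position_in_expert = np.zeros_like(flat)
    counts = np.zeros(num_experts, dtype=np.int64)
    for i, e in enumerate(flat):
        position_in_expert[i] = counts[e]
        counts[e] += 1
    keep = (position_in_expert < capacity).reshape(selected_experts.shape)
    dropped = int((~keep).sum())
    return keep, dropped

# Uniform routing fits; sharpened routing (every token picks expert 0) overflows.
uniform = np.array([[0, 1], [2, 3], [4, 5], [6, 7]])
sharp = np.zeros((4, 2), dtype=int)
_, d_uniform = clip_to_capacity(uniform, num_experts=8, capacity=2)
_, d_sharp = clip_to_capacity(sharp, num_experts=8, capacity=2)
print(d_uniform, d_sharp)  # 0 6
```

If the two dispatch implementations disagree on which assignments are dropped (or where the survivors land), the difference is invisible at initialization and large after sharpening.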
## Changes to make

- Relaunch the 1e23 config as ring `ep4` on the current Ray cluster for a fresh control run.
- Audit `experiments/grug/moe/model.py` and `lib/levanter/src/levanter/grug/grug_moe.py`.
- Run targeted TPU tests on `v5p-8` and `v5p-32` to compare ring vs ragged gradients for the functional MoE MLP block.
## Future Work

- [ ] Check whether ring and ragged produce materially different MLP block gradients on TPU pods.
- [ ] Confirm whether dropped-assignment behavior differs between implementations at equal capacity.
- [ ] Compare router-sharpening behavior after a few hundred optimization steps, not just at initialization.
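For the router-sharpening item, one illustrative metric is mean per-token entropy of the router distribution, which falls as routing sharpens. This helper is hypothetical (not from the codebase), just a sketch of what the comparison could measure:

```python
import numpy as np

def mean_router_entropy(router_logits: np.ndarray) -> float:
    """Mean per-token entropy of softmax(router_logits); decreases as routing sharpens."""
    z = router_logits - router_logits.max(axis=-1, keepdims=True)
    p = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    return float(-(p * np.log(p + 1e-12)).sum(axis=-1).mean())

rng = np.random.default_rng(0)
logits = rng.normal(size=(16, 8))
early = mean_router_entropy(logits)        # near-uniform routing at init
late = mean_router_entropy(logits * 10.0)  # sharpened routing later in training
assert late < early
```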
## Results

- The host-side routing simulation matches ring exactly, so the abstract clip/permute/unpermute math was not the bug.
- The actual bug was in [grug_moe.py](/Users/dlwh/.codex/worktrees/8989/marin/lib/levanter/src/levanter/grug/grug_moe.py): `_shard_a2a_params` was feeding `jax.lax.ragged_all_to_all` receiver-local output offsets instead of the sender-side remote offsets that the primitive expects.
  - JAX internally transposes `output_offsets` with an `all_to_all`.
  - Our old code pre-transposed them, so real distributed runs wrote returned slices into the wrong positions.
  - That explains why a pure Python/JAX simulation of the routing logic looked correct while TPU runs showed large ring-vs-ragged output and gradient deltas.
- Existing EP coverage in [test_grugformer_moe.py](/Users/dlwh/.codex/worktrees/8989/marin/lib/levanter/tests/grug/test_grugformer_moe.py) only checks output shape and finiteness, not ring-vs-ragged parity.
- Fresh ring control relaunch:
  - Ray submission: `ray-run-dlwh-moe-uvtpu-ep4-ring-manual-20260417_152005`
  - W&B run id (expected once training initializes): `moe_1e23_d5120_bs2048_ep4_ring_rayuvtpu_20260417_152005`
  - The first relaunch attempt via `ray_run.py` failed during Ray runtime-env pip setup because `kitoken==0.10.2` was not available through the cluster-visible pip indexes.
  - A manual `ray job submit` without the Ray pip runtime-env is now running and has reached executor dispatch for `grug/moe_1e23_d5120_bs2048_ep4_ring`.
- TPU parity probes:
  - Initial `v5p-8` and `v5p-32` attempts failed because the probe lived under untracked `scratch/`, which the Iris workspace bundle did not include.
  - The probe now lives at [scripts/debug/grug_moe_grad_compare.py](/Users/dlwh/.codex/worktrees/8989/marin/scripts/debug/grug_moe_grad_compare.py) and compiles locally.
  - First tracked-path jobs submitted:
    - `/dlwh/grug-moe-grad-compare-v5p8-20260417-0828`
    - `/dlwh/grug-moe-grad-compare-v5p32-20260417-0828`
  - Those jobs later failed with entrypoint-container OOM (`exit 137`) under the default `1GB` host-memory request, so they did not yet exercise the MoE kernel.
  - Region-widened jobs were submitted next:
    - `/dlwh/grug-moe-grad-compare-v5p8-20260417-0833`
    - `/dlwh/grug-moe-grad-compare-v5p32-20260417-0833`
  - Final corrected jobs use both `us-central1` and `us-east5` plus `--memory 8GB`:
    - `/dlwh/grug-moe-grad-compare-v5p8-20260417-0835`
    - `/dlwh/grug-moe-grad-compare-v5p32-20260417-0835`
- Successful `v5p-8` probe after the `_shard_a2a_params` fix:
  - Job: `/dlwh/grug-moe-grad-compare-v5p8-20260417-094052`
  - Normal routed case now matches:
    - `ring_loss == ragged_loss == 995541.4375`
    - `ring_dropped == ragged_dropped == 9`
    - `output_diff.rel_l2 = 4.17e-08`
    - `grad_x_diff.rel_l2 = 4.20e-08`
    - `grad_w_up_gate_diff = 0`
    - `grad_w_down_diff = 0`
  - Forced-overflow case still matches exactly with zero diffs.
- Local regression coverage added:
  - `_shard_a2a_params` now has a unit test asserting sender-side output offsets.
  - A parity test now checks `ring` vs `ragged_all_to_all` MoE outputs when EP is available on a non-CPU backend.
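A toy numpy illustration of the offset bookkeeping behind the fix (a sketch of the concept only, not the actual `_shard_a2a_params` code): the same offsets table can be read receiver-by-receiver or sender-by-sender, and the two views are transposes of each other.

```python
import numpy as np

# send_sizes[s, r]: rows device s sends to device r.
send_sizes = np.array([
    [2, 1, 0],
    [0, 3, 1],
    [1, 0, 2],
])

# offsets[s, r]: where sender s's block starts inside receiver r's output,
# i.e. an exclusive cumulative sum of send_sizes down each receiver column.
offsets = np.cumsum(send_sizes, axis=0) - send_sizes

ndev = send_sizes.shape[0]
# Sender-side view (what the primitive expects): device d passes its row,
# "for each receiver, where my block lands over there".
correct_args = [offsets[d, :] for d in range(ndev)]
# Receiver-local view (what the buggy code passed): device d's column,
# "where each incoming block lands in my own output".
buggy_args = [offsets[:, d] for d in range(ndev)]

# Because the primitive transposes output_offsets internally, pre-transposing
# amounts to a double transpose: devices scatter slices using another device's
# offsets whenever the table is asymmetric -- which it always is once routing
# is imbalanced.
for d in range(ndev):
    print(d, correct_args[d].tolist(), buggy_args[d].tolist())
```

Note the two views only coincide for a symmetric table, which is why small symmetric toy cases (and host-side simulations that never shard the table) can look correct while real imbalanced runs corrupt outputs.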

experiments/grug/moe/launch.py

Lines changed: 8 additions & 7 deletions
@@ -119,7 +119,7 @@ def run_grug_moe_trial(config: GrugMoeLaunchConfig) -> None:
     run_grug(run_config)
 
 
-RESOLVED_RUN_ID = _resolve_run_id("moe_1e23_d5120_bs2048_ep8_ragged")
+RESOLVED_RUN_ID = _resolve_run_id("moe_1e23_d5120_bs2048_ep4_ring")
 
 
 # 1e23 compute budget, d5120. Model +
@@ -129,18 +129,19 @@ def run_grug_moe_trial(config: GrugMoeLaunchConfig) -> None:
 _BASELINE_BUDGET: float = 1e23
 _BASELINE_HIDDEN_DIM: int = 5120
 _BASELINE_TARGET_STEPS: int = 120_000
+_BASELINE_NUM_LAYERS_OVERRIDE: int | None = 48
 _baseline_model, _baseline_optimizer, _baseline_batch, _baseline_steps = build_from_heuristic(
     budget=_BASELINE_BUDGET,
     hidden_dim=_BASELINE_HIDDEN_DIM,
     target_steps=_BASELINE_TARGET_STEPS,
 )
-# Stack MoE blocks via jax.lax.scan to keep XLA compile + peak HBM tractable at
-# the heuristic-derived depth, and force ragged dispatch so the smoke exercises
-# the high-EP path from #4697.
+# Match the known-good 1e23 ring EP=4 configuration while keeping the current
+# v4-2048/us-central2 launch wiring.
 _baseline_model = dataclasses.replace(
     _baseline_model,
-    moe_implementation="ragged_all_to_all",
+    moe_implementation="ring",
     use_array_stacked_blocks=True,
+    num_layers=_BASELINE_NUM_LAYERS_OVERRIDE or _baseline_model.num_layers,
 )
 
 # Override the heuristic-derived batch_size (round_up_pow2 only produces powers
@@ -157,7 +158,7 @@ def run_grug_moe_trial(config: GrugMoeLaunchConfig) -> None:
 
 
 baseline_moe = ExecutorStep(
-    name="grug/moe_1e23_d5120_bs2048_ep8_ragged",
+    name="grug/moe_1e23_d5120_bs2048_ep4_ring",
     fn=run_grug_moe_trial,
     config=GrugMoeLaunchConfig(
         model=versioned(_baseline_model),
@@ -169,7 +170,7 @@ def run_grug_moe_trial(config: GrugMoeLaunchConfig) -> None:
         resources=versioned(ResourceConfig.with_tpu("v4-2048", regions=["us-central2"])),
         steps=versioned(_baseline_steps),
         batch_size=versioned(_baseline_batch),
-        expert_parallel=versioned(8),
+        expert_parallel=versioned(4),
         seed=versioned(0),
         mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
         tracker=WandbConfig(
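For intuition on the truncated comment above about `round_up_pow2`: a plausible implementation of such a helper is sketched below (hypothetical — the real function lives in the launch heuristics and may differ), which explains why the heuristic-derived batch size needs an explicit override to hit an exact target like 2048.

```python
def round_up_pow2(n: int) -> int:
    # One plausible reading of the helper named in the comment: smallest power
    # of two greater than or equal to n.
    if n <= 1:
        return 1
    return 1 << (n - 1).bit_length()

# A heuristic batch size just past a power of two nearly doubles.
print([round_up_pow2(n) for n in (1500, 2048, 2049)])  # [2048, 2048, 4096]
```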
scripts/debug/grug_moe_grad_compare.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
#!/usr/bin/env python3
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import json

import numpy as np

import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils
from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec as P

from iris.runtime.jax_init import initialize_jax
from levanter.grug.grug_moe import moe_mlp
from levanter.utils.activation import ActivationFunctionEnum


def _make_ep_mesh() -> Mesh:
    devices = jax.devices()
    if len(devices) < 2 or len(devices) % 2 != 0:
        raise RuntimeError(f"Need an even number of devices >= 2, got {len(devices)}")
    mesh_devices = np.array(devices).reshape(len(devices) // 2, 2, 1)
    return Mesh(
        mesh_devices,
        axis_names=("data", "expert", "model"),
        axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Explicit),
    )


def _make_inputs(
    *,
    key: jax.Array,
    tokens: int,
    hidden_dim: int,
    intermediate_dim: int,
    num_experts: int,
    topk: int,
    overflow: bool,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    k_x, k_sel, k_logits, k_w13, k_w2 = jax.random.split(key, 5)
    x = jax.random.normal(k_x, (tokens, hidden_dim), dtype=jnp.float32)
    if overflow:
        selected_experts = jnp.zeros((tokens, topk), dtype=jnp.int32)
        combine_weights = jnp.full((tokens, topk), 1.0 / topk, dtype=jnp.float32)
    else:
        selected_experts = jax.random.randint(k_sel, (tokens, topk), 0, num_experts, dtype=jnp.int32)
        combine_logits = jax.random.normal(k_logits, (tokens, topk), dtype=jnp.float32)
        combine_weights = jax.nn.softmax(combine_logits, axis=-1)
    w_up_gate = jax.random.normal(k_w13, (num_experts, hidden_dim, 2 * intermediate_dim), dtype=jnp.float32)
    w_down = jax.random.normal(k_w2, (num_experts, intermediate_dim, hidden_dim), dtype=jnp.float32)
    return x, selected_experts, combine_weights, w_up_gate, w_down


def _tree_diff_stats(a, b) -> dict[str, float]:
    leaves_a = jax.tree.leaves(a)
    leaves_b = jax.tree.leaves(b)
    max_abs = 0.0
    max_rel = 0.0
    l2_sq = 0.0
    ref_l2_sq = 0.0
    for xa, xb in zip(leaves_a, leaves_b, strict=True):
        da = np.asarray(xa)
        db = np.asarray(xb)
        diff = np.abs(da - db)
        max_abs = max(max_abs, float(diff.max(initial=0.0)))
        denom = np.maximum(np.abs(db), 1e-12)
        max_rel = max(max_rel, float((diff / denom).max(initial=0.0)))
        l2_sq += float(np.sum((da - db) ** 2))
        ref_l2_sq += float(np.sum(db**2))
    return {
        "max_abs": max_abs,
        "max_rel": max_rel,
        "l2": l2_sq**0.5,
        "ref_l2": ref_l2_sq**0.5,
        "rel_l2": (l2_sq**0.5) / max(ref_l2_sq**0.5, 1e-12),
    }


def _host_array(x: jax.Array) -> np.ndarray:
    if jax.process_count() > 1 and getattr(x, "ndim", 0) > 0:
        x = multihost_utils.process_allgather(x, tiled=True)
    return np.asarray(x)


def _host_scalar(x: jax.Array) -> float:
    return float(np.asarray(x))


def _run_case(mesh: Mesh, *, overflow: bool) -> dict[str, object]:
    hidden_dim = 128
    intermediate_dim = 256
    num_experts = 8
    topk = 4
    tokens = max(len(jax.devices()) * 16, 64)

    with jax.set_mesh(mesh):
        x, selected_experts, combine_weights, w_up_gate, w_down = _make_inputs(
            key=jax.random.key(17 if overflow else 7),
            tokens=tokens,
            hidden_dim=hidden_dim,
            intermediate_dim=intermediate_dim,
            num_experts=num_experts,
            topk=topk,
            overflow=overflow,
        )

        batch_sharding = NamedSharding(mesh, P(("data", "expert"), None))
        expert_sharding = NamedSharding(mesh, P("expert", None, None))
        x = jax.sharding.reshard(x, batch_sharding)
        selected_experts = jax.sharding.reshard(selected_experts, batch_sharding)
        combine_weights = jax.sharding.reshard(combine_weights, batch_sharding)
        w_up_gate = jax.sharding.reshard(w_up_gate, expert_sharding)
        w_down = jax.sharding.reshard(w_down, expert_sharding)

        def run_impl(implementation: str):
            def loss_and_drop(
                x_arg,
                selected_experts_arg,
                combine_weights_arg,
                w_up_gate_arg,
                w_down_arg,
            ):
                out, dropped = moe_mlp(
                    x_arg,
                    selected_experts_arg,
                    combine_weights_arg,
                    w_up_gate_arg,
                    w_down_arg,
                    activation=ActivationFunctionEnum.silu,
                    implementation=implementation,
                    mesh=None,
                    report_capacity_overflow=True,
                    capacity_factor=1.0,
                )
                loss = jnp.mean(out.astype(jnp.float32) ** 2)
                return loss, (out, dropped)

            fn = jax.jit(jax.value_and_grad(loss_and_drop, has_aux=True, argnums=(0, 3, 4)))
            (loss, (out, dropped)), grads = fn(x, selected_experts, combine_weights, w_up_gate, w_down)
            return loss, out, dropped, grads

        ring_loss, ring_out, ring_dropped, ring_grads = run_impl("ring")
        ragged_loss, ragged_out, ragged_dropped, ragged_grads = run_impl("ragged_all_to_all")

    ring_loss = _host_scalar(ring_loss)
    ragged_loss = _host_scalar(ragged_loss)
    ring_out_np = _host_array(ring_out)
    ragged_out_np = _host_array(ragged_out)
    ring_grad_x = _host_array(ring_grads[0])
    ragged_grad_x = _host_array(ragged_grads[0])
    ring_grad_w_up_gate = _host_array(ring_grads[1])
    ragged_grad_w_up_gate = _host_array(ragged_grads[1])
    ring_grad_w_down = _host_array(ring_grads[2])
    ragged_grad_w_down = _host_array(ragged_grads[2])

    return {
        "overflow": overflow,
        "tokens": tokens,
        "num_devices": len(jax.devices()),
        "num_processes": jax.process_count(),
        "ring_loss": ring_loss,
        "ragged_loss": ragged_loss,
        "loss_delta": ring_loss - ragged_loss,
        "ring_dropped": int(np.asarray(ring_dropped)),
        "ragged_dropped": int(np.asarray(ragged_dropped)),
        "output_diff": _tree_diff_stats(ring_out_np, ragged_out_np),
        "grad_x_diff": _tree_diff_stats(ring_grad_x, ragged_grad_x),
        "grad_w_up_gate_diff": _tree_diff_stats(ring_grad_w_up_gate, ragged_grad_w_up_gate),
        "grad_w_down_diff": _tree_diff_stats(ring_grad_w_down, ragged_grad_w_down),
    }


def main() -> None:
    initialize_jax()
    mesh = _make_ep_mesh()
    normal = _run_case(mesh, overflow=False)
    overflow = _run_case(mesh, overflow=True)
    if jax.process_index() == 0:
        print(json.dumps({"normal": normal, "overflow": overflow}, indent=2, sort_keys=True), flush=True)


if __name__ == "__main__":
    main()
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
#!/usr/bin/env python3
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import dataclasses

from experiments.grug.moe.launch import (
    ExecutorStep,
    GrugEvalConfig,
    GrugMoeLaunchConfig,
    GrugTrainerConfig,
    NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
    WandbConfig,
    _baseline_batch,
    _baseline_model,
    _baseline_optimizer,
    _baseline_steps,
    _resolve_run_id,
    executor_main,
    run_grug_moe_trial,
    this_output_path,
    versioned,
)
from fray.cluster import ResourceConfig

RUN_ID = _resolve_run_id("moe_1e23_d5120_bs2048_ep8_ragged_48l_rayuvtpu_20260417_0945")
STEP_NAME = "grug/moe_1e23_d5120_bs2048_ep8_ragged_48l_fix_a2a_20260417_0945"


ragged_ep8_fix = ExecutorStep(
    name=STEP_NAME,
    fn=run_grug_moe_trial,
    config=GrugMoeLaunchConfig(
        model=versioned(
            dataclasses.replace(
                _baseline_model,
                moe_implementation="ragged_all_to_all",
                use_array_stacked_blocks=True,
                num_layers=48,
            )
        ),
        data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
        output_path=this_output_path(),
        run_id=RUN_ID,
        resources=versioned(ResourceConfig.with_tpu("v4-2048", regions=["us-central2"])),
        steps=versioned(_baseline_steps),
        batch_size=versioned(_baseline_batch),
        expert_parallel=versioned(8),
        seed=versioned(0),
        mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
        tracker=WandbConfig(
            project="dial_moe",
            tags=["adamh", "qb", "sharded-qb", "gatednorm", "xsa", "zloss", "eq3e3", "ragged-fix"],
            group="moe-iter04",
            name=None,
        ),
        optimizer=versioned(_baseline_optimizer),
        priority_band="production",
        grug_trainer=versioned(
            GrugTrainerConfig(
                z_loss_weight=1e-4,
                ema_beta=None,
                log_every=1,
            )
        ),
        eval=versioned(
            GrugEvalConfig(
                eval_batch_size=1024,
                steps_per_eval=1000,
                max_eval_batches=8,
                eval_current=True,
                eval_ema=False,
            )
        ),
    ),
)


if __name__ == "__main__":
    executor_main(
        steps=[ragged_ep8_fix],
        description="Grug MoE 1e23 ragged EP8 relaunch after ragged_all_to_all offset fix.",
    )
