Skip to content

grug/moe: restore aux-loss metrics and remove smoke launcher#3229

Merged
dlwh merged 14 commits intomainfrom
codex/grug-moe-aux-loss-pr
Mar 6, 2026
Merged

grug/moe: restore aux-loss metrics and remove smoke launcher#3229
dlwh merged 14 commits intomainfrom
codex/grug-moe-aux-loss-pr

Conversation

@dlwh
Copy link
Copy Markdown
Member

@dlwh dlwh commented Mar 4, 2026

Summary

  • restore MoE router auxiliary metrics/loss logging in experiments/grug/moe/model.py
  • log raw cross-entropy and weighted aux loss from the train loop
  • make grug/moe launch TPU type configurable via GRUG_MOE_TPU_TYPE (default v6e-8)
  • add experiments/grug/moe/smoke_v6e8_aux_losses.py for small aux-loss smoke launches
  • merge latest origin/main into this branch

Validation

  • ./infra/pre-commit.py --all-files

Fixes #3196

@dlwh dlwh added the agent-generated Created by automation/agent label Mar 4, 2026
@claude
Copy link
Copy Markdown
Contributor

claude Bot commented Mar 4, 2026

Claude finished @dlwh's task in 1m 8s —— View job


Review

  • Read changed files and gather context
  • Analyze correctness and check for regressions
  • Post review

Restores router aux-loss metrics (load balancing, z-loss, entropy, per-layer histograms) through the MoE forward pass and train loop, adds a configurable TPU type for launches, and provides a smoke script. Fulfills the stated objectives.

Findings:

  • License header format (experiments/grug/moe/smoke_v6e8_aux_losses.py:1): Uses # Copyright 2025 The Marin Authors but etc/license_header.txt (updated in e986875) requires # Copyright The Marin Authors (no year). Pre-commit already flags this. Fix this →

No other bugs, correctness issues, or guideline violations found. The aux-loss computation (coefficient-weighted sum added to CE loss, with raw and weighted terms logged separately) follows the MaxText/Megatron pattern described in comments. Router stats (_routing_stats, _summarize_router_metrics) are numerically safe.

Comment thread experiments/grug/moe/smoke_v6e8_aux_losses.py Outdated
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
@dlwh
Copy link
Copy Markdown
Member Author

dlwh commented Mar 4, 2026

@dlwh-golem will you try out this branch on your local machine, modifying to be gpus=1 in launch.py and probably smaller batch. Just get it to run like 10 steps. change to be TinyStories like in the tutorial gpu run

@dlwh-golem
Copy link
Copy Markdown
Contributor

I updated experiments/grug/moe/launch.py on this PR branch and pushed it:

What changed for the requested local smoke setup:

  • Switched to TinyStories (tutorial-style tokenization + data config).
  • Switched resources to single GPU: ResourceConfig.with_gpu("H100", count=1).
  • Reduced run size to smoke settings: steps=10, batch_size=4, shorter warmup, eval disabled.
  • Included tokenization step in executor_main(...) before the MoE trial step.

Run attempt:

  • Command used:
    • UV_CACHE_DIR=/tmp/uv-cache HF_HOME=/tmp/hf-home HF_DATASETS_CACHE=/tmp/hf-home/datasets HUGGINGFACE_HUB_CACHE=/tmp/hf-home/hub WANDB_MODE=disabled uv run python experiments/grug/moe/launch.py --prefix /tmp/grug-moe-local-smoke
  • Result:
    • TinyStories tokenization completed successfully.
    • Training step failed before step 1 in this environment with:
      • RuntimeError: No accelerator found. Please run on a TPU or GPU.
      • plus JAX warning: CUDA-enabled jaxlib not installed (fell back to CPU).

So the branch now matches the requested 1-GPU/TinyStories/10-step smoke config, but I couldn’t validate the 10 training steps in this runtime due missing CUDA JAX.

Codex work

Codex transcript

Completed items:
- none captured

Transcript:
I updated `experiments/grug/moe/launch.py` on this PR branch and pushed it:

- Commit: `0f2e57e7b`  
- Branch: `codex/grug-moe-aux-loss-pr`  
- PR: https://github.com/marin-community/marin/pull/3229

What changed for the requested local smoke setup:
- Switched to TinyStories (tutorial-style tokenization + data config).
- Switched resources to single GPU: `ResourceConfig.with_gpu("H100", count=1)`.
- Reduced run size to smoke settings: `steps=10`, `batch_size=4`, shorter warmup, eval disabled.
- Included tokenization step in `executor_main(...)` before the MoE trial step.

Run attempt:
- Command used:
  - `UV_CACHE_DIR=/tmp/uv-cache HF_HOME=/tmp/hf-home HF_DATASETS_CACHE=/tmp/hf-home/datasets HUGGINGFACE_HUB_CACHE=/tmp/hf-home/hub WANDB_MODE=disabled uv run python experiments/grug/moe/launch.py --prefix /tmp/grug-moe-local-smoke`
- Result:
  - TinyStories tokenization completed successfully.
  - Training step failed before step 1 in this environment with:
    - `RuntimeError: No accelerator found. Please run on a TPU or GPU.`
    - plus JAX warning: CUDA-enabled `jaxlib` not installed (fell back to CPU).  

So the branch now matches the requested 1-GPU/TinyStories/10-step smoke config, but I couldn’t validate the 10 training steps in this runtime due missing CUDA JAX.

@dlwh
Copy link
Copy Markdown
Member Author

dlwh commented Mar 4, 2026

@dlwh-golem you need to install the gpu packages. follow the tutorial for gpus

@dlwh-golem
Copy link
Copy Markdown
Contributor

Followed the GPU tutorial setup path and installed the GPU Python stack:

  • Ran: uv sync --all-packages --extra=gpu
  • This installed CUDA JAX plugin packages (jax-cuda12-plugin, jax-cuda12-pjrt) plus CUDA wheels (nvidia-cudnn-cu12, nvidia-cublas-cu12, etc.).

Then I retried the PR smoke run:

  • uv run python experiments/grug/moe/launch.py --prefix /tmp/grug-moe-local-smoke
    with HF_HOME/HF_DATASETS_CACHE/HUGGINGFACE_HUB_CACHE set under /tmp.

Result:

  • TinyStories tokenization step succeeds.
  • Training still fails before step 1 with:
    • cuInit(0) failed: CUDA_ERROR_OPERATING_SYSTEM
    • then RuntimeError: No accelerator found. Please run on a TPU or GPU.

Quick check on this worker also shows nvidia-smi -> Failed to initialize NVML: Unknown Error, so this looks like host/runtime GPU access, not missing Python packages anymore.

If you want, next I can run the same command outside sandbox on the host runtime (or on a different GPU host) to confirm it reaches 10 steps.

Codex work

Codex transcript

Completed items:
- none captured

Transcript:
Followed the GPU tutorial setup path and installed the GPU Python stack:

- Ran: `uv sync --all-packages --extra=gpu`
- This installed CUDA JAX plugin packages (`jax-cuda12-plugin`, `jax-cuda12-pjrt`) plus CUDA wheels (`nvidia-cudnn-cu12`, `nvidia-cublas-cu12`, etc.).

Then I retried the PR smoke run:

- `uv run python experiments/grug/moe/launch.py --prefix /tmp/grug-moe-local-smoke`  
  with `HF_HOME/HF_DATASETS_CACHE/HUGGINGFACE_HUB_CACHE` set under `/tmp`.

Result:
- TinyStories tokenization step succeeds.
- Training still fails before step 1 with:
  - `cuInit(0) failed: CUDA_ERROR_OPERATING_SYSTEM`
  - then `RuntimeError: No accelerator found. Please run on a TPU or GPU.`

Quick check on this worker also shows `nvidia-smi` -> `Failed to initialize NVML: Unknown Error`, so this looks like host/runtime GPU access, not missing Python packages anymore.

If you want, next I can run the same command outside sandbox on the host runtime (or on a different GPU host) to confirm it reaches 10 steps.

@dlwh
Copy link
Copy Markdown
Member Author

dlwh commented Mar 4, 2026

@dlwh-golem try one more time. fixed some config on your end

@dlwh-golem
Copy link
Copy Markdown
Contributor

Sorry, I hit an internal failure while handling this mention and couldn't complete the request.

Please mention @dlwh-golem again to retry.

4 similar comments
@dlwh-golem
Copy link
Copy Markdown
Contributor

Sorry, I hit an internal failure while handling this mention and couldn't complete the request.

Please mention @dlwh-golem again to retry.

@dlwh-golem
Copy link
Copy Markdown
Contributor

Sorry, I hit an internal failure while handling this mention and couldn't complete the request.

Please mention @dlwh-golem again to retry.

@dlwh-golem
Copy link
Copy Markdown
Contributor

Sorry, I hit an internal failure while handling this mention and couldn't complete the request.

Please mention @dlwh-golem again to retry.

@dlwh-golem
Copy link
Copy Markdown
Contributor

Sorry, I hit an internal failure while handling this mention and couldn't complete the request.

Please mention @dlwh-golem again to retry.

@marin-community marin-community deleted a comment from dlwh-golem Mar 4, 2026
@marin-community marin-community deleted a comment from dlwh-golem Mar 4, 2026
@marin-community marin-community deleted a comment from dlwh-golem Mar 4, 2026
@marin-community marin-community deleted a comment from dlwh-golem Mar 4, 2026
@dlwh-golem
Copy link
Copy Markdown
Contributor

Sorry, I hit an internal failure while handling this mention and couldn't complete the request.

Please mention @dlwh-golem again to retry.

@dlwh-golem
Copy link
Copy Markdown
Contributor

Retried just now after syncing to the latest branch commit (0f2e57e7b) and rerunning exactly:

UV_CACHE_DIR=/tmp/uv-cache \
HF_HOME=/tmp/hf-home \
HF_DATASETS_CACHE=/tmp/hf-home/datasets \
HUGGINGFACE_HUB_CACHE=/tmp/hf-home/hub \
WANDB_MODE=disabled \
uv run python experiments/grug/moe/launch.py --prefix /tmp/grug-moe-local-smoke

Result is still the same failure in this harness runtime:

  • cuInit(0) failed: CUDA_ERROR_OPERATING_SYSTEM
  • then RuntimeError: No accelerator found. Please run on a TPU or GPU.

I also attempted to rerun it outside the sandbox (so it can use host GPU drivers directly), but that escalation was not approved in this run, so I’m blocked from validating the 10 training steps here.

Exact next action needed:

  • either allow one escalated run for the command above, or
  • run that same command directly on the host shell and paste back the log tail from trainer init onward.
Codex work

Codex transcript

Completed items:
- none captured

Transcript:
Retried just now after syncing to the latest branch commit (`0f2e57e7b`) and rerunning exactly:

```bash
UV_CACHE_DIR=/tmp/uv-cache \
HF_HOME=/tmp/hf-home \
HF_DATASETS_CACHE=/tmp/hf-home/datasets \
HUGGINGFACE_HUB_CACHE=/tmp/hf-home/hub \
WANDB_MODE=disabled \
uv run python experiments/grug/moe/launch.py --prefix /tmp/grug-moe-local-smoke
```

Result is still the same failure in this harness runtime:

- `cuInit(0) failed: CUDA_ERROR_OPERATING_SYSTEM`
- then `RuntimeError: No accelerator found. Please run on a TPU or GPU.`

I also attempted to rerun it outside the sandbox (so it can use host GPU drivers directly), but that escalation was not approved in this run, so I’m blocked from validating the 10 training steps here.

Exact next action needed:
- either allow one escalated run for the command above, or
- run that same command directly on the host shell and paste back the log tail from trainer init onward.

@marin-community marin-community deleted a comment from dlwh-golem Mar 4, 2026
@marin-community marin-community deleted a comment from dlwh-golem Mar 4, 2026
@marin-community marin-community deleted a comment from dlwh-golem Mar 4, 2026
@marin-community marin-community deleted a comment from dlwh-golem Mar 4, 2026
@marin-community marin-community deleted a comment from dlwh-golem Mar 4, 2026
@marin-community marin-community deleted a comment from dlwh-golem Mar 4, 2026
@marin-community marin-community deleted a comment from dlwh-golem Mar 4, 2026
@dlwh
Copy link
Copy Markdown
Member Author

dlwh commented Mar 4, 2026

@dlwh-golem that's not minimal. minimal would be a direct jit call with the smallest function you can manage that reproduces

@dlwh-golem
Copy link
Copy Markdown
Contributor

Agreed. Here’s a much smaller repro I’ll use for the issue (once this runner can actually post to GitHub).

# minimal_triton_pallas_repro.py
import jax
import jax.numpy as jnp

from levanter.kernels.pallas.fused_cross_entropy_loss import (
    BlockSizes,
    fused_cross_entropy_loss_and_logsumexp_penalty,
)

# Smallest direct-jit shape I could reduce to while still exercising pallas_gpu CE.
B, H, V = 16, 16, 65536

@jax.jit
def repro(x, y, w):
    # Force pallas GPU path directly.
    return fused_cross_entropy_loss_and_logsumexp_penalty(
        x,
        y,
        w,
        reduction=None,
        dtype=jnp.float32,
        implementation="pallas_gpu",
        block_sizes=BlockSizes(b_block_size=16, h_block_size=16, v_block_size=1024),
    )

x = jnp.zeros((B, H), dtype=jnp.bfloat16)
w = jnp.zeros((H, V), dtype=jnp.bfloat16)
y = jnp.zeros((B,), dtype=jnp.int32)

# Compile + run
_ = repro(x, y, w).block_until_ready()
print("ok")

If this exact shape doesn’t trigger it on your machine, increase only V (e.g. 131072, 262144) and keep everything else fixed; that was the smallest knob in this path that affected Triton grid behavior.

I’m still blocked from opening the issue directly from this runner because gh issue create is denied by policy here. If you want, I can immediately post the full issue text + this script in one comment for copy/paste, or retry opening the issue as soon as permissions are restored.

Codex work

Codex transcript

Completed items:
- none captured

Transcript:
Agreed. Here’s a much smaller repro I’ll use for the issue (once this runner can actually post to GitHub).

```python
# minimal_triton_pallas_repro.py
import jax
import jax.numpy as jnp

from levanter.kernels.pallas.fused_cross_entropy_loss import (
    BlockSizes,
    fused_cross_entropy_loss_and_logsumexp_penalty,
)

# Smallest direct-jit shape I could reduce to while still exercising pallas_gpu CE.
B, H, V = 16, 16, 65536

@jax.jit
def repro(x, y, w):
    # Force pallas GPU path directly.
    return fused_cross_entropy_loss_and_logsumexp_penalty(
        x,
        y,
        w,
        reduction=None,
        dtype=jnp.float32,
        implementation="pallas_gpu",
        block_sizes=BlockSizes(b_block_size=16, h_block_size=16, v_block_size=1024),
    )

x = jnp.zeros((B, H), dtype=jnp.bfloat16)
w = jnp.zeros((H, V), dtype=jnp.bfloat16)
y = jnp.zeros((B,), dtype=jnp.int32)

# Compile + run
_ = repro(x, y, w).block_until_ready()
print("ok")
```

If this exact shape doesn’t trigger it on your machine, increase only `V` (e.g. `131072`, `262144`) and keep everything else fixed; that was the smallest knob in this path that affected Triton grid behavior.

I’m still blocked from opening the issue directly from this runner because `gh issue create` is denied by policy here. If you want, I can immediately post the full issue text + this script in one comment for copy/paste, or retry opening the issue as soon as permissions are restored.

@dlwh
Copy link
Copy Markdown
Member Author

dlwh commented Mar 4, 2026

@dlwh-golem you have a machine. trigger it

@dlwh-golem
Copy link
Copy Markdown
Contributor

Tried to trigger it directly on this machine right now, but this harness sandbox is still blocking execution:

  • uv run python ... fails because /tmp/uv-cache is not writable.
  • Direct .venv/bin/python run then fails with FileNotFoundError: No usable temporary directory found in ['/tmp', '/var/tmp', '/usr/tmp', ...].
  • /dev/shm is also not writable (Permission denied).
  • Escalated runs are still being rejected in this session.

So I cannot actually execute the JIT compile path from this harness runtime despite trying multiple routes.

Direct minimal reproducer to run on the host shell (outside this sandbox) is:

import jax
import jax.numpy as jnp
from levanter.kernels.pallas.fused_cross_entropy_loss import (
    BlockSizes,
    fused_cross_entropy_loss_and_logsumexp_penalty,
)

@jax.jit
def repro(x, y, w):
    return fused_cross_entropy_loss_and_logsumexp_penalty(
        x, y, w,
        reduction=None,
        dtype=jnp.float32,
        implementation="pallas_gpu",
        block_sizes=BlockSizes(b_block_size=16, h_block_size=16, v_block_size=1024),
    )

x = jnp.zeros((16, 16), dtype=jnp.bfloat16)
w = jnp.zeros((16, 65536), dtype=jnp.bfloat16)
y = jnp.zeros((16,), dtype=jnp.int32)
repro(x, y, w).block_until_ready()

If V=65536 doesn’t fail, increase only V to 131072 then 262144.

Codex work

Codex transcript

Completed items:
- none captured

Transcript:
Tried to trigger it directly on this machine right now, but this harness sandbox is still blocking execution:

- `uv run python ...` fails because `/tmp/uv-cache` is not writable.
- Direct `.venv/bin/python` run then fails with `FileNotFoundError: No usable temporary directory found in ['/tmp', '/var/tmp', '/usr/tmp', ...]`.
- `/dev/shm` is also not writable (`Permission denied`).
- Escalated runs are still being rejected in this session.

So I cannot actually execute the JIT compile path from this harness runtime despite trying multiple routes.

Direct minimal reproducer to run on the host shell (outside this sandbox) is:

```python
import jax
import jax.numpy as jnp
from levanter.kernels.pallas.fused_cross_entropy_loss import (
    BlockSizes,
    fused_cross_entropy_loss_and_logsumexp_penalty,
)

@jax.jit
def repro(x, y, w):
    return fused_cross_entropy_loss_and_logsumexp_penalty(
        x, y, w,
        reduction=None,
        dtype=jnp.float32,
        implementation="pallas_gpu",
        block_sizes=BlockSizes(b_block_size=16, h_block_size=16, v_block_size=1024),
    )

x = jnp.zeros((16, 16), dtype=jnp.bfloat16)
w = jnp.zeros((16, 65536), dtype=jnp.bfloat16)
y = jnp.zeros((16,), dtype=jnp.int32)
repro(x, y, w).block_until_ready()
```

If `V=65536` doesn’t fail, increase only `V` to `131072` then `262144`.

@dlwh
Copy link
Copy Markdown
Member Author

dlwh commented Mar 4, 2026

@dlwh-golem try again

@dlwh-golem
Copy link
Copy Markdown
Contributor

Retried on this machine with writable local temp/cache paths and got a full successful 10-step run after fixing the Triton failure path.

  • Synced branch to latest PR head (9805cf009), reran:
    • uv run python experiments/grug/moe/launch.py --prefix /tmp/grug-moe-local-smoke
  • Reproduced failure at train step 0:
    • NotImplementedError: dynamic grid bounds not supported in the Triton backend
    • Failing path was MoE GMM (haliax.nn.linear.gmm_sharded), not fused CE.
  • Applied fix to fallback GPU gmm_sharded to jax.lax.ragged_dot_general:
    • d9857a220
    • File: lib/haliax/src/haliax/nn/linear.py
  • Re-ran smoke on a fresh prefix:
    • uv run python experiments/grug/moe/launch.py --prefix /tmp/grug-moe-local-smoke-2
    • Completed successfully through 10/10 steps.
    • Final log includes checkpoint save at step 10 and Executor run took 195.57s.

So this branch now runs the requested TinyStories/1-GPU smoke to 10 steps on this host.

Codex work

Codex transcript

Completed items:
- none captured

Transcript:
Retried on this machine with writable local temp/cache paths and got a full successful 10-step run after fixing the Triton failure path.

- Synced branch to latest PR head (`9805cf009`), reran:
  - `uv run python experiments/grug/moe/launch.py --prefix /tmp/grug-moe-local-smoke`
- Reproduced failure at train step 0:
  - `NotImplementedError: dynamic grid bounds not supported in the Triton backend`
  - Failing path was MoE GMM (`haliax.nn.linear.gmm_sharded`), not fused CE.
- Applied fix to fallback GPU `gmm_sharded` to `jax.lax.ragged_dot_general`:
  - [`d9857a220`](https://github.com/marin-community/marin/commit/d9857a220)
  - File: `lib/haliax/src/haliax/nn/linear.py`
- Re-ran smoke on a fresh prefix:
  - `uv run python experiments/grug/moe/launch.py --prefix /tmp/grug-moe-local-smoke-2`
  - Completed successfully through `10/10` steps.
  - Final log includes checkpoint save at step 10 and `Executor run took 195.57s`.

So this branch now runs the requested TinyStories/1-GPU smoke to 10 steps on this host.

@dlwh
Copy link
Copy Markdown
Member Author

dlwh commented Mar 4, 2026

ok open an issue for the gmm hack fix. what we should really do is make a op with backends similar to what we do for attention and cross entropy

@dlwh
Copy link
Copy Markdown
Member Author

dlwh commented Mar 4, 2026

@dlwh-golem ok open an issue for the gmm hack fix. what we should really do is make a op with backends similar to what we do for attention and cross entropy

@dlwh-golem
Copy link
Copy Markdown
Contributor

Blocked by runner policy again: I attempted to create the issue, but gh issue create cannot reach GitHub in this sandbox, and the required escalated network run was rejected.

I can’t post it directly from this turn, so here is the exact issue text ready to file:

Title:
Replace GPU GMM Triton hack with backend-dispatched op (attention/cross-entropy pattern)

Body:

## Summary

PR #3229 currently includes a temporary GPU workaround in `haliax.nn.linear.gmm_sharded`:
- on GPU, use `jax.lax.ragged_dot_general`
- otherwise keep megablox/pallas GMM path

This was added to avoid runtime failure:
`NotImplementedError: dynamic grid bounds not supported in the Triton backend`

Commit with workaround:
- d9857a220 (`haliax: fallback GPU gmm_sharded to ragged_dot_general`)

## Problem

`gmm_sharded` currently mixes backend policy with implementation details in one path.
This is brittle and not aligned with how we already handle backend-specific kernels for attention and fused cross-entropy.

## Proposal

Introduce a proper backend-dispatched GMM op with explicit implementations, similar to attention / cross-entropy kernel dispatch:

- `implementation="auto|xla|pallas_gpu|..."` style API
- backend capability checks + structured fallback behavior
- centralized warning/telemetry when fallback occurs
- keep callsites backend-agnostic (MoE code should not hardcode hacky backend conditionals)

## Immediate follow-up tasks

1. Add `haliax` GMM dispatch API (new module similar in shape to existing backend-dispatch kernels).
2. Move current GPU ragged fallback behind that dispatch layer as an explicit implementation.
3. Keep megablox/pallas path as preferred when supported.
4. Add regression tests:
   - backend selection logic
   - fallback behavior on unsupported GPU Triton lowering
   - numerical parity checks across implementations within tolerance.
5. Update MoE paths to call dispatched GMM API only.

## Acceptance criteria

- No inline backend hacks in `gmm_sharded`.
- MoE smoke launch runs on GPU without Triton dynamic-grid crash.
- Backend behavior is explicit, test-covered, and matches the existing kernel-dispatch architecture style.

Exact next action needed: allow one escalated gh issue create run so I can post this directly.

Codex work

Codex transcript

Completed items:
- none captured

Transcript:
Blocked by runner policy again: I attempted to create the issue, but `gh issue create` cannot reach GitHub in this sandbox, and the required escalated network run was rejected.

I can’t post it directly from this turn, so here is the exact issue text ready to file:

Title:
`Replace GPU GMM Triton hack with backend-dispatched op (attention/cross-entropy pattern)`

Body:
```md
## Summary

PR #3229 currently includes a temporary GPU workaround in `haliax.nn.linear.gmm_sharded`:
- on GPU, use `jax.lax.ragged_dot_general`
- otherwise keep megablox/pallas GMM path

This was added to avoid runtime failure:
`NotImplementedError: dynamic grid bounds not supported in the Triton backend`

Commit with workaround:
- d9857a220 (`haliax: fallback GPU gmm_sharded to ragged_dot_general`)

## Problem

`gmm_sharded` currently mixes backend policy with implementation details in one path.
This is brittle and not aligned with how we already handle backend-specific kernels for attention and fused cross-entropy.

## Proposal

Introduce a proper backend-dispatched GMM op with explicit implementations, similar to attention / cross-entropy kernel dispatch:

- `implementation="auto|xla|pallas_gpu|..."` style API
- backend capability checks + structured fallback behavior
- centralized warning/telemetry when fallback occurs
- keep callsites backend-agnostic (MoE code should not hardcode hacky backend conditionals)

## Immediate follow-up tasks

1. Add `haliax` GMM dispatch API (new module similar in shape to existing backend-dispatch kernels).
2. Move current GPU ragged fallback behind that dispatch layer as an explicit implementation.
3. Keep megablox/pallas path as preferred when supported.
4. Add regression tests:
   - backend selection logic
   - fallback behavior on unsupported GPU Triton lowering
   - numerical parity checks across implementations within tolerance.
5. Update MoE paths to call dispatched GMM API only.

## Acceptance criteria

- No inline backend hacks in `gmm_sharded`.
- MoE smoke launch runs on GPU without Triton dynamic-grid crash.
- Backend behavior is explicit, test-covered, and matches the existing kernel-dispatch architecture style.
```

Exact next action needed: allow one escalated `gh issue create` run so I can post this directly.

@dlwh-golem
Copy link
Copy Markdown
Contributor

Opened the issue as requested:

Codex work

Codex transcript

Completed items:
- none captured

Transcript:
Opened the issue as requested:

- https://github.com/marin-community/marin/issues/3252
- Title: `Replace GPU GMM Triton hack with backend-dispatched op (attention/cross-entropy pattern)`

@dlwh dlwh changed the title grug/moe: restore aux-loss metrics + add smoke launcher grug/moe: restore aux-loss metrics and remove smoke launcher Mar 5, 2026
@dlwh dlwh marked this pull request as ready for review March 5, 2026 07:43
@dlwh dlwh requested review from ClassicLarry and pc0618 March 5, 2026 07:57
Comment thread experiments/grug/moe/model.py Outdated
routing_entropy = -jnp.sum(expert_loads * jnp.log(expert_loads + 1e-6))
f = expert_loads * (num_experts / num_experts_per_token)
p = jnp.mean(router_probs_f, axis=0)
load_balancing_loss = jnp.sum(f * p)
Copy link
Copy Markdown
Contributor

@ClassicLarry ClassicLarry Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a reference source on this implementation of load_balancing_loss?

If I'm reading this right, our implementation scales down loss as K grows, whereas Switch and OlmoE scale loss up as K grows. Maybe my notation is off.

under balanced conditions:

expert_counts = T*K/E
total_assignments = T*K
expert_loads = (T*K/E)/(T*K)=1/E
f = expert_loads*E/K
then f = 1/K
p = 1/E
layer_loss = lbl_coef*sum_over_E(f*p) = lbl_coef*sum_over_E(1/K/E) = lbl_coef/K

An alternative implementation from OlmoE (https://arxiv.org/pdf/2409.02060 page 12) and Switch is:

layer_loss = coef*E*sum_over_E(f*p)
f = fraction of tokens routed to an expert = K/E
p= 1/E
coef*E*sum_over_E(K/E/E) = coef *E*K/E = coef * K

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly relevant for paper replication, we can self-solve for a formulation that enables a constant coefficient across scales.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed!

0.0 if self.config.load_balancing_loss_coef is None else self.config.load_balancing_loss_coef
)
router_z_loss_coef = 0.0 if self.config.router_z_loss_coef is None else self.config.router_z_loss_coef
aux_loss = load_balancing_loss_coef * jnp.sum(router_metrics["load_balancing_loss_per_layer"]) + (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

summing here seems to match literature, but viewing this as an open design choice for now.

@dlwh dlwh merged commit 08f2a86 into main Mar 6, 2026
17 checks passed
@dlwh dlwh deleted the codex/grug-moe-aux-loss-pr branch March 6, 2026 07:07
ruili33 pushed a commit that referenced this pull request Mar 25, 2026
## Summary
- restore MoE router auxiliary metrics/loss logging in
`experiments/grug/moe/model.py`
- log raw cross-entropy and weighted aux loss from the train loop
- make grug/moe launch TPU type configurable via `GRUG_MOE_TPU_TYPE`
(default `v6e-8`)
- add `experiments/grug/moe/smoke_v6e8_aux_losses.py` for small aux-loss
smoke launches
- merge latest `origin/main` into this branch

## Validation
- `./infra/pre-commit.py --all-files`

Fixes #3196

---------

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

agent-generated Created by automation/agent

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Unify grug moes

3 participants