
[Levanter] Add XLA-first Mamba-3 SISO and MIMO TPU kernels #3961

Merged
dlwh merged 35 commits into main from codex/mamba3-xla-hybrid on Mar 25, 2026

Conversation

@dlwh
Member

@dlwh dlwh commented Mar 22, 2026

Add the production Mamba-3 TPU path with XLA-first SISO and real-valued MIMO kernels, hybrid mode defaults, and direct/upstream reference parity coverage. Keep Pallas out of the default path and ship the chunk defaults selected from the TPU sweep.

Part of #3868

@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Repo admins can enable using credits for code reviews in their settings.

@claude
Contributor

claude Bot commented Mar 22, 2026

Claude finished @dlwh's task in 1m 45s.


PR Review: Mamba-3 XLA TPU Kernels


Adds XLA-first SISO and real-valued MIMO Mamba-3 TPU kernels with reference oracles, attention-style APIs, hybrid mode dispatch, and a custom VJP for the SISO path. The implementation follows the spec in .agents/projects/mamba3-mimo-native-attentionish-xla.md and keeps Pallas out of the default path as stated.

Specification check: A spec exists at .agents/projects/mamba3-mimo-native-attentionish-xla.md. The code adheres to the described decomposition (local block, chunk summary, lax.scan, prefix emit, diagonal correction). The spec's carry convention change from [P, N] to [N, P] for the native attention-style path is implemented. Test coverage includes all described parity scenarios (chunked vs sequential, XLA vs reference, direct recurrence, upstream torch reference).
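The decomposition named in the spec (local block, chunk summary, a scan over chunk carries, prefix emit) can be illustrated with a toy scalar recurrence h_t = a_t * h_{t-1} + b_t * x_t. Everything below is a sketch with made-up names, not Levanter's API:

```python
import jax
import jax.numpy as jnp

def sequential_scan(a, b, x):
    """Reference oracle: h_t = a_t * h_{t-1} + b_t * x_t, step by step."""
    def step(h, inputs):
        a_t, b_t, x_t = inputs
        h = a_t * h + b_t * x_t
        return h, h
    _, hs = jax.lax.scan(step, 0.0, (a, b, x))
    return hs

def chunked_scan(a, b, x, chunk=4):
    """Same recurrence, computed chunk-by-chunk: local intra-chunk part,
    per-chunk summaries, a scan over chunk carries, then a prefix emit."""
    T = a.shape[0]
    a_c = a.reshape(T // chunk, chunk)
    b_c = b.reshape(T // chunk, chunk)
    x_c = x.reshape(T // chunk, chunk)

    # Cumulative decay within each chunk: prod_{s<=t} a_s.
    cum_a = jnp.cumprod(a_c, axis=1)

    # Local (intra-chunk) states, assuming a zero carry-in.
    def local(a_row, b_row, x_row):
        def step(h, inp):
            a_t, b_t, x_t = inp
            h = a_t * h + b_t * x_t
            return h, h
        _, hs = jax.lax.scan(step, 0.0, (a_row, b_row, x_row))
        return hs
    local_h = jax.vmap(local)(a_c, b_c, x_c)

    # Chunk summary: carry-out under zero carry-in, plus the total decay
    # a carry-in experiences across the chunk.
    chunk_out = local_h[:, -1]
    chunk_decay = cum_a[:, -1]

    # Scan over chunks to recover each chunk's carry-in.
    def carry_step(h, inp):
        decay, out = inp
        return decay * h + out, h  # emit the carry-in, keep carry-out
    _, carry_in = jax.lax.scan(carry_step, 0.0, (chunk_decay, chunk_out))

    # Prefix emit: fold the decayed carry-in back into every local state.
    return (local_h + cum_a * carry_in[:, None]).reshape(T)

key = jax.random.PRNGKey(0)
a = jax.random.uniform(key, (8,), minval=0.5, maxval=0.9)
b = jnp.ones(8)
x = jnp.arange(8.0)
assert jnp.allclose(sequential_scan(a, b, x), chunked_scan(a, b, x), atol=1e-5)
```

The actual kernels do this with matrix-valued carries and the g_t diagonal correction on top, but the chunk/carry/prefix structure is the same.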

Findings:

  • Duplicate validation (api.py:715-716 and api.py:746-747): mamba3_mimo_attentionish_forward_from_transformed checks implementation not in (None, "xla", "reference") twice: once at line 715 and again at line 746 after the reshape work. The second check is dead code, since any invalid value would already have raised.

  • Traced-value check in validation (api.py:414): _require_none_or_zero uses bool(jnp.all(value == 0)) which forces a concrete evaluation. Under jax.jit, if angles or segsum is a traced array (not a compile-time constant), this will fail with a ConcretizationTypeError. The current callers pass None or literal zero arrays so this works today, but it is fragile if callers ever pass traced inputs. Consider documenting this limitation or using jax.debug.callback / a static check.

  • _materialize_chunked_layout is jnp.copy (api.py:466-469): This forces a concrete copy to break layout fusion. The intent is clear from the docstring, but it is worth noting that jnp.copy may be a no-op under some XLA optimization passes. If this is load-bearing for TPU performance, it may need a jax.ensure_compile_time_eval or a custom primitive in the future.

  • Custom VJP recomputes forward (xla.py:187-220): The backward pass of _mamba3_chunked_forward_xla_batched_default_custom_vjp_bwd saves all 6 primals as residuals, then recomputes the full forward inside jax.vjp to obtain the pullback. This means the forward is computed twice during backprop. This is acknowledged in the spec ("recompute-heavy custom VJP") and is correct, but the residuals tuple is then unused except to feed into the re-forward. The primals_out from jax.vjp is discarded. This is fine for correctness (the VJP is just delegating to autodiff with explicit SymbolicZero handling), but worth noting that the residuals could be reduced to just the primals needed for recomputation.
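For readers unfamiliar with the pattern, the recompute-heavy custom VJP described above can be sketched in miniature; `f` and its residual layout here are illustrative stand-ins, not the actual kernel:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, w):
    # Stand-in for the chunked forward; the real kernel is far larger.
    return jnp.tanh(x @ w)

def f_fwd(x, w):
    # Save only the primals as residuals -- no intermediate activations.
    return f(x, w), (x, w)

def f_bwd(residuals, g):
    x, w = residuals
    # Recompute the forward under jax.vjp to obtain the pullback; the
    # primals_out is discarded here (in the real kernel it is used to
    # materialize SymbolicZero cotangents with the right shapes/dtypes).
    _primals_out, pullback = jax.vjp(lambda x, w: jnp.tanh(x @ w), x, w)
    return pullback(g)

f.defvjp(f_fwd, f_bwd)

x = jnp.ones((2, 3))
w = jnp.full((3, 4), 0.5)
grads = jax.grad(lambda x, w: f(x, w).sum(), argnums=(0, 1))(x, w)
ref = jax.grad(lambda x, w: jnp.tanh(x @ w).sum(), argnums=(0, 1))(x, w)
assert all(jnp.allclose(a, b) for a, b in zip(grads, ref))
```

The forward runs twice during backprop, trading FLOPs for residual memory, which is exactly the trade the spec acknowledges.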

No bugs or correctness issues found. The test coverage is thorough: 30+ tests covering SISO/MIMO parity against sequential oracles, direct recurrences, upstream torch references, gradient matching, dtype preservation, JIT compilation smoke tests, attention-style API equivalence, and QK group mapping.

@dlwh dlwh added the agent-generated (Created by automation/agent) label Mar 22, 2026, via ChatGPT Codex Connector
Member Author

dlwh commented Mar 22, 2026

🤖 Follow-up: the research logbook for this thread now lives on the dedicated research branch research/mamba3-xla-tpu-logbook rather than on the PR branch.

Logbook:

Latest block-level TPU comparison from that logbook:

  • hidden=4096, batch=4, bf16, v5p-8
  • SISO beats both Splash attention baselines at seq_len=4096, 8192, and 32768
  • MIMO r=4 is behind attention at 4096 and 8192, but overtakes both attention baselines by 32768

Representative seq_len=32768 forward/backward tokens/s:

  • attention_gqa16: 3.30M / 0.62M
  • attention_mha32: 3.02M / 0.60M
  • mamba3_siso: 12.98M / 4.48M
  • mamba3_mimo_r4: 5.27M / 1.02M

That strengthens the current shipping recommendation: the XLA Mamba-3 path is already training-competitive for long-context workloads, with SISO especially strong and MIMO becoming competitive once context is long enough.

Member Author

dlwh commented Mar 22, 2026

🤖 Follow-up on Claude's review:

Landed one concrete cleanup in a6654c550:

  • removed the dead duplicate implementation validation in mamba3_mimo_attentionish_forward_from_transformed(...)

I also kept the custom-VJP note in code:

  • added a comment in xla.py explaining that the default SISO custom VJP is intentionally recompute-heavy and that primals_out is only used to materialize SymbolicZero cotangents with the correct shapes/dtypes before calling the pullback

The other two review items are not being changed in this PR:

  • _require_none_or_zero(...) forcing a concrete boolean is a real limitation, but it is also the current intended contract for the real-valued attention-style wrapper: angles and segsum are only supported as None or literal all-zero arrays. I’m treating traced nonzero validation as future work rather than widening behavior in this PR.
  • _materialize_chunked_layout(...) = jnp.copy(...) is deliberate TPU tuning. The current evidence is that forcing chunk-major materialization helps this path; if XLA ever stops respecting that strongly enough, that should be handled as a focused follow-up rather than changed speculatively here.
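For reference, the concrete-evaluation limitation is easy to reproduce with a minimal jitted check; this is a sketch of the pattern, not the actual _require_none_or_zero:

```python
import jax
import jax.numpy as jnp

def require_zero(value):
    # Mirrors the flagged pattern: bool(...) forces a concrete value,
    # which only works when `value` is a compile-time constant.
    if not bool(jnp.all(value == 0)):
        raise ValueError("expected an all-zero array")

require_zero(jnp.zeros(4))  # fine: literal zeros, check runs eagerly

@jax.jit
def f(x):
    require_zero(x)  # x is a tracer here, so bool(...) cannot be evaluated
    return x + 1

raised = False
try:
    f(jnp.zeros(4))
except jax.errors.ConcretizationTypeError:
    raised = True
assert raised
```

This is why the current contract restricts angles/segsum to None or literal all-zero arrays: the check runs eagerly at call time, outside any trace.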

Re-ran:

  • ./infra/pre-commit.py lib/levanter/src/levanter/kernels/pallas/mamba3/api.py lib/levanter/src/levanter/kernels/pallas/mamba3/xla.py --fix
  • uv run --package levanter --group test pytest lib/levanter/tests/kernels/test_pallas_mamba3.py -q
  • result: 31 passed, 2 skipped

@dlwh
Member Author

dlwh commented Mar 25, 2026

@claude review again please

@claude
Contributor

claude Bot commented Mar 25, 2026

Claude finished @dlwh's task in 2m 12s.


PR Re-Review: Mamba-3 XLA TPU Kernels


Full re-review of the current state of this PR (14 files, ~5350 lines). The duplicate validation from the first review has been fixed in a6654c550.

Architecture & Design

The layered decomposition is clean and well-structured:

  • ssd/ — shared SSD primitives (local block, chunk state, scan, prefix emit) used by both SISO and MIMO
  • mamba3/reference.py — Mamba-3 reference oracles building on SSD primitives, with the g_t diagonal correction and MIMO rank expand/collapse/gate
  • mamba3/xla.py — XLA-first production kernels reusing SSD XLA internals, with a custom VJP for the SISO default path and a full native MIMO ranked forward
  • mamba3/api.py — public dispatch layer: chunked, attention-style SISO/MIMO, hybrid mode router
  • mamba3/config.py — tuned TPU chunk-size defaults and HybridModeConfig

The separation of concerns is excellent. Each layer has a clear responsibility, and the SSD → Mamba-3 layering avoids code duplication while keeping the MIMO rank extensions isolated.

Correctness

No bugs found. The mathematical structure is sound:

  1. SISO path: mamba3_intra_chunk = ssd_intra_chunk - diag(c·b) * out_correction * x — correctly implements the g_t rewrite from the Mamba-3 paper.

  2. MIMO ranked path (xla.py:273-339): The bc contraction uses gktnv,gksnu->gktsuv (5-index intermediate) and the local output einsum gkts,gktsuv,gkusp->gkvtp correctly contracts over the rank pairs. The diagonal correction bc_diag einsum gktnu,gktnv->gktuv correctly produces the per-step same-token interaction term.

  3. Custom VJP (xla.py:141-230): The recompute-heavy approach is well-documented in both code comments and the spec. The SymbolicZero materialization via primals_out is the correct JAX pattern for custom VJPs that delegate to autodiff.

  4. Chunk-boundary scale computation (reference.py:37-53): prepare_mamba3_chunked_scales correctly flattens, shifts, and reshapes to handle the cross-chunk q_next boundary term.

  5. Attention-style API reshape logic (api.py:430-463): The _grouped_qk_to_chunked_state_rank and _chunk_*_by_head helpers correctly handle the attention → chunked layout transformation, including QK group expansion via jnp.take.
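As a sanity check on the einsum strings quoted in item 2, a small NumPy sketch with illustrative dimensions (the sizes and axis names are assumptions, not the kernel's real shapes) confirms the two-step contraction matches a single fused contraction:

```python
import numpy as np

rng = np.random.default_rng(0)
G, K, T, N, U, P = 2, 3, 4, 5, 2, 6  # groups, chunks, time, state, rank, head dim

c = rng.standard_normal((G, K, T, N, U))   # gktnv (v indexes rank)
b = rng.standard_normal((G, K, T, N, U))   # gksnu (s indexes time)
x = rng.standard_normal((G, K, U, T, P))   # gkusp
scale = rng.standard_normal((G, K, T, T))  # gkts (causal masking omitted)

# Two-step contraction as described in the review.
bc = np.einsum("gktnv,gksnu->gktsuv", c, b)   # 5-index intermediate
out = np.einsum("gkts,gktsuv,gkusp->gkvtp", scale, bc, x)

# A single fused contraction over s, u, n must agree.
fused = np.einsum("gkts,gktnv,gksnu,gkusp->gkvtp", scale, c, b, x)
assert np.allclose(out, fused)
assert out.shape == (G, K, U, T, P)  # (2, 3, 2, 4, 6); v and u are both rank
```

The kernel keeps the two-step form so the bc intermediate can be reused between the local output and the diagonal correction.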

Test Coverage

Thorough — 30+ tests covering:

  • Scale preparation (unit values, chunk boundaries)
  • Chunked ↔ sequential parity (SISO and MIMO)
  • XLA ↔ reference parity (both intra-chunk and full forward)
  • Direct recurrence oracle parity (the paper's h_t, v_t recurrence)
  • Upstream torch reference parity (ported from state-spaces/mamba — both SISO and MIMO)
  • Gradient matching (XLA vs reference, SISO and MIMO)
  • Attention-style API ↔ chunked API parity (SISO and MIMO)
  • QK group mapping, unreduced output, final_k, d parameter
  • Hybrid mode dispatch (SISO and MIMO)
  • dtype preservation (bfloat16 carry)
  • TPU-aligned JIT compilation smoke tests
  • Finiteness stress test (scaled dt)

Items Worth Noting (not blocking)

  1. _require_none_or_zero concrete eval (api.py:414): As discussed, bool(jnp.all(value == 0)) will raise ConcretizationTypeError if ever called with a traced non-constant. The current callers are safe, and this is tracked as future work.

  2. _materialize_chunked_layout = jnp.copy (api.py:466-469): Deliberate TPU layout tuning. If XLA ever stops respecting jnp.copy as a materialization barrier, this may need a custom primitive.

  3. Attention-style SISO applies biases twice (api.py:626-633 and the _mamba3_attentionish_forward_chunked_from_transformed inner call at api.py:648-660): mamba3_attentionish_forward_from_transformed adds q_bias/k_bias to the chunked tensors at lines 629-631, then passes the same q_bias/k_bias kwargs to _mamba3_attentionish_forward_chunked_from_transformed, which adds them again at lines 532-535. Biases are therefore applied twice in the mamba3_attentionish_forward_from_transformed path (but not via mamba3_attentionish_forward, which simply delegates). Since the tests use zero biases, the double application is invisible in the test suite but would produce incorrect results with non-zero biases.

  4. mamba3_intra_chunk loop never iterates more than once (api.py:112-117): The for impl in impls loop returns on the first successful call and raises ValueError for unknown implementations, so the fallback/ExceptionGroup pattern from the SSD layer isn't actually reachable here. This is harmless but could be simplified to a direct dispatch.

  5. __all__ duplication — Both ssd/api.py and ssd/__init__.py define __all__ with the same symbols. The __init__.py version is canonical; the one in api.py (line 237) is redundant but not harmful.

  6. No MIMO custom VJP: Unlike the SISO path which has an explicit @jax.custom_vjp, the MIMO ranked forward in xla.py relies on JAX's default autodiff. Per the spec, this is intentional for the first pass, with a custom VJP as future work if residual memory becomes a concern.
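The double-application pattern in item 3 can be shown abstractly; `inner_forward` and both wrappers below are hypothetical stand-ins, not Levanter's functions:

```python
import numpy as np

def inner_forward(q, k, q_bias=None, k_bias=None):
    # Hypothetical inner chunked forward: applies biases, then contracts.
    if q_bias is not None:
        q = q + q_bias
    if k_bias is not None:
        k = k + k_bias
    return q @ k.T

def buggy_wrapper(q, k, q_bias, k_bias):
    # Bug pattern from the review: the wrapper adds the biases itself...
    q = q + q_bias
    k = k + k_bias
    # ...and then also forwards them, so the inner call adds them again.
    return inner_forward(q, k, q_bias=q_bias, k_bias=k_bias)

def fixed_wrapper(q, k, q_bias, k_bias):
    # Apply the biases exactly once (here: delegate to the inner call).
    return inner_forward(q, k, q_bias=q_bias, k_bias=k_bias)

q, k = np.ones((2, 3)), np.ones((2, 3))
zero = np.zeros(3)
bias = np.full(3, 0.5)

# With zero biases the bug is invisible -- exactly what the tests hit.
assert np.allclose(buggy_wrapper(q, k, zero, zero), fixed_wrapper(q, k, zero, zero))
# With non-zero biases the two paths diverge.
assert not np.allclose(buggy_wrapper(q, k, bias, bias), fixed_wrapper(q, k, bias, bias))
```

This is also why a test with small random non-zero biases would catch the regression where the current zero-bias tests cannot.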

Verdict

Looks good to merge. The code is well-structured, mathematically correct, thoroughly tested, and follows the spec closely. The one actionable finding is the double bias application in the SISO attention-style path (item 3 above) — this should be fixed before or shortly after merge to avoid a latent bug when non-zero biases are used.


@dlwh dlwh merged commit f38dcca into main Mar 25, 2026
39 checks passed
@dlwh dlwh deleted the codex/mamba3-xla-hybrid branch March 25, 2026 03:39
Helw150 pushed a commit that referenced this pull request Apr 8, 2026
