
Experiment: ship Mamba-3 XLA SISO/MIMO TPU kernel #3868

@dlwh

Description


Summary

This experiment asked whether Marin should ship Mamba-3 on TPU through a new Pallas kernel or an XLA-first implementation. The issue concludes that the maintainable shipping path is pure JAX/XLA: SISO reached production-ready parity and throughput with a chunk_size=512 TPU default, MIMO rank-4 worked with the same scan-state shape and a chunk_size=256 default, and the attempted Pallas variants either failed on TPU or showed no performance win. The broad experiment is effectively resolved; follow-on cleanup and narrower kernel work are tracked separately.


TL;DR: Implement and benchmark a production-facing Mamba-3 TPU path in Levanter with a pure JAX/XLA SSD-style decomposition, keeping Pallas optional and non-default. The thread covered real-valued SISO Mamba-3 first, then a real-valued MIMO rank-4 extension, with all work anchored in a direct recurrent reference, a local research logbook, and TPU benchmarks on v5p-8.

Scope:

  • add a correct SISO Mamba-3 reference and chunked XLA fast path
  • add a real-valued MIMO rank-4 extension without changing scan carry shape
  • expose a hybrid public API/config surface for SISO vs MIMO
  • validate against direct recurrences and upstream Mamba Torch references
  • benchmark representative TPU shapes and determine default chunk sizes
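A minimal sketch of the first two scope items, a direct reference plus a chunked fast path. It assumes a simplified real-valued update with a scalar per-step decay `a_t = exp(log_a_t)`, a carried state `h_t = a_t * h_{t-1} + outer(x_t, B_t)` of shape [P, N], and output `y_t = h_t @ C_t`. This is a Mamba-2-style stand-in, not the exact Mamba-3 discretization, and the names `direct_recurrence` and `chunked_ssd` are hypothetical rather than Levanter's API. The chunked version materializes a chunk-local [L, L] block and runs `lax.scan` only across chunk boundaries, which is the SSD-style shape described above.

```python
import jax
import jax.numpy as jnp

# Simplified stand-in for the Mamba-3 update (assumption), not the real discretization.


def direct_recurrence(x, B, C, log_a):
    """Token-by-token reference. x: [T, P], B/C: [T, N], log_a: [T] (<= 0)."""
    P, N = x.shape[1], B.shape[1]

    def step(h, inp):
        x_t, B_t, C_t, la_t = inp
        h = jnp.exp(la_t) * h + jnp.outer(x_t, B_t)  # carry shape [P, N]
        return h, h @ C_t                            # y_t shape [P]

    h0 = jnp.zeros((P, N), x.dtype)
    h_final, y = jax.lax.scan(step, h0, (x, B, C, log_a))
    return y, h_final


def chunked_ssd(x, B, C, log_a, chunk_size):
    """SSD-style form: dense work inside each chunk, lax.scan across chunks."""
    T, P = x.shape
    N = B.shape[1]
    assert T % chunk_size == 0, "sketch assumes T is a multiple of chunk_size"
    nc, L = T // chunk_size, chunk_size
    xc, Bc, Cc = x.reshape(nc, L, P), B.reshape(nc, L, N), C.reshape(nc, L, N)
    s = jnp.cumsum(log_a.reshape(nc, L), axis=1)  # inclusive cumulative log-decay

    # Within-chunk decay exp(s_i - s_j) for j <= i, zero above the diagonal.
    # The inner where keeps exp() finite on masked entries (and their gradients).
    causal = jnp.tril(jnp.ones((L, L), dtype=bool))
    diff = s[:, :, None] - s[:, None, :]
    decay = jnp.where(causal, jnp.exp(jnp.where(causal, diff, 0.0)), 0.0)

    # Chunk-local output from a materialized [L, L] score block per chunk.
    scores = jnp.einsum("cin,cjn->cij", Cc, Bc) * decay
    y_local = jnp.einsum("cij,cjp->cip", scores, xc)

    # Each chunk's contribution to the state at its last position.
    w_end = jnp.exp(s[:, -1:] - s)                              # [nc, L]
    chunk_states = jnp.einsum("cj,cjp,cjn->cpn", w_end, xc, Bc)  # [nc, P, N]
    chunk_decay = jnp.exp(s[:, -1])                             # [nc]

    def carry_step(h, inp):
        st, dec = inp
        return dec * h + st, h  # emit the state entering this chunk

    h0 = jnp.zeros((P, N), x.dtype)
    h_final, h_in = jax.lax.scan(carry_step, h0, (chunk_states, chunk_decay))

    # Carried-state contribution: exp(s_i) * (h_in @ C_i).
    y_carry = jnp.einsum("cpn,cin->cip", h_in, Cc) * jnp.exp(s)[:, :, None]
    return (y_local + y_carry).reshape(T, P), h_final
```

Both functions should agree to float32 tolerance for any `chunk_size` that divides T. `chunk_size` trades the size of the materialized local block against the number of sequential scan steps, which is the knob the TPU sweeps tune.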

Success metrics:

  • forward and grad parity against direct recurrent references
  • upstream reference parity for the real-valued path
  • stable TPU XLA path with explicit default chunk sizes
  • a production-facing API that defaults to the fastest maintainable path

Stop criteria:

  • SISO and MIMO both run end-to-end on TPU with XLA only
  • the winning defaults are identified and documented
  • Pallas is either shown to help or explicitly rejected for now

Decision log:

  • 2026-03-18: choose pure JAX/XLA SSD-style decomposition as the primary implementation target; do not start from a pairwise Pallas kernel. Owner: agent. Evidence: direct recurrence equivalence tests and initial TPU benchmarks.
  • 2026-03-18: set SISO TPU default to chunk_size=512. Owner: agent. Evidence: best throughput on v5p-8 among 128/256/512/1024/2048 sweep points.
  • 2026-03-18: reject the first SISO Pallas local-output revisit. Owner: agent. Evidence: compilation/runtime failures and no TPU training win worth keeping.
  • 2026-03-19: add real-valued MIMO rank-4 on the same XLA decomposition with unchanged scan carry shape. Owner: agent. Evidence: direct recurrence tests, R=1 parity, TPU compile/run success.
  • 2026-03-19: set MIMO rank-4 TPU default to chunk_size=256. Owner: agent. Evidence: best token throughput across 128/256/512/1024 on v5p-8.
  • 2026-03-19: reject the MIMO Pallas local-output revisit for now. Owner: agent. Evidence: TPU layout failures and no valid parity/perf result.
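The chunk-size defaults above come from throughput sweeps. A minimal sketch of such a sweep, assuming a hypothetical `sweep_chunk_sizes` helper that takes a factory returning a function for a given chunk size; the workload here is a toy chunked 1-D decay recurrence, not the Mamba-3 kernel, and timings on this toy say nothing about the v5p-8 numbers in this issue.

```python
import time

import jax
import jax.numpy as jnp


def sweep_chunk_sizes(make_fn, args, chunk_sizes, n_tokens, iters=5):
    """Hypothetical helper: {chunk_size: tokens/sec} for jit(make_fn(chunk_size))."""
    results = {}
    for cs in chunk_sizes:
        fn = jax.jit(make_fn(cs))
        jax.block_until_ready(fn(*args))  # compile and warm up outside the timed region
        start = time.perf_counter()
        for _ in range(iters):
            out = fn(*args)
        jax.block_until_ready(out)  # dispatch is async; wait before stopping the clock
        elapsed = (time.perf_counter() - start) / iters
        results[cs] = n_tokens / elapsed
    return results


def make_toy_chunked_decay(chunk_size):
    """Toy workload: h_t = exp(log_a_t) * h_{t-1} + u_t, computed chunk by chunk."""
    def fn(log_a, u):
        T = u.shape[0]
        nc = T // chunk_size  # assumes T is divisible by chunk_size
        s = jnp.cumsum(log_a.reshape(nc, chunk_size), axis=1)
        causal = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool))
        diff = s[:, :, None] - s[:, None, :]
        decay = jnp.where(causal, jnp.exp(jnp.where(causal, diff, 0.0)), 0.0)
        local = jnp.einsum("cij,cj->ci", decay, u.reshape(nc, chunk_size))

        # Carry the end-of-chunk state across chunks with lax.scan.
        def step(h, inp):
            st, dec = inp
            return dec * h + st, h  # emit the state entering each chunk

        _, h_in = jax.lax.scan(
            step, jnp.zeros((), u.dtype), (local[:, -1], jnp.exp(s[:, -1]))
        )
        return (local + h_in[:, None] * jnp.exp(s)).reshape(T)
    return fn
```

`max(results, key=results.get)` then picks the sweep winner. Warm-up and `jax.block_until_ready` matter because the first call includes compilation and JAX returns before device work finishes.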

Negative-results index:

  • token-scan local rewrite: much slower than materialized chunk-local XLA on TPU
  • associative_scan: much slower than lax.scan on TPU
  • SISO Pallas local-output attempt: not worth keeping
  • MIMO streamed-rank XLA local rewrite: much slower than current materialized local block
  • MIMO Pallas local-output attempt: TPU layout crash, no keepable result
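For reference, the `lax.scan` and `associative_scan` formulations compute the same linear recurrence; they differ only in how the work is scheduled. A minimal sketch on the toy 1-D recurrence `h_t = a_t * h_{t-1} + u_t`, which says nothing about their relative TPU speed: `associative_scan` needs the composition rule for the affine maps `h -> a*h + u`.

```python
import jax
import jax.numpy as jnp


def linear_recurrence_scan(a, u):
    """Sequential form: h_t = a_t * h_{t-1} + u_t with h_{-1} = 0."""
    def step(h, inp):
        a_t, u_t = inp
        h = a_t * h + u_t
        return h, h
    _, hs = jax.lax.scan(step, jnp.zeros((), u.dtype), (a, u))
    return hs


def linear_recurrence_assoc(a, u):
    """Parallel-prefix form over affine maps h -> a * h + u."""
    def combine(left, right):
        a_l, u_l = left    # earlier segment
        a_r, u_r = right   # later segment
        # Apply the earlier map first, then the later one: a_r * (a_l * h + u_l) + u_r.
        return a_l * a_r, a_r * u_l + u_r
    _, hs = jax.lax.associative_scan(combine, (a, u))
    return hs
```

The associative form shortens the sequential dependency chain at the cost of extra elementwise work and intermediate arrays.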

Conclusion:
The maintainable shipping path is XLA-first. SISO is production-ready with a chunk_size=512 TPU default. Real-valued MIMO rank-4 is viable with the same scan-carry shape and a chunk_size=256 TPU default. Pallas should remain out of the default path until a narrower kernel shape demonstrates a real TPU win.
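These defaults can be captured in a small config object. This is a hypothetical sketch: `Mamba3KernelConfig` and its field names are illustrative and not Levanter's actual config surface, and defaulting `mimo_rank` to 4 is an assumption. Only the SISO/MIMO split, the 512/256 chunk-size defaults, and keeping Pallas off by default come from this issue.

```python
from dataclasses import dataclass
from typing import Literal, Optional

# TPU defaults reported in this issue (MIMO value was measured at rank 4).
_DEFAULT_CHUNK_SIZE = {"siso": 512, "mimo": 256}


@dataclass(frozen=True)
class Mamba3KernelConfig:
    """Hypothetical config; field names are illustrative, not Levanter's."""

    mode: Literal["siso", "mimo"] = "siso"
    mimo_rank: int = 4                # assumption; ignored in SISO mode
    chunk_size: Optional[int] = None  # None -> per-mode TPU default
    use_pallas: bool = False          # XLA path is the default; Pallas stays opt-in

    def __post_init__(self):
        if self.mode not in _DEFAULT_CHUNK_SIZE:
            raise ValueError(f"unknown mode {self.mode!r}")
        if self.mode == "mimo" and self.mimo_rank < 1:
            raise ValueError("mimo_rank must be >= 1")

    @property
    def resolved_chunk_size(self) -> int:
        if self.chunk_size is not None:
            return self.chunk_size
        return _DEFAULT_CHUNK_SIZE[self.mode]

    @property
    def rank(self) -> int:
        """Effective rank: SISO behaves like rank 1."""
        return 1 if self.mode == "siso" else self.mimo_rank
```

Resolving the chunk size lazily keeps an explicit user override separate from the measured default, so a later sweep can change the default without touching callers that pinned a value.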

Confidence:

  • SISO XLA path: stable
  • MIMO rank-4 XLA path: replicated
  • Pallas negative result: replicated for the attempted kernel shapes

Hypothesis or Goal

Mamba-3 on TPU should be implementable as a chunked SSD-style recurrence in pure JAX/XLA first, with Pallas reserved only for a clearly dominant local block if profiling justifies it. For MIMO, the extra rank should live in B/C/X/Z/O while the carried scan state remains [P, N], preserving the SISO decomposition and keeping the TPU implementation maintainable.
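The MIMO hypothesis can be illustrated with a direct recurrence in which the rank appears only in the per-step inputs and outputs. A minimal sketch, assuming the same simplified scalar-decay update as a stand-in for the real Mamba-3 discretization, X of shape [T, P, R], B and C of shape [T, R, N], and a rank-R state update `h_t = a_t * h_{t-1} + X_t @ B_t` (a sum of R outer products). Z gating and the O projection are omitted, and the function names are hypothetical. With R=1 it reduces to the SISO recurrence, which is the R=1 parity check listed in the decision log.

```python
import jax
import jax.numpy as jnp

# Assumed MIMO layout (hypothetical): x [T, P, R], B/C [T, R, N], log_a [T].


def mimo_direct_recurrence(x, B, C, log_a):
    """Returns y: [T, P, R] and the final state h: [P, N]."""
    P, N = x.shape[1], B.shape[2]

    def step(h, inp):
        x_t, B_t, C_t, la_t = inp
        # Rank-R input: sum_r outer(x_t[:, r], B_t[r]); the carry stays [P, N].
        h = jnp.exp(la_t) * h + x_t @ B_t
        # Rank-R readout: y_t[:, r] = h @ C_t[r], giving [P, R].
        return h, h @ C_t.T

    h0 = jnp.zeros((P, N), x.dtype)
    h_final, y = jax.lax.scan(step, h0, (x, B, C, log_a))
    return y, h_final


def siso_direct_recurrence(x, B, C, log_a):
    """SISO special case. x: [T, P], B/C: [T, N]."""
    P, N = x.shape[1], B.shape[1]

    def step(h, inp):
        x_t, B_t, C_t, la_t = inp
        h = jnp.exp(la_t) * h + jnp.outer(x_t, B_t)
        return h, h @ C_t

    h0 = jnp.zeros((P, N), x.dtype)
    h_final, y = jax.lax.scan(step, h0, (x, B, C, log_a))
    return y, h_final
```

Because the carry is [P, N] in both functions, any chunked decomposition of the SISO scan can keep its inter-chunk `lax.scan` unchanged; only the chunk-local blocks gain the R dimension.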

Links

Results

Representative slice-level TPU results on v5p-8 with explicit G sharding, seq_len=16384, groups=16, bf16:

SISO XLA, chunk_size=512

  • 128 x 512: 341.20M tok/s forward / 211.94M tok/s backward
  • 512 x 1024: 152.60M tok/s forward / 91.89M tok/s backward
  • 1024 x 512: 171.88M tok/s forward / 99.67M tok/s backward

MIMO rank-4 XLA, chunk_size=256

  • 128 x 512: 78.53M tok/s forward / 39.73M tok/s backward
  • 512 x 1024: 34.92M tok/s forward / 15.86M tok/s backward
  • 1024 x 512: 37.76M tok/s forward / 22.70M tok/s backward
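The "explicit G sharding" in these runs places the group axis across devices. A minimal sketch, assuming a hypothetical [G, T] layout per toy input with G = groups, a 1-D device mesh whose axis is named "g", and groups whose recurrences are independent of one another; the toy per-group decay scan stands in for the real kernel, and the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed layout: [G, T] per toy input, G = groups (hypothetical).


def per_group_decay_scan(log_a, u):
    """Toy per-group recurrence h_t = exp(log_a_t) * h_{t-1} + u_t, vmapped over G."""
    def one_group(la, x):
        def step(h, inp):
            la_t, x_t = inp
            h = jnp.exp(la_t) * h + x_t
            return h, h
        _, hs = jax.lax.scan(step, jnp.zeros((), x.dtype), (la, x))
        return hs
    return jax.vmap(one_group)(log_a, u)


def shard_over_groups(arr, mesh):
    """Place the leading G axis on the 'g' mesh axis; other axes are replicated."""
    spec = PartitionSpec("g", *([None] * (arr.ndim - 1)))
    return jax.device_put(arr, NamedSharding(mesh, spec))
```

Under `jax.jit`, the independent per-group scans can then run on each device's shard without cross-device communication; G should be divisible by the number of devices on the "g" axis.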

Key findings:

  • SISO and MIMO both pass direct reference parity tests.
  • Real-valued SISO and MIMO references also match the upstream Mamba Torch reference logic in Torch-gated tests.
  • MIMO R=4 keeps the scan carry at SISO shape and compiles/runs end-to-end on TPU with XLA only.
  • The honest production recommendation is hybrid XLA mode, not Pallas.

Already sealed:

  • the research question is resolved enough to ship the XLA path and close the experiment issue
  • future work should be tracked in follow-on optimization issues rather than continuing this thread

Metadata


Labels:

  • agent-generated: Created by automation/agent
  • experiment
  • levanter: Issues related to Levanter library
  • tldr: Issue has a community-friendly TL;DR summary
