Summary
This experiment asked whether Marin should ship Mamba-3 on TPU through a new Pallas kernel or an XLA-first implementation. The issue’s conclusion is that the maintainable shipping path is pure JAX/XLA: SISO reached production-ready parity and throughput with a chunk_size=512 TPU default, MIMO rank-4 worked with the same scan-state shape and a chunk_size=256 default, and the attempted Pallas variants either failed on TPU or lost on performance. The broad experiment is effectively resolved, with follow-on cleanup and narrower kernel work tracked separately.
Description
TL;DR: Implement and benchmark a production-facing Mamba-3 TPU path in Levanter with a pure JAX/XLA SSD-style decomposition, keeping Pallas optional and non-default. The thread covered SISO real-valued Mamba-3 first, then a real-valued MIMO rank-4 extension, with all work anchored in a direct recurrent reference, a local research logbook, and TPU benchmarks on v5p-8.
Scope:
- add a correct SISO Mamba-3 reference and chunked XLA fast path
- add a real-valued MIMO rank-4 extension without changing scan carry shape
- expose a hybrid public API/config surface for SISO vs MIMO
- validate against direct recurrences and upstream Mamba Torch references
- benchmark representative TPU shapes and determine default chunk sizes
Success metrics:
- forward and grad parity against direct recurrent references
- upstream reference parity for the real-valued path
- stable TPU XLA path with explicit default chunk sizes
- a production-facing API that defaults to the fastest maintainable path
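A minimal sketch of what a forward and grad parity check against a direct recurrence could look like. It uses a toy decayed recurrence h_t = a_t * h_{t-1} + x_t rather than the actual Mamba-3 update, and the names `reference`, `materialized`, and `check_parity`, as well as the tolerances, are assumptions for illustration, not the tests in the repository.

```python
import jax
import jax.numpy as jnp


def reference(log_a, x):
    """Direct token-by-token recurrence: h_t = exp(log_a_t) * h_{t-1} + x_t."""
    def step(h, inp):
        la, xt = inp
        h = jnp.exp(la) * h + xt
        return h, h

    _, hs = jax.lax.scan(step, jnp.zeros(x.shape[1:]), (log_a, x))
    return hs


def materialized(log_a, x):
    """Same recurrence via an explicit lower-triangular decay matrix."""
    cs = jnp.cumsum(log_a)
    T = x.shape[0]
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))
    # decay[t, s] = prod_{r=s+1..t} a_r for s <= t, and 0 above the diagonal.
    decay = jnp.exp(jnp.where(mask, cs[:, None] - cs[None, :], -jnp.inf))
    return decay @ x


def check_parity(ref_fn, fast_fn, args, atol=1e-3, rtol=1e-3):
    """Compare outputs and gradients of a sum-of-squares loss (assumed tolerances)."""
    def make_loss(fn):
        return lambda *a: jnp.sum(fn(*a) ** 2)

    argnums = tuple(range(len(args)))
    fwd_ok = jnp.allclose(ref_fn(*args), fast_fn(*args), atol=atol, rtol=rtol)
    g_ref = jax.grad(make_loss(ref_fn), argnums=argnums)(*args)
    g_fast = jax.grad(make_loss(fast_fn), argnums=argnums)(*args)
    grad_ok = all(
        bool(jnp.allclose(a, b, atol=atol, rtol=rtol)) for a, b in zip(g_ref, g_fast)
    )
    return bool(fwd_ok), grad_ok


key_a, key_x = jax.random.split(jax.random.PRNGKey(0))
log_a = -jax.nn.softplus(jax.random.normal(key_a, (16,)))  # decays in (0, 1)
x = jax.random.normal(key_x, (16, 4))
fwd_ok, grad_ok = check_parity(reference, materialized, (log_a, x))
```

The same harness shape applies to any pair of a slow reference and a fast path; only the two callables and the tolerances would change.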
Stop criteria:
- SISO and MIMO both run end-to-end on TPU with XLA only
- the winning defaults are identified and documented
- Pallas is either shown to help or explicitly rejected for now
Decision log:
- 2026-03-18: choose pure JAX/XLA SSD-style decomposition as the primary implementation target; do not start from a pairwise Pallas kernel. Owner: agent. Evidence: direct recurrence equivalence tests and initial TPU benchmarks.
- 2026-03-18: set SISO TPU default to chunk_size=512. Owner: agent. Evidence: best throughput on v5p-8 among 128/256/512/1024/2048 sweep points.
- 2026-03-18: reject the first SISO Pallas local-output revisit. Owner: agent. Evidence: compilation/runtime failures and no kept TPU training win.
- 2026-03-19: add real-valued MIMO rank-4 on the same XLA decomposition with unchanged scan carry shape. Owner: agent. Evidence: direct recurrence tests, R=1 parity, TPU compile/run success.
- 2026-03-19: set MIMO rank-4 TPU default to chunk_size=256. Owner: agent. Evidence: best token throughput across 128/256/512/1024 on v5p-8.
- 2026-03-19: reject the MIMO Pallas local-output revisit for now. Owner: agent. Evidence: TPU layout failures and no valid parity/perf result.
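For readers who want to reproduce the kind of chunk-size sweep behind these defaults, here is a hedged sketch of a timing harness. The operation being timed is a toy chunked prefix sum, not the Mamba-3 kernel, and `chunked_cumsum` and `sweep_chunk_sizes` are hypothetical names. Numbers measured this way on CPU say nothing about v5p-8.

```python
import functools
import time

import jax
import jax.numpy as jnp


def chunked_cumsum(x, chunk_size):
    """Toy stand-in for a chunked kernel: prefix sum over time, computed per chunk."""
    T, P = x.shape
    xc = x.reshape(T // chunk_size, chunk_size, P)
    local = jnp.cumsum(xc, axis=1)                 # within-chunk prefix sums
    totals = local[:, -1, :]                       # per-chunk totals
    offsets = jnp.cumsum(totals, axis=0) - totals  # exclusive prefix over chunks
    return (local + offsets[:, None, :]).reshape(T, P)


def sweep_chunk_sizes(fn, x, chunk_sizes, iters=3):
    """Return (best_chunk_size, {chunk_size: tokens_per_second})."""
    results = {}
    for cs in chunk_sizes:
        jitted = jax.jit(functools.partial(fn, chunk_size=cs))
        jitted(x).block_until_ready()  # compile and warm up outside the timed region
        start = time.perf_counter()
        for _ in range(iters):
            jitted(x).block_until_ready()
        elapsed = (time.perf_counter() - start) / iters
        results[cs] = x.shape[0] / elapsed
    best = max(results, key=results.get)
    return best, results


x = jnp.ones((1024, 8))
best, results = sweep_chunk_sizes(chunked_cumsum, x, (128, 256, 512))
```

Blocking on the result before reading the clock matters because JAX dispatches asynchronously; without it the loop would mostly time the dispatch.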
Negative-results index:
- token-scan local rewrite: much slower than materialized chunk-local XLA on TPU
- associative_scan: much slower than lax.scan on TPU
- SISO Pallas local-output attempt: not worth keeping
- MIMO streamed-rank XLA local rewrite: much slower than current materialized local block
- MIMO Pallas local-output attempt: TPU layout crash, no keepable result
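To make the associative_scan versus lax.scan entry concrete, the sketch below writes the same linear recurrence h_t = a_t * h_{t-1} + u_t both ways. It only shows that the two forms agree numerically; it does not reproduce the TPU slowdown, and which recurrence in the actual implementation was tried with associative_scan is not stated here. Function names are illustrative.

```python
import jax
import jax.numpy as jnp


def scan_sequential(a, u):
    """Sequential form with jax.lax.scan; a has shape [T], u has shape [T, P]."""
    def step(h, inp):
        a_t, u_t = inp
        h = a_t * h + u_t
        return h, h

    _, hs = jax.lax.scan(step, jnp.zeros(u.shape[1:]), (a, u))
    return hs


def scan_associative(a, u):
    """Parallel-prefix form with jax.lax.associative_scan."""
    def combine(left, right):
        # Composition of two affine maps h -> a * h + u, applied left then right.
        a_l, u_l = left
        a_r, u_r = right
        return a_l * a_r, a_r * u_l + u_r

    _, hs = jax.lax.associative_scan(combine, (a[:, None], u))
    return hs


key_a, key_u = jax.random.split(jax.random.PRNGKey(0))
a = jax.nn.sigmoid(jax.random.normal(key_a, (32,)))  # decays in (0, 1)
u = jax.random.normal(key_u, (32, 4))
seq_out = scan_sequential(a, u)
assoc_out = scan_associative(a, u)
```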
Conclusion:
The maintainable shipping path is XLA-first. SISO is production-ready with a chunk_size=512 TPU default. Real-valued MIMO rank-4 is viable with the same scan-carry shape and a chunk_size=256 TPU default. Pallas should remain out of the default path until a narrower kernel shape demonstrates a real TPU win.
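One way these defaults could be encoded in a config object is sketched below. `Mamba3KernelConfig` and its fields are hypothetical and are not the Levanter API; only the 512 and 256 chunk sizes and the Pallas-off default come from this issue.

```python
from dataclasses import dataclass
from typing import Literal, Optional

# TPU defaults reported in this issue (measured on v5p-8).
_TPU_DEFAULT_CHUNK_SIZE = {"siso": 512, "mimo": 256}


@dataclass(frozen=True)
class Mamba3KernelConfig:
    """Hypothetical config surface; field names are assumptions for illustration."""

    mode: Literal["siso", "mimo"] = "siso"
    mimo_rank: int = 4                # only meaningful when mode == "mimo"
    chunk_size: Optional[int] = None  # None means "use the TPU default for the mode"
    use_pallas: bool = False          # Pallas stays opt-in and non-default

    def resolved_chunk_size(self) -> int:
        if self.chunk_size is not None:
            return self.chunk_size
        return _TPU_DEFAULT_CHUNK_SIZE[self.mode]


siso_cfg = Mamba3KernelConfig()
mimo_cfg = Mamba3KernelConfig(mode="mimo")
custom_cfg = Mamba3KernelConfig(mode="mimo", chunk_size=128)
```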
Confidence:
- SISO XLA path: stable
- MIMO rank-4 XLA path: replicated
- Pallas negative result: replicated for the attempted kernel shapes
Hypothesis or Goal
Mamba-3 on TPU should be implementable as a chunked SSD-style recurrence in pure JAX/XLA first, with Pallas reserved only for a clearly dominant local block if profiling justifies it. For MIMO, the extra rank should live in B/C/X/Z/O while the carried scan state remains [P, N], preserving the SISO decomposition and keeping the TPU implementation maintainable.
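The chunked decomposition can be illustrated with a simplified SISO recurrence, h_t = a_t * h_{t-1} + x_t b_t^T with output y_t = h_t c_t, where the carried state has shape [P, N]. This is a sketch under stated assumptions: it uses a scalar per-token decay, omits the actual Mamba-3 discretization, gating, and the MIMO rank dimension, and the function names are illustrative rather than taken from the implementation.

```python
import jax
import jax.numpy as jnp


def direct_recurrence(log_a, x, b, c):
    """Token-by-token reference. Shapes: log_a [T], x [T, P], b [T, N], c [T, N]."""
    P, N = x.shape[1], b.shape[1]

    def step(h, inp):
        la, xt, bt, ct = inp
        h = jnp.exp(la) * h + jnp.outer(xt, bt)  # state stays [P, N]
        return h, h @ ct                          # y_t has shape [P]

    _, y = jax.lax.scan(step, jnp.zeros((P, N)), (log_a, x, b, c))
    return y


def chunked_recurrence(log_a, x, b, c, chunk_size):
    """Same recurrence: materialized work inside chunks, lax.scan across chunks."""
    T, P = x.shape
    N = b.shape[1]
    assert T % chunk_size == 0, "sketch assumes seq_len divisible by chunk_size"
    K = T // chunk_size
    split = lambda v: v.reshape(K, chunk_size, *v.shape[1:])
    mask = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool))

    def chunk_step(h, inp):
        la, xc, bc, cc = inp
        cs = jnp.cumsum(la)
        # decay[t, s] = prod_{r=s+1..t} a_r for s <= t, else 0.
        decay = jnp.exp(jnp.where(mask, cs[:, None] - cs[None, :], -jnp.inf))
        local = (decay * (cc @ bc.T)) @ xc          # contribution from inside the chunk
        inter = jnp.exp(cs)[:, None] * (cc @ h.T)   # contribution from the carried state
        w = jnp.exp(cs[-1] - cs)                    # decay from each token to chunk end
        h_next = jnp.exp(cs[-1]) * h + jnp.einsum("s,sp,sn->pn", w, xc, bc)
        return h_next, local + inter

    init = jnp.zeros((P, N))
    _, y = jax.lax.scan(chunk_step, init, tuple(map(split, (log_a, x, b, c))))
    return y.reshape(T, P)


keys = jax.random.split(jax.random.PRNGKey(0), 4)
T, P, N = 64, 4, 8
log_a = -jax.nn.softplus(jax.random.normal(keys[0], (T,)))
x = jax.random.normal(keys[1], (T, P))
b = jax.random.normal(keys[2], (T, N))
c = jax.random.normal(keys[3], (T, N))
y_direct = direct_recurrence(log_a, x, b, c)
y_chunked = chunked_recurrence(log_a, x, b, c, chunk_size=16)
```

The scan over chunks carries only the [P, N] state, so its length is seq_len / chunk_size instead of seq_len, while the per-chunk work becomes dense matrix products.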
Results
Representative slice-level TPU results on v5p-8 with explicit G sharding, seq_len=16384, groups=16, bf16 (throughput given as forward / backward tok/s):
SISO XLA, chunk_size=512:
- 128 x 512: 341.20M / 211.94M
- 512 x 1024: 152.60M / 91.89M
- 1024 x 512: 171.88M / 99.67M
MIMO rank-4 XLA, chunk_size=256:
- 128 x 512: 78.53M / 39.73M
- 512 x 1024: 34.92M / 15.86M
- 1024 x 512: 37.76M / 22.70M
Key findings:
- SISO and MIMO both pass direct reference parity tests.
- Real-valued SISO and MIMO references also match the upstream Mamba Torch reference logic in Torch-gated tests.
- MIMO R=4 keeps the scan carry at SISO shape and compiles/runs end-to-end on TPU with XLA only.
- The production recommendation is the hybrid SISO/MIMO XLA path, not Pallas.
Already sealed:
- the research question is resolved enough to ship the XLA path and close the experiment issue
- future work should be tracked in follow-on optimization issues rather than continuing this thread