[🤖][Mamba-3] Quaternion States #959

@swfsql

Hi, I believe Quaternions may be reachable with the same RoPE trick.
When I say "believe", I really mean "having faith", because it was an AI that said it (and produced an implementation of it), and I have no idea whether it is correct. It seems to behave well, and differently from Complex, but I don't have any real benchmark to support this claim.

If desired, I'm happy to share the entire prompt and context that generated the response. Below is the message generated by the AI (sorry, but it's beyond me to comment on anything there, and "abelian" sounds like 🐝 in my mother language).


Extending Mamba-3's data-dependent RoPE to a non-abelian (quaternion) rotation

While building a port of Mamba-1/2/3, we noticed something about your complex-SSM / RoPE derivation that seemed worth passing along. Nothing here is a feature request — the point is just that a non-abelian rotation (unit quaternions, SU(2) ⊂ SO(4)) falls out of §4.2 (Complex-Valued SSMs) and Appendix C almost for free, and crucially leaves the official single-pass SSD kernel unchanged. So this is the math plus a few implementation notes, in case any of it is useful to you.

We've verified the whole thing end-to-end in our port — forward = step parity, split-prefill continuity, gradient parity, and exact collapse to your RoPE in the abelian limit — on values and gradients.

1. The RoPE factoring never required commutativity

Prop. Rotary Embedding Equivalence with Exponential-Trapezoidal Discretization gives

h̃ₜ = αₜ h̃ₜ₋₁ + βₜ B̄ₜ₋₁ xₜ₋₁ + γₜ B̄ₜ xₜ ,    yₜ = C̄ₜᵀ h̃ₜ

with B̄ₜ = (∏_{s=0..t} Rₛᵀ) Bₜ, C̄ₜ = (∏_{s=0..t} Rₛᵀ) Cₜ. The Appendix C proof invokes Rᵢ Rⱼ = Rⱼ Rᵢ (the SO(2) blocks commute) to collapse the cumulative rotation into a cumsum of angles.

The factoring itself (pushing the inter-step rotation onto B and C and leaving a plain scalar-decay SSD) does not use commutativity. With the ordered cumulative rotation Pₜ = Rₜ Rₜ₋₁ ⋯ R₀ (newest on the left, matching the s = 0..t range above), the inter-step product telescopes by associativity and invertibility alone:

Rₜ ⋯ Rᵢ₊₁ = Pₜ Pᵢ⁻¹

and then, using only orthogonality Pᵢ⁻¹ = Pᵢᵀ,

Cₜᵀ (Rₜ⋯Rᵢ₊₁) Bᵢ  =  Cₜᵀ Pₜ Pᵢᵀ Bᵢ  =  (Pₜᵀ Cₜ)ᵀ (Pᵢᵀ Bᵢ)  =  C̄ₜᵀ B̄ᵢ .

The per-head scalar decay αₜ commutes with everything, so the mask L is unaffected. Commutativity buys only the closed-form cumsum; drop it and the cumsum becomes an ordered associative scan, while the SSD duality is preserved verbatim.
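A minimal NumPy sketch of the identity above, assuming random 4×4 rotations as stand-ins for Rₜ and 0-based indexing. The helper names (`random_rotation`, `direct`) are illustrative and not taken from the Mamba-3 codebase. It checks numerically that the explicit inter-step product and the rotated B̄/C̄ inner product agree even when the rotations do not commute.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_rotation(n, rng):
    """Random n x n orthogonal matrix with det +1 (QR of a Gaussian matrix)."""
    q, r = np.linalg.qr(rng.standard_normal((n, n)))
    q = q * np.sign(np.diag(r))      # normalise column signs
    if np.linalg.det(q) < 0:         # flip one column to land in SO(n)
        q[:, 0] = -q[:, 0]
    return q

T, n = 6, 4
R = [random_rotation(n, rng) for _ in range(T)]   # hypothetical per-step rotations
B = rng.standard_normal((T, n))
C = rng.standard_normal((T, n))

# Ordered cumulative product P[t] = R[t] @ R[t-1] @ ... @ R[0] (newest on the left).
P, acc = [], np.eye(n)
for t in range(T):
    acc = R[t] @ acc
    P.append(acc)

# Rotated projections: B_bar[t] = P[t]^T B[t], C_bar[t] = P[t]^T C[t].
B_bar = np.stack([P[t].T @ B[t] for t in range(T)])
C_bar = np.stack([P[t].T @ C[t] for t in range(T)])

def direct(t, i):
    """C_t^T (R_t ... R_{i+1}) B_i using the explicit inter-step product."""
    M = np.eye(n)
    for s in range(i + 1, t + 1):
        M = R[s] @ M
    return C[t] @ M @ B[i]

max_err = max(abs(direct(t, i) - C_bar[t] @ B_bar[i])
              for t in range(T) for i in range(t + 1))
commutator = np.linalg.norm(R[0] @ R[1] - R[1] @ R[0])
print("max |direct - factored| =", max_err)
print("||R0 R1 - R1 R0||       =", commutator)
```

The commutator norm is printed only to confirm that the test rotations really are non-commuting, so the agreement cannot be attributed to an accidental abelian case.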

2. SO(2) → SU(2): unit quaternions

Draw Rₜ from the left-isoclinic SU(2) ⊂ SO(4) — left-multiplication v ↦ qₜ ⊗ v by a per-step unit quaternion. This is the smallest non-abelian step up from the 2×2 rotations:

  • SO(2) ≅ U(1) is abelian; its finite subgroups are cyclic (solvable) — enough for parity / mod-k, the TC⁰ regime targeted by the complex update.
  • Unit quaternions are non-abelian and contain the binary icosahedral group 2I = SL(2,5) (a double cover of A₅), which is non-solvable. That non-solvable subgroup is the algebraic ingredient for representing A₅ word problems — the kind of state-tracking a cumsum-of-angles rotation simply can't express. (The TC⁰ → NC¹ framing from Barrington's theorem is the lens we used to motivate this; we're claiming the representational gap here, not a trained-model separation — that's a much bigger empirical question we haven't settled. See the note at the end.)

A state_rank of r = 4·J is treated as J independent quaternion blocks; the rotation acts within each 4-block exactly as RoPE acts within each 2-pair.
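As a rough illustration of the block layout, here is a NumPy sketch that treats a state vector of length r = 4·J as J quaternion blocks and left-multiplies each block by its own unit quaternion. The function names (`quat_mul`, `rotate_blocks`) and the packed (w, x, y, z) layout are assumptions made for this example only.

```python
import numpy as np

def quat_mul(a, b):
    """Hamilton product a ⊗ b for arrays of shape [..., 4] in (w, x, y, z) order."""
    aw, ax, ay, az = np.moveaxis(a, -1, 0)
    bw, bx, by, bz = np.moveaxis(b, -1, 0)
    return np.stack([
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ], axis=-1)

def rotate_blocks(v, q):
    """Left-multiply each 4-block of v[..., 4*J] by its quaternion q[..., J, 4]."""
    J = q.shape[-2]
    blocks = v.reshape(*v.shape[:-1], J, 4)
    return quat_mul(q, blocks).reshape(v.shape)

rng = np.random.default_rng(0)
J = 3                                              # state_rank r = 4 * J = 12
v = rng.standard_normal(4 * J)
q = rng.standard_normal((J, 4))
q /= np.linalg.norm(q, axis=-1, keepdims=True)     # one unit quaternion per block

rotated = rotate_blocks(v, q)
norms_before = np.linalg.norm(v.reshape(J, 4), axis=-1)
norms_after = np.linalg.norm(rotated.reshape(J, 4), axis=-1)

# Non-commutativity in the smallest case: i ⊗ j = k but j ⊗ i = -k.
i = np.array([0.0, 1.0, 0.0, 0.0])
j = np.array([0.0, 0.0, 1.0, 0.0])
ij, ji = quat_mul(i, j), quat_mul(j, i)
```

Each block keeps its norm because left-multiplication by a unit quaternion is orthogonal, and the `ij` / `ji` pair shows the order dependence that the 2×2 RoPE blocks lack.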

3. What changes vs. what stays

Stays identical: the scalar-decay SSD core, the chunked algorithm, and the custom recompute backward. Because the rotation is applied to B/C before chunking and is chunk_len-agnostic, the official single-pass kernel (the Triton SISO / Tilelang MIMO form with scaleₜ = γₜ + (1−λₜ₊₁)·Δₜ₊₁, the strict-lower-triangular intra-chunk mask, the same-step γ correction, and the boundary-β seed) only ever consumes the rotated B̄/C̄. It requires no modification: the rotated projections are produced upstream, exactly where the abelian RoPE already sits. (We additionally cross-check against a double-pass γ/β decomposition; both consume the rotation identically, as expected from §1.)

Changes — three small pieces:

  1. Per-step rotation materialisation. From the in-projection take a scaled rotation vector g ∈ ℝ³ (axis·angle) per block, mapped to a unit quaternion via the exponential map

    q = (cos(‖g‖/2),  sin(‖g‖/2) · g/‖g‖)
    

    Scaling g by Δₜ makes a small step a near-identity rotation — the analogue of Δₜ · π · tanh(θₜ). (sin(‖g‖/2)/‖g‖ → ½ is the stable form at g→0.)

  2. cumsum → ordered associative scan. cum_angle becomes a cumulative quaternion product Pₜ = qₜ ⊗ ⋯ ⊗ q₀. A product of unit quaternions is unit, so the scan is exactly orthogonal — no drift, no wrap_angle — and the cross-chunk carry is a single quaternion per (head, block), the precise analogue of the cum_angle accumulator. We run it as a Hillis–Steele parallel scan (O(log T) dependency depth), not a token loop (see §4).

  3. Apply by conjugate. B̄ = rotate(B, conj(P)), C̄ = rotate(C, conj(P)). For a unit quaternion conj(q) = q⁻¹, and left-multiplication v ↦ q ⊗ v is a 4×4 orthogonal matrix L_q with L_qᵀ = L_conj(q), so Pₜᵀ v is computed as the Hamilton product conj(Pₜ) ⊗ v (no 4×4 materialisation on the hot path).
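The exponential map and the apply-by-conjugate step can be sketched in NumPy as follows. This is an illustration under assumptions, not the port's code: the names `quat_exp`, `quat_mul`, `quat_conj` and `left_matrix` are made up here, and the packed (w, x, y, z) layout is chosen for brevity. The 4×4 matrix is built only to check the identity L_qᵀ v = conj(q) ⊗ v.

```python
import numpy as np

def quat_exp(g):
    """Map rotation vectors g[..., 3] (axis * angle) to unit quaternions [..., 4]."""
    theta = np.linalg.norm(g, axis=-1, keepdims=True)
    # np.sinc(x) = sin(pi x) / (pi x), so this equals sin(theta/2) / theta
    # and evaluates to exactly 1/2 at theta = 0 without a division by zero.
    k = 0.5 * np.sinc(theta / (2.0 * np.pi))
    return np.concatenate([np.cos(theta / 2.0), k * g], axis=-1)

def quat_mul(a, b):
    """Hamilton product a ⊗ b for arrays of shape [..., 4] in (w, x, y, z) order."""
    aw, ax, ay, az = np.moveaxis(a, -1, 0)
    bw, bx, by, bz = np.moveaxis(b, -1, 0)
    return np.stack([
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ], axis=-1)

def quat_conj(q):
    """Conjugate: negate the vector part."""
    return q * np.array([1.0, -1.0, -1.0, -1.0])

def left_matrix(q):
    """4x4 matrix L_q with L_q @ v == q ⊗ v (column j is q ⊗ e_j; checking only)."""
    return np.stack([quat_mul(q, e) for e in np.eye(4)], axis=-1)

rng = np.random.default_rng(0)
delta = 0.1                                  # hypothetical step size scaling g
g = delta * rng.standard_normal(3)
q = quat_exp(g)
v = rng.standard_normal(4)
L = left_matrix(q)

norm_q = np.linalg.norm(q)
orth_err = np.abs(L.T @ L - np.eye(4)).max()
conj_err = np.abs(L.T @ v - quat_mul(quat_conj(q), v)).max()
q_zero = quat_exp(np.zeros(3))               # g = 0 should give the identity
```

Writing the small-angle factor through `np.sinc` is one way to get the stable g → 0 form; a Taylor-series branch would serve equally well.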

The recurrent step() is a single qₜ ⊗ Pₜ₋₁ compose, mirroring the cumsum's single +.
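To make the scan and the step parity concrete, here is a small NumPy sketch, assuming a packed (w, x, y, z) layout and illustrative function names. It compares a Hillis–Steele style ordered scan with the token-by-token recurrence Pₜ = qₜ ⊗ Pₜ₋₁, and checks that splitting the sequence and carrying one quaternion across the boundary gives the same result.

```python
import numpy as np

def quat_mul(a, b):
    """Hamilton product a ⊗ b for arrays of shape [..., 4] in (w, x, y, z) order."""
    aw, ax, ay, az = np.moveaxis(a, -1, 0)
    bw, bx, by, bz = np.moveaxis(b, -1, 0)
    return np.stack([
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ], axis=-1)

def scan_sequential(q, init):
    """Reference token loop: P[t] = q[t] ⊗ P[t-1], with P[-1] = init."""
    out, p = [], init
    for qt in q:
        p = quat_mul(qt, p)
        out.append(p)
    return np.stack(out)

def scan_hillis_steele(q, init):
    """Inclusive ordered scan in ceil(log2 T) passes; newer factors stay on the left."""
    p = q.copy()
    T = p.shape[0]
    shift = 1
    while shift < T:
        nxt = p.copy()
        # p[t] holds q[t] ⊗ ... ⊗ q[t-shift+1] (clipped at 0); extend it on the
        # right with the older partial product that ends at t-shift.
        nxt[shift:] = quat_mul(p[shift:], p[:-shift])
        p = nxt
        shift *= 2
    return quat_mul(p, init)      # fold the carried-in prefix on the right (oldest)

rng = np.random.default_rng(0)
T = 13
q = rng.standard_normal((T, 4))
q /= np.linalg.norm(q, axis=-1, keepdims=True)
identity = np.array([1.0, 0.0, 0.0, 0.0])

ref = scan_sequential(q, identity)
par = scan_hillis_steele(q, identity)

# Split prefill: scan the first chunk, then pass its last quaternion as the carry.
k = 5
head = scan_hillis_steele(q[:k], identity)
tail = scan_hillis_steele(q[k:], head[-1])
split = np.concatenate([head, tail])

scan_err = np.abs(par - ref).max()
split_err = np.abs(split - ref).max()
norm_drift = np.abs(np.linalg.norm(par, axis=-1) - 1.0).max()
```

The only ordering rule the scan has to respect is that the newer partial product goes on the left of the older one, since the operands cannot be swapped as they can in a cumsum.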

4. Implementation findings worth flagging

  • The abelian collapse is exact and serves as a regression test. Restricting every qₜ to a single fixed axis makes them commute, and the quaternion scan provably reduces to a cumsum of half-angles — reproducing apply_rope bit-for-bit. (Isoclinic subtlety: a single q = (cos φ, sin φ, 0, 0) rotates both planes of its 4-block by the same φ, so the cross-check locks each block's two RoPE pairs to a shared angle.)

  • The rotation's backward and the SSD's backward compose independently. The positional rotation is its own autodiff op upstream of the SSD; the SSD's memory-efficient recompute backward (the K1–K5 replay that returns d_v, d_da, d_B̄, d_C̄, d_gamma, d_scale, d_init) treats B̄/C̄ as ordinary leaf inputs. The d_B̄/d_C̄ it produces then flow back through the rotation's own custom backward by ordinary autodiff composition; neither node needs to know about the other. So the kernel-side recompute backward you already ship is untouched, and only the rotation contributes a second (small) custom node.

  • Memory-efficient custom backward for the scan, no token loop. The parallel scan retains O(log T) full-sequence intermediates; to avoid storing them we use a recompute backward (in the spirit of ssd_combined) that saves only the two leaf inputs and recomputes the prefix product. The VJP has a clean closed form for unit quaternions (using conj as inverse), entirely parallel:

    S[t]   = Σ_{s≥t} conj(cum[s]) ⊗ d_cum[s]     (reverse-cumsum; cum[s] = Pₛ)
    G[t]   = cum[t] ⊗ S[t]
    d_q[t] = G[t] ⊗ conj(cum[t−1])               (cum[−1] = init)
    d_init = init ⊗ S[0]                         (= S[0] when init is the identity)
    

    Since final_carry = cum[:, −1] is a slice of the single output, no two-output gradient flattening is needed (unlike the SSD's (y, final_state) pair).

  • The Hamilton product wants struct-of-arrays. A packed […, 4] layout forces every product to narrow four components and cat back (fusion-breakers); keeping (w, x, y, z) as four separate tensors makes the product pure element-wise arithmetic — also the natural register layout for a fused kernel.
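The closed-form VJP for the scan can be checked against finite differences with a short NumPy sketch. Assumptions: the scan convention is cum[t] = q[t] ⊗ cum[t−1] with cum[−1] = init, gradients are taken with respect to the four real components of each quaternion, and all names here are illustrative. A token loop is used for the forward recompute purely for brevity. Under this convention the carry gradient comes out as init ⊗ S[0], which reduces to S[0] when the carry is the identity.

```python
import numpy as np

def quat_mul(a, b):
    """Hamilton product a ⊗ b for arrays of shape [..., 4] in (w, x, y, z) order."""
    aw, ax, ay, az = np.moveaxis(a, -1, 0)
    bw, bx, by, bz = np.moveaxis(b, -1, 0)
    return np.stack([
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ], axis=-1)

def quat_conj(q):
    return q * np.array([1.0, -1.0, -1.0, -1.0])

def scan(q, init):
    """cum[t] = q[t] ⊗ cum[t-1], with cum[-1] = init (loop used for brevity)."""
    out, p = [], init
    for qt in q:
        p = quat_mul(qt, p)
        out.append(p)
    return np.stack(out)

def scan_vjp(q, init, d_cum):
    """Closed-form VJP of cum = scan(q, init), valid when q and init are unit."""
    cum = scan(q, init)                              # recomputed, not saved
    terms = quat_mul(quat_conj(cum), d_cum)          # conj(cum[s]) ⊗ d_cum[s]
    S = np.cumsum(terms[::-1], axis=0)[::-1]         # reverse cumsum over s >= t
    G = quat_mul(cum, S)                             # cum[t] ⊗ S[t]
    prev = np.concatenate([init[None], cum[:-1]])    # cum[t-1], with cum[-1] = init
    d_q = quat_mul(G, quat_conj(prev))
    d_init = quat_mul(init, S[0])                    # equals S[0] for identity init
    return d_q, d_init

def numeric_grad(f, x, eps=1e-6):
    """Central finite differences, one coordinate at a time."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2.0 * eps)
    return g

rng = np.random.default_rng(0)
T = 7
q = rng.standard_normal((T, 4))
q /= np.linalg.norm(q, axis=-1, keepdims=True)
init = rng.standard_normal(4)
init /= np.linalg.norm(init)
W = rng.standard_normal((T, 4))                      # upstream gradient d_cum

def loss(q_, init_):
    return np.sum(W * scan(q_, init_))

d_q, d_init = scan_vjp(q, init, W)
q_err = np.abs(d_q - numeric_grad(lambda x: loss(x, init), q)).max()
init_err = np.abs(d_init - numeric_grad(lambda x: loss(q, x), init)).max()
```

Because the scan is linear in each individual quaternion component, the central differences are exact up to rounding, which makes this a fairly sharp check of the formulas.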

5. Where the generalisation stops

In case quaternions invite a push to octonions or larger k:

  • The true requirements are (a) orthogonality, (b) associative composition, (c) a cheap orthogonal parameterisation, not "division algebra".
  • Octonions do not work: non-associativity means o ↦ L_o is not a homomorphism (L_{ab} ≠ L_a L_b), so there is no unit-octonion product scan; S⁷ is not a group. Plain SO(8) Givens blocks are fine, but that is not octonion algebra and gains nothing over generic SO(8).
  • Non-solvability is already reached at the quaternion (2I is non-solvable, with quotient 2I/{±1} ≅ A₅); larger k is more parameters, not a new complexity class. A commuting higher-D rotation collapses to independent 2D angles (the maximal torus), i.e. RoPE again. Only the non-commuting step is new.
  • For general SO(k), the cheap parameterisations are Cayley (I−S)(I+S)⁻¹ of a skew S, or Givens/Householder products (cf. the DeltaNet / RWKV-7 line); the Hurwitz "free" orthogonal parameterisation only exists at dims 1/2/4.
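For the general SO(k) case, the Cayley parameterisation mentioned above can be sketched in a few lines of NumPy. The function name `cayley` and the choice k = 8 are assumptions for illustration; a linear solve stands in for the explicit inverse.

```python
import numpy as np

def cayley(S):
    """Cayley transform (I - S)(I + S)^{-1} of a skew-symmetric matrix S."""
    I = np.eye(S.shape[0])
    # (I - S) and (I + S)^{-1} commute, so (I + S)^{-1}(I - S) is the same matrix
    # and can be obtained from a linear solve instead of an explicit inverse.
    return np.linalg.solve(I + S, I - S)

rng = np.random.default_rng(0)
k = 8
A = rng.standard_normal((k, k))
S = A - A.T                          # skew-symmetric: S^T = -S
Q = cayley(S)

orth_err = np.abs(Q.T @ Q - np.eye(k)).max()
det_Q = np.linalg.det(Q)
print("max |Q^T Q - I| =", orth_err)
print("det(Q)          =", det_Q)
```

Unlike the quaternion product, this costs a k×k solve per rotation, which is the sense in which the quaternion parameterisation is cheaper at dimension 4.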

So the short version: the complex-SSM / RoPE-trick formulation already contains its non-abelian generalisation. The factoring survives, only cumsum → associative scan and cum_angle → cum_quaternion change, and the official single-pass kernel is untouched. We haven't run this at any serious scale — what we have is the math, the parity/gradient tests, and a working portable implementation, all of which we're happy to share if it's useful.

A note on the complexity-theory framing in §2: the 2I → A₅ / TC⁰ → NC¹ story is genuinely what motivated reaching for a non-abelian group, and the representational facts are solid (an abelian rotation provably can't do A₅; a quaternion can, via the double cover 2I). But "can represent" is not "will learn", and we make no claim about the latter: whether a trained Mamba-3 quaternion layer actually acquires A₅-style state-tracking is exactly the experiment we haven't done. Treat §2 as motivation, not a result.
