Problem 1 -> Quadratic attention cost : The core operation of a Transformer is the attention map $\text{softmax}(QK^\top / \sqrt{d_k})\,V$: every token's query is compared against every other token's key, producing an $N \times N$ score matrix. Both compute and memory therefore scale as $O(N^2)$ in the sequence length $N$.
Problem 2 -> KV cache memory explosion : During autoregressive inference (generating one token at a time), the Transformer caches the Key and Value matrices for every previously generated token. This is necessary to avoid recomputing attention from scratch at each step. But the cache grows linearly with sequence length and never shrinks. For a context window of 100,000 tokens with 32 layers and 8 heads at 128 dimensions per head, the cache holds $2 \times 32 \times 8 \times 128 = 65{,}536$ values per token; at fp16 that is about 128 KB per token, or roughly 13 GB for the full context, before counting a single model weight (worked out in the short script after Problem 3).
Problem 3 -> Uniform attention across all positions : Every token attends to every other token with equal architectural access, regardless of relevance. A token at position 1 and a token at position 50,000 are computationally equivalent. The model must learn which positions matter from data alone, with no inductive bias toward recency or locality. For signals where recent context matters far more than distant context (time series, audio, streaming data), this is wasteful; the model spends capacity learning to ignore most of what it attends to.
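A quick back-of-the-envelope script for the KV-cache figure quoted in Problem 2 (fp16 storage, 2 bytes per value, is an assumption):

```python
# KV-cache size for the configuration quoted in Problem 2.
# Assumes fp16 storage (2 bytes per value); adjust bytes_per_value for fp32 or int8.
context_len = 100_000     # tokens in the context window
n_layers, n_heads, head_dim = 32, 8, 128
bytes_per_value = 2       # fp16

values_per_token = 2 * n_layers * n_heads * head_dim      # K and V for every layer and head
cache_bytes = context_len * values_per_token * bytes_per_value

print(values_per_token)               # 65536 cached values per token
print(f"{cache_bytes / 1e9:.1f} GB")  # ~13.1 GB for the full context
```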
The idea behind Mamba is to model sequences as dynamical systems rather than as sets of pairwise relationships. Instead of asking "how does every token relate to every other token," it asks "how does a hidden state evolve as it absorbs new inputs over time."
This is the continuous-time State Space Model (SSM).
The system is governed by two differential equations:

$$\frac{dx(t)}{dt} = A\,x(t) + B\,u(t), \qquad y(t) = C\,x(t)$$

where:
- $x(t) \in \mathbb{R}^d$ is the hidden state at time $t$; the system's memory.
- $u(t) \in \mathbb{R}$ is the input at time $t$.
- $A \in \mathbb{R}^{d \times d}$ governs how the state evolves over time; the memory dynamics.
- $B \in \mathbb{R}^{d \times 1}$ governs how the input is written into the state.
- $C \in \mathbb{R}^{1 \times d}$ governs how the state is read out as the output $y(t)$.
The equation $\frac{dx(t)}{dt} = A\,x(t)$ (the system with no input) has the solution $x(t) = e^{At}x(0)$.

This comes from the definition of the matrix exponential. The scalar ODE $\frac{dx}{dt} = a\,x$ has solution $x(t) = e^{at}x(0)$; the matrix exponential $e^{At} = \sum_{k=0}^{\infty} \frac{(At)^k}{k!}$ generalizes the same solution to a vector-valued state.

For the full system with input, the variation of parameters method gives:

$$x(t) = e^{At}x(0) + \int_0^t e^{A(t-\tau)}\,B\,u(\tau)\,d\tau$$

This is the general solution. The first term is the initial state decayed/evolved by $e^{At}$; the second term accumulates every past input, each weighted by how much it has evolved between the moment it arrived ($\tau$) and now ($t$).
Neural networks operate on discrete token sequences, not continuous signals. To use the SSM on sequences, we discretize with a learnable timestep $\Delta$: under the standard zero-order hold (ZOH) discretization, $\bar{A} = e^{\Delta A}$ and $\bar{B} = (\Delta A)^{-1}\left(e^{\Delta A} - I\right)\Delta B$.

The discrete recurrence then becomes:

$$x_t = \bar{A}\,x_{t-1} + \bar{B}\,u_t, \qquad y_t = C\,x_t$$

This is a standard linear recurrence. Given the state at step $t-1$ and the input at step $t$, the state at step $t$ follows in a single update.
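As a minimal sketch of that recurrence (assuming a scalar input sequence and already-discretized $\bar{A}$, $\bar{B}$, $C$; all names here are illustrative):

```python
import torch

def ssm_sequential(u, A_bar, B_bar, C):
    """Run x_t = A_bar @ x_{t-1} + B_bar * u_t, y_t = C @ x_t one step at a time.
    u: (N,) scalar inputs; A_bar: (d, d); B_bar, C: (d,)."""
    x = torch.zeros(A_bar.shape[0])
    ys = []
    for u_t in u:                      # N strictly sequential steps
        x = A_bar @ x + B_bar * u_t    # write the new input into the evolving state
        ys.append(C @ x)               # read the state out as a scalar output
    return torch.stack(ys)

# Tiny usage example with arbitrary parameters.
d, N = 4, 10
y = ssm_sequential(torch.randn(N), 0.9 * torch.eye(d), torch.randn(d), torch.randn(d))
```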
The most natural choice for $A$ is a full, dense $d \times d$ matrix learned without constraints.
This creates three immediate problems:
Instability : For the recurrence $x_t = \bar{A}\,x_{t-1} + \bar{B}\,u_t$ to stay bounded, the eigenvalues of $\bar{A} = e^{\Delta A}$ must lie inside the unit circle. An unconstrained dense $A$ provides no such guarantee, so the state can blow up over long sequences.
Parallelism Nightmare : A full $\bar{A}$ makes every step a dense matrix-vector product that depends on the previous step, and the cumulative products $\bar{A}^k$ needed to parallelize the recurrence are dense matrix-matrix products; nothing reduces to cheap elementwise operations.
Too much Computation : Each step costs $O(d^2)$ for the matrix-vector product ($O(N d^2)$ over the sequence), and the matrix exponential $e^{\Delta A}$ of a dense $A$ costs $O(d^3)$ to evaluate.
The solution is to constrain $A$ to be diagonal: $A = -\Lambda$ with $\Lambda = \mathrm{diag}(\lambda_1, \dots, \lambda_d)$ and $\lambda_i > 0$.

Then:

$$\bar{A} = e^{-\Delta \Lambda} = \mathrm{diag}\!\left(e^{-\Delta\lambda_1}, \dots, e^{-\Delta\lambda_d}\right)$$
The exponential of a diagonal matrix is just the diagonal matrix of exponentials.
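A quick check of that identity, plus the resulting elementwise update (the values of lam and dt are arbitrary):

```python
import torch

d = 8
lam = torch.rand(d) + 0.1                        # positive decay rates lambda_i
dt = 0.1                                         # timestep Delta

# exp of a diagonal matrix == diagonal matrix of elementwise exps
A_bar_dense = torch.matrix_exp(-dt * torch.diag(lam))
A_bar_diag = torch.exp(-dt * lam)
assert torch.allclose(A_bar_dense, torch.diag(A_bar_diag), atol=1e-6)

# The state update is then elementwise: O(d) instead of O(d^2)
x_prev, B_bar, u_t = torch.randn(d), torch.randn(d), torch.tensor(0.5)
x_t = A_bar_diag * x_prev + B_bar * u_t          # per-dimension decay + input write
```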
Any matrix $A$ that is diagonalizable can be written as $A = V \Lambda V^{-1}$, so restricting $A$ to be diagonal is less limiting than it first appears.

In the full eigendecomposition, we need $e^{At} = V e^{\Lambda t} V^{-1}$; the basis-change matrices $V$ and $V^{-1}$ can be absorbed into the input and output maps ($\tilde{B} = V^{-1}B$, $\tilde{C} = CV$), leaving only the diagonal $\Lambda$ inside the recurrence.

So working in the diagonal basis is a reparameterization that separates the "remembered" dynamics (the per-dimension decay rates in $\Lambda$) from how information is written into and read out of the state ($B$ and $C$).
The diagonal parameterization brings three concrete benefits:

1. Guaranteed stability when $\lambda_i > 0$ : every entry of $\bar{A} = e^{-\Delta\Lambda}$ is $e^{-\Delta\lambda_i}$ with $\Delta, \lambda_i > 0$.
Each diagonal entry is a decay factor strictly between 0 and 1. The state magnitude can never blow up. Stability is guaranteed by the parameterization, not by constrained optimization.
2. Elementwise multiplication instead of matrix multiply :
The state update becomes:

$$x_t = e^{-\Delta\Lambda} \odot x_{t-1} + \bar{B}\,u_t$$

Each dimension of the state evolves independently with its own decay rate, so the update costs $O(d)$ elementwise operations instead of an $O(d^2)$ matrix-vector product.
3. Interpretability of eigenvalues : each $\lambda_i$ sets a timescale for its dimension: a small $\lambda_i$ means slow decay (long memory), a large $\lambda_i$ means fast decay (short memory).
Classic SSMs have fixed $A$, $B$, $C$, and $\Delta$: the same dynamics are applied at every timestep, regardless of what the input contains.

Selective SSMs (Mamba) make $B$, $C$, and the timestep $\Delta$ functions of the current input $u_t$.

Now the effective decay $\bar{A}_t = e^{-\Delta_t \Lambda}$ and input weight $\bar{B}_t$ change at every position, driven by the content of the token being processed.

The role of $\Delta_t$ is that of a gate:
- $\Delta_t \to 0$ (small timestep): $\bar{A}_t = e^{-\Delta_t \Lambda} \to I$ and $\bar{B}_t \to 0$. The state barely changes; the input is nearly ignored. The system "holds" what it has.
- $\Delta_t \to \infty$ (large timestep): $\bar{A}_t \to 0$ and $\bar{B}_t \to W_B^{-1}$ (ZOH limit). The old state is completely forgotten; the state becomes purely the new input. A hard reset.
- Intermediate $\Delta_t$: partial memory retention. The system blends old state with new input in a learned, input-dependent ratio.
This is selective gating: the model learns when to remember, when to forget, and when to reset, based on the content of the current input.
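A tiny numerical illustration of that gate (the decay rates and timesteps below are arbitrary):

```python
import torch

lam = torch.tensor([0.5, 1.0, 2.0])      # positive per-dimension decay rates

for dt in (1e-3, 1.0, 1e3):              # small, intermediate, large timestep
    A_bar = torch.exp(-dt * lam)         # decay applied to the old state
    print(f"dt={dt:g}: A_bar={A_bar.tolist()}")

# dt=0.001 -> A_bar ~ [1.00, 1.00, 1.00] : hold the state, ignore the input
# dt=1     -> A_bar ~ [0.61, 0.37, 0.14] : blend old state with new input
# dt=1000  -> A_bar ~ [0.00, 0.00, 0.00] : forget the state, reset to the new input
```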
The discrete recurrence:

$$x_t = \bar{A}_t\,x_{t-1} + \bar{B}_t\,u_t$$

is computed as a for-loop over the sequence positions $t = 1, \dots, N$.

Each step depends on the previous. This is $N$ strictly sequential steps: the GPU, built for massive parallelism, sits mostly idle waiting on one tiny update after another.
The recurrence is an associative operation. We define a "state pair" $(a, b)$, where applying a pair to a state means $x \mapsto a\,x + b$; step $t$ corresponds to the pair $(\bar{A}_t,\ \bar{B}_t u_t)$.

The combining rule for two consecutive pairs is :

$$(a_2, b_2) \circ (a_1, b_1) = (a_2 a_1,\ a_2 b_1 + b_2)$$
Verifying :

After two steps:

$$x_2 = \bar{A}_2\left(\bar{A}_1 x_0 + \bar{B}_1 u_1\right) + \bar{B}_2 u_2 = (\bar{A}_2\bar{A}_1)\,x_0 + \left(\bar{A}_2\bar{B}_1 u_1 + \bar{B}_2 u_2\right)$$

The combined pair $(\bar{A}_2\bar{A}_1,\ \bar{A}_2\bar{B}_1 u_1 + \bar{B}_2 u_2)$ applied to $x_0$ gives exactly the same result, and the same algebra shows that the grouping of combinations does not matter.
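A quick numerical check of the combining rule and its associativity (shapes and values are arbitrary):

```python
import torch

def combine(p2, p1):
    """Compose two (a, b) pairs: applying p1 then p2 to x yields a2*(a1*x + b1) + b2."""
    a2, b2 = p2
    a1, b1 = p1
    return a2 * a1, a2 * b1 + b2

d = 5
p1, p2, p3 = [(torch.rand(d), torch.randn(d)) for _ in range(3)]

# Associativity: combining (p3, p2) first or (p2, p1) first gives the same pair.
left = combine(combine(p3, p2), p1)
right = combine(p3, combine(p2, p1))
assert torch.allclose(left[0], right[0]) and torch.allclose(left[1], right[1])
```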
Since the operation is associative, we can apply it like a tree reduction. For $N = 8$ positions:

Level 1 : Combine pairs $(1,2), (3,4), (5,6), (7,8)$ → 4 partial results.

Level 2 : Combine those into $(1\text{–}4)$ and $(5\text{–}8)$ → 2 partial results.

Level 3 : Combine $(1\text{–}4)$ with $(5\text{–}8)$ → the result for the full sequence.

Total depth: $\log_2 N = 3$ parallel levels instead of $N - 1 = 7$ sequential steps.
In practice this is computed via cumulative product and cumulative sum:
Step 1 -> Cumulative product of decay factors :

$$P_t = \prod_{s=1}^{t} \bar{A}_s$$

This gives the total accumulated decay from position 1 through position $t$.
Step 2 -> Weighted inputs :

Each input contribution $\bar{B}_s u_s$ enters at position $s$ and must then be decayed by every step that follows, i.e. multiplied by $\prod_{r=s+1}^{t}\bar{A}_r = P_t / P_s$.

So the normalized contribution is $\bar{B}_s u_s / P_s$, and a cumulative sum accumulates $S_t = \sum_{s=1}^{t} \bar{B}_s u_s / P_s$.

Step 3 -> Recover state :

$$x_t = P_t \cdot S_t = \sum_{s=1}^{t} \left(\prod_{r=s+1}^{t}\bar{A}_r\right)\bar{B}_s u_s$$
Both steps map onto torch.cumprod and torch.cumsum, which are fully parallel GPU primitives. The entire length-$N$ recurrence is computed without a single Python-level loop over time.
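A minimal sketch of the scan built from these three steps (the epsilon placement and tensor shapes are illustrative; the post's implementation adds its epsilons as described under limitations below):

```python
import torch

def mamba_parallel_scan(A_bar, B_u, eps=1e-6):
    """Compute x_t = A_t * x_{t-1} + (B u)_t for every t at once.
    A_bar, B_u: (batch, N, d) per-step decay factors and weighted inputs."""
    P = torch.cumprod(A_bar, dim=1)            # Step 1: accumulated decay P_t
    S = torch.cumsum(B_u / (P + eps), dim=1)   # Step 2: normalized input contributions
    return P * S                               # Step 3: x_t = P_t * S_t

# Sanity check against the sequential recurrence on random inputs.
batch, N, d = 2, 16, 4
A_bar = torch.rand(batch, N, d) * 0.09 + 0.9   # decay factors in (0.9, 0.99)
B_u = torch.randn(batch, N, d)

x, outs = torch.zeros(batch, d), []
for t in range(N):
    x = A_bar[:, t] * x + B_u[:, t]
    outs.append(x)

assert torch.allclose(mamba_parallel_scan(A_bar, B_u), torch.stack(outs, dim=1), atol=1e-4)
```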
| Operation | Sequential Recurrence | Parallel Scan |
|---|---|---|
| Compute | $N$ dependent steps, depth $O(N)$ | $O(N \log N)$ work, depth $O(\log N)$ |
| Memory | $O(d)$ running state | $O(N d)$ scan buffers ($P$, $S$) |
| GPU utilization | Near-zero (serial) | Near-full (parallel) |
Mamba is excellent at long-range memory and GPU-efficient inference.
But there are limitations :
No Bidirectional Context : The recurrence is causal by design; the state at position $t$ summarizes only the inputs up to $t$. Tasks that benefit from looking ahead require a second, backward scan.

Weak at precise token-to-token matching : Attention's explicit pairwise $QK^\top$ comparison can retrieve an exact earlier token on demand (copying, lookups, induction-style patterns). A fixed-size state must compress the entire history into $d$ numbers, and that compression can lose exactly the detail a precise match needs.

Fixed state dimension : The state stays $d$-dimensional no matter how long the sequence grows, so the information it can carry about the past is bounded; very long contexts are necessarily compressed lossily.
The hybrid model interleaves Mamba and Transformer blocks within the same stack.
Each HybridBlock contains :
- Causal Multi-Head Attention : handles precise token-to-token relationships, pattern matching, and short-to-medium range dependencies where $O(N^2)$ cost is acceptable.
- Mamba Parallel Scan : handles long-range context compression and memory-efficient state evolution, and provides $O(N)$ memory cost that does not blow up the KV cache.
- FFN (GELU) : position-wise nonlinear feature mixing, shared across both components.
All three sublayers use Pre-LayerNorm and residual connections:

$$h_1 = h_0 + \mathrm{Attn}(\mathrm{LN}(h_0)), \qquad h_2 = h_1 + \mathrm{Mamba}(\mathrm{LN}(h_1)), \qquad h_3 = h_2 + \mathrm{FFN}(\mathrm{LN}(h_2))$$

The key property is that the attention sublayer provides explicit pairwise access at $O(N^2)$ cost, while the Mamba sublayer carries long-range context in a fixed-size state at $O(N)$ memory; each covers the regime where the other is weak.
Together, the hybrid does what neither alone can: precise local attention along with efficient long-range memory.
Input: token sequence (Batch, N)
|
Token Embedding : vocab_size → d_model = 256 (Batch, N, 256)
|
HybridBlock × 4:
LayerNorm → MultiHeadAttention (8 heads, d_k = 32, causal mask)
QKᵀ / √32 + M_causal → softmax → ⊗ V
Residual add
LayerNorm → MambaScan (parallel scan via cumprod + cumsum)
Δt = σ(u) input-dependent timestep
A = exp(-Δt · λ) diagonal decay matrix
B = (1-g) · W_B(u) gated input projection
A = g · A gated decay
P = cumprod(A) accumulated decay
S = cumsum(B/P) normalized input sum
x = P · S state reconstruction
y = W_C(x) output projection
Residual add
LayerNorm → FFN (256 → 512 → 256, GELU)
Residual add
→ (Batch, N, 256)
Final LayerNorm → Linear(256 → vocab_size) → logits
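A condensed PyTorch sketch of one HybridBlock following the diagram above. This is illustrative rather than the repo's code: the scan gating is simplified relative to the exact g / (1-g) split in the diagram, and nn.MultiheadAttention stands in for the hand-rolled attention.

```python
import torch
import torch.nn as nn

class MambaScan(nn.Module):
    """Diagonal selective SSM computed with the cumprod/cumsum parallel scan (simplified)."""
    def __init__(self, d_model):
        super().__init__()
        self.log_lambda = nn.Parameter(torch.zeros(d_model))  # decay rates, kept positive via exp
        self.to_dt = nn.Linear(d_model, 1)                    # input-dependent timestep
        self.to_B = nn.Linear(d_model, d_model)               # input write projection
        self.to_C = nn.Linear(d_model, d_model)               # output read projection

    def forward(self, u, eps=1e-6):                           # u: (B, N, d_model)
        dt = torch.sigmoid(self.to_dt(u))                     # Δt = σ(u), per token
        A = torch.exp(-dt * self.log_lambda.exp())            # (B, N, d) diagonal decay
        Bu = dt * self.to_B(u)                                # gated input contribution
        P = torch.cumprod(A, dim=1)                           # accumulated decay
        x = P * torch.cumsum(Bu / (P + eps), dim=1)           # parallel scan state
        return self.to_C(x)

class HybridBlock(nn.Module):
    """Pre-LN causal attention -> Mamba scan -> FFN, each with a residual connection."""
    def __init__(self, d_model=256, n_heads=8, d_ff=512):
        super().__init__()
        self.ln1, self.ln2, self.ln3 = nn.LayerNorm(d_model), nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mamba = MambaScan(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, h):                                     # h: (B, N, d_model)
        n = h.size(1)
        causal = torch.triu(torch.ones(n, n, dtype=torch.bool, device=h.device), diagonal=1)
        q = self.ln1(h)
        a, _ = self.attn(q, q, q, attn_mask=causal)
        h = h + a                                             # precise token-to-token matching
        h = h + self.mamba(self.ln2(h))                       # long-range compressed state
        h = h + self.ffn(self.ln3(h))                         # position-wise feature mixing
        return h
```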
Let $N$ be the sequence length, $d$ the model dimension, and $h$ the number of attention heads.

Attention sublayer :

The $QK^\top$ scores cost $O(N^2 d)$ time, and the $N \times N$ attention matrix costs $O(N^2 h)$ memory across heads.
Mamba sublayer (Parallel Scan) :
The parallel scan does $O(N \log N)$ work in $O(\log N)$ parallel depth, and its buffers $P$ and $S$ add $O(N d)$ memory.
Inference with KV cache :
The Mamba sublayer requires no KV cache because its state is a single $O(d)$ vector per layer, updated in place as each token is generated; only the attention sublayer still needs a cache that grows with the generated length.
Full hybrid per layer : $O(N^2 d)$ from attention, plus $O(N \log N)$ from the scan, plus $O(N d^2)$ from the FFN; attention dominates as $N$ grows.

For very long sequences where the attention component is applied only to a local window (sliding window attention), the Mamba component dominates and the hybrid achieves roughly $O(N \log N)$ overall scaling.
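A sketch of why the Mamba inference state stays $O(d)$: autoregressive generation only needs to carry one state vector forward between tokens (names and values below are illustrative):

```python
import torch

def mamba_generate_step(x_state, u_t, lam, dt, B_t):
    """One autoregressive step of the diagonal SSM.
    x_state: (d,) -- the only memory carried between tokens, independent of history length."""
    A_bar = torch.exp(-dt * lam)           # per-dimension decay for this step
    return A_bar * x_state + B_t * u_t     # update in place; nothing else accumulates

d = 256
x_state, lam = torch.zeros(d), torch.rand(d) + 0.1
for step in range(1000):                   # 1,000 generated tokens; memory stays a single (d,) vector
    x_state = mamba_generate_step(x_state, torch.randn(()), lam, dt=0.1, B_t=torch.randn(d))
```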
Attention still quadratic within each block : The hybrid reduces the quadratic cost but does not eliminate it. If full global attention is applied at every layer across the full sequence, the memory bottleneck remains. The hybrid is most effective when attention is applied locally (to a window of recent tokens) and Mamba handles the long-range state.
Mamba parallel scan adds scan buffer overhead : The cumulative-product and cumulative-sum buffers $P$ and $S$ are each full $(B, N, d)$ tensors, so the scan adds $O(N d)$ activation memory per layer on top of the attention matrix.

Stability of the gated scan : The implementation adds 1e-6 to the cumulative product and 1e-6 to the divisor of the $B/P$ term to avoid division by zero. For long sequences the accumulated decay $P_t$ can still underflow toward zero, leaving the normalized contributions dominated by the epsilon rather than the data.
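A small demonstration of why that matters (the 0.95 decay value is arbitrary; the point is that a long cumulative product of factors below 1 eventually underflows to zero in fp32):

```python
import torch

A = torch.full((4096,), 0.95)       # per-step decay factors, all below 1
P = torch.cumprod(A, dim=0)         # accumulated decay P_t = 0.95 ** t

first_zero = (P == 0).nonzero()
print(first_zero[0].item() if len(first_zero) else "no underflow")
# Prints the first index where P_t is exactly 0 in fp32 (around t ≈ 2000 for this decay).
# Beyond that point every B_s u_s / P_s term is divided by (0 + eps), so the epsilon,
# not the data, determines the long-range contributions.
```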
No positional encoding in this implementation : Neither sinusoidal PE nor RoPE is applied. The Mamba component encodes position implicitly through the temporal recurrence. The attention component has no positional information at all; it processes all positions identically. This limits the hybrid's ability to learn position-sensitive patterns in the attention sublayer. Adding RoPE to Q and K (as in Llama) would improve this.
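For reference, a compact sketch of rotary position embeddings that could be applied to Q and K in the attention sublayer (this is a generic RoPE implementation, not code from this post):

```python
import torch

def apply_rope(x, base=10000.0):
    """Rotate channel pairs of x (batch, heads, N, d_head) by position-dependent angles,
    so dot products between rotated Q and K depend on relative position."""
    _, _, N, D = x.shape
    half = D // 2
    freqs = base ** (-torch.arange(half, device=x.device, dtype=x.dtype) / half)   # (half,)
    angles = torch.arange(N, device=x.device, dtype=x.dtype)[:, None] * freqs      # (N, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

# Usage: q, k = apply_rope(q), apply_rope(k) before computing QK^T.
```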
Fixed $\lambda$ initialization : The decay rates are initialized with torch.randn(d_model), which places some of them at or below zero; those dimensions get decay factors $e^{-\Delta_t \lambda_i} \geq 1$ that grow rather than decay. A positivity-enforcing parameterization (e.g. exponentiating a log-rate) or a structured initialization would be more robust.
Dataset : WikiText-2 (raw v1), tokenized with a 10,000-token BPE vocabulary trained on the training split. Both models trained for 200 steps with AdamW ($lr = 3 \times 10^{-4}$), batch size 8, sequence length 512. Benchmarks run at batch size 4 across sequence lengths $\{64, 128, 256, 512, 1024, 2048\}$ on CUDA.
The three plots below show training perplexity, latency scaling, and throughput scaling as measured on the WikiText-2 benchmark run.
Perplexity : Perplexity measures how confidently a language model predicts the next token; lower is better, and it is the exponential of cross-entropy loss. It is the standard quality metric for language modelling.
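Concretely, perplexity is the exponential of the mean next-token cross-entropy; a sketch (logits and targets here are placeholders):

```python
import torch
import torch.nn.functional as F

def perplexity(logits, targets):
    """logits: (B, N, vocab) next-token predictions; targets: (B, N) true token ids."""
    loss = F.cross_entropy(logits.flatten(0, 1), targets.flatten())  # mean NLL in nats
    return loss.exp()                                                # perplexity = e^loss

# A uniform predictor over a 10,000-token vocabulary has perplexity exactly 10,000,
# which is why the training curves start near 10^4.
logits = torch.zeros(2, 8, 10_000)
print(perplexity(logits, torch.randint(0, 10_000, (2, 8))))          # ≈ tensor(10000.)
```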
- Only one curve is prominently visible because both models' traces sit almost on top of each other, which is itself informative: the Mamba scan component does not hurt language modelling quality at this scale.
- The curve drops sharply from $10^4$ in the first 50 steps. This is the model learning BPE token frequency statistics, the easy part. The gradient is steep because even a crude unigram prior slashes perplexity dramatically.
- By step 100 the descent slows and plateaus around $10^3$. This is where syntactic and semantic structure must be learned, and 200 steps on WikiText-2 is nowhere near enough. The plateau is a training budget limit, not an architectural ceiling.
Latency Scaling : Latency measures wall-clock time for a single forward pass; plotted on a log-log scale so that power-law scaling appears as a straight line.
- The Transformer shows a sharp anomalous dip around sequence length 64–128. This is a GPU warmup artifact: the CUDA kernel launch overhead dominates at very small batch-sequence products, so the first measurement is noisy and unreliable.
- Above 256 tokens, both curves grow steeply and roughly in parallel. The Hybrid is consistently above the Transformer because it runs attention and the Mamba scan together; the scan adds to the forward pass rather than replacing attention.
- The crossover where Mamba's $O(N \log N)$ scaling wins over attention's $O(N^2)$ would only appear at sequence lengths well beyond 2048. At this model size ($d = 128$, single-layer attention) the quadratic regime has not yet made attention catastrophically expensive.
Throughput : Throughput measures tokens processed per second; higher is better, and it captures how efficiently the model uses GPU parallelism across the batch.
- The Transformer (blue) peaks around 256–512 tokens then declines. This is the point where the $N^2$ attention matrix starts saturating memory bandwidth; the GPU spends more time moving data than computing.
- The Hybrid (orange) shows a flatter, lower curve throughout. It does strictly more work per forward pass (attention + scan + FFN vs. attention + FFN), so it processes fewer tokens per second at every sequence length tested.
- The Hybrid does not show a throughput advantage here because the benchmark does not reach the regime where Mamba's linear memory cost would dominate. At $N > 8{,}000$ with fixed memory, the Transformer begins to OOM while the Mamba state stays constant; that is where the curves would invert.
Raw per-model, per-sequence-length measurements from the benchmark run. Latency is single-forward-pass wall-clock time in seconds; VRAM is peak GPU memory allocated in MB; Throughput is tokens per second (batch size × sequence length / latency).
| SeqLen | Model | Latency (s) | VRAM (MB) | Throughput (tok/s) |
|---|---|---|---|---|
| 64 | Transformer | 0.016416 | 80.59 | 1.56 × 10⁴ |
| 64 | Hybrid | 0.001456 | 82.60 | 1.76 × 10⁵ |
| 128 | Transformer | 0.001230 | 103.41 | 4.16 × 10⁵ |
| 128 | Hybrid | 0.002068 | 107.42 | 2.48 × 10⁵ |
| 256 | Transformer | 0.001454 | 151.38 | 7.04 × 10⁵ |
| 256 | Hybrid | 0.002229 | 159.41 | 4.59 × 10⁵ |
| 512 | Transformer | 0.002007 | 251.49 | 1.02 × 10⁶ |
| 512 | Hybrid | 0.003202 | 267.20 | 6.40 × 10⁵ |
| 1024 | Transformer | 0.004619 | 482.42 | 8.87 × 10⁵ |
| 1024 | Hybrid | 0.005923 | 514.55 | 6.92 × 10⁵ |
| 2048 | Transformer | 0.013924 | 1063.99 | 5.88 × 10⁵ |
| 2048 | Hybrid | 0.015996 | 1287.74 | 5.12 × 10⁵ |
Interpretation of Metrics :
- At sequence lengths 64–256, the Hybrid's latency is comparable to or slightly better than the Transformer despite doing more computation. cumprod and cumsum are highly parallelized CUDA primitives that hit peak throughput at small sizes where attention head computation has not yet saturated the device.
- By 512 tokens, the Hybrid is reliably slower and more memory-hungry. The VRAM gap grows with sequence length; at 2048 tokens the Hybrid uses ~224 MB more than the Transformer.
- This gap grows because the scan buffers $P$ and $S$ are both $(B, N, d)$ tensors that accumulate on top of the attention matrix. The Transformer's VRAM grows roughly quadratically (attention matrix); the Hybrid grows quadratically plus linearly (attention + scan), so the gap widens predictably.
- In a production setting at 8k–128k sequence lengths, the Mamba component's $O(d)$ constant inference state would begin to invert this relationship. That regime is not captured in this benchmark.
- The benchmark confirms the theoretical predictions at small scale: the Hybrid adds latency and VRAM overhead at moderate sequence lengths because it runs both attention and the SSM, not one instead of the other.
- Perplexity trajectories for both models are nearly identical. The Mamba component does not hurt language modelling quality; it just has not yet had the chance to help it at this sequence length and training budget.
- The architectural advantages, constant KV-cache memory at inference and $O(N \log N)$ long-range state, only become the dominant factor at sequence lengths well beyond what this 200-step WikiText-2 run covers.
- The right experimental setting to observe the Hybrid's edge is fixed-memory inference at $N > 4{,}096$ against a Transformer that has begun to OOM; that is where the Mamba component's $O(d)$ state pays off.

