Commit c0dbc38
feat: Mamba2 SSD Combined Forward Pass (Blackwell CuTe DSL Kernel) (#2709)
<!-- .github/pull_request_template.md -->
## 📌 Description
This PR adds a high-performance **Mamba2 Structured State-Space Duality
(SSD) combined forward kernel** targeting Blackwell (SM100+) GPUs,
implemented using NVIDIA's CuTe DSL. The kernel fuses the entire SSD
chunk-scan forward pass — including cumulative sum preprocessing, chunk
state computation, inter-chunk state passing, and output projection —
into a single persistent kernel launch.
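For intuition, the fused pass computes the chunked form of the SSD recurrence. Below is a simplified single-head NumPy sketch in fp64 (illustrative only — the actual kernel tiles these steps across warps in bf16 and adds the optional D/z terms):

```python
import numpy as np

def ssd_chunked_reference(x, dt, A, B, C, chunk_size, initial_state=None):
    """Single-head SSD chunk scan in fp64 (illustrative reference, not the kernel).

    x: (seqlen, headdim), dt: (seqlen,), A: scalar negative decay rate,
    B, C: (seqlen, dstate).  Per-token recurrence being chunked:
        h_t = exp(A*dt_t) * h_{t-1} + B_t (dt_t * x_t)^T ;   y_t = C_t h_t
    """
    seqlen, headdim = x.shape
    dstate = B.shape[1]
    h = np.zeros((dstate, headdim)) if initial_state is None else initial_state.copy()
    out = np.empty_like(x)
    for c0 in range(0, seqlen, chunk_size):
        xs, dts = x[c0:c0 + chunk_size], dt[c0:c0 + chunk_size]
        Bs, Cs = B[c0:c0 + chunk_size], C[c0:c0 + chunk_size]
        cum = np.cumsum(A * dts)                           # dA cumsum (preprocessing step)
        L = np.tril(np.exp(cum[:, None] - cum[None, :]))   # causal decay mask
        CB = Cs @ Bs.T                                     # pairwise C_t . B_s
        intra = (CB * L) @ (xs * dts[:, None])             # intra-chunk contribution
        inter = np.exp(cum)[:, None] * (Cs @ h)            # carried-in state contribution
        # inter-chunk state passing: decay the old state, add this chunk's inputs
        h = np.exp(cum[-1]) * h + (np.exp(cum[-1] - cum)[:, None] * Bs).T @ (xs * dts[:, None])
        out[c0:c0 + chunk_size] = intra + inter
    return out, h
```

The chunked result matches a naive per-token recurrence exactly (up to fp rounding); the kernel performs the same decomposition with one MMA per matrix product.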
### Features
- **Datatypes**: bf16 I/O, bf16/fp16 state cache (fp32 state not yet
supported)
- **D tensor fusion**: optional additive bias with per-head
(`d_has_hdim=True`) or broadcast (`d_has_hdim=False`) modes
- **Z gating**: optional sigmoid gating on the output
- **Initial states**: optional user-provided initial hidden states
- **Variable-length sequences**: packed varlen via `seq_idx` +
precomputed chunk metadata
- **Batched mode**: uniform-length sequences without varlen overhead
- **CUDA graph compatible**: all allocations cached, no host-device sync
in the hot path
### Architecture
The kernel uses **16 warps** with warp-level task specialization:
| Warps | Role |
|-------|------|
| 0 | TMA loads (B, C, x, dt, A, D, z) |
| 1–4 | MMA: CB = C^T B (state-space matrix) |
| 5–8 | MMA: P = CB * decay (intra-chunk transitions) |
| 9 | Preprocessing (dt softplus/limit, dA cumsum, decay masks) |
| 10–13 | MMA: states = B^T (x * dt) + decay * prev_state |
| 14–15 | Epilogue MMA: out = C * state, write output via TMA |
A persistent tile scheduler distributes chunks across SMs.
Shape-polymorphic compilation (via CuTe DSL fake tensors with symbolic
dims) means **one compilation covers all batch/seqlen combinations**.
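As a toy illustration of persistent scheduling (the assignment policy in `flashinfer/mamba/ssd_tile_scheduler.py` may differ — this sketch only shows the general idea), each persistent block loops over every `num_sms`-th tile instead of exiting after one:

```python
def persistent_schedule(num_tiles: int, num_sms: int) -> dict[int, list[int]]:
    """Static round-robin tile assignment: SM i processes tiles i, i + num_sms, ...
    A persistent block stays resident and iterates over its tile list."""
    return {sm: list(range(sm, num_tiles, num_sms)) for sm in range(num_sms)}
```

For example, 10 tiles on 4 SMs gives SM 0 the tiles `[0, 4, 8]`, with every tile covered exactly once.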
### Files Added/Modified
| Area | Key files |
|------|-----------|
| Python API | `flashinfer/mamba/ssd_combined.py` — `SSDCombined` class |
| CuTe DSL kernel | `flashinfer/mamba/ssd_kernel.py` (4.3k lines) |
| Tile scheduler | `flashinfer/mamba/ssd_tile_scheduler.py` |
| Cumsum preprocessing | `flashinfer/triton/kernels/ssd_chunk_state.py` (Triton) |
| Seq-chunk metadata | `include/flashinfer/mamba/seq_chunk_cumsum.cuh`, `csrc/seq_chunk_cumsum.cu` |
| JIT integration | `flashinfer/jit/mamba/seq_chunk_cumsum.py` |
| Tests | `tests/mamba/test_chunk_scan_combined.py` (1.9k lines) |
| Triton reference | `tests/mamba/triton_reference/ssd_*.py` |
| Benchmark | `benchmarks/bench_mamba_ssd_combined.py` |
### Performance
All numbers were collected on an **NVIDIA B200** in bf16 with `chunk_size=128, nheads=8, headdim=64, dstate=128, ngroups=8`. Timings were collected using CUDA graphs.
**Column definitions:**
- **batch** / **num_seqs**: number of independent sequences (batched) or
packed user sequences (varlen)
- **chunks/seq**: number of `chunk_size`-token chunks per sequence
(`seqlen / chunk_size`)
- **total chunks**: total chunk count across all sequences (`batch *
chunks/seq`)
- **total seqlen**: total number of tokens processed (`batch *
chunks/seq * chunk_size`)
- **FlashInfer (ms)**: end-to-end wall time for the fused CuTe DSL
kernel (including Triton cumsum preprocessing)
- **Triton (ms)**: end-to-end wall time for the Triton reference (5
separate kernel launches)
- **Speedup**: `Triton / FlashInfer` (>1x means FlashInfer is faster)
#### Batched mode (uniform sequence lengths, no initial states)
| batch | chunks/seq | total chunks | total seqlen | FlashInfer (ms) | Triton (ms) | Speedup |
|------:|-----------:|-------------:|-------------:|----------------:|------------:|--------:|
| 1 | 1 | 1 | 128 | 0.012 | 0.016 | 1.28x |
| 1 | 4 | 4 | 512 | 0.032 | 0.018 | 0.57x |
| 1 | 16 | 16 | 2,048 | 0.112 | 0.031 | 0.28x |
| 1 | 64 | 64 | 8,192 | 0.425 | 0.080 | 0.19x |
| 1 | 256 | 256 | 32,768 | 1.675 | 0.317 | 0.19x |
| 4 | 1 | 4 | 512 | 0.013 | 0.018 | 1.44x |
| 4 | 4 | 16 | 2,048 | 0.034 | 0.029 | 0.87x |
| 4 | 16 | 64 | 8,192 | 0.112 | 0.073 | 0.65x |
| 4 | 64 | 256 | 32,768 | 0.426 | 0.255 | 0.60x |
| 16 | 1 | 16 | 2,048 | 0.015 | 0.033 | 2.22x |
| 16 | 4 | 64 | 8,192 | 0.036 | 0.074 | 2.05x |
| 16 | 16 | 256 | 32,768 | 0.116 | 0.252 | 2.17x |
| 64 | 1 | 64 | 8,192 | 0.038 | 0.095 | 2.50x |
| 64 | 4 | 256 | 32,768 | 0.118 | 0.261 | 2.21x |
| 64 | 16 | 1,024 | 131,072 | 0.437 | 0.952 | 2.18x |
| 128 | 1 | 128 | 16,384 | 0.062 | 0.173 | 2.81x |
| 128 | 4 | 512 | 65,536 | 0.201 | 0.502 | 2.49x |
| 128 | 16 | 2,048 | 262,144 | 0.759 | 1.881 | 2.48x |
| 256 | 1 | 256 | 32,768 | 0.113 | 0.326 | 2.87x |
| 256 | 4 | 1,024 | 131,072 | 0.391 | 0.982 | 2.51x |
| 256 | 16 | 4,096 | 524,288 | 1.507 | 3.736 | 2.48x |
| 512 | 2 | 1,024 | 131,072 | 0.397 | 1.028 | 2.59x |
> At low batch / many chunks per sequence (batch=1, long sequences), the
Triton reference is faster because it launches 5 separate small kernels
with lower per-kernel overhead. The fused CuTe kernel wins at **batch >=
16** where parallelism across sequences saturates the GPU.
#### Varlen mode (packed sequences, with initial states)
| num_seqs | chunks/seq | total chunks | total seqlen | FlashInfer (ms) | Triton (ms) | Speedup |
|---------:|-----------:|-------------:|-------------:|----------------:|------------:|--------:|
| 1 | 1 | 1 | 128 | 0.022 | 0.027 | 1.22x |
| 4 | 1 | 4 | 512 | 0.023 | 0.030 | 1.31x |
| 8 | 1 | 8 | 1,024 | 0.024 | 0.041 | 1.71x |
| 32 | 1 | 32 | 4,096 | 0.046 | 0.118 | 2.56x |
| 64 | 1 | 64 | 8,192 | 0.080 | 0.211 | 2.63x |
| 128 | 1 | 128 | 16,384 | 0.134 | 0.388 | 2.90x |
| 256 | 1 | 256 | 32,768 | 0.259 | 0.753 | 2.91x |
| 4 | 8 | 32 | 4,096 | 0.119 | 0.103 | 0.87x |
| 8 | 8 | 64 | 8,192 | 0.121 | 0.177 | 1.46x |
| 16 | 8 | 128 | 16,384 | 0.124 | 0.335 | 2.70x |
| 32 | 8 | 256 | 32,768 | 0.237 | 0.657 | 2.77x |
| 64 | 8 | 512 | 65,536 | 0.463 | 1.283 | 2.77x |
| 32 | 32 | 1,024 | 131,072 | 0.896 | 2.486 | 2.77x |
| 64 | 32 | 2,048 | 262,144 | 1.771 | 4.936 | 2.79x |
| 128 | 32 | 4,096 | 524,288 | 3.099 | 9.819 | 3.17x |
> In the serving-relevant regime (many short sequences packed together),
FlashInfer is consistently **2.5–3.2x faster** than Triton. The
single-fused-kernel design amortizes launch overhead across all packed
sequences.
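To illustrate the varlen inputs named above (the kernel's actual chunk-metadata layout is internal; these helper names are hypothetical):

```python
import numpy as np

def build_seq_idx(seqlens):
    """Per-token sequence index for a packed varlen batch,
    e.g. seqlens [3, 2] -> seq_idx [0, 0, 0, 1, 1]."""
    return np.repeat(np.arange(len(seqlens)), seqlens)

def chunks_per_seq(seqlens, chunk_size=128):
    """Number of chunk_size-token chunks each packed sequence occupies (ceil-div),
    matching the 'chunks/seq' column in the tables above."""
    return -(-np.asarray(seqlens) // chunk_size)
```

With `chunk_size=128`, a 200-token sequence occupies 2 chunks; summing over sequences gives the "total chunks" a scheduler must distribute.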
### Reproducing Benchmarks
The benchmark script is included at
`benchmarks/bench_mamba_ssd_combined.py`:
```bash
# Batched mode (CUDA graphs)
python benchmarks/bench_mamba_ssd_combined.py --batched --cuda_graph
# Varlen mode (CUDA graphs)
python benchmarks/bench_mamba_ssd_combined.py --varlen --cuda_graph
```
### Limitations / Future Work
- **fp32 state cache** not yet supported (bf16/fp16 only)
- **Forward pass only** — backward pass not included
- **SM100+ required** — Blackwell only (CuTe DSL)
- Low-batch / long-sequence regime (batch=1, many chunks) is slower than
Triton due to persistent-kernel overhead vs. 5 small kernel launches
## 🔍 Related Issues
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes
The kernel is Blackwell-only. Please check if I handled all the imports
correctly.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* New SSDCombined implementation with varlen, initial-states, optional
scaling and gating, plus GPU-backed chunked primitives and a tile
scheduler for faster chunked SSM/attention workloads.
* **Benchmarks**
* Added a CLI benchmarking tool with single/multi-config sweeps and
profiling modes (NCU, profiler) for per-configuration comparisons.
* **Tests**
* Large test suite covering batched/varlen paths, dtype combinations,
gating, initial-states, and end-to-end correctness.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent f01e83e · commit c0dbc38
18 files changed
Lines changed: 10983 additions & 0 deletions
File tree
- benchmarks
- csrc
- flashinfer
  - jit/mamba
  - mamba
  - triton/kernels
- include/flashinfer/mamba
- tests/mamba
  - triton_reference