
Commit c0dbc38

feat: Mamba2 SSD Combined Forward Pass (Blackwell CuTe DSL Kernel) (#2709)
## 📌 Description

This PR adds a high-performance **Mamba2 Structured State-Space Duality (SSD) combined forward kernel** targeting Blackwell (SM100+) GPUs, implemented in NVIDIA's CuTe DSL. The kernel fuses the entire SSD chunk-scan forward pass, including cumulative-sum preprocessing, chunk state computation, inter-chunk state passing, and output projection, into a single persistent kernel launch.

### Features

- **Datatypes**: bf16 I/O, bf16/fp16 state cache (fp32 state not yet supported)
- **D tensor fusion**: optional additive bias with per-head (`d_has_hdim=True`) or broadcast (`d_has_hdim=False`) modes
- **Z gating**: optional sigmoid gating on the output
- **Initial states**: optional user-provided initial hidden states
- **Variable-length sequences**: packed varlen via `seq_idx` + precomputed chunk metadata
- **Batched mode**: uniform-length sequences without varlen overhead
- **CUDA graph compatible**: all allocations cached, no host-device sync in the hot path

### Architecture

The kernel uses **16 warps** with warp-level task specialization:

| Warps | Role |
|-------|------|
| 0 | TMA loads (B, C, x, dt, A, D, z) |
| 1–4 | MMA: CB = C^T B (state-space matrix) |
| 5–8 | MMA: P = CB * decay (intra-chunk transitions) |
| 9 | Preprocessing (dt softplus/limit, dA cumsum, decay masks) |
| 10–13 | MMA: states = B^T (x * dt) + decay * prev_state |
| 14–15 | Epilogue MMA: out = C * state, write output via TMA |

A persistent tile scheduler distributes chunks across SMs. Shape-polymorphic compilation (via CuTe DSL fake tensors with symbolic dims) means **one compilation covers all batch/seqlen combinations**.
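For readers unfamiliar with SSD, the stages the kernel fuses (dA cumsum, the causal `CB * decay` intra-chunk matmul, chunk-state computation, and inter-chunk state passing) can be sketched in plain NumPy. This is an illustrative single-head reference with a scalar A per head (as in Mamba2), not the kernel's actual tiling or datatypes:

```python
import numpy as np

def ssd_sequential(x, dt, A, B, C):
    """Token-by-token SSM recurrence: h_t = exp(dt_t*A)*h + dt_t * x_t B_t^T, y_t = h_t C_t."""
    h = np.zeros((x.shape[1], B.shape[1]))  # (headdim, dstate)
    y = np.empty_like(x)
    for t in range(len(x)):
        h = np.exp(dt[t] * A) * h + dt[t] * np.outer(x[t], B[t])
        y[t] = h @ C[t]
    return y

def ssd_chunked(x, dt, A, B, C, chunk_size):
    """Same result, computed chunk-by-chunk in the SSD style."""
    y = np.empty_like(x)
    h = np.zeros((x.shape[1], B.shape[1]))  # inter-chunk state
    for start in range(0, len(x), chunk_size):
        xc, dtc = x[start:start + chunk_size], dt[start:start + chunk_size]
        Bc, Cc = B[start:start + chunk_size], C[start:start + chunk_size]
        s = np.cumsum(dtc * A)  # dA cumsum within the chunk
        # intra-chunk: causal (C B^T) weighted by pairwise decay ("CB * decay")
        P = np.tril((Cc @ Bc.T) * np.exp(s[:, None] - s[None, :]) * dtc[None, :])
        y_intra = P @ xc
        # contribution of the state carried in from previous chunks
        y_state = np.exp(s)[:, None] * (Cc @ h.T)
        y[start:start + chunk_size] = y_intra + y_state
        # chunk-state update + inter-chunk state passing
        w = np.exp(s[-1] - s) * dtc
        h = np.exp(s[-1]) * h + xc.T @ (w[:, None] * Bc)
    return y
```

With random inputs and A < 0, the chunked and sequential paths agree to floating-point tolerance; the real kernel additionally handles multiple heads/groups, the D and z fusions, and varlen packing.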
### Files Added/Modified

| Area | Key files |
|------|-----------|
| Python API | `flashinfer/mamba/ssd_combined.py` (`SSDCombined` class) |
| CuTe DSL kernel | `flashinfer/mamba/ssd_kernel.py` (4.3k lines) |
| Tile scheduler | `flashinfer/mamba/ssd_tile_scheduler.py` |
| Cumsum preprocessing | `flashinfer/triton/kernels/ssd_chunk_state.py` (Triton) |
| Seq-chunk metadata | `include/flashinfer/mamba/seq_chunk_cumsum.cuh`, `csrc/seq_chunk_cumsum.cu` |
| JIT integration | `flashinfer/jit/mamba/seq_chunk_cumsum.py` |
| Tests | `tests/mamba/test_chunk_scan_combined.py` (1.9k lines) |
| Triton reference | `tests/mamba/triton_reference/ssd_*.py` |
| Benchmark | `benchmarks/bench_mamba_ssd_combined.py` |

### Performance

All numbers on **NVIDIA B200**, bf16, `chunk_size=128, nheads=8, headdim=64, dstate=128, ngroups=8`. Timings were collected with CUDA graphs.

**Column definitions:**

- **batch** / **num_seqs**: number of independent sequences (batched) or packed user sequences (varlen)
- **chunks/seq**: number of `chunk_size`-token chunks per sequence (`seqlen / chunk_size`)
- **total chunks**: total chunk count across all sequences (`batch * chunks/seq`)
- **total seqlen**: total number of tokens processed (`batch * chunks/seq * chunk_size`)
- **FlashInfer (ms)**: end-to-end wall time for the fused CuTe DSL kernel (including Triton cumsum preprocessing)
- **Triton (ms)**: end-to-end wall time for the Triton reference (5 separate kernel launches)
- **Speedup**: `Triton / FlashInfer` (>1x means FlashInfer is faster)

#### Batched mode (uniform sequence lengths, no initial states)

| batch | chunks/seq | total chunks | total seqlen | FlashInfer (ms) | Triton (ms) | Speedup |
|------:|-----------:|-------------:|-------------:|----------------:|------------:|--------:|
| 1 | 1 | 1 | 128 | 0.012 | 0.016 | 1.28x |
| 1 | 4 | 4 | 512 | 0.032 | 0.018 | 0.57x |
| 1 | 16 | 16 | 2,048 | 0.112 | 0.031 | 0.28x |
| 1 | 64 | 64 | 8,192 | 0.425 | 0.080 | 0.19x |
| 1 | 256 | 256 | 32,768 | 1.675 | 0.317 | 0.19x |
| 4 | 1 | 1 | 128 | 0.013 | 0.018 | 1.44x |
| 4 | 4 | 4 | 512 | 0.034 | 0.029 | 0.87x |
| 4 | 16 | 16 | 2,048 | 0.112 | 0.073 | 0.65x |
| 4 | 64 | 64 | 8,192 | 0.426 | 0.255 | 0.60x |
| 16 | 1 | 1 | 128 | 0.015 | 0.033 | 2.22x |
| 16 | 4 | 4 | 512 | 0.036 | 0.074 | 2.05x |
| 16 | 16 | 16 | 2,048 | 0.116 | 0.252 | 2.17x |
| 64 | 1 | 1 | 128 | 0.038 | 0.095 | 2.50x |
| 64 | 4 | 4 | 512 | 0.118 | 0.261 | 2.21x |
| 64 | 16 | 16 | 2,048 | 0.437 | 0.952 | 2.18x |
| 128 | 1 | 1 | 128 | 0.062 | 0.173 | 2.81x |
| 128 | 4 | 4 | 512 | 0.201 | 0.502 | 2.49x |
| 128 | 16 | 16 | 2,048 | 0.759 | 1.881 | 2.48x |
| 256 | 1 | 1 | 128 | 0.113 | 0.326 | 2.87x |
| 256 | 4 | 4 | 512 | 0.391 | 0.982 | 2.51x |
| 256 | 16 | 16 | 2,048 | 1.507 | 3.736 | 2.48x |
| 512 | 2 | 2 | 256 | 0.397 | 1.028 | 2.59x |

> At low batch / many chunks per sequence (batch=1, long sequences), the Triton reference is faster because it launches 5 separate small kernels with lower per-kernel overhead. The fused CuTe kernel wins at **batch >= 16**, where parallelism across sequences saturates the GPU.
#### Varlen mode (packed sequences, with initial states)

| num_seqs | chunks/seq | total chunks | total seqlen | FlashInfer (ms) | Triton (ms) | Speedup |
|---------:|-----------:|-------------:|-------------:|----------------:|------------:|--------:|
| 1 | 1 | 1 | 128 | 0.022 | 0.027 | 1.22x |
| 4 | 1 | 4 | 512 | 0.023 | 0.030 | 1.31x |
| 8 | 1 | 8 | 1,024 | 0.024 | 0.041 | 1.71x |
| 32 | 1 | 32 | 4,096 | 0.046 | 0.118 | 2.56x |
| 64 | 1 | 64 | 8,192 | 0.080 | 0.211 | 2.63x |
| 128 | 1 | 128 | 16,384 | 0.134 | 0.388 | 2.90x |
| 256 | 1 | 256 | 32,768 | 0.259 | 0.753 | 2.91x |
| 4 | 8 | 32 | 4,096 | 0.119 | 0.103 | 0.87x |
| 8 | 8 | 64 | 8,192 | 0.121 | 0.177 | 1.46x |
| 16 | 8 | 128 | 16,384 | 0.124 | 0.335 | 2.70x |
| 32 | 8 | 256 | 32,768 | 0.237 | 0.657 | 2.77x |
| 64 | 8 | 512 | 65,536 | 0.463 | 1.283 | 2.77x |
| 32 | 32 | 1,024 | 131,072 | 0.896 | 2.486 | 2.77x |
| 64 | 32 | 2,048 | 262,144 | 1.771 | 4.936 | 2.79x |
| 128 | 32 | 4,096 | 524,288 | 3.099 | 9.819 | 3.17x |

> In the serving-relevant regime (many short sequences packed together), FlashInfer is consistently **2.5–3.2x faster** than Triton. The single-fused-kernel design amortizes launch overhead across all packed sequences.

### Reproducing Benchmarks

The benchmark script is included at `benchmarks/bench_mamba_ssd_combined.py`:

```bash
# Batched mode (CUDA graphs)
python benchmarks/bench_mamba_ssd_combined.py --batched --cuda_graph

# Varlen mode (CUDA graphs)
python benchmarks/bench_mamba_ssd_combined.py --varlen --cuda_graph
```

### Limitations / Future Work

- **fp32 state cache** not yet supported (bf16/fp16 only)
- **Forward pass only**: the backward pass is not included
- **SM100+ required**: Blackwell only (CuTe DSL)
- The low-batch / long-sequence regime (batch=1, many chunks) is slower than Triton due to persistent-kernel overhead vs. 5 small kernel launches

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer!
Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

The kernel is Blackwell-only. Please check whether I handled all the imports correctly.

## Summary by CodeRabbit

* **New Features**
  * New SSDCombined implementation with varlen, initial states, optional scaling and gating, plus GPU-backed chunked primitives and a tile scheduler for faster chunked SSM/attention workloads.
* **Benchmarks**
  * Added a CLI benchmarking tool with single/multi-config sweeps and profiling modes (NCU, profiler) for per-configuration comparisons.
* **Tests**
  * Large test suite covering batched/varlen paths, dtype combinations, gating, initial states, and end-to-end correctness.
1 parent f01e83e commit c0dbc38

18 files changed

Lines changed: 10983 additions & 0 deletions

benchmarks/bench_mamba_ssd_combined.py

Lines changed: 688 additions & 0 deletions

csrc/seq_chunk_cumsum.cu

Lines changed: 58 additions & 0 deletions
```cpp
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "flashinfer/mamba/seq_chunk_cumsum.cuh"
#include "tvm_ffi_utils.h"

using namespace flashinfer::mamba;
using tvm::ffi::Optional;

void seq_chunk_cumsum(TensorView seq_idx, TensorView chunk_indices, TensorView chunk_offsets,
                      TensorView output, Optional<TensorView> tile_state, int64_t chunk_size,
                      int64_t num_logical_chunks, int64_t num_seqs) {
  CHECK_INPUT(seq_idx);
  CHECK_INPUT(chunk_indices);
  CHECK_INPUT(chunk_offsets);
  CHECK_INPUT(output);

  auto stream = get_stream(seq_idx.device());

  uint8_t* tile_state_ptr = nullptr;
  std::size_t tile_state_size = 0;
  if (tile_state.has_value()) {
    CHECK_INPUT(tile_state.value());
    tile_state_ptr = static_cast<uint8_t*>(tile_state.value().data_ptr());
    tile_state_size = static_cast<std::size_t>(tile_state.value().shape()[0]);
  }

  cudaError_t status;
  DISPATCH_DLPACK_IDTYPE_TO_CTYPE(seq_idx.dtype(), SeqIdxT, [&] {
    status = SeqChunkCumsumLauncher(static_cast<const SeqIdxT*>(seq_idx.data_ptr()),
                                    static_cast<const int32_t*>(chunk_indices.data_ptr()),
                                    static_cast<const int32_t*>(chunk_offsets.data_ptr()),
                                    static_cast<int32_t*>(output.data_ptr()), tile_state_ptr,
                                    tile_state_size, static_cast<int>(chunk_size),
                                    static_cast<int>(num_logical_chunks),
                                    static_cast<int>(num_seqs), stream);
    return true;
  });

  TVM_FFI_ICHECK(status == cudaSuccess)
      << "SeqChunkCumsumLauncher failed: " << cudaGetErrorString(status);
}

int64_t seq_chunk_cumsum_tile_state_size(int64_t num_seqs) {
  return static_cast<int64_t>(SeqChunkCumsumWorkspaceSize(static_cast<int>(num_seqs)));
}
```
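The preprocessing this binding wraps is, conceptually, a cumulative sum over a packed token stream that must restart at every sequence boundary (signaled by `seq_idx`). A toy NumPy sketch of that segmented-cumsum idea follows; it illustrates only the concept, not the actual GPU metadata layout, whose `chunk_indices`/`chunk_offsets` encoding lives in `seq_chunk_cumsum.cuh`:

```python
import numpy as np

def segmented_cumsum(values, seq_idx):
    """Cumulative sum over a packed stream that restarts whenever
    seq_idx changes, i.e. at every sequence boundary."""
    out = np.empty(len(values))
    run, prev = 0.0, None
    for i, (v, s) in enumerate(zip(values, seq_idx)):
        if s != prev:          # new sequence begins: reset the running sum
            run, prev = 0.0, s
        run += v
        out[i] = run
    return out

# two packed sequences of lengths 3 and 2; the sum restarts at the boundary
print(segmented_cumsum([1, 1, 1, 5, 5], [0, 0, 0, 1, 1]))
```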
Lines changed: 27 additions & 0 deletions
```cpp
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tvm_ffi_utils.h"

using tvm::ffi::Optional;

void seq_chunk_cumsum(TensorView seq_idx, TensorView chunk_indices, TensorView chunk_offsets,
                      TensorView output, Optional<TensorView> tile_state, int64_t chunk_size,
                      int64_t num_logical_chunks, int64_t num_seqs);

int64_t seq_chunk_cumsum_tile_state_size(int64_t num_seqs);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(seq_chunk_cumsum, seq_chunk_cumsum);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(seq_chunk_cumsum_tile_state_size, seq_chunk_cumsum_tile_state_size);
```

flashinfer/jit/mamba/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -18,8 +18,10 @@
     gen_selective_state_update_module,
     gen_selective_state_update_sm90_module,
 )
+from .seq_chunk_cumsum import gen_seq_chunk_cumsum_module

 __all__ = [
     "gen_selective_state_update_module",
     "gen_selective_state_update_sm90_module",
+    "gen_seq_chunk_cumsum_module",
 ]
```
Lines changed: 33 additions & 0 deletions
```python
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from .. import env as jit_env
from ..core import JitSpec, gen_jit_spec


def gen_seq_chunk_cumsum_module() -> JitSpec:
    """Generate the JIT module for the seq_chunk_cumsum kernel.

    No Jinja, no dtype parameterization: everything is int32.
    No architecture restrictions: plain CUDA (no tensor cores).
    """
    return gen_jit_spec(
        "mamba_seq_chunk_cumsum",
        [
            jit_env.FLASHINFER_CSRC_DIR / "seq_chunk_cumsum.cu",
            jit_env.FLASHINFER_CSRC_DIR / "seq_chunk_cumsum_jit_binding.cu",
        ],
    )
```

flashinfer/mamba/__init__.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -17,3 +17,10 @@
 from .selective_state_update import selective_state_update

 __all__ = ["selective_state_update"]
+
+try:
+    from .ssd_combined import SSDCombined
+
+    __all__.append("SSDCombined")
+except ImportError:
+    pass
```
