
Commit 78a0091

ishovkun and claude authored
Mamba SSU: better automatic kernel selection + algorithm selection optionally exposed to the user. (flashinfer-ai#2591)
## 📌 Description

This PR does several things:

- Improves automatic kernel selection based on the arch, state_dtype, and the batch size (see image below).
- Slightly improves performance at small batch sizes by launching several CTAs per tile.
- Adds jinja templates for the `selective_state_update` function (JIT compilation is fast now).
- Reduces the number of meaningless parameter combinations in the tests (tests are still fast).

## Background

This PR changes the behavior of the function. An optional string `algorithm` can now be passed to the kernel. The default value `'auto'` lets the user not think about the internals of the function; optionally, the user can specify the kernel they want. This adjustment allowed me to make use of the recent mamba benchmarks. The sweep is shown below:

<img width="3600" height="3000" alt="ssu_speedup_vs_batch_size_NVIDIA_B200" src="https://github.com/user-attachments/assets/11325924-64e2-48e7-8fee-7244cdd7a893" />

One can see that the new benchmark now correctly shows the speed difference between the reference Triton and the current implementation, as opposed to my previous [PR](flashinfer-ai#2301). Clearly, I previously messed up the measurements at small batch sizes.

## Kernel Selection

This PR improves the kernel selection in the following ways:

- If the problem size is too small, use the `simple` algorithm with several CTAs per tile.
- If on Blackwell, use the `horizontal` algorithm only for bf16/fp16 states; otherwise fall back to the `vertical` one.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

Please check how I handled the jinja templates, as it's my first time using those. Also, please check whether I accidentally deleted any important tests.

## Summary by CodeRabbit

* **New Features**
  * Runtime-selectable algorithm for selective state update: auto (default), simple, vertical, horizontal.
* **Bug Fixes**
  * Added runtime validation to ensure index/dtype consistency across execution paths.
* **Chores**
  * JIT/module generation reworked to produce specialized builds per dtype/dimension and target architectures.
  * Public API unified to select appropriate compiled module based on device and data.
* **Tests**
  * Expanded and parameterized tests covering algorithms, dtypes, tiling, intermediate states, and large batches.

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
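The selection rules described above can be sketched in pure Python. This is a hedged illustration, not the CUDA-side logic: `small_batch_threshold` is a made-up placeholder for the tuned cutoff, and the off-Blackwell default is an assumption not stated in the PR text.

```python
def select_ssu_algorithm(batch_size, state_dtype, is_blackwell,
                         small_batch_threshold=16):
    """Sketch of the 'auto' kernel-selection policy described above."""
    # Hypothetical cutoff; the real tuned threshold lives in the CUDA-side code.
    if batch_size < small_batch_threshold:
        # Small problems: 'simple' algorithm with several CTAs per tile.
        return "simple"
    if is_blackwell:
        # Blackwell: 'horizontal' only for 16-bit states, else fall back to 'vertical'.
        return "horizontal" if state_dtype in ("bfloat16", "float16") else "vertical"
    # Assumed default on older architectures (not specified in the PR text).
    return "vertical"


print(select_ssu_algorithm(4, "float32", is_blackwell=True))     # -> simple
print(select_ssu_algorithm(128, "bfloat16", is_blackwell=True))  # -> horizontal
print(select_ssu_algorithm(128, "float32", is_blackwell=True))   # -> vertical
```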
1 parent ebf8a71 commit 78a0091

18 files changed

Lines changed: 1151 additions & 996 deletions

benchmarks/routines/mamba.py

Lines changed: 12 additions & 41 deletions
```diff
@@ -14,14 +14,8 @@
 limitations under the License.
 """
 
-# ==============================================================================
-# Triton reference implementation for selective_state_update.
-# Imported from tests/mamba/selective_state_update_triton.py to avoid code
-# duplication. See that file for the canonical Triton kernel source.
-# ==============================================================================
-
-import importlib
 import os
+import sys
 from collections import defaultdict
 
 import numpy as np
@@ -30,6 +24,14 @@
 import flashinfer
 from flashinfer.testing.utils import bench_gpu_time
 
+# Add tests/mamba to sys.path so triton_reference is importable as a package
+_repo_root = os.path.normpath(
+    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..")
+)
+_tests_mamba = os.path.join(_repo_root, "tests", "mamba")
+if _tests_mamba not in sys.path:
+    sys.path.insert(0, _tests_mamba)
+
 from .flashinfer_benchmark_utils import (
     dtype_str_to_torch_dtype,
     get_device,
@@ -38,40 +40,9 @@
     filter_backends_by_compute_capability,
 )
 
-# ---- Import Triton reference kernel from tests/mamba/ ----
-# The canonical Triton selective_state_update lives in tests/mamba/selective_state_update_triton.py.
-# We import it here rather than duplicating ~400 lines of kernel code.
-
-
-def _import_triton_reference():
-    """Import selective_state_update_triton from tests/mamba/.
-
-    Uses importlib to load the module directly by file path, avoiding sys.path
-    pollution and fragile relative path assumptions.
-    """
-    # Resolve path: benchmarks/routines/mamba.py -> ../../tests/mamba/selective_state_update_triton.py
-    _this_dir = os.path.dirname(os.path.abspath(__file__))
-    _repo_root = os.path.normpath(os.path.join(_this_dir, "..", ".."))
-    _triton_ref_path = os.path.join(
-        _repo_root, "tests", "mamba", "selective_state_update_triton.py"
-    )
-
-    if not os.path.isfile(_triton_ref_path):
-        raise ImportError(
-            f"Cannot find Triton reference kernel at: {_triton_ref_path}\n"
-            f"Expected location: <repo>/tests/mamba/selective_state_update_triton.py\n"
-            f"Make sure you are running from within the FlashInfer repository."
-        )
-
-    spec = importlib.util.spec_from_file_location(
-        "selective_state_update_triton", _triton_ref_path
-    )
-    module = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(module)
-    return module.selective_state_update_triton
-
-
-selective_state_update_triton_reference = _import_triton_reference()
+from triton_reference.selective_state_update import (
+    selective_state_update_triton as selective_state_update_triton_reference,
+)
 
 
 # ==============================================================================
```
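The benchmark's new import approach, adding a directory to `sys.path` so a sibling test package becomes importable, can be exercised in isolation. The paths below are temp-dir stand-ins, not the repository layout:

```python
import os
import sys
import tempfile

# Build a throwaway package directory mimicking tests/mamba/triton_reference/.
root = tempfile.mkdtemp()
pkg = os.path.join(root, "triton_reference")
os.makedirs(pkg)
with open(os.path.join(pkg, "__init__.py"), "w") as f:
    f.write("")
with open(os.path.join(pkg, "selective_state_update.py"), "w") as f:
    f.write("def selective_state_update_triton():\n    return 'triton reference'\n")

# Same pattern as the benchmark: idempotent sys.path insertion.
if root not in sys.path:
    sys.path.insert(0, root)

from triton_reference.selective_state_update import (
    selective_state_update_triton as selective_state_update_triton_reference,
)

print(selective_state_update_triton_reference())  # -> triton reference
```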

csrc/flashinfer_mamba_binding.cu

Lines changed: 2 additions & 1 deletion
```diff
@@ -42,7 +42,8 @@ void selective_state_update(
     bool disable_state_update,
     Optional<TensorView> intermediate_states_buffer,  // (batch, cache_steps, nheads, dim, dstate)
     Optional<TensorView> intermediate_state_indices,  // (batch,)
-    int64_t cache_steps);
+    int64_t cache_steps,
+    int64_t algorithm);  // SSUAlgorithm: 0=auto, 1=simple, 2=vertical, 3=horizontal
 
 }  // namespace flashinfer::mamba
```
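The binding passes the algorithm as a plain int64. A thin Python-side translation from the user-facing strings to these codes might look like the sketch below; the helper name is hypothetical, only the 0–3 code mapping comes from the binding's comment:

```python
# Mirrors the SSUAlgorithm comment in the binding:
# 0=auto, 1=simple, 2=vertical, 3=horizontal.
_SSU_ALGORITHMS = {"auto": 0, "simple": 1, "vertical": 2, "horizontal": 3}


def ssu_algorithm_to_code(algorithm: str) -> int:
    """Translate a user-facing algorithm string into the int64 the FFI layer expects."""
    try:
        return _SSU_ALGORITHMS[algorithm]
    except KeyError:
        raise ValueError(
            f"algorithm must be one of {sorted(_SSU_ALGORITHMS)}, got {algorithm!r}"
        ) from None


print(ssu_algorithm_to_code("auto"))        # -> 0
print(ssu_algorithm_to_code("horizontal"))  # -> 3
```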

csrc/selective_state_update.cu

Lines changed: 30 additions & 210 deletions
```diff
@@ -13,9 +13,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+// clang-format off
+// config.inc MUST come before the header: it defines DIM, DSTATE, NTOKENS_MTP
+// constexprs that the header's function templates rely on. Reordering breaks compilation.
+// NOTE: the .inc file is generated from the jinja templates
+#include "selective_state_update_config.inc"
 #include <flashinfer/mamba/selective_state_update.cuh>
-#include <sstream>
-
+// clang-format on
 #include "tvm_ffi_utils.h"
 
 using namespace flashinfer;
@@ -124,87 +128,13 @@ inline void validate_dtype_consistency(
   }
 }
 
-// Helper to convert dtype code to string for error messages
-inline const char* dtype_code_to_string(int64_t code) {
-  if (code == bfloat16_code) return "bfloat16";
-  if (code == float16_code) return "float16";
-  if (code == float32_code) return "float32";
-  return "unknown";
-}
-
-// Type traits to map dtype codes to C++ types
-template <int64_t code>
-struct DTypeToType;
-
-template <>
-struct DTypeToType<bfloat16_code> {
-  using type = nv_bfloat16;
-};
-template <>
-struct DTypeToType<float16_code> {
-  using type = half;
-};
-template <>
-struct DTypeToType<float32_code> {
-  using type = float;
-};
-template <>
-struct DTypeToType<int32_code> {
-  using type = int32_t;
-};
-template <>
-struct DTypeToType<int64_code> {
-  using type = int64_t;
-};
-
-// Allowed dtype combinations: {state_code, input_code, weight_code, matrixA_code, stateIndex_code}
-constexpr std::tuple<int64_t, int64_t, int64_t, int64_t, int64_t> allowed_dtype_combos[] = {
-    {bfloat16_code, bfloat16_code, bfloat16_code, float32_code, int32_code},
-    {float16_code, bfloat16_code, bfloat16_code, float32_code, int32_code},
-    {float32_code, bfloat16_code, bfloat16_code, float32_code, int32_code},
-    {bfloat16_code, bfloat16_code, float32_code, float32_code, int32_code},
-    {float16_code, bfloat16_code, float32_code, float32_code, int32_code},
-    {float32_code, bfloat16_code, float32_code, float32_code, int32_code},
-    {bfloat16_code, bfloat16_code, bfloat16_code, float32_code, int64_code},
-    {float16_code, bfloat16_code, bfloat16_code, float32_code, int64_code},
-    {float32_code, bfloat16_code, bfloat16_code, float32_code, int64_code},
-    {bfloat16_code, bfloat16_code, float32_code, float32_code, int64_code},
-    {float16_code, bfloat16_code, float32_code, float32_code, int64_code},
-    {float32_code, bfloat16_code, float32_code, float32_code, int64_code},
-};
-
-// Helper to dispatch to the right template instantiation for STP
-template <int64_t state_code, int64_t input_code, int64_t weight_code, int64_t matrixA_code,
-          int64_t stateIndex_code>
-void dispatchCombo(SelectiveStateUpdateParams& p, cudaStream_t stream) {
-  using state_t = typename DTypeToType<state_code>::type;
-  using input_t = typename DTypeToType<input_code>::type;
-  using weight_t = typename DTypeToType<weight_code>::type;
-  using matrixA_t = typename DTypeToType<matrixA_code>::type;
-  using stateIndex_t = typename DTypeToType<stateIndex_code>::type;
-  invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t, stateIndex_t>(p, stream);
-}
-
-// Helper to dispatch to the right template instantiation for MTP
-template <int64_t state_code, int64_t input_code, int64_t weight_code, int64_t matrixA_code,
-          int64_t stateIndex_code>
-void dispatchComboMTP(mtp::SelectiveStateMTPParams& p, cudaStream_t stream) {
-  using state_t = typename DTypeToType<state_code>::type;
-  using input_t = typename DTypeToType<input_code>::type;
-  using weight_t = typename DTypeToType<weight_code>::type;
-  using matrixA_t = typename DTypeToType<matrixA_code>::type;
-  using stateIndex_t = typename DTypeToType<stateIndex_code>::type;
-  mtp::invokeSelectiveStateUpdateMTP<input_t, weight_t, matrixA_t, state_t, stateIndex_t>(p,
-                                                                                          stream);
-}
-
 void run_selective_state_update_stp(TensorView const& state, TensorView const& x,
                                     TensorView const& dt, TensorView const& A, TensorView const& B,
                                     TensorView const& C, TensorView const& D,
                                     Optional<TensorView> z, Optional<TensorView> dt_bias,
                                     bool dt_softplus, Optional<TensorView> state_batch_indices,
                                     int64_t pad_slot_id, Optional<TensorView> out,
-                                    bool disable_state_update) {
+                                    bool disable_state_update, int64_t algorithm) {
   // Extract dimensions from input tensors
   auto const batch = x.size(0);
   auto const state_cache_size = state.size(0);
@@ -344,64 +274,8 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x
   ffi::CUDADeviceGuard device_guard(state.device().device_id);
   const cudaStream_t stream = get_stream(state.device());
 
-  // Dispatch based on dtype combination
-  DLDataType state_dtype = state.dtype();
-  DLDataType input_dtype = x.dtype();
-  DLDataType weight_dtype = dt.dtype();
-  DLDataType matrixA_dtype = A.dtype();
-  int64_t state_dtype_code = encode_dlpack_dtype(state_dtype);
-  int64_t input_dtype_code = encode_dlpack_dtype(input_dtype);
-  int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype);
-  int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype);
-
-  // Get state_batch_indices dtype, default to int32 if not provided
-  int64_t stateIndex_dtype_code = int32_code;
-  if (state_batch_indices.has_value()) {
-    DLDataType stateIndex_dtype = state_batch_indices.value().dtype();
-    stateIndex_dtype_code = encode_dlpack_dtype(stateIndex_dtype);
-  }
-
-  // Dispatch kernel based on dtype combination
-  auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code,
-                                   matrixA_dtype_code, stateIndex_dtype_code);
-
-  // Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion
-  auto tryDispatch = [&](const auto& key, auto idx, auto& self) -> bool {
-    constexpr size_t I = decltype(idx)::value;
-    if constexpr (I < std::size(allowed_dtype_combos)) {
-      constexpr auto combo = allowed_dtype_combos[I];
-      if (key == combo) {
-        constexpr auto s = std::get<0>(combo);
-        constexpr auto i = std::get<1>(combo);
-        constexpr auto w = std::get<2>(combo);
-        constexpr auto m = std::get<3>(combo);
-        constexpr auto si = std::get<4>(combo);
-        dispatchCombo<s, i, w, m, si>(p, stream);
-        return true;
-      }
-      return self(key, std::integral_constant<size_t, I + 1>{}, self);
-    }
-    return false;
-  };
-
-  // Dispatch using compile-time type traits
-  if (!tryDispatch(dtype_key, std::integral_constant<size_t, 0>{}, tryDispatch)) {
-    // Unsupported dtype combination - build error message dynamically
-    std::ostringstream error_msg;
-    error_msg << "Unsupported dtype combination for selective_state_update: " << "state_dtype="
-              << state_dtype.code << ":" << state_dtype.bits << ", "
-              << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", "
-              << "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", "
-              << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits
-              << ". Supported combos include:\n";
-    for (const auto& combo : allowed_dtype_combos) {
-      error_msg << "  (state=" << dtype_code_to_string(std::get<0>(combo))
-                << ", input=" << dtype_code_to_string(std::get<1>(combo))
-                << ", weight=" << dtype_code_to_string(std::get<2>(combo))
-                << ", matrixA=" << dtype_code_to_string(std::get<3>(combo)) << ")\n";
-    }
-    TVM_FFI_ICHECK(false) << error_msg.str();
-  }
+  auto algo = static_cast<SSUAlgorithm>(algorithm);
+  invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t, stateIndex_t>(p, algo, stream);
 }
 
 void run_selective_state_update_mtp(
@@ -410,7 +284,7 @@ void run_selective_state_update_mtp(
     Optional<TensorView> dt_bias, bool dt_softplus, Optional<TensorView> state_batch_indices,
     int64_t pad_slot_id, Optional<TensorView> out, bool disable_state_update,
     Optional<TensorView> intermediate_states_buffer,
-    Optional<TensorView> intermediate_state_indices, int64_t cache_steps) {
+    Optional<TensorView> intermediate_state_indices, int64_t cache_steps, int64_t algorithm) {
   // Extract dimensions from input tensors
   auto const batch = x.size(0);
   auto const ntokens_mtp = x.size(1);
@@ -505,6 +379,15 @@ void run_selective_state_update_mtp(
   validate_intermediate_state_indices(intermediate_state_indices, batch);
   validate_intermediate_states_buffer(intermediate_states_buffer);
 
+  // Validate that state_batch_indices and intermediate_state_indices have the same dtype
+  if (state_batch_indices.has_value() && intermediate_state_indices.has_value()) {
+    DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype();
+    DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype();
+    FLASHINFER_CHECK(state_batch_idx_dtype.code == intermediate_idx_dtype.code &&
+                         state_batch_idx_dtype.bits == intermediate_idx_dtype.bits,
+                     "state_batch_indices and intermediate_state_indices must have the same dtype");
+  }
+
   // Validate cache_steps is non-negative
   FLASHINFER_CHECK(cache_steps >= 0, "cache_steps must be non-negative, got ", cache_steps);
 
@@ -588,75 +471,9 @@ void run_selective_state_update_mtp(
   ffi::CUDADeviceGuard device_guard(state.device().device_id);
   const cudaStream_t stream = get_stream(state.device());
 
-  // Dispatch based on dtype combination
-  DLDataType state_dtype = state.dtype();
-  DLDataType input_dtype = x.dtype();
-  DLDataType weight_dtype = dt.dtype();
-  DLDataType matrixA_dtype = A.dtype();
-  int64_t state_dtype_code = encode_dlpack_dtype(state_dtype);
-  int64_t input_dtype_code = encode_dlpack_dtype(input_dtype);
-  int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype);
-  int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype);
-
-  // Get stateIndex dtype from whichever index tensor is available
-  // If both are provided, they must have the same dtype
-  int64_t stateIndex_dtype_code = int32_code;  // default
-  if (state_batch_indices.has_value() && intermediate_state_indices.has_value()) {
-    DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype();
-    DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype();
-    FLASHINFER_CHECK(state_batch_idx_dtype.code == intermediate_idx_dtype.code &&
-                         state_batch_idx_dtype.bits == intermediate_idx_dtype.bits,
-                     "state_batch_indices and intermediate_state_indices must have the same dtype");
-    stateIndex_dtype_code = encode_dlpack_dtype(state_batch_idx_dtype);
-  } else if (state_batch_indices.has_value()) {
-    DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype();
-    stateIndex_dtype_code = encode_dlpack_dtype(state_batch_idx_dtype);
-  } else if (intermediate_state_indices.has_value()) {
-    DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype();
-    stateIndex_dtype_code = encode_dlpack_dtype(intermediate_idx_dtype);
-  }
-
-  // Dispatch kernel based on dtype combination
-  auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code,
-                                   matrixA_dtype_code, stateIndex_dtype_code);
-
-  // Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion
-  auto tryDispatch = [&](const auto& key, auto idx, auto& self) -> bool {
-    constexpr size_t I = decltype(idx)::value;
-    if constexpr (I < std::size(allowed_dtype_combos)) {
-      constexpr auto combo = allowed_dtype_combos[I];
-      if (key == combo) {
-        constexpr auto s = std::get<0>(combo);
-        constexpr auto i = std::get<1>(combo);
-        constexpr auto w = std::get<2>(combo);
-        constexpr auto m = std::get<3>(combo);
-        constexpr auto si = std::get<4>(combo);
-        dispatchComboMTP<s, i, w, m, si>(p, stream);
-        return true;
-      }
-      return self(key, std::integral_constant<size_t, I + 1>{}, self);
-    }
-    return false;
-  };
-
-  // Dispatch using compile-time type traits
-  if (!tryDispatch(dtype_key, std::integral_constant<size_t, 0>{}, tryDispatch)) {
-    // Unsupported dtype combination - build error message dynamically
-    std::ostringstream error_msg;
-    error_msg << "Unsupported dtype combination for selective_state_update: " << "state_dtype="
-              << state_dtype.code << ":" << state_dtype.bits << ", "
-              << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", "
-              << "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", "
-              << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits
-              << ". Supported combos include:\n";
-    for (const auto& combo : allowed_dtype_combos) {
-      error_msg << "  (state=" << dtype_code_to_string(std::get<0>(combo))
-                << ", input=" << dtype_code_to_string(std::get<1>(combo))
-                << ", weight=" << dtype_code_to_string(std::get<2>(combo))
-                << ", matrixA=" << dtype_code_to_string(std::get<3>(combo)) << ")\n";
-    }
-    TVM_FFI_ICHECK(false) << error_msg.str();
-  }
+  auto algo = static_cast<SSUAlgorithm>(algorithm);
+  mtp::invokeSelectiveStateUpdateMTP<input_t, weight_t, matrixA_t, state_t, stateIndex_t>(p, algo,
+                                                                                          stream);
 }
 
 // =============================================================================
@@ -668,14 +485,17 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso
                             Optional<TensorView> state_batch_indices, int64_t pad_slot_id,
                             TensorView output, bool disable_state_update,
                             Optional<TensorView> intermediate_states_buffer,
-                            Optional<TensorView> intermediate_state_indices, int64_t cache_steps) {
+                            Optional<TensorView> intermediate_state_indices, int64_t cache_steps,
+                            int64_t algorithm) {
  if (x.dim() == 3) {
    run_selective_state_update_stp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus,
-                                   state_batch_indices, pad_slot_id, output, disable_state_update);
+                                   state_batch_indices, pad_slot_id, output, disable_state_update,
+                                   algorithm);
  } else if (x.dim() == 4) {
-    run_selective_state_update_mtp(
-        state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, state_batch_indices, pad_slot_id, output,
-        disable_state_update, intermediate_states_buffer, intermediate_state_indices, cache_steps);
+    run_selective_state_update_mtp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus,
+                                   state_batch_indices, pad_slot_id, output, disable_state_update,
+                                   intermediate_states_buffer, intermediate_state_indices,
+                                   cache_steps, algorithm);
  } else {
    FLASHINFER_CHECK(false,
                     "x must have 3 dimensions (single-token) or 4 dimensions (multi-token), got ",
```
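The deleted C++ dispatcher walked a compile-time list of allowed dtype tuples. The same idea, as a runnable Python sketch, is a tuple-keyed lookup table; dtype name strings stand in for the C++ dtype codes, and the function and table names here are illustrative:

```python
# Allowed (state, input, weight, matrixA, stateIndex) combinations, mirroring
# the removed allowed_dtype_combos table: state in {bf16, fp16, fp32},
# weight in {bf16, fp32}, index in {int32, int64}; input is always bf16
# and matrixA is always fp32.
ALLOWED_COMBOS = {
    (state, "bfloat16", weight, "float32", idx)
    for state in ("bfloat16", "float16", "float32")
    for weight in ("bfloat16", "float32")
    for idx in ("int32", "int64")
}


def dispatch(state, x, dt, a, idx="int32"):
    """Reject unsupported dtype combinations, as the old runtime dispatcher did."""
    key = (state, x, dt, a, idx)
    if key not in ALLOWED_COMBOS:
        raise TypeError(f"Unsupported dtype combination: {key}")
    # The C++ version would instantiate the matching kernel template here;
    # after this PR, each build is specialized via the jinja-generated config.inc.
    return key


print(len(ALLOWED_COMBOS))  # -> 12 combinations, as in the removed table
```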
Lines changed: 14 additions & 0 deletions
```diff
@@ -0,0 +1,14 @@
+#pragma once
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+
+using state_t = {{ state_dtype }};
+using input_t = {{ input_dtype }};
+using weight_t = {{ weight_dtype }};
+using matrixA_t = {{ matrixA_dtype }};
+using stateIndex_t = {{ stateIndex_dtype }};
+
+constexpr int DIM = {{ dim }};
+constexpr int DSTATE = {{ dstate }};
+constexpr int NTOKENS_MTP = {{ ntokens_mtp }};
```
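The template above is filled in at JIT time to produce the `selective_state_update_config.inc` that the .cu file includes. A minimal stand-in renderer, a plain `re.sub` here rather than the jinja2 engine the build actually uses, shows the substitution on a shortened copy of the template:

```python
import re

# Shortened copy of the jinja template above, for illustration.
TEMPLATE = """using state_t = {{ state_dtype }};
constexpr int DIM = {{ dim }};
constexpr int DSTATE = {{ dstate }};
"""


def render(template: str, **params) -> str:
    """Replace each {{ name }} placeholder with its value, jinja-style."""
    return re.sub(
        r"\{\{\s*(\w+)\s*\}\}",
        lambda m: str(params[m.group(1)]),
        template,
    )


inc = render(TEMPLATE, state_dtype="nv_bfloat16", dim=64, dstate=128)
print(inc)
# using state_t = nv_bfloat16;
# constexpr int DIM = 64;
# constexpr int DSTATE = 128;
```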
