1313 * See the License for the specific language governing permissions and
1414 * limitations under the License.
1515 */
16+ // clang-format off
17+ // config.inc MUST come before the header: it defines DIM, DSTATE, NTOKENS_MTP
18+ // constexprs that the header's function templates rely on. Reordering breaks compilation.
19+ // NOTE: the .inc file is generated from the jinja templates
20+ #include " selective_state_update_config.inc"
1621#include < flashinfer/mamba/selective_state_update.cuh>
17- #include < sstream>
18-
22+ // clang-format on
1923#include " tvm_ffi_utils.h"
2024
2125using namespace flashinfer ;
@@ -124,87 +128,13 @@ inline void validate_dtype_consistency(
124128 }
125129}
126130
127- // Helper to convert dtype code to string for error messages
128- inline const char * dtype_code_to_string (int64_t code) {
129- if (code == bfloat16_code) return " bfloat16" ;
130- if (code == float16_code) return " float16" ;
131- if (code == float32_code) return " float32" ;
132- return " unknown" ;
133- }
134-
135- // Type traits to map dtype codes to C++ types
136- template <int64_t code>
137- struct DTypeToType ;
138-
139- template <>
140- struct DTypeToType <bfloat16_code> {
141- using type = nv_bfloat16;
142- };
143- template <>
144- struct DTypeToType <float16_code> {
145- using type = half;
146- };
147- template <>
148- struct DTypeToType <float32_code> {
149- using type = float ;
150- };
151- template <>
152- struct DTypeToType <int32_code> {
153- using type = int32_t ;
154- };
155- template <>
156- struct DTypeToType <int64_code> {
157- using type = int64_t ;
158- };
159-
160- // Allowed dtype combinations: {state_code, input_code, weight_code, matrixA_code, stateIndex_code}
161- constexpr std::tuple<int64_t , int64_t , int64_t , int64_t , int64_t > allowed_dtype_combos[] = {
162- {bfloat16_code, bfloat16_code, bfloat16_code, float32_code, int32_code},
163- {float16_code, bfloat16_code, bfloat16_code, float32_code, int32_code},
164- {float32_code, bfloat16_code, bfloat16_code, float32_code, int32_code},
165- {bfloat16_code, bfloat16_code, float32_code, float32_code, int32_code},
166- {float16_code, bfloat16_code, float32_code, float32_code, int32_code},
167- {float32_code, bfloat16_code, float32_code, float32_code, int32_code},
168- {bfloat16_code, bfloat16_code, bfloat16_code, float32_code, int64_code},
169- {float16_code, bfloat16_code, bfloat16_code, float32_code, int64_code},
170- {float32_code, bfloat16_code, bfloat16_code, float32_code, int64_code},
171- {bfloat16_code, bfloat16_code, float32_code, float32_code, int64_code},
172- {float16_code, bfloat16_code, float32_code, float32_code, int64_code},
173- {float32_code, bfloat16_code, float32_code, float32_code, int64_code},
174- };
175-
176- // Helper to dispatch to the right template instantiation for STP
177- template <int64_t state_code, int64_t input_code, int64_t weight_code, int64_t matrixA_code,
178- int64_t stateIndex_code>
179- void dispatchCombo (SelectiveStateUpdateParams& p, cudaStream_t stream) {
180- using state_t = typename DTypeToType<state_code>::type;
181- using input_t = typename DTypeToType<input_code>::type;
182- using weight_t = typename DTypeToType<weight_code>::type;
183- using matrixA_t = typename DTypeToType<matrixA_code>::type;
184- using stateIndex_t = typename DTypeToType<stateIndex_code>::type;
185- invokeSelectiveStateUpdate<input_t , weight_t , matrixA_t, state_t , stateIndex_t>(p, stream);
186- }
187-
188- // Helper to dispatch to the right template instantiation for MTP
189- template <int64_t state_code, int64_t input_code, int64_t weight_code, int64_t matrixA_code,
190- int64_t stateIndex_code>
191- void dispatchComboMTP (mtp::SelectiveStateMTPParams& p, cudaStream_t stream) {
192- using state_t = typename DTypeToType<state_code>::type;
193- using input_t = typename DTypeToType<input_code>::type;
194- using weight_t = typename DTypeToType<weight_code>::type;
195- using matrixA_t = typename DTypeToType<matrixA_code>::type;
196- using stateIndex_t = typename DTypeToType<stateIndex_code>::type;
197- mtp::invokeSelectiveStateUpdateMTP<input_t , weight_t , matrixA_t, state_t , stateIndex_t>(p,
198- stream);
199- }
200-
201131void run_selective_state_update_stp (TensorView const & state, TensorView const & x,
202132 TensorView const & dt, TensorView const & A, TensorView const & B,
203133 TensorView const & C, TensorView const & D,
204134 Optional<TensorView> z, Optional<TensorView> dt_bias,
205135 bool dt_softplus, Optional<TensorView> state_batch_indices,
206136 int64_t pad_slot_id, Optional<TensorView> out,
207- bool disable_state_update) {
137+ bool disable_state_update, int64_t algorithm ) {
208138 // Extract dimensions from input tensors
209139 auto const batch = x.size (0 );
210140 auto const state_cache_size = state.size (0 );
@@ -344,64 +274,8 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x
344274 ffi::CUDADeviceGuard device_guard (state.device ().device_id );
345275 const cudaStream_t stream = get_stream (state.device ());
346276
347- // Dispatch based on dtype combination
348- DLDataType state_dtype = state.dtype ();
349- DLDataType input_dtype = x.dtype ();
350- DLDataType weight_dtype = dt.dtype ();
351- DLDataType matrixA_dtype = A.dtype ();
352- int64_t state_dtype_code = encode_dlpack_dtype (state_dtype);
353- int64_t input_dtype_code = encode_dlpack_dtype (input_dtype);
354- int64_t weight_dtype_code = encode_dlpack_dtype (weight_dtype);
355- int64_t matrixA_dtype_code = encode_dlpack_dtype (matrixA_dtype);
356-
357- // Get state_batch_indices dtype, default to int32 if not provided
358- int64_t stateIndex_dtype_code = int32_code;
359- if (state_batch_indices.has_value ()) {
360- DLDataType stateIndex_dtype = state_batch_indices.value ().dtype ();
361- stateIndex_dtype_code = encode_dlpack_dtype (stateIndex_dtype);
362- }
363-
364- // Dispatch kernel based on dtype combination
365- auto dtype_key = std::make_tuple (state_dtype_code, input_dtype_code, weight_dtype_code,
366- matrixA_dtype_code, stateIndex_dtype_code);
367-
368- // Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion
369- auto tryDispatch = [&](const auto & key, auto idx, auto & self) -> bool {
370- constexpr size_t I = decltype (idx)::value;
371- if constexpr (I < std::size (allowed_dtype_combos)) {
372- constexpr auto combo = allowed_dtype_combos[I];
373- if (key == combo) {
374- constexpr auto s = std::get<0 >(combo);
375- constexpr auto i = std::get<1 >(combo);
376- constexpr auto w = std::get<2 >(combo);
377- constexpr auto m = std::get<3 >(combo);
378- constexpr auto si = std::get<4 >(combo);
379- dispatchCombo<s, i, w, m, si>(p, stream);
380- return true ;
381- }
382- return self (key, std::integral_constant<size_t , I + 1 >{}, self);
383- }
384- return false ;
385- };
386-
387- // Dispatch using compile-time type traits
388- if (!tryDispatch (dtype_key, std::integral_constant<size_t , 0 >{}, tryDispatch)) {
389- // Unsupported dtype combination - build error message dynamically
390- std::ostringstream error_msg;
391- error_msg << " Unsupported dtype combination for selective_state_update: " << " state_dtype="
392- << state_dtype.code << " :" << state_dtype.bits << " , "
393- << " input_dtype=" << input_dtype.code << " :" << input_dtype.bits << " , "
394- << " weight_dtype=" << weight_dtype.code << " :" << weight_dtype.bits << " , "
395- << " matrixA_dtype=" << matrixA_dtype.code << " :" << matrixA_dtype.bits
396- << " . Supported combos include:\n " ;
397- for (const auto & combo : allowed_dtype_combos) {
398- error_msg << " (state=" << dtype_code_to_string (std::get<0 >(combo))
399- << " , input=" << dtype_code_to_string (std::get<1 >(combo))
400- << " , weight=" << dtype_code_to_string (std::get<2 >(combo))
401- << " , matrixA=" << dtype_code_to_string (std::get<3 >(combo)) << " )\n " ;
402- }
403- TVM_FFI_ICHECK (false ) << error_msg.str ();
404- }
277+ auto algo = static_cast <SSUAlgorithm>(algorithm);
278+ invokeSelectiveStateUpdate<input_t , weight_t , matrixA_t, state_t , stateIndex_t>(p, algo, stream);
405279}
406280
407281void run_selective_state_update_mtp (
@@ -410,7 +284,7 @@ void run_selective_state_update_mtp(
410284 Optional<TensorView> dt_bias, bool dt_softplus, Optional<TensorView> state_batch_indices,
411285 int64_t pad_slot_id, Optional<TensorView> out, bool disable_state_update,
412286 Optional<TensorView> intermediate_states_buffer,
413- Optional<TensorView> intermediate_state_indices, int64_t cache_steps) {
287+ Optional<TensorView> intermediate_state_indices, int64_t cache_steps, int64_t algorithm ) {
414288 // Extract dimensions from input tensors
415289 auto const batch = x.size (0 );
416290 auto const ntokens_mtp = x.size (1 );
@@ -505,6 +379,15 @@ void run_selective_state_update_mtp(
505379 validate_intermediate_state_indices (intermediate_state_indices, batch);
506380 validate_intermediate_states_buffer (intermediate_states_buffer);
507381
382+ // Validate that state_batch_indices and intermediate_state_indices have the same dtype
383+ if (state_batch_indices.has_value () && intermediate_state_indices.has_value ()) {
384+ DLDataType state_batch_idx_dtype = state_batch_indices.value ().dtype ();
385+ DLDataType intermediate_idx_dtype = intermediate_state_indices.value ().dtype ();
386+ FLASHINFER_CHECK (state_batch_idx_dtype.code == intermediate_idx_dtype.code &&
387+ state_batch_idx_dtype.bits == intermediate_idx_dtype.bits ,
388+ " state_batch_indices and intermediate_state_indices must have the same dtype" );
389+ }
390+
508391 // Validate cache_steps is non-negative
509392 FLASHINFER_CHECK (cache_steps >= 0 , " cache_steps must be non-negative, got " , cache_steps);
510393
@@ -588,75 +471,9 @@ void run_selective_state_update_mtp(
588471 ffi::CUDADeviceGuard device_guard (state.device ().device_id );
589472 const cudaStream_t stream = get_stream (state.device ());
590473
591- // Dispatch based on dtype combination
592- DLDataType state_dtype = state.dtype ();
593- DLDataType input_dtype = x.dtype ();
594- DLDataType weight_dtype = dt.dtype ();
595- DLDataType matrixA_dtype = A.dtype ();
596- int64_t state_dtype_code = encode_dlpack_dtype (state_dtype);
597- int64_t input_dtype_code = encode_dlpack_dtype (input_dtype);
598- int64_t weight_dtype_code = encode_dlpack_dtype (weight_dtype);
599- int64_t matrixA_dtype_code = encode_dlpack_dtype (matrixA_dtype);
600-
601- // Get stateIndex dtype from whichever index tensor is available
602- // If both are provided, they must have the same dtype
603- int64_t stateIndex_dtype_code = int32_code; // default
604- if (state_batch_indices.has_value () && intermediate_state_indices.has_value ()) {
605- DLDataType state_batch_idx_dtype = state_batch_indices.value ().dtype ();
606- DLDataType intermediate_idx_dtype = intermediate_state_indices.value ().dtype ();
607- FLASHINFER_CHECK (state_batch_idx_dtype.code == intermediate_idx_dtype.code &&
608- state_batch_idx_dtype.bits == intermediate_idx_dtype.bits ,
609- " state_batch_indices and intermediate_state_indices must have the same dtype" );
610- stateIndex_dtype_code = encode_dlpack_dtype (state_batch_idx_dtype);
611- } else if (state_batch_indices.has_value ()) {
612- DLDataType state_batch_idx_dtype = state_batch_indices.value ().dtype ();
613- stateIndex_dtype_code = encode_dlpack_dtype (state_batch_idx_dtype);
614- } else if (intermediate_state_indices.has_value ()) {
615- DLDataType intermediate_idx_dtype = intermediate_state_indices.value ().dtype ();
616- stateIndex_dtype_code = encode_dlpack_dtype (intermediate_idx_dtype);
617- }
618-
619- // Dispatch kernel based on dtype combination
620- auto dtype_key = std::make_tuple (state_dtype_code, input_dtype_code, weight_dtype_code,
621- matrixA_dtype_code, stateIndex_dtype_code);
622-
623- // Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion
624- auto tryDispatch = [&](const auto & key, auto idx, auto & self) -> bool {
625- constexpr size_t I = decltype (idx)::value;
626- if constexpr (I < std::size (allowed_dtype_combos)) {
627- constexpr auto combo = allowed_dtype_combos[I];
628- if (key == combo) {
629- constexpr auto s = std::get<0 >(combo);
630- constexpr auto i = std::get<1 >(combo);
631- constexpr auto w = std::get<2 >(combo);
632- constexpr auto m = std::get<3 >(combo);
633- constexpr auto si = std::get<4 >(combo);
634- dispatchComboMTP<s, i, w, m, si>(p, stream);
635- return true ;
636- }
637- return self (key, std::integral_constant<size_t , I + 1 >{}, self);
638- }
639- return false ;
640- };
641-
642- // Dispatch using compile-time type traits
643- if (!tryDispatch (dtype_key, std::integral_constant<size_t , 0 >{}, tryDispatch)) {
644- // Unsupported dtype combination - build error message dynamically
645- std::ostringstream error_msg;
646- error_msg << " Unsupported dtype combination for selective_state_update: " << " state_dtype="
647- << state_dtype.code << " :" << state_dtype.bits << " , "
648- << " input_dtype=" << input_dtype.code << " :" << input_dtype.bits << " , "
649- << " weight_dtype=" << weight_dtype.code << " :" << weight_dtype.bits << " , "
650- << " matrixA_dtype=" << matrixA_dtype.code << " :" << matrixA_dtype.bits
651- << " . Supported combos include:\n " ;
652- for (const auto & combo : allowed_dtype_combos) {
653- error_msg << " (state=" << dtype_code_to_string (std::get<0 >(combo))
654- << " , input=" << dtype_code_to_string (std::get<1 >(combo))
655- << " , weight=" << dtype_code_to_string (std::get<2 >(combo))
656- << " , matrixA=" << dtype_code_to_string (std::get<3 >(combo)) << " )\n " ;
657- }
658- TVM_FFI_ICHECK (false ) << error_msg.str ();
659- }
474+ auto algo = static_cast <SSUAlgorithm>(algorithm);
475+ mtp::invokeSelectiveStateUpdateMTP<input_t , weight_t , matrixA_t, state_t , stateIndex_t>(p, algo,
476+ stream);
660477}
661478
662479// =============================================================================
@@ -668,14 +485,17 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso
668485 Optional<TensorView> state_batch_indices, int64_t pad_slot_id,
669486 TensorView output, bool disable_state_update,
670487 Optional<TensorView> intermediate_states_buffer,
671- Optional<TensorView> intermediate_state_indices, int64_t cache_steps) {
488+ Optional<TensorView> intermediate_state_indices, int64_t cache_steps,
489+ int64_t algorithm) {
672490 if (x.dim () == 3 ) {
673491 run_selective_state_update_stp (state, x, dt, A, B, C, D, z, dt_bias, dt_softplus,
674- state_batch_indices, pad_slot_id, output, disable_state_update);
492+ state_batch_indices, pad_slot_id, output, disable_state_update,
493+ algorithm);
675494 } else if (x.dim () == 4 ) {
676- run_selective_state_update_mtp (
677- state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, state_batch_indices, pad_slot_id, output,
678- disable_state_update, intermediate_states_buffer, intermediate_state_indices, cache_steps);
495+ run_selective_state_update_mtp (state, x, dt, A, B, C, D, z, dt_bias, dt_softplus,
496+ state_batch_indices, pad_slot_id, output, disable_state_update,
497+ intermediate_states_buffer, intermediate_state_indices,
498+ cache_steps, algorithm);
679499 } else {
680500 FLASHINFER_CHECK (false ,
681501 " x must have 3 dimensions (single-token) or 4 dimensions (multi-token), got " ,
0 commit comments