diff --git a/csrc/flat_prefill_kernel_delta_rule_sm120_extern.inc b/csrc/flat_prefill_kernel_delta_rule_sm120_extern.inc new file mode 100644 index 0000000000..c7e901d444 --- /dev/null +++ b/csrc/flat_prefill_kernel_delta_rule_sm120_extern.inc @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Extern template declarations to prevent implicit instantiation in the dispatcher. +// Explicit instantiations are in separate generated files for parallel compilation. + +#pragma once + +#include +#include +#include "cutlass/arch/arch.h" + +namespace flat { + +// clang-format off + +#define FOR_EACH_BOOL_5(MACRO, ...) 
\ + MACRO(false, false, false, false, false, __VA_ARGS__) \ + MACRO(false, false, false, false, true, __VA_ARGS__) \ + MACRO(false, false, false, true, false, __VA_ARGS__) \ + MACRO(false, false, false, true, true, __VA_ARGS__) \ + MACRO(false, false, true, false, false, __VA_ARGS__) \ + MACRO(false, false, true, false, true, __VA_ARGS__) \ + MACRO(false, false, true, true, false, __VA_ARGS__) \ + MACRO(false, false, true, true, true, __VA_ARGS__) \ + MACRO(false, true, false, false, false, __VA_ARGS__) \ + MACRO(false, true, false, false, true, __VA_ARGS__) \ + MACRO(false, true, false, true, false, __VA_ARGS__) \ + MACRO(false, true, false, true, true, __VA_ARGS__) \ + MACRO(false, true, true, false, false, __VA_ARGS__) \ + MACRO(false, true, true, false, true, __VA_ARGS__) \ + MACRO(false, true, true, true, false, __VA_ARGS__) \ + MACRO(false, true, true, true, true, __VA_ARGS__) \ + MACRO(true, false, false, false, false, __VA_ARGS__) \ + MACRO(true, false, false, false, true, __VA_ARGS__) \ + MACRO(true, false, false, true, false, __VA_ARGS__) \ + MACRO(true, false, false, true, true, __VA_ARGS__) \ + MACRO(true, false, true, false, false, __VA_ARGS__) \ + MACRO(true, false, true, false, true, __VA_ARGS__) \ + MACRO(true, false, true, true, false, __VA_ARGS__) \ + MACRO(true, false, true, true, true, __VA_ARGS__) \ + MACRO(true, true, false, false, false, __VA_ARGS__) \ + MACRO(true, true, false, false, true, __VA_ARGS__) \ + MACRO(true, true, false, true, false, __VA_ARGS__) \ + MACRO(true, true, false, true, true, __VA_ARGS__) \ + MACRO(true, true, true, false, false, __VA_ARGS__) \ + MACRO(true, true, true, false, true, __VA_ARGS__) \ + MACRO(true, true, true, true, false, __VA_ARGS__) \ + MACRO(true, true, true, true, true, __VA_ARGS__) + +#define DECLARE_TEMPLATE_INSTANCE(is_gva, needs_beta, needs_alpha, init_state, enable_ckpt, ctype) \ +extern template void launch_delta_rule_prefill_kernel_gbai( \ + cudaStream_t, ctype*, float*, ctype const*, ctype 
const*, ctype const*, \ + float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t, \ + int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t, float*, int64_t const*, \ + int32_t); + +// Extern template declarations for half +FOR_EACH_BOOL_5(DECLARE_TEMPLATE_INSTANCE, half) + +// Extern template declarations for nv_bfloat16 +FOR_EACH_BOOL_5(DECLARE_TEMPLATE_INSTANCE, nv_bfloat16) + +#undef DECLARE_TEMPLATE_INSTANCE +#undef FOR_EACH_BOOL_5 + +// clang-format on + +} // namespace flat diff --git a/csrc/gdn_prefill_launcher.cu b/csrc/gdn_prefill_launcher.cu index 449d9e01b6..d7c984a7ec 100644 --- a/csrc/gdn_prefill_launcher.cu +++ b/csrc/gdn_prefill_launcher.cu @@ -46,8 +46,8 @@ void gdn_prefill_launcher(void* output, void* output_state, void* q, void* k, vo int device_major; cudaDeviceGetAttribute(&device_major, cudaDevAttrComputeCapabilityMajor, dev_id); -#if defined(FLAT_SM90A_ENABLED) if (device_major == 9) { +#if defined(FLAT_SM90A_ENABLED) flat::launch_delta_rule_prefill_kernel( stream, static_cast(output), static_cast(output_state), static_cast(q), static_cast(k), static_cast(v), @@ -57,16 +57,31 @@ void gdn_prefill_launcher(void* output, void* output_state, void* q, void* k, vo static_cast(state_checkpoints), checkpoint_cu_starts, static_cast(checkpoint_every_n_tokens)); return true; +#else + FLASHINFER_ERROR("sm_90a is not enabled, delta rule kernel is not built"); + return false; +#endif + } else if (device_major == 12) { +#if defined(FLAT_SM120A_ENABLED) + flat::launch_delta_rule_prefill_kernel( + stream, static_cast(output), static_cast(output_state), + static_cast(q), static_cast(k), static_cast(v), + static_cast(input_state), static_cast(alpha), + static_cast(beta), cu_seqlens, workspace_buffer, num_seqs, num_q_heads, + num_k_heads, num_v_heads, num_o_heads, head_size, packed_seq, scale, sm_count, + static_cast(state_checkpoints), checkpoint_cu_starts, + static_cast(checkpoint_every_n_tokens)); + return true; +#else + 
FLASHINFER_ERROR("sm_120a is not enabled, delta rule kernel is not built"); + return false; +#endif } else { std::ostringstream err_msg; err_msg << "delta rule kernel does not support this device major version: " << device_major; FLASHINFER_ERROR(err_msg.str()); return false; } -#else - FLASHINFER_ERROR("sm_90a is not enabled, delta rule kernel is not built"); - return false; -#endif }); } diff --git a/csrc/gdn_prefill_sm120_kernel_inst.jinja b/csrc/gdn_prefill_sm120_kernel_inst.jinja new file mode 100644 index 0000000000..ffdecf5907 --- /dev/null +++ b/csrc/gdn_prefill_sm120_kernel_inst.jinja @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Auto-generated file for separate compilation of GDN prefill kernel variants. 
+// Template parameters: dtype={{ dtype }}, is_gva={{ is_gva }}, needs_beta={{ needs_beta }}, +// needs_alpha={{ needs_alpha }}, init_state={{ init_state }}, +// enable_checkpointing={{ enable_checkpointing }} + +// CUDA type definitions for half and nv_bfloat16 +#include +#include + +// Include the header which defines the function template +// The header includes all necessary CUTLASS type definitions +#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh" + +namespace flat { + +// Explicit template instantiation for launch_delta_rule_prefill_kernel_gbai +// Parameter types must exactly match the extern template declaration in flat_prefill_kernel_delta_rule_sm120_extern.inc +template void launch_delta_rule_prefill_kernel_gbai<{{ is_gva }}, {{ needs_beta }}, {{ needs_alpha }}, {{ init_state }}, {{ enable_checkpointing }}, cutlass::arch::Sm120, {{ dtype }}, {{ dtype }}, float>( + cudaStream_t, {{ dtype }}*, float*, {{ dtype }} const*, {{ dtype }} const*, {{ dtype }} const*, + float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t, + int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t, float*, int64_t const*, + int32_t); + +} // namespace flat diff --git a/csrc/prefill_kernel_delta_rule_sm120.cu b/csrc/prefill_kernel_delta_rule_sm120.cu new file mode 100644 index 0000000000..52cfbcd320 --- /dev/null +++ b/csrc/prefill_kernel_delta_rule_sm120.cu @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh" + +// Extern template declarations prevent implicit instantiation here. +// Explicit instantiations are in separate generated files for parallel compilation. +#include "flat_prefill_kernel_delta_rule_sm120_extern.inc" + +namespace flat { + +using namespace cute; + +template +void launch_delta_rule_prefill_kernel( + cudaStream_t stream, TO* output, TState* output_state, TQKV const* q, TQKV const* k, + TQKV const* v, TState const* input_state, float const* alpha, float const* beta, + int64_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, int32_t num_q_heads, + int32_t num_k_heads, int32_t num_v_heads, int32_t num_o_heads, int32_t head_size, + int64_t total_seqlen, float scale, int32_t sm_count, float* state_checkpoints, + int64_t const* checkpoint_cu_starts, int32_t checkpoint_every_n_tokens) { + bool is_gva = num_v_heads > num_q_heads; + bool needs_beta = beta != nullptr; + bool needs_alpha = alpha != nullptr; + bool init_state = input_state != nullptr; + bool enable_ckpt = checkpoint_every_n_tokens > 0; + +#define LAUNCH(is_gva, needs_beta, needs_alpha, init_state, enable_ckpt) \ + launch_delta_rule_prefill_kernel_gbai( \ + stream, output, output_state, q, k, v, input_state, alpha, beta, cu_seqlens, \ + workspace_buffer, num_seqs, num_q_heads, num_k_heads, num_v_heads, num_o_heads, head_size, \ + total_seqlen, scale, sm_count, state_checkpoints, checkpoint_cu_starts, \ + checkpoint_every_n_tokens); + +#define DISPATCH_GBAI(init_state, enable_ckpt) \ + if (is_gva && needs_beta && needs_alpha) { \ + LAUNCH(true, true, true, init_state, enable_ckpt); \ + } else if (is_gva && needs_beta && !needs_alpha) { \ + LAUNCH(true, true, false, init_state, enable_ckpt); \ + } else if (is_gva && !needs_beta && needs_alpha) { \ + LAUNCH(true, false, true, init_state, 
enable_ckpt); \ + } else if (is_gva && !needs_beta && !needs_alpha) { \ + LAUNCH(true, false, false, init_state, enable_ckpt); \ + } else if (!is_gva && needs_beta && needs_alpha) { \ + LAUNCH(false, true, true, init_state, enable_ckpt); \ + } else if (!is_gva && needs_beta && !needs_alpha) { \ + LAUNCH(false, true, false, init_state, enable_ckpt); \ + } else if (!is_gva && !needs_beta && needs_alpha) { \ + LAUNCH(false, false, true, init_state, enable_ckpt); \ + } else if (!is_gva && !needs_beta && !needs_alpha) { \ + LAUNCH(false, false, false, init_state, enable_ckpt); \ + } else { \ + throw std::runtime_error("unreachable"); \ + } + + if (enable_ckpt) { + if (init_state) { + DISPATCH_GBAI(true, true); + } else { + DISPATCH_GBAI(false, true); + } + } else { + if (init_state) { + DISPATCH_GBAI(true, false); + } else { + DISPATCH_GBAI(false, false); + } + } + +#undef DISPATCH_GBAI +#undef LAUNCH +} + +// Explicit instantiations for the outer dispatch function only. +// The inner launch_delta_rule_prefill_kernel_gbai instantiations are in separate files. 
+template void launch_delta_rule_prefill_kernel( + cudaStream_t stream, half* output, float* state, half const* q, half const* k, half const* v, + float const* input_state, float const* alpha, float const* beta, int64_t const* cu_seqlens, + uint8_t* workspace_buffer, int32_t num_seqs, int32_t num_q_heads, int32_t num_k_heads, + int32_t num_v_heads, int32_t num_o_heads, int32_t head_size, int64_t total_seqlen, float scale, + int32_t sm_count, float* state_checkpoints, int64_t const* checkpoint_cu_starts, + int32_t checkpoint_every_n_tokens); + +template void +launch_delta_rule_prefill_kernel( + cudaStream_t stream, nv_bfloat16* output, float* state, nv_bfloat16 const* q, + nv_bfloat16 const* k, nv_bfloat16 const* v, float const* input_state, float const* alpha, + float const* beta, int64_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, + int32_t num_q_heads, int32_t num_k_heads, int32_t num_v_heads, int32_t num_o_heads, + int32_t head_size, int64_t total_seqlen, float scale, int32_t sm_count, + float* state_checkpoints, int64_t const* checkpoint_cu_starts, + int32_t checkpoint_every_n_tokens); + +} // namespace flat diff --git a/flashinfer/gdn_prefill.py b/flashinfer/gdn_prefill.py index 124784ff22..84e8a7c326 100644 --- a/flashinfer/gdn_prefill.py +++ b/flashinfer/gdn_prefill.py @@ -21,20 +21,29 @@ import torch from .api_logging import flashinfer_api -from .jit.gdn import gen_gdn_prefill_sm90_module +from .jit.gdn import gen_gdn_prefill_sm90_module, gen_gdn_prefill_sm120_module from .utils import ( register_custom_op, register_fake_op, get_device_sm_count, + is_sm90a_supported, is_sm100a_supported, + is_sm120a_supported, _get_cache_buf, ) from .gdn_kernels import chunk_gated_delta_rule_sm100, _has_blackwell_prefill @functools.cache -def get_gdn_prefill_module(): - module = gen_gdn_prefill_sm90_module().build_and_load() +def get_gdn_prefill_module(device: torch.device): + if is_sm90a_supported(device): + module = 
gen_gdn_prefill_sm90_module().build_and_load() + elif is_sm120a_supported(device): + module = gen_gdn_prefill_sm120_module().build_and_load() + else: + raise RuntimeError( + f"GDN prefill kernel requires SM90 or SM120, but device {device} is not supported" + ) @register_custom_op( "flashinfer::gdn_prefill", @@ -183,7 +192,7 @@ def chunk_gated_delta_rule( - Supports GQA: ``num_q_heads > num_k_heads = num_v_heads`` - Supports GVA: ``num_v_heads > num_q_heads = num_k_heads`` - The final state layout is ``[N, H, V, K]``. - - Requires SM90 (Hopper) or SM100 (Blackwell) architecture. + - Requires SM90 (Hopper) or SM100 (Blackwell) or SM120 (Blackwell RTX) architecture. - SM100 path requires head_size == 128. - SM100 path requires ``nvidia-cutlass-dsl[cu13]>=4.4.2`` (install via ``pip install flashinfer-python[cu13]``). @@ -336,7 +345,7 @@ def chunk_gated_delta_rule( "gdn_prefill_workspace", workspace_size, device ) - get_gdn_prefill_module().gdn_prefill( + get_gdn_prefill_module(q.device).gdn_prefill( output, output_state, q, diff --git a/flashinfer/jit/gdn.py b/flashinfer/jit/gdn.py index 95667c373d..b288e253dc 100644 --- a/flashinfer/jit/gdn.py +++ b/flashinfer/jit/gdn.py @@ -24,25 +24,37 @@ JitSpec, gen_jit_spec, sm90a_nvcc_flags, + sm120a_nvcc_flags, ) from .utils import write_if_different -def gen_gdn_prefill_sm90_module() -> JitSpec: +def _gen_gdn_prefill_module(arch: str) -> JitSpec: """Generate JIT module for GDN prefill kernel with separate compilation. This generates 64 separate kernel instantiation files (2 dtypes × 32 boolean combinations) plus the original launcher file. The separate files enable parallel compilation by ninja, significantly reducing build time on multi-core machines. 
""" - uri = "gdn_prefill_sm90" + assert arch in ["sm90", "sm120"], ( + "GDN prefill kernel is only supported on sm_90a and sm_120a" + ) + + if arch == "sm90": + arch_specific_flags = sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED"] + elif arch == "sm120": + arch_specific_flags = sm120a_nvcc_flags + ["-DFLAT_SM120A_ENABLED"] + + uri = f"gdn_prefill_{arch}" gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri os.makedirs(gen_directory, exist_ok=True) source_paths = [] # Load kernel instantiation template - with open(jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_sm90_kernel_inst.jinja") as f: + with open( + jit_env.FLASHINFER_CSRC_DIR / f"gdn_prefill_{arch}_kernel_inst.jinja" + ) as f: kernel_inst_templ = jinja2.Template(f.read()) # Generate 64 separate instance files (2 dtypes × 32 boolean combinations) @@ -74,7 +86,7 @@ def gen_gdn_prefill_sm90_module() -> JitSpec: # Headers are now in include/flashinfer/flat/ and accessible via standard include paths for filename in [ "gdn_prefill_launcher.cu", - "prefill_kernel_delta_rule_sm90.cu", + f"prefill_kernel_delta_rule_{arch}.cu", ]: src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / src_path.name @@ -84,5 +96,13 @@ def gen_gdn_prefill_sm90_module() -> JitSpec: return gen_jit_spec( uri, source_paths, - extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"], + extra_cuda_cflags=arch_specific_flags + ["-std=c++20"], ) + + +def gen_gdn_prefill_sm90_module(): + return _gen_gdn_prefill_module("sm90") + + +def gen_gdn_prefill_sm120_module(): + return _gen_gdn_prefill_module("sm120") diff --git a/include/flashinfer/flat/hopper/collective/flat_common.hpp b/include/flashinfer/flat/hopper/collective/flat_common.hpp index df3f66ce54..0e9a565333 100644 --- a/include/flashinfer/flat/hopper/collective/flat_common.hpp +++ b/include/flashinfer/flat/hopper/collective/flat_common.hpp @@ -93,6 +93,7 @@ CUTE_DEVICE constexpr auto convert_c_layout_to_a_layout(CLayout const& c, AValue } template 
+[[deprecated("use restage_smem_layout instead")]] CUTE_DEVICE constexpr auto unstage_smem_layout(Layout const& layout, Stages stages = {}) { return composition(layout, make_tuple(_, _, make_layout(stages))); } @@ -163,4 +164,9 @@ CUTE_DEVICE auto make_acc_into_op(Accumulator const& acc, return operand; } +template +CUTE_DEVICE constexpr auto restage_smem_layout(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, make_layout(stages))); +} + } // namespace flat::collective diff --git a/include/flashinfer/flat/prefill/prefill_kernel.hpp b/include/flashinfer/flat/prefill/prefill_kernel.hpp index 9e09b74797..48be30808a 100644 --- a/include/flashinfer/flat/prefill/prefill_kernel.hpp +++ b/include/flashinfer/flat/prefill/prefill_kernel.hpp @@ -22,6 +22,7 @@ // Forward declarations to avoid including full cutlass headers namespace cutlass::arch { struct Sm90; +struct Sm120; } // namespace cutlass::arch namespace flat { diff --git a/include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh b/include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh new file mode 100644 index 0000000000..99c5f94f40 --- /dev/null +++ b/include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/util/device_memory.h" +#include "flashinfer/flat/common.hpp" +#include "flashinfer/flat/sm120/device/device_universal.hpp" +#include "flashinfer/flat/sm120/kernel/flat_kernel_builder_delta_rule.hpp" + +namespace flat { + +using namespace cute; + +template +void launch_delta_rule_prefill_kernel_gbai( + cudaStream_t stream, TO* output, TState* output_state, TQKV const* q, TQKV const* k, + TQKV const* v, TState const* input_state, float const* alpha, float const* beta, + int64_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, int32_t num_q_heads, + int32_t num_k_heads, int32_t num_v_heads, int32_t num_o_heads, int32_t head_size, + int64_t total_seqlen, float scale, int32_t sm_count, float* state_checkpoints, + int64_t const* checkpoint_cu_starts, int32_t checkpoint_every_n_tokens) { +#if defined(FLAT_SM120A_ENABLED) + constexpr bool ArchSupported = true; +#else + constexpr bool ArchSupported = false; +#endif + + if constexpr (ArchSupported) { + static_assert(std::is_same_v); + + using namespace flat::kernel; + using T = map_to_cutlass_t; + + cutlass::KernelHardwareInfo hw_info; + hw_info.sm_count = sm_count; + + using Options = decltype([&]() { + constexpr auto options_0 = DefaultOptions{}; + constexpr auto options_1 = + add_option(Option{}, options_0); + constexpr auto options_2 = add_option( + Option>{}, + options_1); + constexpr auto options_3 = + add_option(Option>{}, + options_2); + constexpr auto options_4 = + add_option(Option>{}, + options_3); + constexpr auto options_5 = add_option( + Option>{}, + options_4); + constexpr auto options_6 = add_option( + Option>{}, + options_5); + return options_6; + }()); + + using TileShape = Shape<_64, _64, _128>; + using Scheduler = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using Operation = cutlass::device::Universal, + /*LayoutK=*/cute::tuple, + 
/*LayoutV=*/cute::tuple, + /*LayoutO=*/cute::tuple, Scheduler, Options>::Kernel>; + using Arguments = typename Operation::Arguments; + + // NOTE: LayoutQ/K/V in (seq, head_size, (b,h)) coordinate semantics + + int32_t num_sab_heads = std::max(num_q_heads, num_v_heads); + + int32_t q_tok_stride = num_q_heads * head_size; + int32_t o_tok_stride = num_o_heads * head_size; + int32_t k_tok_stride = num_k_heads * head_size; + int32_t v_tok_stride = num_v_heads * head_size; + + int32_t q_head_stride = head_size; + int32_t o_head_stride = head_size; + int32_t k_head_stride = head_size; + int32_t v_head_stride = head_size; + + Operation op; + Arguments arguments{.problem_size = + { + .cu_seqlens = cu_seqlens, + .total_seqlen = total_seqlen, + .num_seqs = num_seqs, + .num_q_heads = num_q_heads, + .num_k_heads = num_k_heads, + .num_v_heads = num_v_heads, + .num_o_heads = num_o_heads, + .num_sab_heads = num_sab_heads, + .head_size = head_size, + }, + .mainloop = + { + // clang-format off + .ptr_Q = (T*)q, .dQ = {q_tok_stride, _1{}, q_head_stride}, + .ptr_K = (T*)k, .dK = {k_tok_stride, _1{}, k_head_stride}, + .ptr_V = (T*)v, .dV = {v_tok_stride, _1{}, v_head_stride}, + .ptr_O = (T*)output, .dO = {o_tok_stride, _1{}, o_head_stride}, + .ptr_output_state = (float*)output_state, + .ptr_input_state = (float*)input_state, + .scale = scale, + .alpha_ptr = alpha, .alpha_stride = {num_sab_heads, 1}, + .beta_ptr = beta, .beta_stride = {num_sab_heads, 1}, + .ptr_state_checkpoints = state_checkpoints, + .checkpoint_cu_starts = checkpoint_cu_starts, + .checkpoint_every_n_tokens = checkpoint_every_n_tokens, + }, // clang-format on + .hw_info = hw_info}; + + cutlass::Status status; + status = op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("can_implement failed"); + } + + status = op.initialize(arguments, workspace_buffer, stream); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("initialize failed"); + } + + status 
= op.run(stream); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("run failed"); + } + } else { + throw std::runtime_error("sm_120a not supported"); + } +} + +} // namespace flat diff --git a/include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp b/include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp new file mode 100644 index 0000000000..d8a15c57f3 --- /dev/null +++ b/include/flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp @@ -0,0 +1,1206 @@ +/* + * Copyright (c) 2026 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +// Reuse hopper code +#include "flashinfer/flat/ampere/collective/flat_collective_inverse.hpp" +#include "flashinfer/flat/ampere/collective/flat_collective_load.hpp" +#include "flashinfer/flat/cute_ext.hpp" +#include "flashinfer/flat/hopper/collective/flat_collective_load.hpp" +#include "flashinfer/flat/hopper/collective/flat_collective_store.hpp" +#include "flashinfer/flat/hopper/collective/flat_common.hpp" +#include "flashinfer/flat/hopper/collective/flat_named_barriers.hpp" +#include "flashinfer/flat/hopper/kernel/flat_options.hpp" +#include "flashinfer/flat/math_order_barrier.hpp" +#include "flashinfer/flat/unused.hpp" + +// #define INLINE_LAMBDA [[gnu::always_inline]] +#define INLINE_LAMBDA __attribute__((always_inline)) +// #define INLINE_LAMBDA [[msvc::forceinline]] + +namespace flat::collective { + +struct DeltaRuleNamedBarriers : FlatSharedNamedBarriers { + static constexpr int KKLaunched = FlatSharedNamedBarriers::NumBarriersUsed + 0; + static constexpr int AuxMath = FlatSharedNamedBarriers::NumBarriersUsed + 1; +}; + +using namespace cute; +using flat::kernel::find_option_t; +using flat::kernel::Tag; + +template +struct FlatMainloopTmaWarpSpecializedDeltaRule { + static_assert(std::is_same_v, + "HMMA pipeline only supports float accumulator for QK matmul"); + static_assert(std::is_same_v, + "HMMA pipeline only supports float accumulator for KV matmul"); + using Element = Element_; + using ElementAccumulatorQK = ElementAccumulatorQK_; + using ElementAccumulatorO = ElementAccumulatorQK; + using ElementAccumulatorKV = ElementAccumulatorKV_; + using ElementO = Element; + + using TileShape = TileShape_; + + using LayoutQ = LayoutQ_; // (seqlen_q, d, h) + using LayoutK = LayoutK_; // (seqlen_k, d, h) + using LayoutV = LayoutV_; // (seqlen_k, d, h) + using LayoutO = LayoutO_; // (seqlen_k, d, h) + + // Options + static constexpr bool kIsPersistent 
= + find_option_t::value; + + static constexpr bool kInitStateFromInput = + find_option_t::value; + + static constexpr bool kEnableCheckpointing = + find_option_t::value; + + static constexpr int NumLoadWarpGroups = 1; + static constexpr int NumMmaWarpGroups = 2; + + static constexpr int StageCountQ = find_option_t, Options>::value; + static constexpr int StageCountK = find_option_t, Options>::value; + static constexpr int StageCountV = find_option_t, Options>::value; + + static constexpr int NeedsAlpha = + find_option_t::value; + static constexpr int NeedsBeta = find_option_t::value; + + static constexpr int NeedsDecay = + find_option_t::value; + static_assert(!NeedsDecay, "DeltaRule does not supports decay"); + + static constexpr int NumLoadThreads = NumLoadWarpGroups * 128; + static constexpr int NumMmaThreads = NumMmaWarpGroups * 128; + + static constexpr uint32_t OrderedBarrierId0 = + uint32_t(cutlass::arch::ReservedNamedBarriers::StreamkBarrier0); + static constexpr uint32_t OrderedBarrierId1 = + uint32_t(cutlass::arch::ReservedNamedBarriers::StreamkBarrier1); + + using OrderedMathBarriers = std::conditional_t< + NumMmaWarpGroups == 2, + OrderedNamedBarriers, + OrderedNamedBarriers>; + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesK = cutlass::gemm::collective::StageCount; + using StagesV = cutlass::gemm::collective::StageCount; + using StagesO = cutlass::gemm::collective::StageCount<1>; + using ClusterShape = Shape<_1, _1, _1>; + + using StagesAlphaBeta = cutlass::gemm::collective::StageCount<2>; + + static constexpr int Alignment = 16 / sizeof(Element); + + static constexpr auto BlkSeqQ = get<0>(TileShape{}); // Blk_Q + static constexpr auto BlkSeqKV = get<1>(TileShape{}); // Blk_K/V + static constexpr auto HeadSize = get<2>(TileShape{}); // D (Dq, Dk, Dv all equal) + static constexpr auto HeadSizeQK = HeadSize; + static constexpr auto HeadSizeV = HeadSize; + + using TileShapeQK = decltype(make_shape(BlkSeqQ, BlkSeqKV, HeadSizeQK)); 
+ using TileShapeKK = decltype(make_shape(BlkSeqKV, BlkSeqKV, HeadSizeQK)); + using TileShapeKV = decltype(make_shape(HeadSizeV, HeadSizeQK, BlkSeqKV)); + static_assert(std::is_same_v); + + using TileShapeO2 = decltype(make_shape(HeadSizeV, BlkSeqQ, BlkSeqKV)); + using TileShapeO1 = decltype(make_shape(HeadSizeV, BlkSeqQ, HeadSizeQK)); + + static_assert(BlkSeqQ % 64 == 0); + static_assert(BlkSeqQ == 64 || BlkSeqQ == 128); + static constexpr bool IsQKCooperative = BlkSeqQ == 128; + static constexpr bool IsKKCooperative = IsQKCooperative; + + using SmemLayoutQ_SD = decltype(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + make_shape(get<0>(TileShape{}), get<2>(TileShape{}), Int{}), + Step<_0, _1, _2>{})); + using SmemLayoutK_SD = decltype(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + make_shape(get<1>(TileShape{}), get<2>(TileShape{}), Int{}), + Step<_0, _1, _2>{})); + using SmemLayoutV_SD = decltype(restage_smem_layout(SmemLayoutK_SD{}, Int{})); + ; + + using SmemLayoutK_DS = decltype(select_layout<1, 0, 2>(SmemLayoutK_SD{})); + using SmemLayoutV_DS = decltype(select_layout<1, 0, 2>(SmemLayoutV_SD{})); + + using MmaOp = std::conditional_t, + SM80_16x8x16_F32BF16BF16F32_TN, SM80_16x8x16_F32F16F16F32_TN>; + + using RefLayoutV = decltype(make_layout(select<0, 2>(TileShapeKV{}), LayoutRight{})); + + using RefLayoutKV = + decltype(make_layout(select<0, 1>(TileShapeKV{}), LayoutRight{})); // (dv, dk) + + // (blk_q,blk_k) to align with O2 mma, LayoutRight to align with QK mma output + using DesiredLayoutQK = decltype(make_layout(select<0, 1>(TileShapeQK{}), LayoutRight{})); + + using TiledMmaQK = decltype(make_tiled_mma(MmaOp{}, Layout<_4, _1>{}, TileShapeQK{})); // Q@K^t + using TiledMmaKV = decltype(make_tiled_mma(MmaOp{}, Layout<_8, _1>{}, TileShapeKV{})); // V @ K + using TiledMmaO1 = decltype(make_tiled_mma(MmaOp{}, Layout<_8, _1>{}, TileShapeO1{})); // KV @ Q + using TiledMmaO2 = decltype(make_tiled_mma(MmaOp{}, Layout<_8, _1>{}, TileShapeO2{})); // V @ QK + + 
static_assert(size(TiledMmaQK{}) == NumMmaThreads || size(TiledMmaQK{}) == NumMmaThreads / 2); + + static_assert(size(TiledMmaKV{}) == NumMmaThreads); + static_assert(size(TiledMmaO1{}) == NumMmaThreads); + static_assert(size(TiledMmaO2{}) == NumMmaThreads); + + using CollectiveStoreO = + CollectiveStoreTma(LayoutO{})), StagesO::value>; + + // layout for compute output + using LayoutAtom = Layout, Stride<_8, _1>>; + using SmemLayoutQK = + decltype(tile_to_shape(LayoutAtom{}, select<0, 1>(TileShapeQK{}), Step<_1, _2>{})); + using SmemLayoutKK = + decltype(tile_to_shape(LayoutAtom{}, select<0, 1>(TileShapeQK{}), Step<_1, _2>{})); + using SmemLayoutO = typename CollectiveStoreO::SmemLayoutO; + + using InverseType = cutlass::half_t; + using CollectiveInverse = flat::collective::CollectiveInverse; + + using ElementAccumulatorSK = float; + using TileShapeSK = decltype(make_shape(HeadSizeV, BlkSeqKV, HeadSizeQK)); + + using ElementAccumulatorNewV = float; + using TileShapeNewV = decltype(make_shape(HeadSizeV, BlkSeqKV, BlkSeqKV)); + using RefLayoutSK = + decltype(make_layout(select<0, 2>(TileShapeNewV{}), LayoutRight{})); // (dv, Blk) + using DesiredLayoutKK = decltype(make_layout(select<1, 2>(TileShapeNewV{}), LayoutRight{})); // + + using TiledMmaKK = decltype(make_tiled_mma( + MmaOp{}, Layout<_4, _1>{}, TileShapeKK{})); // T = inv(I + strict_lower_triangular(K@K^t)) + using TiledMmaSK = + decltype(make_tiled_mma(MmaOp{}, Layout<_8, _1>{}, TileShapeSK{})); // ?? 
= -S@K^t + V^t + + using TiledMmaNewV = + decltype(make_tiled_mma(MmaOp{}, Layout<_8, _1>{}, TileShapeNewV{})); // NewV = ??@T^t + + static_assert(size(TiledMmaKK{}) == size(TiledMmaQK{})); + + using GmemStrideAlphaBeta = Stride; + using GmemLayoutAlphaBeta = Layout, GmemStrideAlphaBeta>; // (seq, head) + + // (blk, pipe, cumsum_log/cumprod), + // 0 for cumsum(log(alpha)) aka log(cumprod(alpha)) + // 1 for cumprod(alpha) + // 2 for cumprod(alpha) * scale + using AlphaCumSumLogIdx = _0; + using AlphaCumProdIdx = _1; + using AlphaCumProdScaleIdx = _2; + + using SmemLayoutAlpha = + decltype(make_layout(make_shape(BlkSeqQ, Int<3>{}, Int{}))); + using SmemLayoutBeta = decltype(make_layout(make_shape(BlkSeqQ, Int{}))); + + using MainloopQPipeline = cutlass::PipelineTmaAsync; + using MainloopKPipeline = cutlass::PipelineTmaAsync; + using MainloopVPipeline = cutlass::PipelineTmaAsync; + using MainloopOPipeline = typename CollectiveStoreO::Pipeline; + + using MainloopAlphaPipeline = + std::conditional_t, Unused>; + using MainloopBetaPipeline = + std::conditional_t, Unused>; + + using QPipelineState = typename cutlass::PipelineState; + using KPipelineState = typename cutlass::PipelineState; + using VPipelineState = typename cutlass::PipelineState; + using OPipelineState = typename CollectiveStoreO::PipelineState; + + using AlphaPipelineState = + std::conditional_t, Unused>; + using BetaPipelineState = + std::conditional_t, Unused>; + + struct AlphaProcessor { + CUTE_DEVICE + AlphaProcessor(float scale) : scale_(scale) {} + + template + CUTE_DEVICE void operator()(T&& vecs) { + constexpr int WarpSize = cutlass::NumThreadsPerWarp; + int lane_id = cutlass::canonical_lane_idx(); + + Tensor vecs_32 = flat_divide( + std::forward(vecs), + make_tile(Int{})); // ((32), iter, cumsum_log/cumprod/cumprod_scale) + Tensor vec_cumsum_log = vecs_32(make_coord(_), _, AlphaCumSumLogIdx{}); + Tensor vec_cumprod = vecs_32(make_coord(_), _, AlphaCumProdIdx{}); + Tensor vec_cumprod_s = 
vecs_32(make_coord(_), _, AlphaCumProdScaleIdx{}); // cumprod * scale + Tensor frag = make_tensor(size<1>(vec_cumprod)); + + CUTE_UNROLL + for (int iter = 0; iter < size(frag); ++iter) { + frag(iter) = log2f(vec_cumsum_log(lane_id, iter) + 1e-10f); + } + + CUTE_UNROLL + for (int offset = 1; offset < WarpSize; offset *= 2) { + CUTE_UNROLL + for (int iter = 0; iter < size(frag); ++iter) { + auto v = __shfl_up_sync(0xFFFFFFFF, frag(iter), offset); + if (lane_id >= offset) { + frag(iter) += v; + } + } + } + + float sum = 0.0f; + CUTE_UNROLL + for (int iter = 1; iter < size(frag); ++iter) { + sum = __shfl_sync(0xFFFFFFFF, frag(iter - 1), 31); + frag(iter) += sum; + } + + CUTE_UNROLL + for (int iter = 0; iter < size(frag); ++iter) { + vec_cumsum_log(lane_id, iter) = frag(iter); + float cumprod = exp2f(frag(iter)); + vec_cumprod(lane_id, iter) = cumprod; + vec_cumprod_s(lane_id, iter) = cumprod * scale_; + } + } + + float scale_ = 1.0f; + }; + + using BetaProcessor = Unused; + // struct BetaProcessor { + // template + // CUTE_DEVICE + // void operator()(T&& vec) { + // int lane_id = cutlass::canonical_lane_idx(); + // int warp_size = cutlass::NumThreadsPerWarp; + // for (int i = lane_id; i < size(vec); i += warp_size) { + // auto val = vec(i); + // val = max(val, 1e-10f); // clamp due to fusion with IKK before matrix inverse + // vec(i) = 1.0f / val; + // } + // } + // }; + + static constexpr int LoadQBytes = size(SmemLayoutQ_SD{}(_, _, _0{})) * sizeof(Element); + static constexpr int LoadKBytes = size(SmemLayoutK_DS{}(_, _, _0{})) * sizeof(Element); + static constexpr int LoadVBytes = size(SmemLayoutV_DS{}(_, _, _0{})) * sizeof(Element); + static constexpr int StoreOBytes = CollectiveStoreO::TmaTransactionBytes; + + using SharedStorageO = typename CollectiveStoreO::SharedStorage; + + struct SharedStorage { + alignas(alignment_for_swizzle( + SmemLayoutQ_SD{})) cute::array_aligned> smem_q; + alignas(alignment_for_swizzle( + SmemLayoutK_DS{})) cute::array_aligned> smem_k; + 
alignas(alignment_for_swizzle( + SmemLayoutV_DS{})) cute::array_aligned> smem_v; + alignas(alignment_for_swizzle( + SmemLayoutQK{})) cute::array_aligned> smem_qk; + alignas(alignment_for_swizzle( + SmemLayoutKK{})) cute::array_aligned> smem_kk; + + SharedStorageO smem_o; + // TODO: make optional + cute::array_aligned> smem_beta; + cute::array_aligned> smem_alpha; + }; + + using GmemTiledCopyQKV = cute::SM90_TMA_LOAD; + using ShapeQKV = Shape; // (seq, d, h) + + static_assert(size(ClusterShape{}) == 1, "mcast TMA not supported"); + static constexpr auto cluster_size_no_mcast = _1{}; + + using TMA_Q = decltype(make_tma_copy( + GmemTiledCopyQKV{}, + make_tensor(static_cast(nullptr), + make_layout(ShapeQKV{}, LayoutQ{})), // LayoutQ is stride actually + take<0, 2>(SmemLayoutQ_SD{}), // no Stages + select<0, 2>(TileShape{}), // (seqlen, d) + cluster_size_no_mcast)); + + using TMA_K = decltype(make_tma_copy( + GmemTiledCopyQKV{}, + select_tensor<1, 0, 2>( + make_tensor(static_cast(nullptr), + make_layout(ShapeQKV{}, LayoutK{}))), // LayoutK is stride actually + take<0, 2>(SmemLayoutK_DS{}), // no Stages + select<2, 1>(TileShape{}), // (d, seqlen) + cluster_size_no_mcast)); + + using TMA_V = decltype(make_tma_copy( + GmemTiledCopyQKV{}, + select_tensor<1, 0, 2>( + make_tensor(static_cast(nullptr), + make_layout(ShapeQKV{}, LayoutV{}))), // LayoutV is stride actually + take<0, 2>(SmemLayoutV_DS{}), // no Stages + select<2, 1>(TileShape{}), // (d, seqlen) + cluster_size_no_mcast)); + + using TMA_O = typename CollectiveStoreO::Params::TMA_O; + + using LoadQ = CollectiveLoadTma; + using LoadK = CollectiveLoadTma; + using LoadV = CollectiveLoadTma; + + using LoadAlpha = + CollectiveLoadVector; + using LoadBeta = CollectiveLoadVector; + + struct Arguments { // clang-format off + Element const* ptr_Q; LayoutQ dQ; + Element const* ptr_K; LayoutK dK; + Element const* ptr_V; LayoutV dV; + Element* ptr_O; LayoutO dO; + float* ptr_output_state; // layout fixed (kdim, vdim, num_heads, 
num_seqs):LayoutLeft{} + float const* ptr_input_state; + float scale; + float const* alpha_ptr; GmemStrideAlphaBeta alpha_stride; + float const* beta_ptr; GmemStrideAlphaBeta beta_stride; + float* ptr_state_checkpoints; // [total_checkpoints, num_sab_heads, K, V] + int64_t const* checkpoint_cu_starts; // [num_seqs + 1] + int32_t checkpoint_every_n_tokens; // 0 = disabled, must be multiple of BlkSeqKV(64) + }; // clang-format on + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_O tma_store_o; + void* tensormaps; + float scale; + + float* ptr_output_state; + float const* ptr_input_state; + + float const* alpha_ptr; + GmemLayoutAlphaBeta alpha_layout; + float const* beta_ptr; + GmemLayoutAlphaBeta beta_layout; + + float* ptr_state_checkpoints; + int64_t const* checkpoint_cu_starts; + int32_t checkpoint_every_n_tokens; + }; + + template + static bool can_implement(ProblemShape const& problem_size, Arguments const& args) { + auto ratio = problem_size.num_q_heads > problem_size.num_v_heads + ? 
problem_size.num_q_heads / problem_size.num_v_heads + : problem_size.num_v_heads / problem_size.num_q_heads; + + constexpr bool IsGVAEnabled = find_option_t::value; + + bool is_gqa_like = (problem_size.num_k_heads == problem_size.num_v_heads) && + (problem_size.num_q_heads == ratio * problem_size.num_k_heads) && + (problem_size.num_q_heads == ratio * problem_size.num_v_heads); + + bool is_gva_like = (problem_size.num_q_heads == problem_size.num_k_heads) && + (problem_size.num_v_heads == ratio * problem_size.num_q_heads) && + (problem_size.num_v_heads == ratio * problem_size.num_k_heads); + return true && ((!IsGVAEnabled && is_gqa_like) || (IsGVAEnabled && is_gva_like)) && + (problem_size.head_size <= get<2>(TileShape{})) && + ((problem_size.head_size % Alignment) == 0); + } + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, + void* workspace) { + int64_t s = problem_size.total_seqlen; + int64_t t = problem_size.total_seqlen; + int32_t d = problem_size.head_size; + + Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), + make_layout(ShapeQKV(s, d, problem_size.num_q_heads), args.dQ)); + Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), + make_layout(ShapeQKV(t, d, problem_size.num_k_heads), args.dK)); + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), + make_layout(ShapeQKV(t, d, problem_size.num_v_heads), args.dV)); + + TMA_Q tma_load_q = + make_tma_copy(GmemTiledCopyQKV{}, mQ, take<0, 2>(SmemLayoutQ_SD{}), // no Stages + select<0, 2>(TileShape{}), // (seqlen_q, d) + cluster_size_no_mcast); + + TMA_K tma_load_k = make_tma_copy(GmemTiledCopyQKV{}, select_tensor<1, 0, 2>(mK), + take<0, 2>(SmemLayoutK_DS{}), // no Stages + select<2, 1>(TileShape{}), // (d, seqlen_kv) + cluster_size_no_mcast); + + TMA_V tma_load_v = make_tma_copy(GmemTiledCopyQKV{}, select_tensor<1, 0, 2>(mV), + take<0, 2>(SmemLayoutV_DS{}), // no Stages + select<2, 1>(TileShape{}), // (d, seqlen_kv) + cluster_size_no_mcast); + + auto params_o 
= CollectiveStoreO::to_underlying_arguments( + make_shape(d, s, d, problem_size.num_o_heads), // in O1 + // make_shape(d, s, s, problem_size.num_o_heads), // in O2 + typename CollectiveStoreO::Arguments{args.ptr_O, select<1, 0, 2>(args.dO)}, workspace); + + return Params{ + .tma_load_q = tma_load_q, + .tma_load_k = tma_load_k, + .tma_load_v = tma_load_v, + .tma_store_o = params_o.tma_store_o, + .tensormaps = params_o.tensormaps, + .scale = args.scale, + + .ptr_output_state = args.ptr_output_state, + .ptr_input_state = args.ptr_input_state, + + // TODO: refactor all name to varname_vartype + .alpha_ptr = args.alpha_ptr, + .alpha_layout = make_layout(make_shape(s, problem_size.num_sab_heads), args.alpha_stride), + .beta_ptr = args.beta_ptr, + .beta_layout = make_layout(make_shape(s, problem_size.num_sab_heads), args.beta_stride), + + .ptr_state_checkpoints = args.ptr_state_checkpoints, + .checkpoint_cu_starts = args.checkpoint_cu_starts, + .checkpoint_every_n_tokens = args.checkpoint_every_n_tokens, + }; + } + + static size_t get_workspace_size(Arguments const& args, int sm_count) { + return CollectiveStoreO::get_workspace_size(sm_count); + } + + template + static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, + Arguments const& args, void* workspace, + cudaStream_t stream) { + return CollectiveStoreO::initialize_workspace(problem_shape, workspace, stream); + } + + CUTE_DEVICE static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); + } + + template + CUTE_DEVICE void load_qkv(Params const& params, ProblemShape const& problem_size, + LoadTileShape const& load_tile_shape, WorkDesc const& work_desc, + MainloopQPipeline& q_pipeline, QPipelineState& 
q_smem_pipe_write, + MainloopKPipeline& k_pipeline, KPipelineState& k_smem_pipe_write, + MainloopVPipeline& v_pipeline, VPipelineState& v_smem_pipe_write, + SharedStorage& storage) { + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + uint32_t lane_predicate = cute::elect_one_sync(); + + auto q_collective_load = LoadQ(params.tma_load_q, q_pipeline, storage.smem_q); + auto k_collective_load = LoadK(params.tma_load_k, k_pipeline, storage.smem_k); + auto v_collective_load = LoadV(params.tma_load_v, v_pipeline, storage.smem_v); + + auto q_src_dst = q_collective_load.partition_SD(problem_size, load_tile_shape, work_desc); + auto k_src_dst = k_collective_load.partition_SD(problem_size, load_tile_shape, work_desc); + auto v_src_dst = v_collective_load.partition_SD(problem_size, load_tile_shape, work_desc); + + CUTE_NO_UNROLL + for (int blk = 0; blk < num_blocks; ++blk) { + k_collective_load.step(k_src_dst, blk, k_smem_pipe_write, lane_predicate); + q_collective_load.step(q_src_dst, blk, q_smem_pipe_write, lane_predicate); + v_collective_load.step(v_src_dst, blk, v_smem_pipe_write, lane_predicate); + } + } + + template + CUTE_DEVICE void load_beta(Params const& params, ProblemShape const& problem_size, + TileShape const& tile_shape, WorkDesc const& work_desc, + MainloopBetaPipeline& pipeline, BetaPipelineState& smem_pipe_write, + SharedStorage& storage) { + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + + // fuse post inverse diag(beta) into diagonal of IKK + // auto collective_load = LoadBeta{params.beta_ptr, params.beta_layout, /*oob_value=*/1.0f, + // pipeline, storage.smem_beta}; + auto collective_load = LoadBeta{params.beta_ptr, params.beta_layout, /*oob_value=*/0.0f, + pipeline, storage.smem_beta}; + auto src_dst = collective_load.partition_SD(problem_size, tile_shape, work_desc); + + CUTE_NO_UNROLL + for (int blk = 0; blk < num_blocks - 1; ++blk) { + collective_load.step(src_dst, blk, smem_pipe_write, num_blocks); + } 
+ collective_load.step(src_dst, num_blocks - 1, smem_pipe_write, num_blocks); + } + + template + CUTE_DEVICE void load_alpha(Params const& params, ProblemShape const& problem_size, + TileShape const& tile_shape, WorkDesc const& work_desc, + MainloopAlphaPipeline& pipeline, AlphaPipelineState& smem_pipe_write, + SharedStorage& storage) { + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + + auto collective_load = LoadAlpha{params.alpha_ptr, params.alpha_layout, /*oob_value=*/1.0f, + pipeline, storage.smem_alpha}; + auto src_dst = collective_load.partition_SD(problem_size, tile_shape, work_desc); + + typename LoadAlpha::VectorProcessor processor{params.scale}; + + CUTE_NO_UNROLL + for (int blk = 0; blk < num_blocks - 1; ++blk) { + collective_load.step(src_dst, blk, smem_pipe_write, num_blocks, processor); + } + collective_load.step(src_dst, num_blocks - 1, smem_pipe_write, num_blocks, + processor); + } + + template + CUTE_DEVICE void store(TMA_O const& tma_store, void* tensormaps, ProblemSize const& problem_size, + StoreTileShape const& store_tile_shape, WorkDesc const& work_desc, + MainloopOPipeline& pipeline, PipelineState& smem_pipe_read, + SharedStorageO& storage) { + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + uint32_t lane_predicate = cute::elect_one_sync(); + + auto collective_store = CollectiveStoreO{tma_store, pipeline, storage, tensormaps}; + auto src_dst = collective_store.partition_SD(problem_size, store_tile_shape, work_desc); + + CUTE_NO_UNROLL + for (int blk = 0; blk < num_blocks; ++blk) { + DPRINTF0_W("O collective_store.step smem_pipe_read:%d -> blk_idx:%d, num_blocks:%d\n", + smem_pipe_read.index(), blk, num_blocks); + collective_store.step(problem_size, work_desc, src_dst, smem_pipe_read, blk, num_blocks, + lane_predicate); + } + } + + template + CUTE_DEVICE void compute(Params const& params, ProblemShape const& problem_size, + WorkDesc const& work_desc, MainloopQPipeline& q_pipeline, + 
QPipelineState& q_smem_pipe_read, MainloopKPipeline& k_pipeline, + KPipelineState& k_smem_pipe_read, MainloopVPipeline& v_pipeline, + VPipelineState& v_smem_pipe_read, MainloopOPipeline& o_pipeline, + OPipelineState& o_smem_pipe_write, MainloopAlphaPipeline& alpha_pipeline, + AlphaPipelineState& alpha_smem_pipe_read, + MainloopBetaPipeline& beta_pipeline, + BetaPipelineState& beta_smem_pipe_read, + OrderedMathBarriers& math_barriers, SharedStorage& storage) { + // MAKE NVCC HAPPY! + constexpr auto zero = Element{}; + + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + DPRINTF0_WG("num_blocks: %d\n", num_blocks); + + int thread_idx = int(threadIdx.x) - NumLoadThreads; + int warpgroup_idx = thread_idx / cutlass::NumThreadsPerWarpGroup; + + int kk_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup; + int qk_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup; + bool is_kk_wg = NumMmaWarpGroups == 1 || warpgroup_idx == 0; + bool is_qk_wg = NumMmaWarpGroups == 1 || warpgroup_idx == 1; + + float scale = params.scale; + + Tensor Beta = make_tensor(make_smem_ptr(storage.smem_beta.data()), SmemLayoutBeta{}); + Tensor Alpha = make_tensor(make_smem_ptr(storage.smem_alpha.data()), SmemLayoutAlpha{}); + + Tensor sQ_SD = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ_SD{}); + Tensor sK_SD = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK_SD{}); + Tensor sK_DS = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK_DS{}); + Tensor sV_DS = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV_DS{}); + Tensor sQK = make_tensor(make_smem_ptr(storage.smem_qk.data()), SmemLayoutQK{}); + Tensor sO = make_tensor(make_smem_ptr(storage.smem_o.data()), SmemLayoutO{}); + + static_assert(sizeof(InverseType) == sizeof(Element)); + Tensor sKK_inv = make_tensor(make_smem_ptr(storage.smem_kk.data()), SmemLayoutKK{}); + Tensor sKK_opd = make_tensor(make_smem_ptr(reinterpret_cast(storage.smem_kk.data())), + 
SmemLayoutKK{}); + + /////////////////////////////////////////////////////////////////////////// + // Q@K + auto qk_tiled_mma = TiledMmaQK{}; + auto qk_thr_mma = qk_tiled_mma.get_thread_slice(qk_thread_idx); + auto qk_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, qk_tiled_mma); + auto qk_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, qk_tiled_mma); + auto qk_thr_copy_A = qk_tiled_copy_A.get_thread_slice(qk_thread_idx); + auto qk_thr_copy_B = qk_tiled_copy_B.get_thread_slice(qk_thread_idx); + + Tensor tQKrQ = qk_thr_mma.partition_fragment_A(sQ_SD(_, _, _0{})); + Tensor tQKrQ_cv = qk_thr_copy_A.retile_D(tQKrQ); + Tensor tQKsQ = qk_thr_copy_A.partition_S(sQ_SD); + + Tensor tQKrK = qk_thr_mma.partition_fragment_B(sK_SD(_, _, _0{})); + Tensor tQKrK_cv = qk_thr_copy_B.retile_D(tQKrK); + Tensor tQKsK = qk_thr_copy_B.partition_S(sK_SD); + + auto cMqk = make_identity_tensor(select<0, 1>(TileShapeQK{})); // (QTok, KTok) + auto tQKcMqk = qk_thr_mma.partition_C(cMqk); // (idx) -> (tok_q, tok_k) + + /////////////////////////////////////////////////////////////////////////// + // K@K (basically I + strict_lower_triangular(K K^T) + auto kk_tiled_mma = TiledMmaKK{}; + auto kk_thr_mma = kk_tiled_mma.get_thread_slice(kk_thread_idx); + auto kk_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, kk_tiled_mma); + auto kk_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, kk_tiled_mma); + auto kk_thr_copy_A = kk_tiled_copy_A.get_thread_slice(kk_thread_idx); + auto kk_thr_copy_B = kk_tiled_copy_B.get_thread_slice(kk_thread_idx); + + Tensor tKKrA = kk_thr_mma.partition_fragment_A(sK_SD(_, _, _0{})); + Tensor tKKrA_cv = kk_thr_copy_A.retile_D(tKKrA); + Tensor tKKsA = kk_thr_copy_A.partition_S(sK_SD); + + Tensor tKKrB = kk_thr_mma.partition_fragment_B(sK_SD(_, _, _0{})); + Tensor tKKrB_cv = kk_thr_copy_B.retile_D(tKKrB); + Tensor tKKsB = kk_thr_copy_B.partition_S(sK_SD); + + auto const& cMkk = cMqk; + auto tKKcMkk = kk_thr_mma.partition_C(cMkk); + + 
/////////////////////////////////////////////////////////////////////////// + // S@K (-S K^T + V^T) + auto sk_tiled_mma = TiledMmaSK{}; + auto sk_thr_mma = sk_tiled_mma.get_thread_slice(thread_idx); + + auto layout_SKAlpha = flatten(make_layout( // broadcast Alpha vector to SK size + make_layout(select<0, 1>(TileShapeSK{}), Stride<_0, _1>{}), // (D, Blk_KV) + select<1, 2>(SmemLayoutAlpha{}) // (Idx, pipe) + )); // (D, Blk_KV, Idx, pipe) + + auto tSKrAlpha = sk_thr_mma.partition_C(Alpha.compose(layout_SKAlpha))( + _, _, _, AlphaCumProdIdx{}, _); // (frag, iter_D, iter_Blk_Q, pipe) + + // tSKrV adds to tSKrSK (acc) + auto sk_tiled_copy_C = make_tiled_copy_C(Copy_Atom{}, sk_tiled_mma); + auto sk_thr_copy_C = sk_tiled_copy_C.get_thread_slice(thread_idx); + auto sk_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, sk_tiled_mma); + auto sk_thr_copy_B = sk_tiled_copy_B.get_thread_slice(thread_idx); + + Tensor tSKrK = sk_thr_mma.partition_fragment_B(sK_SD(_, _, _0{})); + Tensor tSKrK_cv = sk_thr_copy_B.retile_D(tSKrK); + Tensor tSKsK = sk_thr_copy_B.partition_S(sK_SD); + + /////////////////////////////////////////////////////////////////////////// + // NewV = (S@K result) @ T^t + auto newv_tiled_mma = TiledMmaNewV{}; + auto newv_thr_mma = newv_tiled_mma.get_thread_slice(thread_idx); + auto newv_tiled_copy_B = + make_tiled_copy_B(Copy_Atom{}, newv_tiled_mma); + auto newv_thr_copy_B = newv_tiled_copy_B.get_thread_slice(thread_idx); + + Tensor tNewVrB = newv_thr_mma.partition_fragment_B(sKK_opd); + Tensor tNewVrB_cv = newv_thr_copy_B.retile_D(tNewVrB); + Tensor tNewVsB = newv_thr_copy_B.partition_S(sKK_opd); + + /////////////////////////////////////////////////////////////////////////// + // K@V + auto kv_tiled_mma = TiledMmaKV{}; + auto kv_thr_mma = kv_tiled_mma.get_thread_slice(thread_idx); + auto kv_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, kv_tiled_mma); + auto kv_thr_copy_B = kv_tiled_copy_B.get_thread_slice(thread_idx); + + Tensor tKVrKV = 
partition_fragment_C(kv_thr_mma, select<0, 1>(TileShapeKV{})); + Tensor tKVrK = kv_thr_mma.partition_fragment_B(sK_DS(_, _, _0{})); + Tensor tKVrK_cv = kv_thr_copy_B.retile_D(tKVrK); + Tensor tKVsK = kv_thr_copy_B.partition_S(sK_DS); + + auto const cV = make_identity_tensor(Shape, Int>{}); + Tensor tKVcV = kv_thr_mma.partition_A(cV); + + /////////////////////////////////////////////////////////////////////////// + // Q@K@V + auto o1_tiled_mma = TiledMmaO1{}; + auto o1_thr_mma = o1_tiled_mma.get_thread_slice(thread_idx); + auto o2_tiled_mma = TiledMmaO2{}; + auto o2_thr_mma = o2_tiled_mma.get_thread_slice(thread_idx); + + auto o1_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, o1_tiled_mma); + auto o1_thr_copy_B = o1_tiled_copy_B.get_thread_slice(thread_idx); + auto o2_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, o2_tiled_mma); + auto o2_thr_copy_B = o2_tiled_copy_B.get_thread_slice(thread_idx); + + // A1 for Q@(KV) + // Tensor tOrKV = make_acc_into_op(tKVrKV, typename TiledMmaO1::LayoutA_TV{}); + // B1 for Q@(KV) + Tensor tOrQ = o1_thr_mma.partition_fragment_B(sQ_SD(_, _, _0{})); + Tensor tOrQ_cv = o1_thr_copy_B.retile_D(tOrQ); + Tensor tOsQ = o1_thr_copy_B.partition_S(sQ_SD); + + // A2 for QK@V + // Tensor tOrV = make_acc_into_op(tKVrKV, typename TiledMmaO2::LayoutA_TV{}); + // B2 for QK@V + Tensor tOrQK = o2_thr_mma.partition_fragment_B(sQK); + Tensor tOrQK_cv = o2_thr_copy_B.retile_D(tOrQK); + Tensor tOsQK = o2_thr_copy_B.partition_S(sQK); + + using O_R2S = typename CollectiveStoreO::CopyAtomR2S; + auto tiled_copy_o = make_tiled_copy_C(O_R2S{}, o1_tiled_mma); + auto thr_copy_o = tiled_copy_o.get_thread_slice(thread_idx); + auto tOsO = thr_copy_o.partition_D(sO); + + auto const cO = make_identity_tensor(Shape, Int>{}); + Tensor tOcO = o1_thr_mma.partition_C(cO); + + auto layout_OAlpha = flatten(make_layout( // broadcast Alpha vector to O size + make_layout(select<0, 1>(TileShapeO1{}), Stride<_0, _1>{}), // (D, Blk_Q) + select<1, 2>(SmemLayoutAlpha{}) // (Idx, 
pipe) + )); // (D, Blk_Q, Idx, pipe) + + auto tOrAlphaScale = o1_thr_mma.partition_C(Alpha.compose(layout_OAlpha))( + _, _, _, AlphaCumProdScaleIdx{}, _); // (frag, iter_D, iter_Blk_Q, pipe) + + auto const seq_idx = work_desc.seq_idx; + auto const q_head_idx = work_desc.q_head_idx(); + auto const k_head_idx = work_desc.k_head_idx(); + auto const v_head_idx = work_desc.v_head_idx(); + + auto qk_or_kk_mask = [&](auto& tQKrQK, auto is_final_block_, auto B /*valid seqlen*/) { + constexpr bool is_final_block = decltype(is_final_block_)::value; + for_each(make_int_sequence{}, [&](auto i) { + auto coord = tQKcMqk(i); + auto [s, t] = coord; + bool pred = s >= t; + if constexpr (is_final_block) { + pred = pred && (s < B || t < B); + } + // for tKKrKK diagonal is garbage filled, will be processed during inversion + tQKrQK(i) = pred ? tQKrQK(i) : 0.0f; + }); + }; + + auto qk_epi = [&](auto& tQKrQK, auto const& alpha_smem_pipe_read) { + if constexpr (NeedsAlpha) { + Tensor Alpha_cumsum_log = Alpha(_, AlphaCumSumLogIdx{}, alpha_smem_pipe_read.index()); + for_each(make_int_sequence{}, [&](auto i) { + auto coord = tQKcMqk(i); + auto [s, t] = coord; + float alpha = exp2f(Alpha_cumsum_log(s) - Alpha_cumsum_log(t)); + tQKrQK(i) *= alpha * scale; + }); + } else { + transform(tQKrQK, [scale](auto v) { return v * scale; }); + } + }; + + auto qk_store = [&](auto tQKrQK) { + static_assert(sizeof(Element) == 2); + using CopyOpR2S = SM90_U32x4_STSM_N; + auto tiled_copy_qk = make_tiled_copy_C(Copy_Atom{}, qk_tiled_mma); + auto thr_copy_qk = tiled_copy_qk.get_thread_slice(qk_thread_idx); + auto tQKsQK = thr_copy_qk.partition_D(sQK); + auto tQKrQK_cv = thr_copy_qk.retile_S(tQKrQK); + auto tQKrQK_cvt_cv = make_fragment_like(tQKrQK_cv); + cute::transform(tQKrQK_cv, tQKrQK_cvt_cv, [](auto v) { return Element(v); }); + copy(tiled_copy_qk, tQKrQK_cvt_cv, tQKsQK); + }; + + auto kk_epi = [&](auto& tKKrKK, auto const& alpha_smem_pipe_read, + auto const& beta_smem_pipe_read) { + if constexpr 
(NeedsAlpha) { + Tensor Alpha_cumsum_log = Alpha(_, AlphaCumSumLogIdx{}, alpha_smem_pipe_read.index()); + for_each(make_int_sequence{}, [&](auto i) { + auto coord = tQKcMqk(i); + auto [s, t] = coord; + float alpha = exp2f(Alpha_cumsum_log(s) - Alpha_cumsum_log(t)); + tKKrKK(i) *= alpha; + }); + } + + if constexpr (NeedsBeta) { + Tensor Beta_ = Beta(_, beta_smem_pipe_read.index()); + for_each(make_int_sequence{}, [&](auto i) { + auto coord = tQKcMqk(i); + auto [s, t] = coord; + tKKrKK(i) *= Beta_(s); + }); + } + }; + + auto kk_store_and_inv = [&](auto tKKrKK) INLINE_LAMBDA { + static_assert(sizeof(Element) == 2); + using CopyOpR2S = SM90_U32x4_STSM_N; + auto tiled_store_kk = make_tiled_copy_C(Copy_Atom{}, kk_tiled_mma); + auto thr_store_kk = tiled_store_kk.get_thread_slice(kk_thread_idx); + auto tKKsKK = thr_store_kk.partition_D(sKK_inv); + auto tKKrKK_cv = thr_store_kk.retile_S(tKKrKK); + auto tKKrKK_cvt_cv = make_fragment_like(tKKrKK_cv); + cute::transform(tKKrKK_cv, tKKrKK_cvt_cv, [](auto v) { return InverseType(v); }); + copy(tiled_store_kk, tKKrKK_cvt_cv, tKKsKK); + + cutlass::arch::NamedBarrier::arrive_and_wait(cutlass::NumThreadsPerWarpGroup, + DeltaRuleNamedBarriers::AuxMath); + + auto collective_inverse = CollectiveInverse(DeltaRuleNamedBarriers::AuxMath); + collective_inverse.compute(sKK_inv); + + // FIXME: we can ignore core matrices above diagonal + if constexpr (NeedsBeta || !std::is_same_v) { + cutlass::arch::NamedBarrier::arrive_and_wait(cutlass::NumThreadsPerWarpGroup, + DeltaRuleNamedBarriers::AuxMath); + using CopyOpS2R = SM75_U32x4_LDSM_N; + auto tiled_load_kk = make_tiled_copy_C(Copy_Atom{}, kk_tiled_mma); + auto thr_load_kk = tiled_load_kk.get_thread_slice(kk_thread_idx); + auto tKKrKK_cpy = make_fragment_like(tKKrKK_cvt_cv); + auto tKKrKK_cvt = make_fragment_like(tKKrKK_cvt_cv); + auto tKKcMkk_cv = thr_load_kk.retile_D(tKKcMkk); + copy(tiled_load_kk, thr_load_kk.partition_S(sKK_inv), tKKrKK_cpy); + cute::transform(tKKrKK_cpy, tKKcMkk_cv, 
tKKrKK_cvt, [&](auto val, auto coord) { + auto [_, t] = coord; + if constexpr (NeedsBeta) { + return Element(float(val) * Beta(t, beta_smem_pipe_read.index())); + } else { + return Element(val); + } + }); + copy(tiled_store_kk, tKKrKK_cvt, recast(tKKsKK)); + } + }; + + auto sk_epi = [&](auto& tSKrSK, auto const& alpha_smem_pipe_read) INLINE_LAMBDA { + if constexpr (NeedsAlpha) { + transform(tSKrSK, tSKrAlpha(_, _, _, alpha_smem_pipe_read.index()), tSKrSK, + [&](auto sk, auto coeff) { return sk * coeff; }); + } + }; + + auto sk_load_v = [&](int pipe_idx) INLINE_LAMBDA { + Tensor tSKrV = + make_fragment_like(partition_fragment_C(sk_thr_mma, sV_DS(_, _, _0{}))); + Tensor tSKrV_cv = sk_thr_copy_C.retile_D(tSKrV); + Tensor tSKsV = sk_thr_copy_C.partition_S(sV_DS); + copy(sk_tiled_copy_C, tSKsV(_, _, _, pipe_idx), tSKrV_cv); + return tSKrV; + }; + + auto kv_decay_v = [&](auto& tKVrV, auto const& alpha_smem_pipe_read, auto is_final_block_, + auto B) INLINE_LAMBDA { + constexpr bool is_final_block = decltype(is_final_block_)::value; + if constexpr (NeedsAlpha) { + Tensor Alpha_cumsum_log = Alpha(_, AlphaCumSumLogIdx{}, alpha_smem_pipe_read.index()); + float block_coeff_log = Alpha_cumsum_log(B - 1); + cute::transform(tKVrV, tKVcV, tKVrV, [&](auto val, auto coord) { + auto tok = get<1>(coord); + float coeff = [&] { + if constexpr (!is_final_block) { + return exp2f(block_coeff_log - Alpha_cumsum_log(tok)); + } else { + return tok < B ? exp2f(block_coeff_log - Alpha_cumsum_log(tok)) : 0.0f; + } + }(); + return decltype(val)(val * coeff); + }); + } + if constexpr (is_final_block) { + if constexpr (!NeedsAlpha) { + cute::transform(tKVrV, tKVcV, tKVrV, [&](auto val, auto coord) { + auto tok = get<1>(coord); + return tok < B ? 
val : zero; // mask v of tail oob values + }); + } + } + }; + + auto kv_load = [&](auto& tKVrKV) INLINE_LAMBDA { + DPRINTF0_WG("[%d,%d,%d,%d]>> load tKVgKV -> tKVrKV\n", seq_idx, q_head_idx, k_head_idx, + v_head_idx); + int num_state_heads = problem_size.num_sab_heads; + int state_head_idx = work_desc.o_head_idx(); + auto gKV = make_tensor(make_gmem_ptr(params.ptr_input_state), + make_layout(make_shape(Int{}, Int{}, + num_state_heads, problem_size.num_seqs)))( + _, _, state_head_idx, seq_idx); // (KDim, VDim), K-contiguous + + auto tiled_copy_kv = + make_tiled_copy_C(Copy_Atom{}, kv_tiled_mma); + auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx); + + auto tKVgKV = thr_copy_kv.partition_S(select_tensor<1, 0>(gKV)); + copy(tiled_copy_kv, tKVgKV, tKVrKV); + }; + + auto kv_store = [&]() INLINE_LAMBDA { // tKVrKV is carried over whole mainloop + DPRINTF0_WG("[%d,%d,%d,%d]>> save tKVrKV -> tKVgKV\n", seq_idx, q_head_idx, k_head_idx, + v_head_idx); + int num_state_heads = problem_size.num_sab_heads; + int state_head_idx = work_desc.o_head_idx(); // num_o_heads == num_sab_heads + auto gKV = make_tensor(make_gmem_ptr(params.ptr_output_state), + make_layout(make_shape(Int{}, Int{}, + num_state_heads, problem_size.num_seqs)))( + _, _, state_head_idx, seq_idx); // (KDim, VDim), K-contiguous + + auto tiled_copy_kv = + make_tiled_copy_C(Copy_Atom{}, kv_tiled_mma); + auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx); + + auto tKVgKV = thr_copy_kv.partition_D(select_tensor<1, 0>(gKV)); + copy(tiled_copy_kv, tKVrKV, tKVgKV); + }; + + auto kv_checkpoint_store = [&](int checkpoint_idx) INLINE_LAMBDA { + if constexpr (kEnableCheckpointing) { + DPRINTF0_WG("[%d,%d,%d,%d]>> save tKVrKV -> checkpoint[%d]\n", seq_idx, q_head_idx, + k_head_idx, v_head_idx, checkpoint_idx); + int num_state_heads = problem_size.num_sab_heads; + int state_head_idx = work_desc.o_head_idx(); + int64_t ckpt_offset = params.checkpoint_cu_starts[seq_idx] + checkpoint_idx; + + // Layout: 
[total_checkpoints, num_sab_heads, HeadSizeQK, HeadSizeV] LayoutLeft + auto gKV = + make_tensor(make_gmem_ptr(params.ptr_state_checkpoints + + ckpt_offset * num_state_heads * HeadSizeQK * HeadSizeV + + state_head_idx * HeadSizeQK * HeadSizeV), + make_layout(make_shape(Int{}, Int{}))); + + auto tiled_copy_kv = + make_tiled_copy_C(Copy_Atom{}, kv_tiled_mma); + auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx); + + auto tKVgKV = thr_copy_kv.partition_D(select_tensor<1, 0>(gKV)); + copy(tiled_copy_kv, tKVrKV, tKVgKV); + } + }; + + auto o1_epi = [&](auto& tOrO1, auto const& alpha_smem_pipe_read) INLINE_LAMBDA { + if constexpr (NeedsAlpha) { + auto tOrAlphaScale_ = tOrAlphaScale(_, _, _, alpha_smem_pipe_read.index()); + CUTE_UNROLL + for (int i = 0; i < size(tOrO1); ++i) { + tOrO1(i) = tOrAlphaScale_(i) * tOrO1(i); + } + } else { + CUTE_UNROLL + for (int i = 0; i < size(tOrO1); ++i) { + tOrO1(i) = scale * tOrO1(i); + } + } + }; + + auto o_store = [&](auto tOrO) INLINE_LAMBDA { + auto tOrO_cvt = make_fragment_like(tOrO); + copy(tOrO, tOrO_cvt); + + DPRINTF0_WG("compute: o_pipeline.producer_wait: smem_pipe_write:%d\n", + o_smem_pipe_write.index()); + o_pipeline.producer_acquire(o_smem_pipe_write); + Tensor tOrO_cvt_cv = thr_copy_o.retile_S(tOrO_cvt); + cutlass::arch::fence_view_async_shared(); + copy(tiled_copy_o, tOrO_cvt_cv, tOsO(_, _, _, o_smem_pipe_write.index())); + cutlass::arch::fence_view_async_shared(); + o_pipeline.producer_commit(o_smem_pipe_write); + ++o_smem_pipe_write; + }; + + auto compute_loop_body = [&](int blk, auto is_first_block_, + auto is_final_block_) INLINE_LAMBDA { + constexpr bool is_first_block = decltype(is_first_block_)::value; + constexpr bool is_final_block = decltype(is_final_block_)::value; + int B = is_final_block ? 
valid_seq_len(work_desc, blk) : BlkSeqKV;
+
+      // --- Stage 1: K^T @ K (lower-triangular) and its inverse -------------
+      // K tile must be resident before any consumer stage below.
+      DPRINTF0_WG("compute: k_pipeline.consumer_wait: smem_pipe_read:%d\n",
+                  k_smem_pipe_read.index());
+      k_pipeline.consumer_wait(k_smem_pipe_read);
+      if constexpr (NeedsAlpha) {
+        alpha_pipeline.consumer_wait(alpha_smem_pipe_read);
+      }
+      if constexpr (NeedsBeta) {
+        beta_pipeline.consumer_wait(beta_smem_pipe_read);
+      }
+      // do/while(0) so non-participating warps can `break` out yet still hit
+      // the shared release below; they only need the __syncwarp for reconvergence.
+      do {
+        if (!is_kk_wg) {
+          __syncwarp();
+          break;
+        }
+        DPRINTF0_WG("[%d,%d,%d,%d]** dispatch KK MMA\n", seq_idx, q_head_idx, k_head_idx,
+                    v_head_idx);
+        copy(kk_tiled_copy_A, tKKsA(_, _, _, k_smem_pipe_read.index()), tKKrA_cv);
+        copy(kk_tiled_copy_B, tKKsB(_, _, _, k_smem_pipe_read.index()), tKKrB_cv);
+
+        Tensor tKKrKK = partition_fragment_C(TiledMmaKK{}, select<0, 1>(TileShapeKK{}));
+        clear(tKKrKK);
+        gemm(kk_tiled_mma, tKKrA, tKKrB, tKKrKK);
+
+        // Epilogue scales by alpha/beta, masks the invalid tail (B < BlkSeqKV on
+        // the final block), then stores and inverts the triangular factor.
+        kk_epi(tKKrKK, alpha_smem_pipe_read, beta_smem_pipe_read);
+        qk_or_kk_mask(tKKrKK, is_final_block_, B);
+        kk_store_and_inv(tKKrKK);
+      } while (0);
+      // Beta is only consumed by the KK epilogue; release it as early as possible.
+      if constexpr (NeedsBeta) {
+        beta_pipeline.consumer_release(beta_smem_pipe_read);
+        ++beta_smem_pipe_read;
+      }
+
+      // --- Stage 2: Q @ K^T (intra-block attention scores) ------------------
+      DPRINTF0_WG("compute: q_pipeline.consumer_wait: smem_pipe_read:%d\n",
+                  q_smem_pipe_read.index());
+      q_pipeline.consumer_wait(q_smem_pipe_read);
+      do {
+        if (!is_qk_wg) {
+          __syncwarp();
+          break;
+        }
+        DPRINTF0_WG("[%d,%d,%d,%d]** dispatch QK MMA\n", seq_idx, q_head_idx, k_head_idx,
+                    v_head_idx);
+        copy(qk_tiled_copy_A, tQKsQ(_, _, _, q_smem_pipe_read.index()), tQKrQ_cv);
+        copy(qk_tiled_copy_B, tQKsK(_, _, _, k_smem_pipe_read.index()), tQKrK_cv);
+
+        Tensor tQKrQK = partition_fragment_C(TiledMmaQK{}, select<0, 1>(TileShapeQK{}));
+        clear(tQKrQK);
+        gemm(qk_tiled_mma, tQKrQ, tQKrK, tQKrQK);
+
+        qk_epi(tQKrQK, alpha_smem_pipe_read);
+        qk_or_kk_mask(tQKrQK, is_final_block_, B);
+        qk_store(tQKrQK);
+      } while (0);
+
+      // 2.1 Q @ KV, NOTE: use old KV here
+      // (inter-block contribution from the running state; skipped for the very
+      //  first block, where the state is zero / freshly loaded this iteration).
+      auto tOrO = partition_fragment_C(o1_thr_mma, select<0, 1>(TileShapeO1{}));
+      clear(tOrO);
+      if constexpr (!is_first_block) {
+        DPRINTF0_WG("[%d,%d,%d,%d]** dispatch O1 MMA\n", seq_idx, q_head_idx, k_head_idx,
+                    v_head_idx);
+        copy(o1_tiled_copy_B, tOsQ(_, _, _, q_smem_pipe_read.index()), tOrQ_cv);
+        Tensor tOrKV = detail::SM80::make_acc_into_op(tKVrKV, o1_thr_mma);
+        gemm(o1_thr_mma, tOrKV, tOrQ, tOrO);
+        o1_epi(tOrO, alpha_smem_pipe_read);
+      }
+      DPRINTF0_WG("compute: q_pipeline.consumer_release: smem_pipe_read:%d\n",
+                  q_smem_pipe_read.index());
+      q_pipeline.consumer_release(q_smem_pipe_read);
+      ++q_smem_pipe_read;
+
+      // --- Stage 3: S = KV @ K (state projected onto this block's keys) -----
+      auto tSKrSK = partition_fragment_C(sk_thr_mma, sV_DS(_, _, _0{}));
+      if constexpr (!is_first_block) {
+        auto tSKrS = detail::SM80::make_acc_into_op(tKVrKV, sk_tiled_mma);
+        copy(sk_tiled_copy_B, tSKsK(_, _, _, k_smem_pipe_read.index()), tSKrK_cv);
+        clear(tSKrSK);
+        gemm(sk_tiled_mma, tSKrS, tSKrK, tSKrSK);
+      }
+
+      DPRINTF0_WG("compute: v_pipeline.consumer_wait: smem_pipe_read:%d\n",
+                  v_smem_pipe_read.index());
+      v_pipeline.consumer_wait(v_smem_pipe_read);
+      auto tSKrV = sk_load_v(v_smem_pipe_read.index());
+      if constexpr (!is_first_block) {
+        sk_epi(tSKrSK, alpha_smem_pipe_read);
+        // Delta-rule residual: v <- v - (state @ k), elementwise over the fragment.
+        transform(tSKrV, tSKrSK, tSKrV, [](auto v, auto sk) { return v - Element(sk); });
+      }
+
+      // --- Stage 4: NewV = inv(KK) @ (V - S) ---------------------------------
+      // math_barriers serialize the two math warpgroups around the shared
+      // inv(KK) operand in smem (ordered_or_wait / notify_next_blocked pair).
+      DPRINTF0_WG("[%d,%d,%d,%d]** dispatch NewV MMA\n", seq_idx, q_head_idx, k_head_idx,
+                  v_head_idx);
+      auto tNewVrA = detail::SM80::make_acc_into_op(tSKrV, newv_tiled_mma);
+      auto tNewVrC = partition_fragment_C(newv_thr_mma, select<0, 1>(TileShapeNewV{}));
+      math_barriers.ordered_or_wait(warpgroup_idx);
+      copy(newv_tiled_copy_B, tNewVsB, tNewVrB_cv);
+      clear(tNewVrC);
+      gemm(newv_tiled_mma, tNewVrA, tNewVrB, tNewVrC);
+      math_barriers.notify_next_blocked(warpgroup_idx);
+      DPRINTF0_WG("compute: v_pipeline.consumer_release: smem_pipe_read:%d\n",
+                  v_smem_pipe_read.index());
+      v_pipeline.consumer_release(v_smem_pipe_read);
+      ++v_smem_pipe_read;
+
+      /////////////////////////////////////////////////////////////////////////
+      // 2. compute qkv
+      // 2.2 QK @ V, NOTE: use old KV here and QK is scaled
+      DPRINTF0_WG("[%d,%d,%d,%d]** dispatch O2 MMA\n", seq_idx, q_head_idx, k_head_idx, v_head_idx);
+      auto tOrV_or_tKVrV = detail::SM80::make_acc_into_op(tNewVrC, o2_tiled_mma);
+      math_barriers.ordered_or_wait(warpgroup_idx);
+      copy(o2_tiled_copy_B, tOsQK, tOrQK_cv);
+      gemm(o2_tiled_mma, tOrV_or_tKVrV, tOrQK, tOrO);
+      math_barriers.notify_next_blocked(warpgroup_idx);
+      o_store(tOrO);
+
+      /////////////////////////////////////////////////////////////////////////
+      // 3. update KV: KV <- block_coeff * KV + NewV^T @ K
+      // block_coeff is the cumulative alpha decay over this block (last row of
+      // the alpha cumprod); 1.0f when alpha gating is disabled.
+      float block_coeff = 1.0f;
+      if constexpr (NeedsAlpha) {
+        block_coeff = Alpha(B - 1, AlphaCumProdIdx{}, alpha_smem_pipe_read.index());
+      }
+
+      cute::transform(tKVrKV, [&](auto kv) { return block_coeff * kv; });
+      kv_decay_v(tOrV_or_tKVrV, alpha_smem_pipe_read, is_final_block_, B);
+
+      DPRINTF0_WG("[%d,%d,%d,%d]** dispatch KV MMA\n", seq_idx, q_head_idx, k_head_idx, v_head_idx);
+      copy(kv_tiled_copy_B, tKVsK(_, _, _, k_smem_pipe_read.index()), tKVrK_cv);
+      gemm(kv_tiled_mma, tOrV_or_tKVrV, tKVrK, tKVrKV);
+
+      DPRINTF0_WG("compute: k_pipeline.consumer_release: smem_pipe_read:%d\n",
+                  k_smem_pipe_read.index());
+      k_pipeline.consumer_release(k_smem_pipe_read);
+      ++k_smem_pipe_read;
+
+      if constexpr (NeedsAlpha) {
+        alpha_pipeline.consumer_release(alpha_smem_pipe_read);
+        ++alpha_smem_pipe_read;
+      }
+    };
+
+    int ckpt_blk_interval =
+        (params.checkpoint_every_n_tokens > 0) ?
params.checkpoint_every_n_tokens / BlkSeqKV : 0;
+    int ckpt_count = 0;
+
+    // First block: the KV state either starts at zero or is loaded from the
+    // caller-provided initial state; in both cases it is the "final" block too
+    // when num_blocks == 1 (passing true_type is harmless otherwise, since
+    // valid_seq_len returns a full BlkSeqKV for interior blocks).
+    if constexpr (!kInitStateFromInput) {
+      clear(tKVrKV);
+      compute_loop_body(0, /*is_first_block_=*/cute::true_type{},
+                        /*is_final_block_=*/cute::true_type{});
+    } else {
+      kv_load(tKVrKV);
+      compute_loop_body(0, /*is_first_block_=*/cute::false_type{},
+                        /*is_final_block_=*/cute::true_type{});
+    }
+    // NOTE: ckpt_blk_interval can be 0 when checkpoint_every_n_tokens is not a
+    // positive multiple of BlkSeqKV (e.g. smaller than one block). `% 0` is
+    // undefined behavior on device, so every modulo below is guarded; with an
+    // interval of 0 no periodic checkpoints are taken.
+    if constexpr (kEnableCheckpointing) {
+      if (ckpt_blk_interval == 1) {
+        kv_checkpoint_store(ckpt_count++);
+      }
+    }
+    CUTE_NO_UNROLL
+    for (int blk = 1; blk < num_blocks - 1; ++blk) {
+      compute_loop_body(blk, /*is_first_block_=*/cute::false_type{},
+                        /*is_final_block_=*/cute::false_type{});
+      if constexpr (kEnableCheckpointing) {
+        if (ckpt_blk_interval > 0 && (blk + 1) % ckpt_blk_interval == 0) {
+          kv_checkpoint_store(ckpt_count++);
+        }
+      }
+    }
+    if (num_blocks != 1) {
+      compute_loop_body(num_blocks - 1, /*is_first_block_=*/cute::false_type{},
+                        /*is_final_block_=*/cute::true_type{});
+      // Only checkpoint on exact boundaries; the final (possibly partial) state
+      // is always available via output_state from kv_store() below.
+      if constexpr (kEnableCheckpointing) {
+        if (ckpt_blk_interval > 0 && num_blocks % ckpt_blk_interval == 0) {
+          kv_checkpoint_store(ckpt_count);
+        }
+      }
+    }
+    kv_store();
+  }
+
+  // Number of valid tokens in KV block `blk_idx`: a full BlkSeqKV for interior
+  // blocks, the remainder of seq_len for the final (possibly partial) block.
+  // NOTE(review): seq_len is narrowed to int here — assumes per-sequence
+  // lengths fit in 32 bits; confirm against the host-side launcher.
+  template <class WorkDesc>
+  CUTE_DEVICE int valid_seq_len(WorkDesc work_desc, int blk_idx) {
+    int remain_len = work_desc.seq_len - BlkSeqKV * blk_idx;
+    return remain_len <= BlkSeqKV ? remain_len : BlkSeqKV;
+  }
+};
+
+}  // namespace flat::collective
diff --git a/include/flashinfer/flat/sm120/device/device_universal.hpp b/include/flashinfer/flat/sm120/device/device_universal.hpp
new file mode 100644
index 0000000000..410c1ec3a3
--- /dev/null
+++ b/include/flashinfer/flat/sm120/device/device_universal.hpp
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) 2026 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// Simply reuse hopper code, this file is exposed to kernel instantiation, +// so we put it under sm120/ to avoid confusion. +#include "flashinfer/flat/hopper/device/device_universal.hpp" diff --git a/include/flashinfer/flat/sm120/kernel/flat_kernel_builder_delta_rule.hpp b/include/flashinfer/flat/sm120/kernel/flat_kernel_builder_delta_rule.hpp new file mode 100644 index 0000000000..e87ba1b2d2 --- /dev/null +++ b/include/flashinfer/flat/sm120/kernel/flat_kernel_builder_delta_rule.hpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2026 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "flashinfer/flat/sm120/collective/flat_collective_tma_warpspecialized_delta_rule.hpp" +#include "flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp" + +// Reuse hopper code +#include "flashinfer/flat/hopper/kernel/flat_options.hpp" +#include "flashinfer/flat/hopper/kernel/flat_tile_scheduler.hpp" +#include "flashinfer/flat/type_traits.hpp" + +namespace flat::kernel { + +template +struct FlatBuilderDeltaRule; + +template +struct FlatBuilderDeltaRule { + using CollectiveMainloop = flat::collective::FlatMainloopTmaWarpSpecializedDeltaRule< + Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShape, LayoutQ, LayoutK, LayoutV, + LayoutO, Options>; + + static constexpr bool kIsPersistent = + find_option_t::value; + static_assert(!kIsPersistent, "not implemented"); + + static constexpr bool kIsGVA = find_option_t::value; + using GroupingTag = std::conditional_t; + using TileScheduler = flat::kernel::IndividualTileScheduler; + // using TileScheduler = std::conditional_t; + + using Kernel = flat::kernel::FlatKernelTmaWarpSpecializedDeltaRule; +}; + +} // namespace flat::kernel diff --git a/include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp b/include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp new file mode 100644 index 0000000000..c5715d2053 --- /dev/null +++ b/include/flashinfer/flat/sm120/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp @@ -0,0 +1,428 @@ +/* + * Copyright (c) 2026 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cutlass.h" +#include "cutlass/pipeline/pipeline.hpp" + +// Reuse hopper code +#include "flashinfer/flat/common.hpp" +#include "flashinfer/flat/hopper/kernel/flat_options.hpp" +#include "flashinfer/flat/unused.hpp" + +namespace flat::kernel { + +using namespace cute; + +template +constexpr T1 round_down(T1 a, T2 b) { + return (a / b) * b; +} + +constexpr std::tuple get_register_requirements( + uint32_t max_threads_per_block, uint32_t min_blocks_per_multiprocessor, + uint32_t num_mma_warp_groups) { + uint32_t reg_alloc_granularity = 8; + +#if !defined(FLAT_DEBUG_PRINT) || !FLAT_DEBUG_PRINT + uint32_t load_registers = 40 - 2 * reg_alloc_granularity; +#else + uint32_t load_registers = 40; +#endif + uint32_t total_registers = round_down(64 * 1024 / min_blocks_per_multiprocessor, + max_threads_per_block * reg_alloc_granularity) / + cutlass::NumThreadsPerWarpGroup; + uint32_t mma_registers = + round_down((total_registers - load_registers) / num_mma_warp_groups, reg_alloc_granularity); + + // max reg is 255, 248 round to multiple of reg_alloc_granularity; + return {cute::min(248, load_registers), cute::min(248, mma_registers)}; +} + +template +struct FlatKernelTmaWarpSpecializedDeltaRule { + using ArchTag = cutlass::arch::Sm120; + + static const int NumLoadWarpGroups = 1; + static constexpr int NumMmaWarpGroups = CollectiveMainloop::NumMmaWarpGroups; + + static constexpr int NeedsAlpha = CollectiveMainloop::NeedsAlpha; + static constexpr int 
NeedsBeta = CollectiveMainloop::NeedsBeta; + + using TileShape = typename CollectiveMainloop::TileShape; + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using MainloopQPipeline = typename CollectiveMainloop::MainloopQPipeline; + using MainloopKPipeline = typename CollectiveMainloop::MainloopKPipeline; + using MainloopVPipeline = typename CollectiveMainloop::MainloopVPipeline; + using MainloopOPipeline = typename CollectiveMainloop::MainloopOPipeline; + + using MainloopAlphaPipeline = typename CollectiveMainloop::MainloopAlphaPipeline; + using MainloopBetaPipeline = typename CollectiveMainloop::MainloopBetaPipeline; + + using OrderedMathBarriers = typename CollectiveMainloop::OrderedMathBarriers; + + static constexpr uint32_t StagesPerMathWarpGroup = 2; + + // FIXME: remove this after moving to HMMA + using MathWarpGroupOrderBarrier = + cutlass::OrderedSequenceBarrier; + + struct TensorStorage { + typename CollectiveMainloop::SharedStorage mainloop; + }; + + struct SharedStorage { + TensorStorage tensors; + + using QPipelineStorage = typename MainloopQPipeline::SharedStorage; + using KPipelineStorage = typename MainloopKPipeline::SharedStorage; + using VPipelineStorage = typename MainloopVPipeline::SharedStorage; + using OPipelineStorage = typename MainloopOPipeline::SharedStorage; + + alignas(16) QPipelineStorage q_pipeline_storage; + alignas(16) KPipelineStorage k_pipeline_storage; + alignas(16) VPipelineStorage v_pipeline_storage; + alignas(16) OPipelineStorage o_pipeline_storage; + + using AlphaPipelineStorage = typename MainloopAlphaPipeline::SharedStorage; + using BetaPipelineStorage = typename MainloopBetaPipeline::SharedStorage; + alignas(16) AlphaPipelineStorage alpha_pipeline_storage; + alignas(16) BetaPipelineStorage beta_pipeline_storage; + + alignas(16) cutlass::arch::ClusterBarrier load_warp_barrier; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct VarlenProblemShape { + int64_t const* 
cu_seqlens; + int64_t total_seqlen; + int32_t num_seqs; + int32_t num_q_heads; + int32_t num_k_heads; + int32_t num_v_heads; + int32_t num_o_heads; + int32_t num_sab_heads; // state, alpha, beta + int32_t head_size; // d + }; + using ProblemShape = VarlenProblemShape; + + struct Arguments { + ProblemShape problem_size; + typename CollectiveMainloop::Arguments mainloop; + cutlass::KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_size; + typename CollectiveMainloop::Params mainloop; + typename TileScheduler::Params scheduler; + }; + + using QPipelineParams = typename MainloopQPipeline::Params; + using QPipelineState = typename cutlass::PipelineState; + + using KPipelineParams = typename MainloopKPipeline::Params; + using KPipelineState = typename cutlass::PipelineState; + + using VPipelineParams = typename MainloopVPipeline::Params; + using VPipelineState = typename cutlass::PipelineState; + + using OPipelineParams = typename MainloopOPipeline::Params; + using OPipelineState = typename cutlass::PipelineState; + + using AlphaPipelineParams = + std::conditional_t; + using AlphaPipelineState = + std::conditional_t, Unused>; + + using BetaPipelineParams = + std::conditional_t; + using BetaPipelineState = + std::conditional_t, Unused>; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int MaxThreadsPerBlock = + (NumLoadWarpGroups + NumMmaWarpGroups) * cutlass::NumThreadsPerWarpGroup; + + static constexpr auto RegisterRequirements = + get_register_requirements(MaxThreadsPerBlock, MinBlocksPerMultiprocessor, NumMmaWarpGroups); + static constexpr uint32_t LdStRegisterRequirement = get<0>(RegisterRequirements); + static constexpr uint32_t MmaRegisterRequirement = get<1>(RegisterRequirements); + + static size_t get_workspace_size(Arguments const& args) { + return CollectiveMainloop::get_workspace_size(args.mainloop, args.hw_info.sm_count); + } + + static cutlass::Status initialize_workspace(Arguments const& args, void* 
workspace, + cudaStream_t stream) { + return CollectiveMainloop::initialize_workspace(args.problem_size, args.mainloop, workspace, + stream); + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_size, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_size, + CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace), + TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, + TileShape{})}; + } + + CUTE_DEVICE void operator()(const Params& params, char* smem) { + enum class WarpGroupRole { + LdSt = 0, + Math0 = 1, + Math1 = 2, + }; + + // NOTE: CollectiveInverse will have more utilization on warp 0&1 + // so we put beta and alpha preprocessing on warp 2&3 + enum class LdStWarpRole { + LoadQKV = 0, + StoreO = 1, + LoadBeta = 2, + LoadAlpha = 3, + }; + + TileScheduler scheduler{params.scheduler}; + + // Shared memory. 
+ auto& storage = *reinterpret_cast(smem); + + int lane_idx = cutlass::canonical_lane_idx(); + int warp_idx = cutlass::canonical_warp_idx_sync(); + int warp_idx_in_wg = warp_idx % cutlass::NumWarpsPerWarpGroup; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + auto warp_group_role = WarpGroupRole(warp_group_idx); + auto ldst_warp_role = LdStWarpRole(warp_idx_in_wg); + + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + + QPipelineParams q_pipeline_params; + q_pipeline_params.transaction_bytes = CollectiveMainloop::LoadQBytes; + q_pipeline_params.is_leader = lane_predicate && (ldst_warp_role == LdStWarpRole::LoadQKV); + q_pipeline_params.num_consumers = NumMmaThreads; + + KPipelineParams k_pipeline_params; + k_pipeline_params.transaction_bytes = CollectiveMainloop::LoadKBytes; + k_pipeline_params.is_leader = lane_predicate && (ldst_warp_role == LdStWarpRole::LoadQKV); + k_pipeline_params.num_consumers = NumMmaThreads; + + VPipelineParams v_pipeline_params; + v_pipeline_params.transaction_bytes = CollectiveMainloop::LoadVBytes; + v_pipeline_params.is_leader = lane_predicate && (ldst_warp_role == LdStWarpRole::LoadQKV); + v_pipeline_params.num_consumers = NumMmaThreads; + + OPipelineParams o_pipeline_params; + o_pipeline_params.producer_arv_count = NumMmaThreads; + o_pipeline_params.consumer_arv_count = cutlass::NumThreadsPerWarp; + + AlphaPipelineParams alpha_pipeline_params; + if constexpr (NeedsAlpha) { + alpha_pipeline_params.producer_arv_count = cutlass::NumThreadsPerWarp; + alpha_pipeline_params.consumer_arv_count = NumMmaThreads; + } + + BetaPipelineParams beta_pipeline_params; + if constexpr (NeedsBeta) { + 
beta_pipeline_params.producer_arv_count = cutlass::NumThreadsPerWarp; + beta_pipeline_params.consumer_arv_count = NumMmaThreads; + } + + OrderedMathBarriers math_barriers; + + if (warp_group_role == WarpGroupRole::LdSt && ldst_warp_role == LdStWarpRole::LoadQKV) { + DPRINTF0_W("ldst_warp_role: LoadQKV\n"); + q_pipeline_params.role = MainloopQPipeline::ThreadCategory::Producer; + k_pipeline_params.role = MainloopKPipeline::ThreadCategory::Producer; + v_pipeline_params.role = MainloopVPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::LdSt && ldst_warp_role == LdStWarpRole::StoreO) { + DPRINTF0_W("ldst_warp_role: StoreO\n"); + o_pipeline_params.role = MainloopOPipeline::ThreadCategory::Consumer; + } + if (warp_group_role == WarpGroupRole::LdSt && ldst_warp_role == LdStWarpRole::LoadBeta) { + if constexpr (NeedsBeta) { + beta_pipeline_params.role = MainloopBetaPipeline::ThreadCategory::Producer; + } + } + if (warp_group_role == WarpGroupRole::LdSt && ldst_warp_role == LdStWarpRole::LoadAlpha) { + if constexpr (NeedsAlpha) { + alpha_pipeline_params.role = MainloopAlphaPipeline::ThreadCategory::Producer; + } + } + if (warp_group_role == WarpGroupRole::Math0 || warp_group_role == WarpGroupRole::Math1) { + DPRINTF0_WG("warp_group_role: MathX\n"); + q_pipeline_params.role = MainloopQPipeline::ThreadCategory::Consumer; + k_pipeline_params.role = MainloopKPipeline::ThreadCategory::Consumer; + v_pipeline_params.role = MainloopVPipeline::ThreadCategory::Consumer; + o_pipeline_params.role = MainloopOPipeline::ThreadCategory::Producer; + + if constexpr (NeedsAlpha) { + alpha_pipeline_params.role = MainloopAlphaPipeline::ThreadCategory::Consumer; + } + + math_barriers.init(warp_group_idx - 1); + } + + MainloopQPipeline q_pipeline(storage.q_pipeline_storage, q_pipeline_params, ClusterShape{}); + MainloopKPipeline k_pipeline(storage.k_pipeline_storage, k_pipeline_params, ClusterShape{}); + MainloopVPipeline v_pipeline(storage.v_pipeline_storage, 
v_pipeline_params, ClusterShape{}); + MainloopOPipeline o_pipeline(storage.o_pipeline_storage, o_pipeline_params, + /*InitBarriers=*/cute::true_type{}); + + MainloopAlphaPipeline alpha_pipeline(storage.alpha_pipeline_storage, alpha_pipeline_params, + /*InitBarriers=*/cute::true_type{}); + MainloopBetaPipeline beta_pipeline(storage.beta_pipeline_storage, beta_pipeline_params, + /*InitBarriers=*/cute::true_type{}); + + QPipelineState q_smem_pipe_read; + QPipelineState q_smem_pipe_write = cutlass::make_producer_start_state(); + KPipelineState k_smem_pipe_read; + KPipelineState k_smem_pipe_write = cutlass::make_producer_start_state(); + VPipelineState v_smem_pipe_read; + VPipelineState v_smem_pipe_write = cutlass::make_producer_start_state(); + OPipelineState o_smem_pipe_read; + OPipelineState o_smem_pipe_write = cutlass::make_producer_start_state(); + + AlphaPipelineState alpha_smem_pipe_read; + AlphaPipelineState alpha_smem_pipe_write; + if constexpr (NeedsAlpha) { + alpha_smem_pipe_write = cutlass::make_producer_start_state(); + } + BetaPipelineState beta_smem_pipe_read; + BetaPipelineState beta_smem_pipe_write; + if constexpr (NeedsBeta) { + beta_smem_pipe_write = cutlass::make_producer_start_state(); + } + + // barrier sm or cluster level for initialization + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + DPRINTF0_WG("warpspecialized grid initialized\n"); + + CollectiveMainloop collective_mainloop; + + if (warp_group_role == WarpGroupRole::LdSt) { + DPRINTF0_WG("LsSt warp_group_idx:%d, RegisterRequirement:%d\n", warp_group_idx, + LdStRegisterRequirement); + cutlass::arch::warpgroup_reg_dealloc(); + if (ldst_warp_role == LdStWarpRole::LoadQKV) { + auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size); + CUTE_NO_UNROLL + for (; work_desc.is_valid(params.scheduler); + work_desc = scheduler.get_next_work(params.scheduler, params.problem_size)) { + 
DPRINTF0_WG( + "LsSt working on LoadQ/K/V, seq_idx:%d, q/k/v_head_idx:(%d,%d,%d), seq_len:%lld)\n", + work_desc.seq_idx, work_desc.q_head_idx(), work_desc.k_head_idx(), + work_desc.v_head_idx(), work_desc.seq_len); + auto tile_shape = typename CollectiveMainloop::TileShape{}; + collective_mainloop.load_qkv(params.mainloop, params.problem_size, tile_shape, work_desc, + q_pipeline, q_smem_pipe_write, k_pipeline, k_smem_pipe_write, + v_pipeline, v_smem_pipe_write, storage.tensors.mainloop); + } + } else if (ldst_warp_role == LdStWarpRole::LoadBeta) { + if constexpr (NeedsBeta) { + auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size); + CUTE_NO_UNROLL + for (; work_desc.is_valid(params.scheduler); + work_desc = scheduler.get_next_work(params.scheduler, params.problem_size)) { + DPRINTF0_WG("LsSt working on LoadBeta, seq_idx:%d, sab_head_idx:%d, seq_len:%lld)\n", + work_desc.seq_idx, work_desc.o_head_idx(), work_desc.seq_len); + auto tile_shape = typename CollectiveMainloop::TileShape{}; + collective_mainloop.load_beta(params.mainloop, params.problem_size, tile_shape, + work_desc, beta_pipeline, beta_smem_pipe_write, + storage.tensors.mainloop); + } + } + } else if (ldst_warp_role == LdStWarpRole::LoadAlpha) { + if constexpr (NeedsAlpha) { + auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size); + CUTE_NO_UNROLL + for (; work_desc.is_valid(params.scheduler); + work_desc = scheduler.get_next_work(params.scheduler, params.problem_size)) { + DPRINTF0_WG("LsSt working on LoadAlpha, seq_idx:%d, sab_head_idx:%d, seq_len:%lld)\n", + work_desc.seq_idx, work_desc.o_head_idx(), work_desc.seq_len); + auto tile_shape = typename CollectiveMainloop::TileShape{}; + collective_mainloop.load_alpha(params.mainloop, params.problem_size, tile_shape, + work_desc, alpha_pipeline, alpha_smem_pipe_write, + storage.tensors.mainloop); + } + } + } else if (ldst_warp_role == LdStWarpRole::StoreO) { + auto work_desc = 
scheduler.get_next_work(params.scheduler, params.problem_size); + DPRINTF0_WG("LsSt working on StoreO, seq_idx:%d, o_head_idx:%d, seq_len:%lld)\n", + work_desc.seq_idx, work_desc.o_head_idx(), work_desc.seq_len); + auto tile_shape = typename CollectiveMainloop::TileShape{}; + collective_mainloop.store(params.mainloop.tma_store_o, params.mainloop.tensormaps, + params.problem_size, tile_shape, work_desc, o_pipeline, + o_smem_pipe_read, storage.tensors.mainloop.smem_o); + } + } else if (warp_group_role == WarpGroupRole::Math0 || warp_group_role == WarpGroupRole::Math1) { + DPRINTF0_WG("Compute[state]: warp_group_idx:%d, RegisterRequirement:%d\n", warp_group_idx, + MmaRegisterRequirement); + cutlass::arch::warpgroup_reg_alloc(); + auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size); + CUTE_NO_UNROLL + for (; work_desc.is_valid(params.scheduler); + work_desc = scheduler.get_next_work(params.scheduler, params.problem_size)) { + DPRINTF0_WG("Compute[state]: seq_idx:%d, qk/v/o_head_idx:(%d,%d,%d,%d), seq_len:%lld)\n", + work_desc.seq_idx, work_desc.q_head_idx(), work_desc.k_head_idx(), + work_desc.v_head_idx(), work_desc.o_head_idx(), work_desc.seq_len); + collective_mainloop.compute(params.mainloop, params.problem_size, work_desc, q_pipeline, + q_smem_pipe_read, k_pipeline, k_smem_pipe_read, v_pipeline, + v_smem_pipe_read, o_pipeline, o_smem_pipe_write, alpha_pipeline, + alpha_smem_pipe_read, beta_pipeline, beta_smem_pipe_read, + math_barriers, storage.tensors.mainloop); + } + } else { + DPRINTF0_WG("Unknown warp role, warp_group_idx:%d\n", warp_group_idx); + } + + __syncthreads(); + } +}; + +} // namespace flat::kernel diff --git a/tests/gdn/test_prefill_delta_rule.py b/tests/gdn/test_prefill_delta_rule.py index 85dac83efd..471bac7ddb 100644 --- a/tests/gdn/test_prefill_delta_rule.py +++ b/tests/gdn/test_prefill_delta_rule.py @@ -25,7 +25,11 @@ from .reference_delta_rule import exclusive_cumsum, blockwise_delta_rule -from flashinfer.utils 
import is_sm90a_supported, is_sm100a_supported +from flashinfer.utils import ( + is_sm90a_supported, + is_sm100a_supported, + is_sm120a_supported, +) from flashinfer.gdn_prefill import chunk_gated_delta_rule @@ -38,6 +42,15 @@ def _skip_if_unsupported(): pytest.skip( f"SM100 GDN prefill requires CUDA 13+, got {torch.version.cuda}" ) + elif is_sm120a_supported(device): + cuda_major, cuda_minor = 0, 0 + if torch.version.cuda: + cuda_version = torch.version.cuda.split(".")[:2] + cuda_major, cuda_minor = (int(x) for x in cuda_version) + if (cuda_major, cuda_minor) < (12, 8): + pytest.skip( + f"SM120 GDN prefill requires CUDA 12.8+, got {torch.version.cuda}" + ) elif not is_sm90a_supported(device): pytest.skip("GDN prefill requires SM90 or SM100")