-
Notifications
You must be signed in to change notification settings - Fork 920
Implement Gated Delta Rule for sm_120a (Blackwell RTX) #3088
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
9c83fd6
bbbba9f
ba18e9f
2a357e6
e77385a
f809316
b17d688
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| /* | ||
| * Copyright (c) 2025 by FlashInfer team. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| // Extern template declarations to prevent implicit instantiation in the dispatcher. | ||
| // Explicit instantiations are in separate generated files for parallel compilation. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <cuda_bf16.h> | ||
| #include <cuda_fp16.h> | ||
| #include "cutlass/arch/arch.h" | ||
|
|
||
| namespace flat { | ||
|
|
||
| // clang-format off | ||
|
|
||
| #define FOR_EACH_BOOL_5(MACRO, ...) \ | ||
| MACRO(false, false, false, false, false, __VA_ARGS__) \ | ||
| MACRO(false, false, false, false, true, __VA_ARGS__) \ | ||
| MACRO(false, false, false, true, false, __VA_ARGS__) \ | ||
| MACRO(false, false, false, true, true, __VA_ARGS__) \ | ||
| MACRO(false, false, true, false, false, __VA_ARGS__) \ | ||
| MACRO(false, false, true, false, true, __VA_ARGS__) \ | ||
| MACRO(false, false, true, true, false, __VA_ARGS__) \ | ||
| MACRO(false, false, true, true, true, __VA_ARGS__) \ | ||
| MACRO(false, true, false, false, false, __VA_ARGS__) \ | ||
| MACRO(false, true, false, false, true, __VA_ARGS__) \ | ||
| MACRO(false, true, false, true, false, __VA_ARGS__) \ | ||
| MACRO(false, true, false, true, true, __VA_ARGS__) \ | ||
| MACRO(false, true, true, false, false, __VA_ARGS__) \ | ||
| MACRO(false, true, true, false, true, __VA_ARGS__) \ | ||
| MACRO(false, true, true, true, false, __VA_ARGS__) \ | ||
| MACRO(false, true, true, true, true, __VA_ARGS__) \ | ||
| MACRO(true, false, false, false, false, __VA_ARGS__) \ | ||
| MACRO(true, false, false, false, true, __VA_ARGS__) \ | ||
| MACRO(true, false, false, true, false, __VA_ARGS__) \ | ||
| MACRO(true, false, false, true, true, __VA_ARGS__) \ | ||
| MACRO(true, false, true, false, false, __VA_ARGS__) \ | ||
| MACRO(true, false, true, false, true, __VA_ARGS__) \ | ||
| MACRO(true, false, true, true, false, __VA_ARGS__) \ | ||
| MACRO(true, false, true, true, true, __VA_ARGS__) \ | ||
| MACRO(true, true, false, false, false, __VA_ARGS__) \ | ||
| MACRO(true, true, false, false, true, __VA_ARGS__) \ | ||
| MACRO(true, true, false, true, false, __VA_ARGS__) \ | ||
| MACRO(true, true, false, true, true, __VA_ARGS__) \ | ||
| MACRO(true, true, true, false, false, __VA_ARGS__) \ | ||
| MACRO(true, true, true, false, true, __VA_ARGS__) \ | ||
| MACRO(true, true, true, true, false, __VA_ARGS__) \ | ||
| MACRO(true, true, true, true, true, __VA_ARGS__) | ||
|
|
||
| #define DECLARE_TEMPLATE_INSTANCE(is_gva, needs_beta, needs_alpha, init_state, enable_ckpt, ctype) \ | ||
| extern template void launch_delta_rule_prefill_kernel_gbai<is_gva, needs_beta, needs_alpha, init_state, enable_ckpt, cutlass::arch::Sm120, ctype, ctype, float>( \ | ||
| cudaStream_t, ctype*, float*, ctype const*, ctype const*, ctype const*, \ | ||
| float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t, \ | ||
| int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t, float*, int64_t const*, \ | ||
| int32_t); | ||
|
|
||
| // Extern template declarations for half | ||
| FOR_EACH_BOOL_5(DECLARE_TEMPLATE_INSTANCE, half) | ||
|
|
||
| // Extern template declarations for nv_bfloat16 | ||
| FOR_EACH_BOOL_5(DECLARE_TEMPLATE_INSTANCE, nv_bfloat16) | ||
|
|
||
| #undef DECLARE_TEMPLATE_INSTANCE | ||
| #undef FOR_EACH_BOOL_5 | ||
|
|
||
| // clang-format on | ||
|
|
||
| } // namespace flat |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| /* | ||
| * Copyright (c) 2025 by FlashInfer team. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| // Auto-generated file for separate compilation of GDN prefill kernel variants. | ||
| // Template parameters: dtype={{ dtype }}, is_gva={{ is_gva }}, needs_beta={{ needs_beta }}, | ||
| // needs_alpha={{ needs_alpha }}, init_state={{ init_state }}, | ||
| // enable_checkpointing={{ enable_checkpointing }} | ||
|
|
||
| // CUDA type definitions for half and nv_bfloat16 | ||
| #include <cuda_bf16.h> | ||
| #include <cuda_fp16.h> | ||
|
|
||
| // Include the header which defines the function template | ||
| // The header includes all necessary CUTLASS type definitions | ||
| #include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh" | ||
|
|
||
| namespace flat { | ||
|
|
||
| // Explicit template instantiation for launch_delta_rule_prefill_kernel_gbai | ||
| // Parameter types must exactly match the extern template declaration in prefill_kernel_delta_rule_sm120_extern.inc | ||
| template void launch_delta_rule_prefill_kernel_gbai<{{ is_gva }}, {{ needs_beta }}, {{ needs_alpha }}, {{ init_state }}, {{ enable_checkpointing }}, cutlass::arch::Sm120, {{ dtype }}, {{ dtype }}, float>( | ||
| cudaStream_t, {{ dtype }}*, float*, {{ dtype }} const*, {{ dtype }} const*, {{ dtype }} const*, | ||
| float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t, | ||
| int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t, float*, int64_t const*, | ||
| int32_t); | ||
|
|
||
| } // namespace flat | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| /* | ||
| * Copyright (c) 2025 by FlashInfer team. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #include <cuda_bf16.h> | ||
|
|
||
| #include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh" | ||
|
|
||
| // Extern template declarations prevent implicit instantiation here. | ||
| // Explicit instantiations are in separate generated files for parallel compilation. | ||
| #include "flat_prefill_kernel_delta_rule_sm120_extern.inc" | ||
|
|
||
| namespace flat { | ||
|
|
||
| using namespace cute; | ||
|
|
||
| template <typename ArchTag, // FIXME: hide this | ||
| typename TO, typename TQKV, typename TState> | ||
| void launch_delta_rule_prefill_kernel( | ||
| cudaStream_t stream, TO* output, TState* output_state, TQKV const* q, TQKV const* k, | ||
| TQKV const* v, TState const* input_state, float const* alpha, float const* beta, | ||
| int64_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, int32_t num_q_heads, | ||
| int32_t num_k_heads, int32_t num_v_heads, int32_t num_o_heads, int32_t head_size, | ||
| int64_t total_seqlen, float scale, int32_t sm_count, float* state_checkpoints, | ||
| int64_t const* checkpoint_cu_starts, int32_t checkpoint_every_n_tokens) { | ||
| bool is_gva = num_v_heads > num_q_heads; | ||
| bool needs_beta = beta != nullptr; | ||
| bool needs_alpha = alpha != nullptr; | ||
| bool init_state = input_state != nullptr; | ||
| bool enable_ckpt = checkpoint_every_n_tokens > 0; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As far as I can tell, this is only supported by the SM120 kernel, not the SM90 kernel, is it correct?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They are all supported. |
||
|
|
||
| #define LAUNCH(is_gva, needs_beta, needs_alpha, init_state, enable_ckpt) \ | ||
| launch_delta_rule_prefill_kernel_gbai<is_gva, needs_beta, needs_alpha, init_state, enable_ckpt, \ | ||
| ArchTag>( \ | ||
| stream, output, output_state, q, k, v, input_state, alpha, beta, cu_seqlens, \ | ||
| workspace_buffer, num_seqs, num_q_heads, num_k_heads, num_v_heads, num_o_heads, head_size, \ | ||
| total_seqlen, scale, sm_count, state_checkpoints, checkpoint_cu_starts, \ | ||
| checkpoint_every_n_tokens); | ||
|
|
||
| #define DISPATCH_GBAI(init_state, enable_ckpt) \ | ||
| if (is_gva && needs_beta && needs_alpha) { \ | ||
| LAUNCH(true, true, true, init_state, enable_ckpt); \ | ||
| } else if (is_gva && needs_beta && !needs_alpha) { \ | ||
| LAUNCH(true, true, false, init_state, enable_ckpt); \ | ||
| } else if (is_gva && !needs_beta && needs_alpha) { \ | ||
| LAUNCH(true, false, true, init_state, enable_ckpt); \ | ||
| } else if (is_gva && !needs_beta && !needs_alpha) { \ | ||
| LAUNCH(true, false, false, init_state, enable_ckpt); \ | ||
| } else if (!is_gva && needs_beta && needs_alpha) { \ | ||
| LAUNCH(false, true, true, init_state, enable_ckpt); \ | ||
| } else if (!is_gva && needs_beta && !needs_alpha) { \ | ||
| LAUNCH(false, true, false, init_state, enable_ckpt); \ | ||
| } else if (!is_gva && !needs_beta && needs_alpha) { \ | ||
| LAUNCH(false, false, true, init_state, enable_ckpt); \ | ||
| } else if (!is_gva && !needs_beta && !needs_alpha) { \ | ||
| LAUNCH(false, false, false, init_state, enable_ckpt); \ | ||
| } else { \ | ||
| throw std::runtime_error("unreachable"); \ | ||
| } | ||
|
|
||
| if (enable_ckpt) { | ||
| if (init_state) { | ||
| DISPATCH_GBAI(true, true); | ||
| } else { | ||
| DISPATCH_GBAI(false, true); | ||
| } | ||
| } else { | ||
| if (init_state) { | ||
| DISPATCH_GBAI(true, false); | ||
| } else { | ||
| DISPATCH_GBAI(false, false); | ||
| } | ||
| } | ||
|
|
||
| #undef DISPATCH_GBAI | ||
| #undef LAUNCH | ||
| } | ||
|
|
||
| // Explicit instantiations for the outer dispatch function only. | ||
| // The inner launch_delta_rule_prefill_kernel_gbai instantiations are in separate files. | ||
| template void launch_delta_rule_prefill_kernel<cutlass::arch::Sm120, half, half, float>( | ||
| cudaStream_t stream, half* output, float* state, half const* q, half const* k, half const* v, | ||
| float const* input_state, float const* alpha, float const* beta, int64_t const* cu_seqlens, | ||
| uint8_t* workspace_buffer, int32_t num_seqs, int32_t num_q_heads, int32_t num_k_heads, | ||
| int32_t num_v_heads, int32_t num_o_heads, int32_t head_size, int64_t total_seqlen, float scale, | ||
| int32_t sm_count, float* state_checkpoints, int64_t const* checkpoint_cu_starts, | ||
| int32_t checkpoint_every_n_tokens); | ||
|
|
||
| template void | ||
| launch_delta_rule_prefill_kernel<cutlass::arch::Sm120, nv_bfloat16, nv_bfloat16, float>( | ||
| cudaStream_t stream, nv_bfloat16* output, float* state, nv_bfloat16 const* q, | ||
| nv_bfloat16 const* k, nv_bfloat16 const* v, float const* input_state, float const* alpha, | ||
| float const* beta, int64_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, | ||
| int32_t num_q_heads, int32_t num_k_heads, int32_t num_v_heads, int32_t num_o_heads, | ||
| int32_t head_size, int64_t total_seqlen, float scale, int32_t sm_count, | ||
| float* state_checkpoints, int64_t const* checkpoint_cu_starts, | ||
| int32_t checkpoint_every_n_tokens); | ||
|
|
||
| } // namespace flat | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,20 +21,29 @@ | |
| import torch | ||
|
|
||
| from .api_logging import flashinfer_api | ||
| from .jit.gdn import gen_gdn_prefill_sm90_module | ||
| from .jit.gdn import gen_gdn_prefill_sm90_module, gen_gdn_prefill_sm120_module | ||
| from .utils import ( | ||
| register_custom_op, | ||
| register_fake_op, | ||
| get_device_sm_count, | ||
| is_sm90a_supported, | ||
| is_sm100a_supported, | ||
| is_sm120a_supported, | ||
| _get_cache_buf, | ||
| ) | ||
| from .gdn_kernels import chunk_gated_delta_rule_sm100, _has_blackwell_prefill | ||
|
|
||
|
|
||
| @functools.cache | ||
| def get_gdn_prefill_module(): | ||
| module = gen_gdn_prefill_sm90_module().build_and_load() | ||
| def get_gdn_prefill_module(device: torch.device): | ||
| if is_sm90a_supported(device): | ||
| module = gen_gdn_prefill_sm90_module().build_and_load() | ||
| elif is_sm120a_supported(device): | ||
| module = gen_gdn_prefill_sm120_module().build_and_load() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the test, we mention
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am actually not quite sure what is the recommended version. According to https://docs.nvidia.com/cuda/archive/12.8.1/blackwell-compatibility-guide/index.html
So I think we are safe to relax the requirement to 12.8.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Relaxed to 12.8 in test |
||
| else: | ||
| raise RuntimeError( | ||
| f"GDN prefill kernel requires SM90 or SM120, but device {device} is not supported" | ||
| ) | ||
|
|
||
| @register_custom_op( | ||
| "flashinfer::gdn_prefill", | ||
|
|
@@ -183,7 +192,7 @@ def chunk_gated_delta_rule( | |
| - Supports GQA: ``num_q_heads > num_k_heads = num_v_heads`` | ||
| - Supports GVA: ``num_v_heads > num_q_heads = num_k_heads`` | ||
| - The final state layout is ``[N, H, V, K]``. | ||
| - Requires SM90 (Hopper) or SM100 (Blackwell) architecture. | ||
| - Requires SM90 (Hopper) or SM100 (Blackwell) or SM120 (Blackwell RTX) architecture. | ||
| - SM100 path requires head_size == 128. | ||
| - SM100 path requires ``nvidia-cutlass-dsl[cu13]>=4.4.2`` | ||
| (install via ``pip install flashinfer-python[cu13]``). | ||
|
guangyunh-nv marked this conversation as resolved.
|
||
|
|
@@ -336,7 +345,7 @@ def chunk_gated_delta_rule( | |
| "gdn_prefill_workspace", workspace_size, device | ||
| ) | ||
|
|
||
| get_gdn_prefill_module().gdn_prefill( | ||
| get_gdn_prefill_module(q.device).gdn_prefill( | ||
| output, | ||
| output_state, | ||
| q, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,25 +24,37 @@ | |
| JitSpec, | ||
| gen_jit_spec, | ||
| sm90a_nvcc_flags, | ||
| sm120a_nvcc_flags, | ||
| ) | ||
| from .utils import write_if_different | ||
|
|
||
|
|
||
| def gen_gdn_prefill_sm90_module() -> JitSpec: | ||
| def _gen_gdn_prefill_module(arch: str) -> JitSpec: | ||
| """Generate JIT module for GDN prefill kernel with separate compilation. | ||
|
|
||
| This generates 32 separate kernel instantiation files (2 dtypes × 16 boolean combinations) | ||
| plus the original launcher file. The separate files enable parallel compilation by ninja, | ||
| significantly reducing build time on multi-core machines. | ||
| """ | ||
|
Comment on lines
33
to
38
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Docstring count is stale: 32 combinations, not 16. The docstring says "32 separate kernel instantiation files (2 dtypes Γ 16 boolean combinations)", but the loop at Line 69 uses π‘ Proposed fix- This generates 32 separate kernel instantiation files (2 dtypes Γ 16 boolean combinations)
- plus the original launcher file. The separate files enable parallel compilation by ninja,
+ This generates 64 separate kernel instantiation files (2 dtypes x 32 boolean combinations)
+ plus the original launcher file. The separate files enable parallel compilation by ninja,
significantly reducing build time on multi-core machines.π§° Toolsπͺ Ruff (0.15.10)[warning] 35-35: Docstring contains ambiguous (RUF002) π€ Prompt for AI Agents
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is copied from SM90 version, we should open a seperate PR to fix them all in once.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
π§ Learnings used |
||
| uri = "gdn_prefill_sm90" | ||
| assert arch in ["sm90", "sm120"], ( | ||
| "GDN prefill kernel is only supported on sm_90a and sm_120a" | ||
| ) | ||
|
|
||
| if arch == "sm90": | ||
| arch_specific_flags = sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED"] | ||
| elif arch == "sm120": | ||
| arch_specific_flags = sm120a_nvcc_flags + ["-DFLAT_SM120A_ENABLED"] | ||
|
|
||
| uri = f"gdn_prefill_{arch}" | ||
| gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri | ||
| os.makedirs(gen_directory, exist_ok=True) | ||
|
|
||
| source_paths = [] | ||
|
|
||
| # Load kernel instantiation template | ||
| with open(jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_sm90_kernel_inst.jinja") as f: | ||
| with open( | ||
| jit_env.FLASHINFER_CSRC_DIR / f"gdn_prefill_{arch}_kernel_inst.jinja" | ||
| ) as f: | ||
| kernel_inst_templ = jinja2.Template(f.read()) | ||
|
|
||
| # Generate 64 separate instance files (2 dtypes × 32 boolean combinations) | ||
|
|
@@ -74,7 +86,7 @@ def gen_gdn_prefill_sm90_module() -> JitSpec: | |
| # Headers are now in include/flashinfer/flat/ and accessible via standard include paths | ||
| for filename in [ | ||
| "gdn_prefill_launcher.cu", | ||
| "prefill_kernel_delta_rule_sm90.cu", | ||
| f"prefill_kernel_delta_rule_{arch}.cu", | ||
| ]: | ||
| src_path = jit_env.FLASHINFER_CSRC_DIR / filename | ||
| dest_path = gen_directory / src_path.name | ||
|
|
@@ -84,5 +96,13 @@ def gen_gdn_prefill_sm90_module() -> JitSpec: | |
| return gen_jit_spec( | ||
| uri, | ||
| source_paths, | ||
| extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"], | ||
| extra_cuda_cflags=arch_specific_flags + ["-std=c++20"], | ||
| ) | ||
|
|
||
|
|
||
| def gen_gdn_prefill_sm90_module(): | ||
| return _gen_gdn_prefill_module("sm90") | ||
|
|
||
|
|
||
| def gen_gdn_prefill_sm120_module(): | ||
| return _gen_gdn_prefill_module("sm120") | ||
Uh oh!
There was an error while loading. Please reload this page.