Skip to content
Merged
6 changes: 6 additions & 0 deletions csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

#include "prefill_kernel_delta_rule_sm90.cuh"

// Extern template declarations prevent implicit instantiation here.
// Explicit instantiations are in separate generated files for parallel compilation.
#include "prefill_kernel_delta_rule_sm90_extern.inc"

namespace flat {

using namespace cute;
Expand Down Expand Up @@ -86,6 +90,8 @@ void launch_delta_rule_prefill_kernel(cudaStream_t stream, TO* output, TState* o
#undef LAUNCH
}

// Explicit instantiations for the outer dispatch function only.
// The inner launch_delta_rule_prefill_kernel_gbai instantiations are in separate files.
template void launch_delta_rule_prefill_kernel<cutlass::arch::Sm90, half, half, float>(
cudaStream_t stream, half* output, float* state, half const* q, half const* k, half const* v,
float const* input_state, float const* alpha, float const* beta, int64_t const* cu_seqlens,
Expand Down
64 changes: 64 additions & 0 deletions csrc/flat/prefill/prefill_kernel_delta_rule_sm90_extern.inc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// Extern template declarations to prevent implicit instantiation in the dispatcher.
// Explicit instantiations are in separate generated files for parallel compilation.
// (The generated files are rendered from gdn_prefill_sm90_kernel_inst.jinja by
// flashinfer/jit/gdn.py; the signature below must stay in exact sync with that
// template, or the linker will report unresolved symbols.)

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "cutlass/arch/arch.h"

namespace flat {

// clang-format off

// Expands MACRO once for every one of the 16 combinations of four booleans
// (is_gva, needs_beta, needs_alpha, init_state); trailing arguments are
// forwarded to MACRO via __VA_ARGS__.
#define FOR_EACH_BOOL_4(MACRO, ...) \
MACRO(false, false, false, false, __VA_ARGS__) \
MACRO(false, false, false, true, __VA_ARGS__) \
MACRO(false, false, true, false, __VA_ARGS__) \
MACRO(false, false, true, true, __VA_ARGS__) \
MACRO(false, true, false, false, __VA_ARGS__) \
MACRO(false, true, false, true, __VA_ARGS__) \
MACRO(false, true, true, false, __VA_ARGS__) \
MACRO(false, true, true, true, __VA_ARGS__) \
MACRO(true, false, false, false, __VA_ARGS__) \
MACRO(true, false, false, true, __VA_ARGS__) \
MACRO(true, false, true, false, __VA_ARGS__) \
MACRO(true, false, true, true, __VA_ARGS__) \
MACRO(true, true, false, false, __VA_ARGS__) \
MACRO(true, true, false, true, __VA_ARGS__) \
MACRO(true, true, true, false, __VA_ARGS__) \
MACRO(true, true, true, true, __VA_ARGS__)

// Declares one extern template for launch_delta_rule_prefill_kernel_gbai with
// the given boolean template arguments and compute type `ctype` (TQ/TK/TV use
// ctype; accumulation/state type is float; architecture is fixed to Sm90).
#define DECLARE_TEMPLATE_INSTANCE(is_gva, needs_beta, needs_alpha, init_state, ctype) \
extern template void launch_delta_rule_prefill_kernel_gbai<is_gva, needs_beta, needs_alpha, init_state, cutlass::arch::Sm90, ctype, ctype, float>( \
cudaStream_t, ctype*, float*, ctype const*, ctype const*, ctype const*, \
float const*, float const*, float const*, int64_t const*, int32_t, int32_t, \
int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t);

// Extern template declarations for half
FOR_EACH_BOOL_4(DECLARE_TEMPLATE_INSTANCE, half)

// Extern template declarations for nv_bfloat16
FOR_EACH_BOOL_4(DECLARE_TEMPLATE_INSTANCE, nv_bfloat16)

#undef DECLARE_TEMPLATE_INSTANCE
#undef FOR_EACH_BOOL_4

// clang-format on

}  // namespace flat
38 changes: 38 additions & 0 deletions csrc/gdn_prefill_sm90_kernel_inst.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// Auto-generated file for separate compilation of GDN prefill kernel variants.
// Rendered by flashinfer/jit/gdn.py: the boolean parameters are passed as the
// lowercase strings "true"/"false" so they are valid C++ bool literals here.
// Template parameters: dtype={{ dtype }}, is_gva={{ is_gva }}, needs_beta={{ needs_beta }},
// needs_alpha={{ needs_alpha }}, init_state={{ init_state }}

#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Ensure cutlass arch types are defined
#include "cutlass/arch/arch.h"

// Use full path since generated files are in a different directory
#include "flat/prefill/prefill_kernel_delta_rule_sm90.cuh"

namespace flat {

// Explicit template instantiation for launch_delta_rule_prefill_kernel_gbai
// Parameter types must exactly match the extern template declaration in prefill_kernel_delta_rule_sm90_extern.inc
// (a mismatch would silently produce a second instantiation and an unresolved
// symbol at link time rather than a compile error here).
template void launch_delta_rule_prefill_kernel_gbai<{{ is_gva }}, {{ needs_beta }}, {{ needs_alpha }}, {{ init_state }}, cutlass::arch::Sm90, {{ dtype }}, {{ dtype }}, float>(
    cudaStream_t, {{ dtype }}*, float*, {{ dtype }} const*, {{ dtype }} const*, {{ dtype }} const*,
    float const*, float const*, float const*, int64_t const*, int32_t, int32_t,
    int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t);

}  // namespace flat
3 changes: 3 additions & 0 deletions flashinfer/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,9 @@ def gen_all_modules(
add_misc: bool,
add_xqa: bool,
) -> List[JitSpec]:
# TEMPORARY: Only compile gdn_prefill_sm90 for testing
Comment thread
yzh119 marked this conversation as resolved.
Outdated
return [gen_gdn_prefill_sm90_module()]
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

jit_specs: List[JitSpec] = []
jit_specs.append(gen_spdlog_module())
has_sm90 = sm_capabilities.get("sm90", False)
Expand Down
2 changes: 1 addition & 1 deletion flashinfer/gdn_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def chunk_gated_delta_rule(
Note:
- Supports GQA: ``num_q_heads > num_k_heads = num_v_heads``
- Supports GVA: ``num_v_heads > num_q_heads = num_k_heads``
- The final state is in k-major layout ``[N, H, K, V]``.
- The final state is in k-last layout ``[N, H, V, K]``.
- Requires SM90 (Hopper) architecture.
"""
assert cu_seqlens is not None, "cu_seqlens is required for varlen mode"
Expand Down
17 changes: 15 additions & 2 deletions flashinfer/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,16 +416,29 @@ def gen_jit_spec(
verbose_env = os.environ.get("FLASHINFER_JIT_VERBOSE", "0")
debug = (debug_env if debug_env is not None else verbose_env) == "1"

cflags = ["-std=c++17", "-Wno-switch-bool"]
# Only add default C++ standard if not specified in extra flags
cflags_has_std = extra_cflags is not None and any(
f.startswith("-std=") for f in extra_cflags
)
cuda_cflags_has_std = extra_cuda_cflags is not None and any(
f.startswith("-std=") for f in extra_cuda_cflags
)

cflags = ["-Wno-switch-bool"]
if not cflags_has_std:
cflags.insert(0, "-std=c++17")

cuda_cflags = [
"-std=c++17",
f"--threads={os.environ.get('FLASHINFER_NVCC_THREADS', '1')}",
"-use_fast_math",
"-DFLASHINFER_ENABLE_F16",
"-DFLASHINFER_ENABLE_BF16",
"-DFLASHINFER_ENABLE_FP8_E4M3",
"-DFLASHINFER_ENABLE_FP8_E5M2",
]
if not cuda_cflags_has_std:
cuda_cflags.insert(0, "-std=c++17")

if debug:
cflags += ["-O0", "-g"]
cuda_cflags += [
Expand Down
70 changes: 62 additions & 8 deletions flashinfer/jit/gdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,78 @@
limitations under the License.
"""

import itertools
import os

import jinja2

from . import env as jit_env
from .core import (
JitSpec,
gen_jit_spec,
sm90a_nvcc_flags,
)
from .utils import write_if_different


def gen_gdn_prefill_sm90_module() -> JitSpec:
    """Generate the JIT module for the GDN prefill kernel with separate compilation.

    Renders 32 kernel-instantiation files (2 dtypes x 16 boolean combinations)
    from a jinja template, then copies the launcher sources and headers into
    the generated-source directory. Splitting the instantiations into separate
    translation units lets ninja compile them in parallel, significantly
    reducing build time on multi-core machines.

    Returns:
        JitSpec covering the launcher plus all generated instantiation files,
        compiled with sm90a flags (the kernel requires Hopper).
    """
    uri = "gdn_prefill_sm90"
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
    os.makedirs(gen_directory, exist_ok=True)

    source_paths = []

    # Load the per-variant kernel instantiation template.
    with open(jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_sm90_kernel_inst.jinja") as f:
        kernel_inst_templ = jinja2.Template(f.read())

    # Generate 32 separate instance files (2 dtypes x 16 boolean combinations).
    dtypes = [("half", "half"), ("bf16", "nv_bfloat16")]
    for dtype_name, dtype in dtypes:
        for is_gva, needs_beta, needs_alpha, init_state in itertools.product(
            [False, True], repeat=4
        ):
            # Encode the variant in the filename, e.g. "half_g1b0a1i0".
            suffix = (
                f"{dtype_name}_g{int(is_gva)}b{int(needs_beta)}"
                f"a{int(needs_alpha)}i{int(init_state)}"
            )
            filename = f"gdn_prefill_kernel_{suffix}.cu"
            dest_path = gen_directory / filename
            source_paths.append(dest_path)

            # C++ bool literals must be lowercase, hence str(...).lower().
            source = kernel_inst_templ.render(
                dtype=dtype,
                is_gva=str(is_gva).lower(),
                needs_beta=str(needs_beta).lower(),
                needs_alpha=str(needs_alpha).lower(),
                init_state=str(init_state).lower(),
            )
            write_if_different(dest_path, source)

    # Copy source files to gen_directory (like the POD module does).
    for filename in [
        "gdn_prefill_launcher.cu",
        "flat/prefill/prefill_kernel_delta_rule_sm90.cu",
    ]:
        src_path = jit_env.FLASHINFER_CSRC_DIR / filename
        dest_path = gen_directory / src_path.name
        source_paths.append(dest_path)
        write_if_different(dest_path, src_path.read_text())

    # Copy header files next to the sources so their relative includes resolve.
    for filename in [
        "flat/prefill/prefill_kernel_delta_rule_sm90.cuh",
        "flat/prefill/prefill_kernel_delta_rule_sm90_extern.inc",
    ]:
        src_path = jit_env.FLASHINFER_CSRC_DIR / filename
        write_if_different(gen_directory / src_path.name, src_path.read_text())

    return gen_jit_spec(
        uri,
        source_paths,
        extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"],
        extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR],
    )
Comment on lines 79 to 83
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Check gen_jit_spec function signature and supported_major_versions usage
rg -n "supported_major_versions" --type py -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 7679


🏁 Script executed:

# Find gen_jit_spec definition
rg -n "def gen_jit_spec" --type py -A 10

Repository: flashinfer-ai/flashinfer

Length of output: 829


🏁 Script executed:

# Check for other JIT module generators to see usage patterns
fd "\.py$" flashinfer/jit/ | head -10 | xargs grep -l "gen_jit_spec" | xargs -I {} sh -c 'echo "=== {} ===" && grep -A 5 "gen_jit_spec(" {}'

Repository: flashinfer-ai/flashinfer

Length of output: 4164


🏁 Script executed:

cat -n flashinfer/jit/gdn.py

Repository: flashinfer-ai/flashinfer

Length of output: 3736


🏁 Script executed:

grep -n "sm90a_nvcc_flags" flashinfer/jit/core.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 92


🏁 Script executed:

# Get context around the definition
rg -n "sm90a_nvcc_flags\s*=" flashinfer/jit/core.py -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 238


🏁 Script executed:

sed -n '120,135p' flashinfer/jit/core.py

Repository: flashinfer-ai/flashinfer

Length of output: 810


🏁 Script executed:

# Check how other SM90-specific generators handle the architecture restriction
rg -n "SM90\|sm_90" flashinfer/jit/ --type py -B 3 -A 3 | grep -E "(supported_major_versions|sm90)" | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

head -30 flashinfer/jit/gdn.py

Repository: flashinfer-ai/flashinfer

Length of output: 839


🏁 Script executed:

# Check if CompilationContext is available and how it's used
rg -n "CompilationContext\|current_compilation_context" flashinfer/jit/ --type py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check what common_nvcc_flags contains
sed -n '1,130p' flashinfer/jit/core.py | grep -A 10 "common_nvcc_flags"

Repository: flashinfer-ai/flashinfer

Length of output: 567


Refactor to use CompilationContext.get_nvcc_flags_list(supported_major_versions=[9]) pattern.

Per coding guidelines for JIT modules, supported_major_versions should be specified when creating nvcc flags. Other JIT modules (fused_moe.py, xqa.py, mamba, gemm, comm, attention) consistently use CompilationContext().get_nvcc_flags_list(supported_major_versions=[...]) before passing flags to gen_jit_spec(). This module should follow the same pattern instead of using the pre-defined sm90a_nvcc_flags constant:

compilation_context = CompilationContext()
nvcc_flags = compilation_context.get_nvcc_flags_list(supported_major_versions=[9])
nvcc_flags += ["-DFLAT_SM90A_ENABLED", "-std=c++20"]
return gen_jit_spec(
    uri,
    source_paths,
    extra_cuda_cflags=nvcc_flags,
    extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR],
)
🧰 Tools
πŸͺ› Ruff (0.14.14)

84-84: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)

πŸ€– Prompt for AI Agents
In `@flashinfer/jit/gdn.py` around lines 81 - 86, Replace the hard-coded
sm90a_nvcc_flags usage with a CompilationContext-derived nvcc flags list: create
a CompilationContext(), call
CompilationContext.get_nvcc_flags_list(supported_major_versions=[9]) to get
nvcc_flags, append ["-DFLAT_SM90A_ENABLED","-std=c++20"], and pass that as
extra_cuda_cflags to gen_jit_spec (keep
extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR] unchanged); update references
to sm90a_nvcc_flags in this function to use the new nvcc_flags variable and
ensure CompilationContext is imported or available.

4 changes: 2 additions & 2 deletions tests/gdn/test_prefill_delta_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _test_prefill_kernel(

torch.cuda.synchronize()

# postprocessing raw output, ref_state is v-major, our_state is k-major, unify to v-major for testing
# postprocessing raw output: ref_state is v-last [H,K,V], our_state is k-last [H,V,K], transpose to match
our_state = our_state.transpose(-1, -2)

ref_o, ref_state = blockwise_delta_rule(
Expand Down Expand Up @@ -330,7 +330,7 @@ def _test_chunked_prefill(

torch.cuda.synchronize()

# postprocessing raw output, ref_state is v-major, our_state is k-major, unify to v-major for testing
# postprocessing raw output: ref_state is v-last [H,K,V], our_state is k-last [H,V,K], transpose to match
our_state = our_state.transpose(-1, -2)

def concat_varlen(t1, cu_seq_lens1, t2, cu_seq_lens2):
Expand Down
Loading