
Commit 6ae5bfe

yzh119 and claude authored
refactor: reduce hopper's gdn prefill compilation time and fix docstring. (#2422)
## 📌 Description

This PR implements two changes:

1. Reduce Hopper's GDN prefill compilation time by splitting the kernel instantiations into separately compiled files.
2. Fix the docstring of the GDN prefill kernel: the final state is laid out as [N, H, V, K], not [N, H, K, V].

## 🔍 Related Issues

#2276

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc @guangyunh-nv

## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Enhanced JIT module generation for GDN prefill kernels with template-driven compilation and separate kernel instantiation.
* **Improvements**
  * JIT specification now intelligently handles C++ standard flags, applying defaults only when not already specified.
* **Documentation**
  * Clarified final state memory layout description for GDN prefill operations.

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 0eb69bb commit 6ae5bfe

29 files changed: 207 additions & 41 deletions

flat_prefill_kernel_delta_rule_sm90_extern.inc (new file)

Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// Extern template declarations to prevent implicit instantiation in the dispatcher.
// Explicit instantiations are in separate generated files for parallel compilation.

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "cutlass/arch/arch.h"

namespace flat {

// clang-format off

#define FOR_EACH_BOOL_4(MACRO, ...) \
  MACRO(false, false, false, false, __VA_ARGS__) \
  MACRO(false, false, false, true, __VA_ARGS__) \
  MACRO(false, false, true, false, __VA_ARGS__) \
  MACRO(false, false, true, true, __VA_ARGS__) \
  MACRO(false, true, false, false, __VA_ARGS__) \
  MACRO(false, true, false, true, __VA_ARGS__) \
  MACRO(false, true, true, false, __VA_ARGS__) \
  MACRO(false, true, true, true, __VA_ARGS__) \
  MACRO(true, false, false, false, __VA_ARGS__) \
  MACRO(true, false, false, true, __VA_ARGS__) \
  MACRO(true, false, true, false, __VA_ARGS__) \
  MACRO(true, false, true, true, __VA_ARGS__) \
  MACRO(true, true, false, false, __VA_ARGS__) \
  MACRO(true, true, false, true, __VA_ARGS__) \
  MACRO(true, true, true, false, __VA_ARGS__) \
  MACRO(true, true, true, true, __VA_ARGS__)

#define DECLARE_TEMPLATE_INSTANCE(is_gva, needs_beta, needs_alpha, init_state, ctype) \
  extern template void launch_delta_rule_prefill_kernel_gbai<is_gva, needs_beta, needs_alpha, init_state, cutlass::arch::Sm90, ctype, ctype, float>( \
      cudaStream_t, ctype*, float*, ctype const*, ctype const*, ctype const*, \
      float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t, \
      int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t);

// Extern template declarations for half
FOR_EACH_BOOL_4(DECLARE_TEMPLATE_INSTANCE, half)

// Extern template declarations for nv_bfloat16
FOR_EACH_BOOL_4(DECLARE_TEMPLATE_INSTANCE, nv_bfloat16)

#undef DECLARE_TEMPLATE_INSTANCE
#undef FOR_EACH_BOOL_4

// clang-format on

}  // namespace flat
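For readers unfamiliar with the technique, here is a minimal, self-contained sketch of the extern-template pattern the file above applies; all names here are hypothetical and not part of the commit. The "extern template" declarations stop the dispatcher's translation unit from instantiating the heavy kernel body itself, while one explicit instantiation per generated file gives ninja an independent compile job for each variant:

// kernel.hpp: the function template, included everywhere.
#pragma once
#include <cstdio>

template <bool Flag, typename T>
void launch(T x) {
  // Stand-in for an expensive-to-compile kernel body.
  std::printf("Flag=%d value=%f\n", int(Flag), double(x));
}

// kernel_extern.inc: included by the dispatcher TU only.
// "extern template" promises the instantiation exists elsewhere, so the
// dispatcher TU does not compile the body for these template arguments.
extern template void launch<false, float>(float);
extern template void launch<true, float>(float);

// kernel_inst_f0.cu and kernel_inst_f1.cu: one explicit instantiation
// each, so the two variants build as separate, parallel ninja jobs.
template void launch<false, float>(float);  // contents of kernel_inst_f0.cu
template void launch<true, float>(float);   // contents of kernel_inst_f1.cu

With 2 dtypes × 16 boolean combinations, this turns one monolithic translation unit into 32 small ones that compile concurrently.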

csrc/gdn_prefill_launcher.cu

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
 #include <iostream>
 #include <sstream>

-#include "flat/prefill/prefill_kernel.hpp"
+#include "flashinfer/flat/prefill/prefill_kernel.hpp"

 using tvm::ffi::Optional;
 using tvm::ffi::TensorView;
csrc/gdn_prefill_sm90_kernel_inst.jinja (new file)

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// Auto-generated file for separate compilation of GDN prefill kernel variants.
// Template parameters: dtype={{ dtype }}, is_gva={{ is_gva }}, needs_beta={{ needs_beta }},
// needs_alpha={{ needs_alpha }}, init_state={{ init_state }}

// CUDA type definitions for half and nv_bfloat16
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Include the header which defines the function template
// The header includes all necessary CUTLASS type definitions
#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm90.cuh"

namespace flat {

// Explicit template instantiation for launch_delta_rule_prefill_kernel_gbai
// Parameter types must exactly match the extern template declaration in prefill_kernel_delta_rule_sm90_extern.inc
template void launch_delta_rule_prefill_kernel_gbai<{{ is_gva }}, {{ needs_beta }}, {{ needs_alpha }}, {{ init_state }}, cutlass::arch::Sm90, {{ dtype }}, {{ dtype }}, float>(
    cudaStream_t, {{ dtype }}*, float*, {{ dtype }} const*, {{ dtype }} const*, {{ dtype }} const*,
    float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t,
    int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t);

}  // namespace flat
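For illustration (rendered by hand from the template above; not a file checked into the commit), the instance file generated for dtype=half with is_gva=true and the remaining flags false, named gdn_prefill_kernel_half_g1b0a0i0.cu per the suffix scheme in flashinfer/jit/gdn.py below, would read:

// gdn_prefill_kernel_half_g1b0a0i0.cu (license header omitted)
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm90.cuh"

namespace flat {

// One explicit instantiation per file: this variant compiles independently.
template void launch_delta_rule_prefill_kernel_gbai<true, false, false, false, cutlass::arch::Sm90, half, half, float>(
    cudaStream_t, half*, float*, half const*, half const*, half const*,
    float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t,
    int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t);

}  // namespace flat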

csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu renamed to csrc/prefill_kernel_delta_rule_sm90.cu

Lines changed: 7 additions & 1 deletion
@@ -15,7 +15,11 @@
 */
 #include <cuda_bf16.h>

-#include "prefill_kernel_delta_rule_sm90.cuh"
+#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm90.cuh"
+
+// Extern template declarations prevent implicit instantiation here.
+// Explicit instantiations are in separate generated files for parallel compilation.
+#include "flat_prefill_kernel_delta_rule_sm90_extern.inc"

 namespace flat {

@@ -87,6 +91,8 @@ void launch_delta_rule_prefill_kernel(cudaStream_t stream, TO* output, TState* o
 #undef LAUNCH
 }

+// Explicit instantiations for the outer dispatch function only.
+// The inner launch_delta_rule_prefill_kernel_gbai instantiations are in separate files.
 template void launch_delta_rule_prefill_kernel<cutlass::arch::Sm90, half, half, float>(
     cudaStream_t stream, half* output, float* state, half const* q, half const* k, half const* v,
     float const* input_state, float const* alpha, float const* beta, int64_t const* cu_seqlens,
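The division of labor in the two hunks above follows the usual boolean-dispatch pattern: the outer function converts runtime flags into compile-time template arguments, and each variant's body is compiled elsewhere. A minimal runnable sketch with hypothetical names (not the commit's actual dispatcher):

#include <cstdio>

// Each <IsGva, NeedsBeta> variant would be explicitly instantiated in its own
// generated .cu file; the dispatcher TU would see only "extern template" declarations.
template <bool IsGva, bool NeedsBeta>
void launch_variant() {
  std::printf("is_gva=%d needs_beta=%d\n", int(IsGva), int(NeedsBeta));
}

// Outer dispatcher: maps runtime booleans to compile-time template arguments.
void dispatch(bool is_gva, bool needs_beta) {
  if (is_gva) {
    needs_beta ? launch_variant<true, true>() : launch_variant<true, false>();
  } else {
    needs_beta ? launch_variant<false, true>() : launch_variant<false, false>();
  }
}

int main() { dispatch(true, false); }  // prints: is_gva=1 needs_beta=0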

flashinfer/gdn_prefill.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def chunk_gated_delta_rule(
     Note:
         - Supports GQA: ``num_q_heads > num_k_heads = num_v_heads``
         - Supports GVA: ``num_v_heads > num_q_heads = num_k_heads``
-        - The final state is in k-major layout ``[N, H, K, V]``.
+        - The final state is in k-last layout ``[N, H, V, K]``.
         - Requires SM90 (Hopper) architecture.
     """
     assert cu_seqlens is not None, "cu_seqlens is required for varlen mode"

flashinfer/jit/core.py

Lines changed: 15 additions & 2 deletions
@@ -417,16 +417,29 @@ def gen_jit_spec(
     verbose_env = os.environ.get("FLASHINFER_JIT_VERBOSE", "0")
     debug = (debug_env if debug_env is not None else verbose_env) == "1"

-    cflags = ["-std=c++17", "-Wno-switch-bool"]
+    # Only add default C++ standard if not specified in extra flags
+    cflags_has_std = extra_cflags is not None and any(
+        f.startswith("-std=") for f in extra_cflags
+    )
+    cuda_cflags_has_std = extra_cuda_cflags is not None and any(
+        f.startswith("-std=") for f in extra_cuda_cflags
+    )
+
+    cflags = ["-Wno-switch-bool"]
+    if not cflags_has_std:
+        cflags.insert(0, "-std=c++17")
+
     cuda_cflags = [
-        "-std=c++17",
         f"--threads={os.environ.get('FLASHINFER_NVCC_THREADS', '1')}",
         "-use_fast_math",
         "-DFLASHINFER_ENABLE_F16",
         "-DFLASHINFER_ENABLE_BF16",
         "-DFLASHINFER_ENABLE_FP8_E4M3",
         "-DFLASHINFER_ENABLE_FP8_E5M2",
     ]
+    if not cuda_cflags_has_std:
+        cuda_cflags.insert(0, "-std=c++17")
+
     if debug:
         cflags += ["-O0", "-g"]
         cuda_cflags += [

flashinfer/jit/gdn.py

Lines changed: 55 additions & 9 deletions
@@ -14,24 +14,70 @@
     limitations under the License.
 """

+import itertools
+import os
+
+import jinja2
+
 from . import env as jit_env
 from .core import (
     JitSpec,
     gen_jit_spec,
     sm90a_nvcc_flags,
 )
+from .utils import write_if_different


 def gen_gdn_prefill_sm90_module() -> JitSpec:
+    """Generate JIT module for GDN prefill kernel with separate compilation.
+
+    This generates 32 separate kernel instantiation files (2 dtypes × 16 boolean combinations)
+    plus the original launcher file. The separate files enable parallel compilation by ninja,
+    significantly reducing build time on multi-core machines.
+    """
+    uri = "gdn_prefill_sm90"
+    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
+    os.makedirs(gen_directory, exist_ok=True)
+
+    source_paths = []
+
+    # Load kernel instantiation template
+    with open(jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_sm90_kernel_inst.jinja") as f:
+        kernel_inst_templ = jinja2.Template(f.read())
+
+    # Generate 32 separate instance files (2 dtypes × 16 boolean combinations)
+    dtypes = [("half", "half"), ("bf16", "nv_bfloat16")]
+    for dtype_name, dtype in dtypes:
+        for is_gva, needs_beta, needs_alpha, init_state in itertools.product(
+            [False, True], repeat=4
+        ):
+            suffix = f"{dtype_name}_g{int(is_gva)}b{int(needs_beta)}a{int(needs_alpha)}i{int(init_state)}"
+            filename = f"gdn_prefill_kernel_{suffix}.cu"
+            dest_path = gen_directory / filename
+            source_paths.append(dest_path)
+
+            source = kernel_inst_templ.render(
+                dtype=dtype,
+                is_gva=str(is_gva).lower(),
+                needs_beta=str(needs_beta).lower(),
+                needs_alpha=str(needs_alpha).lower(),
+                init_state=str(init_state).lower(),
+            )
+            write_if_different(dest_path, source)
+
+    # Copy source files to gen_directory for compilation
+    # Headers are now in include/flashinfer/flat/ and accessible via standard include paths
+    for filename in [
+        "gdn_prefill_launcher.cu",
+        "prefill_kernel_delta_rule_sm90.cu",
+    ]:
+        src_path = jit_env.FLASHINFER_CSRC_DIR / filename
+        dest_path = gen_directory / src_path.name
+        source_paths.append(dest_path)
+        write_if_different(dest_path, src_path.read_text())
+
     return gen_jit_spec(
-        name="gdn_prefill_launcher",
-        sources=[
-            jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_launcher.cu",
-            jit_env.FLASHINFER_CSRC_DIR
-            / "flat"
-            / "prefill"
-            / "prefill_kernel_delta_rule_sm90.cu",
-        ],
+        uri,
+        source_paths,
         extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"],
-        extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR],
     )

csrc/flat/ampere/collective/flat_collective_inverse.hpp renamed to include/flashinfer/flat/ampere/collective/flat_collective_inverse.hpp

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 #include "cute/tensor.hpp"
 #include "cutlass/arch/barrier.h"
 #include "cutlass/cutlass.h"
-#include "flat/cute_ext.hpp"
+#include "flashinfer/flat/cute_ext.hpp"

 namespace flat::collective {

csrc/flat/ampere/collective/flat_collective_load.hpp renamed to include/flashinfer/flat/ampere/collective/flat_collective_load.hpp

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 #include "cute/tensor.hpp"
 #include "cutlass/cutlass.h"
 #include "cutlass/pipeline/sm90_pipeline.hpp"
-#include "flat/unused.hpp"
+#include "flashinfer/flat/unused.hpp"

 namespace flat::collective {

csrc/flat/unused.hpp renamed to include/flashinfer/flat/unused.hpp

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@
 #include <stdexcept>
 #include <string>

-#include "debug.hpp"
+#include "flashinfer/flat/debug.hpp"

 #define FLAT_UNUSED_PARAMETER(x) (void)x
