
Commit bfe6991

upd
1 parent 2491b2b commit bfe6991

2 files changed: 20 additions & 21 deletions


csrc/gdn_prefill_sm90_kernel_inst.jinja

Lines changed: 5 additions & 13 deletions
@@ -29,18 +29,10 @@
 namespace flat {
 
 // Explicit template instantiation for launch_delta_rule_prefill_kernel_gbai
-template void launch_delta_rule_prefill_kernel_gbai<
-    /*IsGVA=*/{{ is_gva }},
-    /*NeedsBeta=*/{{ needs_beta }},
-    /*NeedsAlpha=*/{{ needs_alpha }},
-    /*InitStateFromInput=*/{{ init_state }},
-    cutlass::arch::Sm90,
-    {{ dtype }}, {{ dtype }}, float>(
-    cudaStream_t stream, {{ dtype }}* output, float* output_state,
-    {{ dtype }} const* q, {{ dtype }} const* k, {{ dtype }} const* v,
-    float const* input_state, float const* alpha, float const* beta,
-    int64_t const* cu_seqlens, int32_t num_seqs, int32_t num_q_heads,
-    int32_t num_k_heads, int32_t num_v_heads, int32_t num_o_heads,
-    int32_t head_size, int64_t total_seqlen, float scale, int32_t sm_count);
+// Parameter types must exactly match the extern template declaration in prefill_kernel_delta_rule_sm90_extern.inc
+template void launch_delta_rule_prefill_kernel_gbai<{{ is_gva }}, {{ needs_beta }}, {{ needs_alpha }}, {{ init_state }}, cutlass::arch::Sm90, {{ dtype }}, {{ dtype }}, float>(
+    cudaStream_t, {{ dtype }}*, float*, {{ dtype }} const*, {{ dtype }} const*, {{ dtype }} const*,
+    float const*, float const*, float const*, int64_t const*, int32_t, int32_t,
+    int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t);
 
 } // namespace flat
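
Since this file is a Jinja template, the {{ ... }} placeholders in the instantiation above are filled in when the JIT layer renders the template. Below is a minimal rendering sketch; the placeholder values (the true/false flags and cutlass::half_t) are purely illustrative and not taken from this commit:

    # Hypothetical rendering of the instantiation template; the values passed
    # to render() are illustrative, not taken from this commit.
    from jinja2 import Template

    inst = Template(
        "template void launch_delta_rule_prefill_kernel_gbai<"
        "{{ is_gva }}, {{ needs_beta }}, {{ needs_alpha }}, {{ init_state }}, "
        "cutlass::arch::Sm90, {{ dtype }}, {{ dtype }}, float>(...);"
    )

    print(
        inst.render(
            is_gva="true",
            needs_beta="true",
            needs_alpha="false",
            init_state="true",
            dtype="cutlass::half_t",
        )
    )
    # -> template void launch_delta_rule_prefill_kernel_gbai<true, true, false, true,
    #    cutlass::arch::Sm90, cutlass::half_t, cutlass::half_t, float>(...);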

flashinfer/jit/core.py

Lines changed: 15 additions & 8 deletions
@@ -416,16 +416,29 @@ def gen_jit_spec(
     verbose_env = os.environ.get("FLASHINFER_JIT_VERBOSE", "0")
     debug = (debug_env if debug_env is not None else verbose_env) == "1"
 
-    cflags = ["-std=c++17", "-Wno-switch-bool"]
+    # Only add default C++ standard if not specified in extra flags
+    cflags_has_std = extra_cflags is not None and any(
+        f.startswith("-std=") for f in extra_cflags
+    )
+    cuda_cflags_has_std = extra_cuda_cflags is not None and any(
+        f.startswith("-std=") for f in extra_cuda_cflags
+    )
+
+    cflags = ["-Wno-switch-bool"]
+    if not cflags_has_std:
+        cflags.insert(0, "-std=c++17")
+
     cuda_cflags = [
-        "-std=c++17",
         f"--threads={os.environ.get('FLASHINFER_NVCC_THREADS', '1')}",
         "-use_fast_math",
         "-DFLASHINFER_ENABLE_F16",
         "-DFLASHINFER_ENABLE_BF16",
         "-DFLASHINFER_ENABLE_FP8_E4M3",
         "-DFLASHINFER_ENABLE_FP8_E5M2",
     ]
+    if not cuda_cflags_has_std:
+        cuda_cflags.insert(0, "-std=c++17")
+
     if debug:
         cflags += ["-O0", "-g"]
         cuda_cflags += [
@@ -446,14 +459,8 @@ def gen_jit_spec(
         cuda_cflags += ["-lineinfo"]
 
     if extra_cflags is not None:
-        # If extra_cflags contains a -std flag, remove the default one to avoid conflicts
-        if any(f.startswith("-std=") for f in extra_cflags):
-            cflags = [f for f in cflags if not f.startswith("-std=")]
         cflags += extra_cflags
     if extra_cuda_cflags is not None:
-        # If extra_cuda_cflags contains a -std flag, remove the default one to avoid conflicts
-        if any(f.startswith("-std=") for f in extra_cuda_cflags):
-            cuda_cflags = [f for f in cuda_cflags if not f.startswith("-std=")]
         cuda_cflags += extra_cuda_cflags
 
     spec = JitSpec(
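
Net effect of the core.py change: instead of always seeding cflags and cuda_cflags with "-std=c++17" and stripping it again later when the caller passes its own -std flag, the default standard is now only inserted when the extra flags do not already pin one. A standalone sketch of the same pattern follows; build_cflags is a hypothetical helper for illustration, not part of flashinfer:

    # Standalone sketch of the flag handling introduced above; build_cflags is
    # a hypothetical helper, not part of flashinfer/jit/core.py.
    def build_cflags(extra_cflags=None):
        # Only add the default C++ standard when the caller has not pinned one.
        has_std = extra_cflags is not None and any(
            f.startswith("-std=") for f in extra_cflags
        )
        cflags = ["-Wno-switch-bool"]
        if not has_std:
            cflags.insert(0, "-std=c++17")
        if extra_cflags is not None:
            cflags += extra_cflags
        return cflags

    assert build_cflags() == ["-std=c++17", "-Wno-switch-bool"]
    assert build_cflags(["-std=c++20"]) == ["-Wno-switch-bool", "-std=c++20"]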
