Skip to content

Commit 2491b2b

Browse files
committed
fix
1 parent a7cafc0 commit 2491b2b

2 files changed

Lines changed: 10 additions & 0 deletions

File tree

csrc/gdn_prefill_sm90_kernel_inst.jinja

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,10 @@
1818
// needs_alpha={{ needs_alpha }}, init_state={{ init_state }}
1919

2020
#include <cuda_bf16.h>
21+
#include <cuda_fp16.h>
22+
23+
// Ensure cutlass arch types are defined
24+
#include "cutlass/arch/arch.h"
2125

2226
// Use full path since generated files are in a different directory
2327
#include "flat/prefill/prefill_kernel_delta_rule_sm90.cuh"

flashinfer/jit/core.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -446,8 +446,14 @@ def gen_jit_spec(
446446
cuda_cflags += ["-lineinfo"]
447447

448448
if extra_cflags is not None:
449+
# If extra_cflags contains a -std flag, remove the default one to avoid conflicts
450+
if any(f.startswith("-std=") for f in extra_cflags):
451+
cflags = [f for f in cflags if not f.startswith("-std=")]
449452
cflags += extra_cflags
450453
if extra_cuda_cflags is not None:
454+
# If extra_cuda_cflags contains a -std flag, remove the default one to avoid conflicts
455+
if any(f.startswith("-std=") for f in extra_cuda_cflags):
456+
cuda_cflags = [f for f in cuda_cflags if not f.startswith("-std=")]
451457
cuda_cflags += extra_cuda_cflags
452458

453459
spec = JitSpec(

0 commit comments

Comments
 (0)