File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1818// needs_alpha={{ needs_alpha }}, init_state={{ init_state }}
1919
2020#include <cuda _bf16.h >
21+ #include <cuda _fp16.h >
22+
23+ // Ensure cutlass arch types are defined
24+ #include "cutlass/arch/arch.h"
2125
2226// Use full path since generated files are in a different directory
2327#include "flat/prefill/prefill_kernel_delta_rule_sm90.cuh"
Original file line number Diff line number Diff line change @@ -446,8 +446,14 @@ def gen_jit_spec(
446446 cuda_cflags += ["-lineinfo" ]
447447
448448 if extra_cflags is not None :
449+ # If extra_cflags contains a -std flag, remove the default one to avoid conflicts
450+ if any (f .startswith ("-std=" ) for f in extra_cflags ):
451+ cflags = [f for f in cflags if not f .startswith ("-std=" )]
449452 cflags += extra_cflags
450453 if extra_cuda_cflags is not None :
454+ # If extra_cuda_cflags contains a -std flag, remove the default one to avoid conflicts
455+ if any (f .startswith ("-std=" ) for f in extra_cuda_cflags ):
456+ cuda_cflags = [f for f in cuda_cflags if not f .startswith ("-std=" )]
451457 cuda_cflags += extra_cuda_cflags
452458
453459 spec = JitSpec (
You can’t perform that action at this time.
0 commit comments