Skip to content

Commit a7cafc0

Browse files
committed
upd
1 parent 413cc5f commit a7cafc0

1 file changed

Lines changed: 9 additions & 9 deletions

File tree

flashinfer/jit/gdn.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import itertools
1818
import os
19-
import pathlib
2019

2120
import jinja2
2221

@@ -67,21 +66,22 @@ def gen_gdn_prefill_sm90_module() -> JitSpec:
6766
write_if_different(dest_path, source)
6867

6968
# Copy source files to gen_directory (like POD module does)
70-
# Include .cuh and .inc files so relative includes work
7169
for filename in [
7270
"gdn_prefill_launcher.cu",
7371
"flat/prefill/prefill_kernel_delta_rule_sm90.cu",
72+
]:
73+
src_path = jit_env.FLASHINFER_CSRC_DIR / filename
74+
dest_path = gen_directory / src_path.name
75+
source_paths.append(dest_path)
76+
write_if_different(dest_path, src_path.read_text())
77+
78+
# Copy header files so relative includes work
79+
for filename in [
7480
"flat/prefill/prefill_kernel_delta_rule_sm90.cuh",
7581
"flat/prefill/prefill_kernel_delta_rule_sm90_extern.inc",
7682
]:
7783
src_path = jit_env.FLASHINFER_CSRC_DIR / filename
78-
dest_path = gen_directory / pathlib.Path(filename).name
79-
with open(src_path, "r") as f:
80-
source = f.read()
81-
write_if_different(dest_path, source)
82-
# Only add .cu files to source_paths for compilation
83-
if filename.endswith(".cu"):
84-
source_paths.append(dest_path)
84+
write_if_different(gen_directory / src_path.name, src_path.read_text())
8585

8686
return gen_jit_spec(
8787
uri,

0 commit comments

Comments
 (0)