|
16 | 16 |
|
17 | 17 | import itertools |
18 | 18 | import os |
19 | | -import pathlib |
20 | 19 |
|
21 | 20 | import jinja2 |
22 | 21 |
|
@@ -67,21 +66,22 @@ def gen_gdn_prefill_sm90_module() -> JitSpec: |
67 | 66 | write_if_different(dest_path, source) |
68 | 67 |
|
69 | 68 | # Copy source files to gen_directory (like POD module does) |
70 | | - # Include .cuh and .inc files so relative includes work |
71 | 69 | for filename in [ |
72 | 70 | "gdn_prefill_launcher.cu", |
73 | 71 | "flat/prefill/prefill_kernel_delta_rule_sm90.cu", |
| 72 | + ]: |
| 73 | + src_path = jit_env.FLASHINFER_CSRC_DIR / filename |
| 74 | + dest_path = gen_directory / src_path.name |
| 75 | + source_paths.append(dest_path) |
| 76 | + write_if_different(dest_path, src_path.read_text()) |
| 77 | + |
| 78 | + # Copy header files so relative includes work |
| 79 | + for filename in [ |
74 | 80 | "flat/prefill/prefill_kernel_delta_rule_sm90.cuh", |
75 | 81 | "flat/prefill/prefill_kernel_delta_rule_sm90_extern.inc", |
76 | 82 | ]: |
77 | 83 | src_path = jit_env.FLASHINFER_CSRC_DIR / filename |
78 | | - dest_path = gen_directory / pathlib.Path(filename).name |
79 | | - with open(src_path, "r") as f: |
80 | | - source = f.read() |
81 | | - write_if_different(dest_path, source) |
82 | | - # Only add .cu files to source_paths for compilation |
83 | | - if filename.endswith(".cu"): |
84 | | - source_paths.append(dest_path) |
| 84 | + write_if_different(gen_directory / src_path.name, src_path.read_text()) |
85 | 85 |
|
86 | 86 | return gen_jit_spec( |
87 | 87 | uri, |
|
0 commit comments