@@ -20,36 +20,57 @@ pub mod cuda_kernel_generator;
2020
2121fn main ( ) {
2222 let manifest_dir = env:: var ( "CARGO_MANIFEST_DIR" ) . expect ( "Failed to get manifest dir" ) ;
23- let kernels_dir = Path :: new ( & manifest_dir) . join ( "kernels" ) ;
2423
25- // Always emit the kernels directory path as a compile-time env var so any binary
24+ // Source directory for kernels (hand-written and generated .cu/.cuh files)
25+ let kernels_src = Path :: new ( & manifest_dir) . join ( "kernels/src" ) ;
26+ // Output directory for compiled .ptx files
27+ let kernels_gen = Path :: new ( & manifest_dir) . join ( "kernels/gen" ) ;
28+
29+ std:: fs:: create_dir_all ( & kernels_gen) . expect ( "Failed to create kernels/gen directory" ) ;
30+
31+ // Always emit the kernels output directory path as a compile-time env var so any binary
2632 // linking against vortex-cuda can find the PTX files. This must be set regardless
2733 // of CUDA availability since the code using env!() is always compiled.
2834 // At runtime, VORTEX_CUDA_KERNELS_DIR can be set to override this path.
2935 println ! (
3036 "cargo:rustc-env=VORTEX_CUDA_KERNELS_DIR={}" ,
31- kernels_dir . display( )
37+ kernels_gen . display( )
3238 ) ;
3339
34- println ! ( "cargo:rerun-if-changed={}" , kernels_dir. to_str( ) . unwrap( ) ) ;
35-
36- generate_unpack :: < u8 > ( & kernels_dir, 32 ) . expect ( "Failed to generate unpack for u8" ) ;
37- generate_unpack :: < u16 > ( & kernels_dir, 32 ) . expect ( "Failed to generate unpack for u16" ) ;
38- generate_unpack :: < u32 > ( & kernels_dir, 32 ) . expect ( "Failed to generate unpack for u32" ) ;
39- generate_unpack :: < u64 > ( & kernels_dir, 16 ) . expect ( "Failed to generate unpack for u64" ) ;
40+ // Regenerate bit_unpack kernels only when the generator changes
41+ for entry in std:: fs:: read_dir ( Path :: new ( & manifest_dir) . join ( "cuda_kernel_generator" ) )
42+ . expect ( "Failed to read cuda_kernel_generator directory" )
43+ . flatten ( )
44+ {
45+ println ! ( "cargo:rerun-if-changed={}" , entry. path( ) . display( ) ) ;
46+ }
47+ generate_unpack :: < u8 > ( & kernels_src, 32 ) . expect ( "Failed to generate unpack for u8" ) ;
48+ generate_unpack :: < u16 > ( & kernels_src, 32 ) . expect ( "Failed to generate unpack for u16" ) ;
49+ generate_unpack :: < u32 > ( & kernels_src, 32 ) . expect ( "Failed to generate unpack for u32" ) ;
50+ generate_unpack :: < u64 > ( & kernels_src, 16 ) . expect ( "Failed to generate unpack for u64" ) ;
4051
4152 if !is_cuda_available ( ) {
4253 return ;
4354 }
4455
45- if let Ok ( entries) = std:: fs:: read_dir ( & kernels_dir) {
56+ // Watch and compile .cu and .cuh files from kernels/src to PTX in kernels/gen
57+ if let Ok ( entries) = std:: fs:: read_dir ( & kernels_src) {
4658 for path in entries. flatten ( ) . map ( |entry| entry. path ( ) ) {
59+ let is_generated = path
60+ . file_name ( )
61+ . and_then ( |n| n. to_str ( ) )
62+ . is_some_and ( |n| n. starts_with ( "bit_unpack_" ) ) ;
63+
4764 match path. extension ( ) . and_then ( |e| e. to_str ( ) ) {
4865 Some ( "cuh" ) => println ! ( "cargo:rerun-if-changed={}" , path. display( ) ) ,
4966 Some ( "cu" ) => {
50- println ! ( "cargo:rerun-if-changed={}" , path. display( ) ) ;
51- // Compile .cu files to PTX
52- nvcc_compile_ptx ( & kernels_dir, & path)
67+ // Only watch hand-written .cu files, not generated ones
68+ // (generated files are rebuilt when cuda_kernel_generator changes)
69+ if !is_generated {
70+ println ! ( "cargo:rerun-if-changed={}" , path. display( ) ) ;
71+ }
72+ // Compile all .cu files to PTX in gen directory
73+ nvcc_compile_ptx ( & kernels_src, & kernels_gen, & path)
5374 . map_err ( |e| {
5475 format ! ( "Failed to compile CUDA kernel {}: {}" , path. display( ) , e)
5576 } )
@@ -67,7 +88,7 @@ fn generate_unpack<T: FastLanes>(output_dir: &Path, thread_count: usize) -> io::
6788 generate_cuda_unpack_for_width :: < T , _ > ( & mut cu_writer, thread_count)
6889}
6990
70- fn nvcc_compile_ptx ( kernel_dir : & Path , cu_path : & Path ) -> io:: Result < ( ) > {
91+ fn nvcc_compile_ptx ( include_dir : & Path , output_dir : & Path , cu_path : & Path ) -> io:: Result < ( ) > {
7192 // https://doc.rust-lang.org/cargo/reference/environment-variables.html#environment-variables-cargo-sets-for-build-scripts
7293 let profile = env:: var ( "PROFILE" ) . unwrap ( ) ;
7394
@@ -98,18 +119,23 @@ fn nvcc_compile_ptx(kernel_dir: &Path, cu_path: &Path) -> io::Result<()> {
98119 cmd. arg ( "-O3" ) ;
99120 }
100121
122+ // Output PTX file goes to output_dir with same base name
123+ let ptx_path = output_dir
124+ . join ( cu_path. file_name ( ) . unwrap ( ) )
125+ . with_extension ( "ptx" ) ;
126+
101127 cmd. arg ( "-std=c++17" )
102128 . arg ( "-arch=native" )
103129 // Flags forwarded to Clang.
104130 . arg ( "--compiler-options=-Wall -Wextra -Wpedantic -Werror" )
105131 . arg ( "--restrict" )
106132 . arg ( "--ptx" )
107133 . arg ( "--include-path" )
108- . arg ( kernel_dir )
134+ . arg ( include_dir )
109135 . arg ( "-c" )
110136 . arg ( cu_path)
111137 . arg ( "-o" )
112- . arg ( cu_path . with_extension ( "ptx" ) ) ;
138+ . arg ( & ptx_path ) ;
113139
114140 let res = cmd. output ( ) ?;
115141
0 commit comments