Skip to content

Commit cdf10a6

Browse files
chore[cuda]: clean up build & dict decompress float values (#6202)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 4eb78c6 commit cdf10a6

File tree

15 files changed

+49
-17
lines changed

15 files changed

+49
-17
lines changed

REUSE.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ SPDX-FileCopyrightText = "Copyright the Vortex contributors"
3737
SPDX-License-Identifier = "Apache-2.0"
3838

3939
[[annotations]]
40-
path = ["vortex-cuda/kernels/bit_unpack_*"]
40+
path = ["vortex-cuda/kernels/src/bit_unpack_*"]
4141
precedence = "override"
4242
SPDX-FileCopyrightText = "Copyright the Vortex contributors"
4343
SPDX-License-Identifier = "Apache-2.0"

vortex-cuda/build.rs

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,36 +20,57 @@ pub mod cuda_kernel_generator;
2020

2121
fn main() {
2222
let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("Failed to get manifest dir");
23-
let kernels_dir = Path::new(&manifest_dir).join("kernels");
2423

25-
// Always emit the kernels directory path as a compile-time env var so any binary
24+
// Source directory for kernels (hand-written and generated .cu/.cuh files)
25+
let kernels_src = Path::new(&manifest_dir).join("kernels/src");
26+
// Output directory for compiled .ptx files
27+
let kernels_gen = Path::new(&manifest_dir).join("kernels/gen");
28+
29+
std::fs::create_dir_all(&kernels_gen).expect("Failed to create kernels/gen directory");
30+
31+
// Always emit the kernels output directory path as a compile-time env var so any binary
2632
// linking against vortex-cuda can find the PTX files. This must be set regardless
2733
// of CUDA availability since the code using env!() is always compiled.
2834
// At runtime, VORTEX_CUDA_KERNELS_DIR can be set to override this path.
2935
println!(
3036
"cargo:rustc-env=VORTEX_CUDA_KERNELS_DIR={}",
31-
kernels_dir.display()
37+
kernels_gen.display()
3238
);
3339

34-
println!("cargo:rerun-if-changed={}", kernels_dir.to_str().unwrap());
35-
36-
generate_unpack::<u8>(&kernels_dir, 32).expect("Failed to generate unpack for u8");
37-
generate_unpack::<u16>(&kernels_dir, 32).expect("Failed to generate unpack for u16");
38-
generate_unpack::<u32>(&kernels_dir, 32).expect("Failed to generate unpack for u32");
39-
generate_unpack::<u64>(&kernels_dir, 16).expect("Failed to generate unpack for u64");
40+
// Regenerate bit_unpack kernels only when the generator changes
41+
for entry in std::fs::read_dir(Path::new(&manifest_dir).join("cuda_kernel_generator"))
42+
.expect("Failed to read cuda_kernel_generator directory")
43+
.flatten()
44+
{
45+
println!("cargo:rerun-if-changed={}", entry.path().display());
46+
}
47+
generate_unpack::<u8>(&kernels_src, 32).expect("Failed to generate unpack for u8");
48+
generate_unpack::<u16>(&kernels_src, 32).expect("Failed to generate unpack for u16");
49+
generate_unpack::<u32>(&kernels_src, 32).expect("Failed to generate unpack for u32");
50+
generate_unpack::<u64>(&kernels_src, 16).expect("Failed to generate unpack for u64");
4051

4152
if !is_cuda_available() {
4253
return;
4354
}
4455

45-
if let Ok(entries) = std::fs::read_dir(&kernels_dir) {
56+
// Watch .cuh and hand-written .cu files in kernels/src; compile every .cu file to PTX in kernels/gen
57+
if let Ok(entries) = std::fs::read_dir(&kernels_src) {
4658
for path in entries.flatten().map(|entry| entry.path()) {
59+
let is_generated = path
60+
.file_name()
61+
.and_then(|n| n.to_str())
62+
.is_some_and(|n| n.starts_with("bit_unpack_"));
63+
4764
match path.extension().and_then(|e| e.to_str()) {
4865
Some("cuh") => println!("cargo:rerun-if-changed={}", path.display()),
4966
Some("cu") => {
50-
println!("cargo:rerun-if-changed={}", path.display());
51-
// Compile .cu files to PTX
52-
nvcc_compile_ptx(&kernels_dir, &path)
67+
// Only watch hand-written .cu files, not generated ones
68+
// (generated files are rebuilt when cuda_kernel_generator changes)
69+
if !is_generated {
70+
println!("cargo:rerun-if-changed={}", path.display());
71+
}
72+
// Compile all .cu files to PTX in gen directory
73+
nvcc_compile_ptx(&kernels_src, &kernels_gen, &path)
5374
.map_err(|e| {
5475
format!("Failed to compile CUDA kernel {}: {}", path.display(), e)
5576
})
@@ -67,7 +88,7 @@ fn generate_unpack<T: FastLanes>(output_dir: &Path, thread_count: usize) -> io::
6788
generate_cuda_unpack_for_width::<T, _>(&mut cu_writer, thread_count)
6889
}
6990

70-
fn nvcc_compile_ptx(kernel_dir: &Path, cu_path: &Path) -> io::Result<()> {
91+
fn nvcc_compile_ptx(include_dir: &Path, output_dir: &Path, cu_path: &Path) -> io::Result<()> {
7192
// https://doc.rust-lang.org/cargo/reference/environment-variables.html#environment-variables-cargo-sets-for-build-scripts
7293
let profile = env::var("PROFILE").unwrap();
7394

@@ -98,18 +119,23 @@ fn nvcc_compile_ptx(kernel_dir: &Path, cu_path: &Path) -> io::Result<()> {
98119
cmd.arg("-O3");
99120
}
100121

122+
// Output PTX file goes to output_dir with same base name
123+
let ptx_path = output_dir
124+
.join(cu_path.file_name().unwrap())
125+
.with_extension("ptx");
126+
101127
cmd.arg("-std=c++17")
102128
.arg("-arch=native")
103129
// Flags forwarded to Clang.
104130
.arg("--compiler-options=-Wall -Wextra -Wpedantic -Werror")
105131
.arg("--restrict")
106132
.arg("--ptx")
107133
.arg("--include-path")
108-
.arg(kernel_dir)
134+
.arg(include_dir)
109135
.arg("-c")
110136
.arg(cu_path)
111137
.arg("-o")
112-
.arg(cu_path.with_extension("ptx"));
138+
.arg(&ptx_path);
113139

114140
let res = cmd.output()?;
115141

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
#include <cuda.h>
5+
#include <cuda_fp16.h>
56
#include <cuda_runtime.h>
67
#include <stdint.h>
78

@@ -57,6 +58,11 @@ GENERATE_DICT_KERNELS_FOR_VALUE(i32, int32_t)
5758
GENERATE_DICT_KERNELS_FOR_VALUE(u64, uint64_t)
5859
GENERATE_DICT_KERNELS_FOR_VALUE(i64, int64_t)
5960

61+
// Float types
62+
GENERATE_DICT_KERNELS_FOR_VALUE(f16, __half)
63+
GENERATE_DICT_KERNELS_FOR_VALUE(f32, float)
64+
GENERATE_DICT_KERNELS_FOR_VALUE(f64, double)
65+
6066
// Decimal types (128-bit and 256-bit)
6167
GENERATE_DICT_KERNELS_FOR_VALUE(i128, int128_t)
6268
GENERATE_DICT_KERNELS_FOR_VALUE(i256, int256_t)

0 commit comments

Comments
 (0)