Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ dynamic-loading = []
dynamic-linking = []
static-linking = []

driver = []
nvrtc = []
driver = ["nvrtc"]
cublas = ["driver"]
cublaslt = ["driver"]
runtime = ["driver"]
Expand Down
1 change: 1 addition & 0 deletions src/driver/safe/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1702,6 +1702,7 @@ impl CudaContext {
/// Dynamically load a compiled ptx into this context.
///
/// - `ptx` contains the compiled ptx
#[cfg(feature = "nvrtc")]
pub fn load_module(
self: &Arc<Self>,
ptx: crate::nvrtc::Ptx,
Expand Down
20 changes: 16 additions & 4 deletions src/driver/safe/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,13 @@ impl LaunchArgs<'_> {

#[cfg(test)]
mod tests {
use crate::{
driver::{CudaContext, DriverError},
nvrtc::compile_ptx_with_opts,
};
use crate::driver::{CudaContext, DriverError};
#[cfg(feature = "nvrtc")]
use crate::nvrtc::compile_ptx_with_opts;

use super::*;

#[cfg(feature = "nvrtc")]
#[test]
fn test_launch_arrays() -> Result<(), DriverError> {
Comment on lines 274 to 284

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

For better maintainability, you could apply the `nvrtc` feature gate to the entire `tests` module instead of each individual test. This avoids repetition and ensures future tests in this file that rely on `nvrtc` are correctly gated. You've used a similar pattern in `src/driver/safe/unified_memory.rs`.

This change would make the code more concise and less error-prone.

Suggested change
#[cfg(test)]
mod tests {
use crate::{
driver::{CudaContext, DriverError},
nvrtc::compile_ptx_with_opts,
};
use crate::driver::{CudaContext, DriverError};
#[cfg(feature = "nvrtc")]
use crate::nvrtc::compile_ptx_with_opts;
use super::*;
#[cfg(feature = "nvrtc")]
#[test]
fn test_launch_arrays() -> Result<(), DriverError> {
#[cfg(all(test, feature = "nvrtc"))]
mod tests {
use crate::{
driver::{CudaContext, DriverError},
nvrtc::compile_ptx_with_opts,
};
use super::*;
#[test]
fn test_launch_arrays() -> Result<(), DriverError> {

#[repr(C)]
Expand Down Expand Up @@ -341,6 +341,7 @@ extern \"C\" __global__ void sin_kernel(float *out, const float *inp, size_t num
}
}";

#[cfg(feature = "nvrtc")]
#[test]
fn test_launch_with_mut_and_ref_cudarc() {
let ctx = CudaContext::new(0).unwrap();
Expand Down Expand Up @@ -375,6 +376,7 @@ extern \"C\" __global__ void sin_kernel(float *out, const float *inp, size_t num
drop(a_dev);
}

#[cfg(feature = "nvrtc")]
#[test]
fn test_large_launches() {
let ctx = CudaContext::new(0).unwrap();
Expand Down Expand Up @@ -407,6 +409,7 @@ extern \"C\" __global__ void sin_kernel(float *out, const float *inp, size_t num
}
}

#[cfg(feature = "nvrtc")]
#[test]
fn test_launch_with_views() {
let ctx = CudaContext::new(0).unwrap();
Expand Down Expand Up @@ -481,6 +484,7 @@ extern \"C\" __global__ void floating(float f, double d) {
}
";

#[cfg(feature = "nvrtc")]
#[test]
fn test_launch_with_8bit() {
let ctx = CudaContext::new(0).unwrap();
Expand All @@ -501,6 +505,7 @@ extern \"C\" __global__ void floating(float f, double d) {
stream.synchronize().unwrap();
}

#[cfg(feature = "nvrtc")]
#[test]
fn test_launch_with_16bit() {
let ctx = CudaContext::new(0).unwrap();
Expand All @@ -521,6 +526,7 @@ extern \"C\" __global__ void floating(float f, double d) {
stream.synchronize().unwrap();
}

#[cfg(feature = "nvrtc")]
#[test]
fn test_launch_with_32bit() {
let ctx = CudaContext::new(0).unwrap();
Expand All @@ -541,6 +547,7 @@ extern \"C\" __global__ void floating(float f, double d) {
stream.synchronize().unwrap();
}

#[cfg(feature = "nvrtc")]
#[test]
fn test_launch_with_64bit() {
let ctx = CudaContext::new(0).unwrap();
Expand All @@ -561,6 +568,7 @@ extern \"C\" __global__ void floating(float f, double d) {
stream.synchronize().unwrap();
}

#[cfg(feature = "nvrtc")]
#[test]
fn test_launch_with_floats() {
let ctx = CudaContext::new(0).unwrap();
Expand Down Expand Up @@ -626,6 +634,7 @@ extern \"C\" __global__ void slow_worker(const float *data, const size_t len, fl
}
";

#[cfg(feature = "nvrtc")]
#[test]
fn test_par_launch() -> Result<(), DriverError> {
let ptx = compile_ptx_with_opts(SLOW_KERNELS, Default::default()).unwrap();
Expand Down Expand Up @@ -698,6 +707,7 @@ extern \"C\" __global__ void slow_worker(const float *data, const size_t len, fl
Ok(())
}

#[cfg(feature = "nvrtc")]
#[test]
fn test_multi_stream_concurrent_reads() -> Result<(), DriverError> {
let ptx = compile_ptx_with_opts(SLOW_KERNELS, Default::default()).unwrap();
Expand Down Expand Up @@ -739,6 +749,7 @@ extern \"C\" __global__ void slow_worker(const float *data, const size_t len, fl
Ok(())
}

#[cfg(feature = "nvrtc")]
#[test]
fn test_multi_stream_writes_block() -> Result<(), DriverError> {
let ptx = compile_ptx_with_opts(SLOW_KERNELS, Default::default()).unwrap();
Expand Down Expand Up @@ -778,6 +789,7 @@ extern \"C\" __global__ void slow_worker(const float *data, const size_t len, fl
Ok(())
}

#[cfg(feature = "nvrtc")]
#[test]
#[ignore = "must be executed by itself"]
fn test_device_side_assert() -> Result<(), DriverError> {
Expand Down
1 change: 1 addition & 0 deletions src/driver/safe/unified_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ unsafe impl<'a, 'b: 'a, T> PushKernelArg<&'b mut UnifiedSlice<T>> for LaunchArgs
}
}

#[cfg(feature = "nvrtc")]
#[cfg(test)]
mod tests {
#![allow(clippy::needless_range_loop)]
Expand Down
Loading