Skip to content

Commit 4091de1

Browse files
committed
Make the driver feature not depend on nvrtc
1 parent 4ddc9e2 commit 4091de1

4 files changed

Lines changed: 19 additions & 5 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,8 +68,8 @@ dynamic-loading = []
6868
dynamic-linking = []
6969
static-linking = []
7070

71+
driver = []
7172
nvrtc = []
72-
driver = ["nvrtc"]
7373
cublas = ["driver"]
7474
cublaslt = ["driver"]
7575
runtime = ["driver"]

src/driver/safe/core.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1702,6 +1702,7 @@ impl CudaContext {
17021702
/// Dynamically load a compiled ptx into this context.
17031703
///
17041704
/// - `ptx` contains the compiled ptx
1705+
#[cfg(feature = "nvrtc")]
17051706
pub fn load_module(
17061707
self: &Arc<Self>,
17071708
ptx: crate::nvrtc::Ptx,

src/driver/safe/launch.rs

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -273,13 +273,13 @@ impl LaunchArgs<'_> {
273273

274274
#[cfg(test)]
275275
mod tests {
276-
use crate::{
277-
driver::{CudaContext, DriverError},
278-
nvrtc::compile_ptx_with_opts,
279-
};
276+
use crate::driver::{CudaContext, DriverError};
277+
#[cfg(feature = "nvrtc")]
278+
use crate::nvrtc::compile_ptx_with_opts;
280279

281280
use super::*;
282281

282+
#[cfg(feature = "nvrtc")]
283283
#[test]
284284
fn test_launch_arrays() -> Result<(), DriverError> {
285285
#[repr(C)]
@@ -341,6 +341,7 @@ extern \"C\" __global__ void sin_kernel(float *out, const float *inp, size_t num
341341
}
342342
}";
343343

344+
#[cfg(feature = "nvrtc")]
344345
#[test]
345346
fn test_launch_with_mut_and_ref_cudarc() {
346347
let ctx = CudaContext::new(0).unwrap();
@@ -375,6 +376,7 @@ extern \"C\" __global__ void sin_kernel(float *out, const float *inp, size_t num
375376
drop(a_dev);
376377
}
377378

379+
#[cfg(feature = "nvrtc")]
378380
#[test]
379381
fn test_large_launches() {
380382
let ctx = CudaContext::new(0).unwrap();
@@ -407,6 +409,7 @@ extern \"C\" __global__ void sin_kernel(float *out, const float *inp, size_t num
407409
}
408410
}
409411

412+
#[cfg(feature = "nvrtc")]
410413
#[test]
411414
fn test_launch_with_views() {
412415
let ctx = CudaContext::new(0).unwrap();
@@ -481,6 +484,7 @@ extern \"C\" __global__ void floating(float f, double d) {
481484
}
482485
";
483486

487+
#[cfg(feature = "nvrtc")]
484488
#[test]
485489
fn test_launch_with_8bit() {
486490
let ctx = CudaContext::new(0).unwrap();
@@ -501,6 +505,7 @@ extern \"C\" __global__ void floating(float f, double d) {
501505
stream.synchronize().unwrap();
502506
}
503507

508+
#[cfg(feature = "nvrtc")]
504509
#[test]
505510
fn test_launch_with_16bit() {
506511
let ctx = CudaContext::new(0).unwrap();
@@ -521,6 +526,7 @@ extern \"C\" __global__ void floating(float f, double d) {
521526
stream.synchronize().unwrap();
522527
}
523528

529+
#[cfg(feature = "nvrtc")]
524530
#[test]
525531
fn test_launch_with_32bit() {
526532
let ctx = CudaContext::new(0).unwrap();
@@ -541,6 +547,7 @@ extern \"C\" __global__ void floating(float f, double d) {
541547
stream.synchronize().unwrap();
542548
}
543549

550+
#[cfg(feature = "nvrtc")]
544551
#[test]
545552
fn test_launch_with_64bit() {
546553
let ctx = CudaContext::new(0).unwrap();
@@ -561,6 +568,7 @@ extern \"C\" __global__ void floating(float f, double d) {
561568
stream.synchronize().unwrap();
562569
}
563570

571+
#[cfg(feature = "nvrtc")]
564572
#[test]
565573
fn test_launch_with_floats() {
566574
let ctx = CudaContext::new(0).unwrap();
@@ -626,6 +634,7 @@ extern \"C\" __global__ void slow_worker(const float *data, const size_t len, fl
626634
}
627635
";
628636

637+
#[cfg(feature = "nvrtc")]
629638
#[test]
630639
fn test_par_launch() -> Result<(), DriverError> {
631640
let ptx = compile_ptx_with_opts(SLOW_KERNELS, Default::default()).unwrap();
@@ -698,6 +707,7 @@ extern \"C\" __global__ void slow_worker(const float *data, const size_t len, fl
698707
Ok(())
699708
}
700709

710+
#[cfg(feature = "nvrtc")]
701711
#[test]
702712
fn test_multi_stream_concurrent_reads() -> Result<(), DriverError> {
703713
let ptx = compile_ptx_with_opts(SLOW_KERNELS, Default::default()).unwrap();
@@ -739,6 +749,7 @@ extern \"C\" __global__ void slow_worker(const float *data, const size_t len, fl
739749
Ok(())
740750
}
741751

752+
#[cfg(feature = "nvrtc")]
742753
#[test]
743754
fn test_multi_stream_writes_block() -> Result<(), DriverError> {
744755
let ptx = compile_ptx_with_opts(SLOW_KERNELS, Default::default()).unwrap();
@@ -778,6 +789,7 @@ extern \"C\" __global__ void slow_worker(const float *data, const size_t len, fl
778789
Ok(())
779790
}
780791

792+
#[cfg(feature = "nvrtc")]
781793
#[test]
782794
#[ignore = "must be executed by itself"]
783795
fn test_device_side_assert() -> Result<(), DriverError> {

src/driver/safe/unified_memory.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -346,6 +346,7 @@ unsafe impl<'a, 'b: 'a, T> PushKernelArg<&'b mut UnifiedSlice<T>> for LaunchArgs
346346
}
347347
}
348348

349+
#[cfg(feature = "nvrtc")]
349350
#[cfg(test)]
350351
mod tests {
351352
#![allow(clippy::needless_range_loop)]

0 commit comments

Comments
 (0)