@@ -4,12 +4,14 @@ use std::path::{Path, PathBuf};
44pub fn get_cuda_path ( ) -> Option < & ' static Path > {
55 #[ cfg( target_os = "linux" ) ]
66 {
7- let path = Path :: new ( "/usr/local/cuda" ) ;
8- if path. exists ( ) {
9- Some ( path)
10- } else {
11- None
7+ for path_name in [ option_env ! ( "CUDA_PATH" ) , Some ( "/usr/local/cuda" ) ] . iter ( ) . flatten ( ) {
8+ let path = Path :: new ( path_name) ;
9+ if path. exists ( ) {
10+ println ! ( "CUDA installation found at `{}`" , path. display( ) ) ;
11+ return Some ( path)
12+ }
1213 }
14+ println ! ( "CUDA installation path not found" )
1315 }
1416 #[ cfg( target_os = "windows" ) ]
1517 {
@@ -42,12 +44,24 @@ pub fn get_cuda_lib_path() -> Option<PathBuf> {
4244
4345pub fn get_cuda_version ( ) -> Option < String > {
4446 if let Some ( version) = option_env ! ( "CUDA_VERSION" ) {
45- Some ( version. to_string ( ) )
47+ println ! ( "CUDA version defined in CUDA_VERSION as `{}`" , version) ;
48+ version. to_string ( )
4649 } else if let Some ( path) = get_cuda_path ( ) {
47- let file = File :: open ( path. join ( "version.json" ) ) . expect ( "CUDA Toolkit should be installed" ) ;
48- let reader = std:: io:: BufReader :: new ( file) ;
49- let value: serde_json:: Value = serde_json:: from_reader ( reader) . unwrap ( ) ;
50- Some ( value[ "cuda" ] [ "version" ] . as_str ( ) . unwrap ( ) . to_string ( ) )
50+ println ! ( "inferring CUDA version from nvcc output..." ) ;
51+ let re = regex_lite:: Regex :: new ( r"V(?<version>\d{2}\.\d+\.\d+)" ) . unwrap ( ) ;
52+ let nvcc_out = std:: process:: Command :: new ( "nvcc" )
53+ . arg ( "--version" )
54+ . output ( )
55+ . expect ( "failed to start `nvcc`" ) ;
56+ let nvcc_str = std:: str:: from_utf8 ( & nvcc_out. stdout ) . expect ( "`nvcc` output is not UTF8" ) ;
57+ let captures = re. captures ( & nvcc_str) . unwrap ( ) ;
58+ let version = captures
59+ . get ( 0 )
60+ . expect ( "unable to find nvcc version in the form VMM.mm.pp in the output of `nvcc --version`:\n {nvcc_str}" )
61+ . as_str ( )
62+ . to_string ( ) ;
63+ println ! ( "CUDA version inferred to be `{version}`." ) ;
64+ Some ( version)
5165 } else {
5266 None
5367 }
0 commit comments