@@ -15,6 +15,15 @@ using Libdl
1515
1616# # global state
1717
18+ const toolkit_dirs = Ref {Vector{String}} ()
19+
20+ """
21+ prefix()
22+
23+ Returns the installation prefix directories of the CUDA toolkit in use.
24+ """
25+ prefix () = toolkit_dirs[]
26+
1827const toolkit_version = Ref {VersionNumber} ()
1928
2029"""
@@ -39,7 +48,6 @@ const ptx_support = Ref{Vector{VersionNumber}}()
3948const libdevice = Ref {Union{String,Dict{VersionNumber,String}}} ()
4049const libcudadevrt = Ref {String} ()
4150const nvdisasm = Ref {String} ()
42- const ptxas = Ref {String} ()
4351
4452
4553# # source code includes
@@ -109,58 +117,54 @@ function __init__()
109117
110118 # CUDA
111119
112- toolkit_dirs = find_toolkit ()
113- toolkit_version[] = find_toolkit_version (toolkit_dirs)
120+ toolkit_dirs[] = find_toolkit ()
121+
122+ let val = find_cuda_binary (" nvdisasm" , toolkit_dirs[])
123+ val === nothing && error (" Your CUDA installation does not provide the nvdisasm binary" )
124+ nvdisasm[] = val
125+ end
126+
127+ toolkit_version[] = parse_toolkit_version (nvdisasm[])
114128 if release () < v " 9"
115129 silent || @warn " CUDAnative.jl only supports CUDA 9.0 or higher (your toolkit provides CUDA $(release ()) )"
116130 elseif release () > CUDAdrv. release ()
117131 silent || @warn """ You are using CUDA toolkit $(release ()) with a driver that only supports up to $(CUDAdrv. release ()) .
118132 It is recommended to upgrade your driver."""
119133 end
120134
121- llvm_support = llvm_compat (llvm_version)
122- cuda_support = cuda_compat ()
123-
124- target_support[] = sort (collect (llvm_support. cap ∩ cuda_support. cap))
125- isempty (target_support[]) && error (" Your toolchain does not support any device capability" )
126-
127- ptx_support[] = sort (collect (llvm_support. ptx ∩ cuda_support. ptx))
128- isempty (ptx_support[]) && error (" Your toolchain does not support any PTX ISA" )
129-
130- @debug (" CUDAnative supports devices $(verlist (target_support[])) ; PTX $(verlist (ptx_support[])) " )
131-
132- let val = find_libdevice (target_support[], toolkit_dirs)
135+ let val = find_libdevice (toolkit_dirs[])
133136 val === nothing && error (" Your CUDA installation does not provide libdevice" )
134137 libdevice[] = val
135138 end
136139
137- let val = find_libcudadevrt (toolkit_dirs)
140+ let val = find_libcudadevrt (toolkit_dirs[] )
138141 val === nothing && error (" Your CUDA installation does not provide libcudadevrt" )
139142 libcudadevrt[] = val
140143 end
141144
142- let val = find_cuda_binary (" nvdisasm" , toolkit_dirs)
143- val === nothing && error (" Your CUDA installation does not provide the nvdisasm binary" )
144- nvdisasm[] = val
145- end
146-
147- let val = find_cuda_binary (" ptxas" , toolkit_dirs)
148- val === nothing && error (" Your CUDA installation does not provide the ptxas binary" )
149- ptxas[] = val
150- end
151-
152- let val = find_cuda_library (" nvtx" , toolkit_dirs)
145+ let val = find_cuda_library (" nvtx" , toolkit_dirs[], [v " 1" ])
153146 val === nothing && error (" Your CUDA installation does not provide the NVTX library" )
154147 NVTX. libnvtx[] = val
155148 end
156149
157- toolkit_extras_dirs = filter (dir-> isdir (joinpath (dir, " extras" )), toolkit_dirs)
150+ toolkit_extras_dirs = filter (dir-> isdir (joinpath (dir, " extras" )), toolkit_dirs[] )
158151 cupti_dirs = map (dir-> joinpath (dir, " extras" , " CUPTI" ), toolkit_extras_dirs)
159- let val = find_cuda_library (" cupti" , cupti_dirs)
152+ let val = find_cuda_library (" cupti" , cupti_dirs, [toolkit_version[]] )
160153 val === nothing && error (" Your CUDA installation does not provide the CUPTI library" )
161154 CUPTI. libcupti[] = val
162155 end
163156
157+ llvm_support = llvm_compat (llvm_version)
158+ cuda_support = cuda_compat ()
159+
160+ target_support[] = sort (collect (llvm_support. cap ∩ cuda_support. cap))
161+ isempty (target_support[]) && error (" Your toolchain does not support any device capability" )
162+
163+ ptx_support[] = sort (collect (llvm_support. ptx ∩ cuda_support. ptx))
164+ isempty (ptx_support[]) && error (" Your toolchain does not support any PTX ISA" )
165+
166+ @debug (" CUDAnative supports devices $(verlist (target_support[])) ; PTX $(verlist (ptx_support[])) " )
167+
164168
165169 # # actual initialization
166170
0 commit comments