@@ -55,89 +55,24 @@ if Libdl.dlopen(libcuda_system, Libdl.RTLD_NOLOAD; throw_error=false) !== nothin
5555 return
5656end
5757
58+ # This shaves ~120ms off the load time
59+ precompile (Base. cmd_gen, (Tuple{Tuple{Base. Cmd}, Tuple{String}, Tuple{Bool}, Tuple{Array{String, 1 }}},))
60+ precompile (Base. read, (Base. Cmd, Type{String}))
61+
5862# helper function to load a driver, query its version, and optionally query device
5963# capabilities. needs to happen in a separate process because dlclose is unreliable.
6064function inspect_driver (driver, deps= String[]; inspect_devices= false )
61- script = raw """
62- using Libdl
63-
64- const DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75
65- const DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76
66-
67- function main(driver, inspect_devices, deps...)
68- inspect_devices = parse(Bool, inspect_devices)
69-
70- for dep in deps
71- Libdl.dlopen(dep; throw_error=false) === nothing && exit(-1)
72- end
73-
74- library_handle = Libdl.dlopen(driver; throw_error=false)
75- library_handle === nothing && return -1
76-
77- cuInit = Libdl.dlsym(library_handle, "cuInit")
78- status = ccall(cuInit, Cint, (UInt32,), 0)
79- status == 0 || return -2
80-
81- cuDriverGetVersion = Libdl.dlsym(library_handle, "cuDriverGetVersion")
82- version = Ref{Cint}()
83- status = ccall(cuDriverGetVersion, Cint, (Ptr{Cint},), version)
84- status == 0 || return -3
85- major, ver = divrem(version[], 1000)
86- minor, patch = divrem(ver, 10)
87- println(major, ".", minor, ".", patch)
88-
89- if inspect_devices
90- cuDeviceGetCount = Libdl.dlsym(library_handle, "cuDeviceGetCount")
91- device_count = Ref{Cint}()
92- status = ccall(cuDeviceGetCount, Cint, (Ptr{Cint},), device_count)
93- status == 0 || return -4
94-
95- cuDeviceGet = Libdl.dlsym(library_handle, "cuDeviceGet")
96- cuDeviceGetAttribute = Libdl.dlsym(library_handle, "cuDeviceGetAttribute")
97- for i in 1:device_count[]
98- device = Ref{Cint}()
99- status = ccall(cuDeviceGet, Cint, (Ptr{Cint}, Cint), device, i-1)
100- status == 0 || return -5
101-
102- major = Ref{Cint}()
103- status = ccall(cuDeviceGetAttribute, Cint, (Ptr{Cint}, UInt32, Cint), major, DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device[])
104- status == 0 || return -6
105- minor = Ref{Cint}()
106- status = ccall(cuDeviceGetAttribute, Cint, (Ptr{Cint}, UInt32, Cint), minor, DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device[])
107- status == 0 || return -7
108- println(major[], ".", minor[])
109- end
110- end
111-
112- return 0
113- end
114-
115- exit(main(ARGS...))
116- """
117-
118- # make sure we don't include any system image flags here since this will cause an infinite loop of __init__()
119- cmd = ``` $(Cmd (filter (! startswith (r" -J|--sysimage" ), Base. julia_cmd (). exec)))
120- -O0 --compile=min -t1 --startup-file=no
121- -e $script $driver $inspect_devices $deps ```
122-
123- # make sure we use a fresh environment we can load Libdl in
124- cmd = addenv (cmd, " JULIA_LOAD_PATH" => nothing , " JULIA_DEPOT_PATH" => nothing )
65+ cmd = ` $(cuda_inspect_driver ()) $driver $inspect_devices $deps `
12566
12667 # run the command
127- out = Pipe ()
128- proc = run (pipeline (cmd, stdin = devnull , stdout = out), wait= false )
129- close (out. in)
130- out_reader = @static if VERSION >= v " 1.12-"
131- # XXX : avoid concurrent compilation (JuliaLang/julia#59834)
132- Threads. @spawn :samepool String .(readlines (out))
133- else
134- Threads. @spawn String .(readlines (out))
68+ version_strings = String[]
69+ try
70+ version_strings = split (read (cmd, String))
71+ catch _
72+ return nothing
13573 end
136- wait (proc)
137- success (proc) || return nothing
13874
13975 # parse the versions
140- version_strings = fetch (out_reader)
14176 driver_version = parse (VersionNumber, version_strings[1 ])
14277 if inspect_devices
14378 device_capabilities = map (str -> parse (VersionNumber, str), version_strings[2 : end ])
@@ -148,18 +83,18 @@ function inspect_driver(driver, deps=String[]; inspect_devices=false)
14883end
14984
15085# fetch driver details
151- compat_driver_task = @static if VERSION >= v " 1.12-"
152- # XXX : avoid concurrent compilation (JuliaLang/julia#59834)
153- Threads. @spawn :samepool inspect_driver (libcuda_compat, libcuda_deps)
154- else
155- Threads. @spawn inspect_driver (libcuda_compat, libcuda_deps)
156- end
157- system_driver_task = @static if VERSION >= v " 1.12-"
158- # XXX : avoid concurrent compilation (JuliaLang/julia#59834)
159- Threads. @spawn :samepool inspect_driver (libcuda_system; inspect_devices= true )
160- else
161- Threads. @spawn inspect_driver (libcuda_system; inspect_devices= true )
162- end
86+ compat_driver_task = @static if VERSION >= v " 1.12-"
87+ # XXX : avoid concurrent compilation (JuliaLang/julia#59834)
88+ Threads. @spawn :samepool inspect_driver (libcuda_compat, libcuda_deps)
89+ else
90+ Threads. @spawn inspect_driver (libcuda_compat, libcuda_deps)
91+ end
92+ system_driver_task = @static if VERSION >= v " 1.12-"
93+ # XXX : avoid concurrent compilation (JuliaLang/julia#59834)
94+ Threads. @spawn :samepool inspect_driver (libcuda_system; inspect_devices= true )
95+ else
96+ Threads. @spawn inspect_driver (libcuda_system; inspect_devices= true )
97+ end
16398compat_driver_details = fetch (compat_driver_task)
16499if compat_driver_details === nothing
165100 @debug " Failed to load forwards-compatible driver."
0 commit comments