Skip to content

Commit 449681e

Browse files
committed
[CUDA_Driver] Refactor to use a C binary to inspect the driver
This is much faster than calling the Julia script.
1 parent 70866c8 commit 449681e

File tree

3 files changed

+127
-87
lines changed

3 files changed

+127
-87
lines changed

C/CUDA/CUDA_Driver/build_tarballs.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ cuda_version = v"13.1"
1717
driver_version = "590.48.01"
1818

1919
script = raw"""
20+
# Build the driver inspection binary. Libraries are listed after the source
# file because the linker processes its inputs in command-line order.
mkdir -p ${bindir}
${CC} -std=c99 cuda_inspect_driver.c -o ${bindir}/cuda_inspect_driver -ldl
23+
2024
mkdir -p ${libdir}
2125
2226
cd ${WORKSPACE}/srcdir/cuda_compat*
@@ -48,6 +52,7 @@ products = [
4852
LibraryProduct("libnvidia-nvvm", :libnvidia_nvvm; dont_dlopen=true),
4953
LibraryProduct("libnvidia-ptxjitcompiler", :libnvidia_ptxjitcompiler; dont_dlopen=true),
5054
LibraryProduct("libnvidia-tileiras", :libnvidia_tileiras; dont_dlopen=true),
55+
ExecutableProduct("cuda_inspect_driver", :cuda_inspect_driver)
5156
]
5257

5358
dependencies = []
@@ -63,6 +68,7 @@ for platform in platforms
6368

6469
sources = get_sources("nvidia-driver", ["cuda_compat"]; version=driver_version,
6570
platform=augmented_platform, variant="cuda$(cuda_version.major).$(cuda_version.minor)")
71+
push!(sources, DirectorySource("./src"))
6672

6773
push!(builds, (; platforms=[platform], sources))
6874
end

C/CUDA/CUDA_Driver/init.jl

Lines changed: 22 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -55,89 +55,24 @@ if Libdl.dlopen(libcuda_system, Libdl.RTLD_NOLOAD; throw_error=false) !== nothin
5555
return
5656
end
5757

58+
# Precompiling the command-construction and `read` methods that `inspect_driver` uses shaves ~120ms off the load time
59+
precompile(Base.cmd_gen, (Tuple{Tuple{Base.Cmd}, Tuple{String}, Tuple{Bool}, Tuple{Array{String, 1}}},))
60+
precompile(Base.read, (Base.Cmd, Type{String}))
61+
5862
# helper function to load a driver, query its version, and optionally query device
5963
# capabilities. needs to happen in a separate process because dlclose is unreliable.
6064
function inspect_driver(driver, deps=String[]; inspect_devices=false)
61-
script = raw"""
62-
using Libdl
63-
64-
const DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75
65-
const DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76
66-
67-
function main(driver, inspect_devices, deps...)
68-
inspect_devices = parse(Bool, inspect_devices)
69-
70-
for dep in deps
71-
Libdl.dlopen(dep; throw_error=false) === nothing && exit(-1)
72-
end
73-
74-
library_handle = Libdl.dlopen(driver; throw_error=false)
75-
library_handle === nothing && return -1
76-
77-
cuInit = Libdl.dlsym(library_handle, "cuInit")
78-
status = ccall(cuInit, Cint, (UInt32,), 0)
79-
status == 0 || return -2
80-
81-
cuDriverGetVersion = Libdl.dlsym(library_handle, "cuDriverGetVersion")
82-
version = Ref{Cint}()
83-
status = ccall(cuDriverGetVersion, Cint, (Ptr{Cint},), version)
84-
status == 0 || return -3
85-
major, ver = divrem(version[], 1000)
86-
minor, patch = divrem(ver, 10)
87-
println(major, ".", minor, ".", patch)
88-
89-
if inspect_devices
90-
cuDeviceGetCount = Libdl.dlsym(library_handle, "cuDeviceGetCount")
91-
device_count = Ref{Cint}()
92-
status = ccall(cuDeviceGetCount, Cint, (Ptr{Cint},), device_count)
93-
status == 0 || return -4
94-
95-
cuDeviceGet = Libdl.dlsym(library_handle, "cuDeviceGet")
96-
cuDeviceGetAttribute = Libdl.dlsym(library_handle, "cuDeviceGetAttribute")
97-
for i in 1:device_count[]
98-
device = Ref{Cint}()
99-
status = ccall(cuDeviceGet, Cint, (Ptr{Cint}, Cint), device, i-1)
100-
status == 0 || return -5
101-
102-
major = Ref{Cint}()
103-
status = ccall(cuDeviceGetAttribute, Cint, (Ptr{Cint}, UInt32, Cint), major, DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device[])
104-
status == 0 || return -6
105-
minor = Ref{Cint}()
106-
status = ccall(cuDeviceGetAttribute, Cint, (Ptr{Cint}, UInt32, Cint), minor, DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device[])
107-
status == 0 || return -7
108-
println(major[], ".", minor[])
109-
end
110-
end
111-
112-
return 0
113-
end
114-
115-
exit(main(ARGS...))
116-
"""
117-
118-
# make sure we don't include any system image flags here since this will cause an infinite loop of __init__()
119-
cmd = ```$(Cmd(filter(!startswith(r"-J|--sysimage"), Base.julia_cmd().exec)))
120-
-O0 --compile=min -t1 --startup-file=no
121-
-e $script $driver $inspect_devices $deps```
122-
123-
# make sure we use a fresh environment we can load Libdl in
124-
cmd = addenv(cmd, "JULIA_LOAD_PATH" => nothing, "JULIA_DEPOT_PATH" => nothing)
65+
cmd = `$(cuda_inspect_driver()) $driver $inspect_devices $deps`
12566

12667
# run the command
127-
out = Pipe()
128-
proc = run(pipeline(cmd, stdin=devnull, stdout=out), wait=false)
129-
close(out.in)
130-
out_reader = @static if VERSION >= v"1.12-"
131-
# XXX: avoid concurrent compilation (JuliaLang/julia#59834)
132-
Threads.@spawn :samepool String.(readlines(out))
133-
else
134-
Threads.@spawn String.(readlines(out))
68+
version_strings = String[]
69+
try
70+
version_strings = split(read(cmd, String))
71+
catch _
72+
return nothing
13573
end
136-
wait(proc)
137-
success(proc) || return nothing
13874

13975
# parse the versions
140-
version_strings = fetch(out_reader)
14176
driver_version = parse(VersionNumber, version_strings[1])
14277
if inspect_devices
14378
device_capabilities = map(str -> parse(VersionNumber, str), version_strings[2:end])
@@ -148,18 +83,18 @@ function inspect_driver(driver, deps=String[]; inspect_devices=false)
14883
end
14984

15085
# fetch driver details
151-
compat_driver_task = @static if VERSION >= v"1.12-"
152-
# XXX: avoid concurrent compilation (JuliaLang/julia#59834)
153-
Threads.@spawn :samepool inspect_driver(libcuda_compat, libcuda_deps)
154-
else
155-
Threads.@spawn inspect_driver(libcuda_compat, libcuda_deps)
156-
end
157-
system_driver_task = @static if VERSION >= v"1.12-"
158-
# XXX: avoid concurrent compilation (JuliaLang/julia#59834)
159-
Threads.@spawn :samepool inspect_driver(libcuda_system; inspect_devices=true)
160-
else
161-
Threads.@spawn inspect_driver(libcuda_system; inspect_devices=true)
162-
end
86+
compat_driver_task = @static if VERSION >= v"1.12-"
87+
# XXX: avoid concurrent compilation (JuliaLang/julia#59834)
88+
Threads.@spawn :samepool inspect_driver(libcuda_compat, libcuda_deps)
89+
else
90+
Threads.@spawn inspect_driver(libcuda_compat, libcuda_deps)
91+
end
92+
system_driver_task = @static if VERSION >= v"1.12-"
93+
# XXX: avoid concurrent compilation (JuliaLang/julia#59834)
94+
Threads.@spawn :samepool inspect_driver(libcuda_system; inspect_devices=true)
95+
else
96+
Threads.@spawn inspect_driver(libcuda_system; inspect_devices=true)
97+
end
16398
compat_driver_details = fetch(compat_driver_task)
16499
if compat_driver_details === nothing
165100
@debug "Failed to load forwards-compatible driver."
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
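/* Helper program to check the usability of CUDA drivers: it loads a driver, prints its version and, optionally, the compute capability of each device. */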
2+
3+
#include <dlfcn.h>
4+
#include <stdio.h>
5+
#include <stdlib.h>
6+
#include <string.h>
7+
8+
// Flags passed to every dlopen() call in this program. They match the flags
// of Libdl.dlopen():
// https://docs.julialang.org/en/v1/stdlib/Libdl/#Base.Libc.Libdl.dlopen
//
// RTLD_DEEPBIND is a glibc extension that not every libc declares (macOS does
// not, for instance). Fall back to 0 so the flag becomes a no-op there instead
// of breaking the build.
#ifndef RTLD_DEEPBIND
#define RTLD_DEEPBIND 0
#endif

// Parenthesized so the macro stays a single operand wherever it is expanded.
#ifdef __APPLE__
#define DLOPEN_FLAGS (RTLD_LAZY | RTLD_DEEPBIND | RTLD_GLOBAL)
#else
#define DLOPEN_FLAGS (RTLD_LAZY | RTLD_DEEPBIND | RTLD_LOCAL)
#endif
15+
16+
// Attribute selectors handed to cuDeviceGetAttribute() to query the major and
// minor part of a device's compute capability.
// NOTE(review): presumably these mirror the CUdevice_attribute values in
// cuda.h -- confirm when updating.
const int DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75;
const int DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76;

// Signatures of the driver entry points that main() looks up with dlsym().
// Each one returns a status code; main() treats anything but 0 as a failure.
typedef int (*cuInit_t)(unsigned int flags);
typedef int (*cuDriverGetVersion_t)(int *version);
typedef int (*cuDeviceGetCount_t)(int *count);
typedef int (*cuDeviceGet_t)(int *device, int ordinal);
typedef int (*cuDeviceGetAttribute_t)(int *value, unsigned int attribute, int device);
24+
25+
26+
int main(int argc, char *argv[]) {
27+
if (argc < 3) {
28+
fprintf(stderr, "Usage: %s <driver> <inspect_devices> [deps...]\n", argv[0]);
29+
return -1;
30+
}
31+
32+
const char *driver = argv[1];
33+
int inspect_devices = (strcmp(argv[2], "true") == 0 || strcmp(argv[2], "1") == 0);
34+
35+
for (int i = 3; i < argc; i++) {
36+
if (dlopen(argv[i], DLOPEN_FLAGS) == NULL) {
37+
return -1;
38+
}
39+
}
40+
41+
void *library_handle = dlopen(driver, DLOPEN_FLAGS);
42+
if (library_handle == NULL) {
43+
return -1;
44+
}
45+
46+
cuInit_t cuInit = (cuInit_t)dlsym(library_handle, "cuInit");
47+
int status = cuInit(0);
48+
if (status != 0) {
49+
return -2;
50+
}
51+
52+
cuDriverGetVersion_t cuDriverGetVersion = (cuDriverGetVersion_t)dlsym(library_handle, "cuDriverGetVersion");
53+
int version;
54+
status = cuDriverGetVersion(&version);
55+
if (status != 0) {
56+
return -3;
57+
}
58+
int major = version / 1000;
59+
int ver = version % 1000;
60+
int minor = ver / 10;
61+
int patch = ver % 10;
62+
printf("%d.%d.%d\n", major, minor, patch);
63+
64+
if (inspect_devices) {
65+
cuDeviceGetCount_t cuDeviceGetCount = (cuDeviceGetCount_t)dlsym(library_handle, "cuDeviceGetCount");
66+
int device_count;
67+
status = cuDeviceGetCount(&device_count);
68+
if (status != 0) {
69+
return -4;
70+
}
71+
72+
cuDeviceGet_t cuDeviceGet = (cuDeviceGet_t)dlsym(library_handle, "cuDeviceGet");
73+
cuDeviceGetAttribute_t cuDeviceGetAttribute = (cuDeviceGetAttribute_t)dlsym(library_handle, "cuDeviceGetAttribute");
74+
75+
for (int i = 0; i < device_count; i++) {
76+
int device = -1;
77+
status = cuDeviceGet(&device, i);
78+
if (status != 0) {
79+
return -5;
80+
}
81+
82+
int dev_major;
83+
status = cuDeviceGetAttribute(&dev_major, DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device);
84+
if (status != 0) {
85+
return -6;
86+
}
87+
88+
int dev_minor;
89+
status = cuDeviceGetAttribute(&dev_minor, DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device);
90+
if (status != 0) {
91+
return -7;
92+
}
93+
94+
printf("%d.%d\n", dev_major, dev_minor);
95+
}
96+
}
97+
98+
return 0;
99+
}

0 commit comments

Comments
 (0)