Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion C/CUDA/CUDA_Driver/build_tarballs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ cuda_version = v"13.1"
driver_version = "590.48.01"

script = raw"""
# Build the driver inspection binary
mkdir -p ${bindir}
${CC} -std=c99 -ldl cuda_inspect_driver.c -o ${bindir}/cuda_inspect_driver

mkdir -p ${libdir}

cd ${WORKSPACE}/srcdir/cuda_compat*
Expand Down Expand Up @@ -48,6 +52,7 @@ products = [
LibraryProduct("libnvidia-nvvm", :libnvidia_nvvm; dont_dlopen=true),
LibraryProduct("libnvidia-ptxjitcompiler", :libnvidia_ptxjitcompiler; dont_dlopen=true),
LibraryProduct("libnvidia-tileiras", :libnvidia_tileiras; dont_dlopen=true),
ExecutableProduct("cuda_inspect_driver", :cuda_inspect_driver)
]

dependencies = []
Expand All @@ -63,6 +68,7 @@ for platform in platforms

sources = get_sources("nvidia-driver", ["cuda_compat"]; version=driver_version,
platform=augmented_platform, variant="cuda$(cuda_version.major).$(cuda_version.minor)")
push!(sources, DirectorySource("./src"))

push!(builds, (; platforms=[platform], sources))
end
Expand All @@ -78,5 +84,13 @@ for (i,build) in enumerate(builds)
build_tarballs(i == lastindex(builds) ? non_platform_ARGS : non_reg_ARGS,
name, cuda_version, build.sources, script,
build.platforms, products, dependencies;
skip_audit=true, init_block)
skip_audit=true, init_block, julia_compat="1.10",
augment_platform_block="""
# This shaves ~120ms off the load time
precompile(Base.cmd_gen, (Tuple{Tuple{Base.Cmd}, Tuple{String}, Tuple{Bool}, Tuple{Array{String, 1}}},))
precompile(Base.read, (Base.Cmd, Type{String}))
precompile(Tuple{typeof(Base.arg_gen), Bool})

augment_platform! = identity
""")
end
105 changes: 18 additions & 87 deletions C/CUDA/CUDA_Driver/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,86 +58,17 @@ end
# helper function to load a driver, query its version, and optionally query device
# capabilities. needs to happen in a separate process because dlclose is unreliable.
function inspect_driver(driver, deps=String[]; inspect_devices=false)
script = raw"""
using Libdl

const DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75
const DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76

function main(driver, inspect_devices, deps...)
inspect_devices = parse(Bool, inspect_devices)

for dep in deps
Libdl.dlopen(dep; throw_error=false) === nothing && exit(-1)
end

library_handle = Libdl.dlopen(driver; throw_error=false)
library_handle === nothing && return -1

cuInit = Libdl.dlsym(library_handle, "cuInit")
status = ccall(cuInit, Cint, (UInt32,), 0)
status == 0 || return -2

cuDriverGetVersion = Libdl.dlsym(library_handle, "cuDriverGetVersion")
version = Ref{Cint}()
status = ccall(cuDriverGetVersion, Cint, (Ptr{Cint},), version)
status == 0 || return -3
major, ver = divrem(version[], 1000)
minor, patch = divrem(ver, 10)
println(major, ".", minor, ".", patch)

if inspect_devices
cuDeviceGetCount = Libdl.dlsym(library_handle, "cuDeviceGetCount")
device_count = Ref{Cint}()
status = ccall(cuDeviceGetCount, Cint, (Ptr{Cint},), device_count)
status == 0 || return -4

cuDeviceGet = Libdl.dlsym(library_handle, "cuDeviceGet")
cuDeviceGetAttribute = Libdl.dlsym(library_handle, "cuDeviceGetAttribute")
for i in 1:device_count[]
device = Ref{Cint}()
status = ccall(cuDeviceGet, Cint, (Ptr{Cint}, Cint), device, i-1)
status == 0 || return -5

major = Ref{Cint}()
status = ccall(cuDeviceGetAttribute, Cint, (Ptr{Cint}, UInt32, Cint), major, DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device[])
status == 0 || return -6
minor = Ref{Cint}()
status = ccall(cuDeviceGetAttribute, Cint, (Ptr{Cint}, UInt32, Cint), minor, DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device[])
status == 0 || return -7
println(major[], ".", minor[])
end
end

return 0
end

exit(main(ARGS...))
"""

# make sure we don't include any system image flags here since this will cause an infinite loop of __init__()
cmd = ```$(Cmd(filter(!startswith(r"-J|--sysimage"), Base.julia_cmd().exec)))
-O0 --compile=min -t1 --startup-file=no
-e $script $driver $inspect_devices $deps```

# make sure we use a fresh environment we can load Libdl in
cmd = addenv(cmd, "JULIA_LOAD_PATH" => nothing, "JULIA_DEPOT_PATH" => nothing)
cmd = `$(cuda_inspect_driver()) $driver $inspect_devices $deps`

# run the command
out = Pipe()
proc = run(pipeline(cmd, stdin=devnull, stdout=out), wait=false)
close(out.in)
out_reader = @static if VERSION >= v"1.12-"
# XXX: avoid concurrent compilation (JuliaLang/julia#59834)
Threads.@spawn :samepool String.(readlines(out))
else
Threads.@spawn String.(readlines(out))
version_strings = String[]
try
version_strings = split(read(cmd, String))
catch _
return nothing
end
wait(proc)
success(proc) || return nothing

# parse the versions
version_strings = fetch(out_reader)
driver_version = parse(VersionNumber, version_strings[1])
if inspect_devices
device_capabilities = map(str -> parse(VersionNumber, str), version_strings[2:end])
Expand All @@ -148,18 +79,18 @@ function inspect_driver(driver, deps=String[]; inspect_devices=false)
end

# fetch driver details
compat_driver_task = @static if VERSION >= v"1.12-"
# XXX: avoid concurrent compilation (JuliaLang/julia#59834)
Threads.@spawn :samepool inspect_driver(libcuda_compat, libcuda_deps)
else
Threads.@spawn inspect_driver(libcuda_compat, libcuda_deps)
end
system_driver_task = @static if VERSION >= v"1.12-"
# XXX: avoid concurrent compilation (JuliaLang/julia#59834)
Threads.@spawn :samepool inspect_driver(libcuda_system; inspect_devices=true)
else
Threads.@spawn inspect_driver(libcuda_system; inspect_devices=true)
end
compat_driver_task = @static if VERSION >= v"1.12-"
# XXX: avoid concurrent compilation (JuliaLang/julia#59834)
Threads.@spawn :samepool inspect_driver(libcuda_compat, libcuda_deps)
else
Threads.@spawn inspect_driver(libcuda_compat, libcuda_deps)
end
system_driver_task = @static if VERSION >= v"1.12-"
# XXX: avoid concurrent compilation (JuliaLang/julia#59834)
Threads.@spawn :samepool inspect_driver(libcuda_system; inspect_devices=true)
else
Threads.@spawn inspect_driver(libcuda_system; inspect_devices=true)
end
compat_driver_details = fetch(compat_driver_task)
if compat_driver_details === nothing
@debug "Failed to load forwards-compatible driver."
Expand Down
99 changes: 99 additions & 0 deletions C/CUDA/CUDA_Driver/src/cuda_inspect_driver.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/* Script to check the usability of CUDA drivers. */

#include <dlfcn.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// These match the flags of Libdl.dlopen():
// https://docs.julialang.org/en/v1/stdlib/Libdl/#Base.Libc.Libdl.dlopen
#ifdef __APPLE__
#define DLOPEN_FLAGS RTLD_LAZY | RTLD_DEEPBIND | RTLD_GLOBAL
#else
#define DLOPEN_FLAGS RTLD_LAZY | RTLD_DEEPBIND | RTLD_LOCAL
#endif

const int DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75;
const int DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76;

typedef int (*cuInit_t)(unsigned int);
typedef int (*cuDriverGetVersion_t)(int *);
typedef int (*cuDeviceGetCount_t)(int *);
typedef int (*cuDeviceGet_t)(int *, int);
typedef int (*cuDeviceGetAttribute_t)(int *, unsigned int, int);


int main(int argc, char *argv[]) {
if (argc < 3) {
fprintf(stderr, "Usage: %s <driver> <inspect_devices> [deps...]\n", argv[0]);
return -1;
}

const char *driver = argv[1];
int inspect_devices = (strcmp(argv[2], "true") == 0 || strcmp(argv[2], "1") == 0);

for (int i = 3; i < argc; i++) {
if (dlopen(argv[i], DLOPEN_FLAGS) == NULL) {
return -1;
}
}

void *library_handle = dlopen(driver, DLOPEN_FLAGS);
if (library_handle == NULL) {
return -1;
}

cuInit_t cuInit = (cuInit_t)dlsym(library_handle, "cuInit");
int status = cuInit(0);
if (status != 0) {
return -2;
}

cuDriverGetVersion_t cuDriverGetVersion = (cuDriverGetVersion_t)dlsym(library_handle, "cuDriverGetVersion");
int version;
status = cuDriverGetVersion(&version);
if (status != 0) {
return -3;
}
int major = version / 1000;
int ver = version % 1000;
int minor = ver / 10;
int patch = ver % 10;
printf("%d.%d.%d\n", major, minor, patch);

if (inspect_devices) {
cuDeviceGetCount_t cuDeviceGetCount = (cuDeviceGetCount_t)dlsym(library_handle, "cuDeviceGetCount");
int device_count;
status = cuDeviceGetCount(&device_count);
if (status != 0) {
return -4;
}

cuDeviceGet_t cuDeviceGet = (cuDeviceGet_t)dlsym(library_handle, "cuDeviceGet");
cuDeviceGetAttribute_t cuDeviceGetAttribute = (cuDeviceGetAttribute_t)dlsym(library_handle, "cuDeviceGetAttribute");

for (int i = 0; i < device_count; i++) {
int device = -1;
status = cuDeviceGet(&device, i);
if (status != 0) {
return -5;
}

int dev_major;
status = cuDeviceGetAttribute(&dev_major, DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device);
if (status != 0) {
return -6;
}

int dev_minor;
status = cuDeviceGetAttribute(&dev_minor, DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device);
if (status != 0) {
return -7;
}

printf("%d.%d\n", dev_major, dev_minor);
}
}

return 0;
}