Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 61 additions & 30 deletions llamafile/cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@ static bool LinkCuda(const char *dso) {
else
*(void **)(&g_cuda.backend_reg.default_abi) = sym;

// Optional - don't fail if not found
// Required: TryGpuBackend uses this to reject 0-device DSOs
sym = cosmo_dlsym(lib, "ggml_backend_cuda_get_device_count");
ok &= (sym != NULL);
if (IsWindows())
*(void **)(&g_cuda.get_device_count.windows_abi) = sym;
else
Expand Down Expand Up @@ -158,6 +159,53 @@ static bool LinkCuda(const char *dso) {
return true;
}

static void UnlinkCuda(void) {
if (g_cuda.lib_handle) {
cosmo_dlclose(g_cuda.lib_handle);
g_cuda.lib_handle = NULL;
}
memset(&g_cuda.backend_init, 0, sizeof(g_cuda.backend_init));
memset(&g_cuda.backend_reg, 0, sizeof(g_cuda.backend_reg));
memset(&g_cuda.get_device_count, 0, sizeof(g_cuda.get_device_count));
memset(&g_cuda.get_device_description, 0, sizeof(g_cuda.get_device_description));
memset(&g_cuda.log_set, 0, sizeof(g_cuda.log_set));
}

static bool TryGpuBackend(const char *dso, bool is_amd) {
if (!llamafile_try_load_prebuilt_dso(dso, "cuda", LinkCuda))
return false;

// Suppress the DSO's ggml logging before we touch any function that
// triggers ggml_cuda_init() (e.g. get_device_count). Without this, a
// failed init on the wrong backend would print a confusing error to
// stderr even when --verbose is not set.
if (!FLAG_verbose && (g_cuda.log_set.default_abi || g_cuda.log_set.windows_abi)) {
if (IsWindows())
g_cuda.log_set.windows_abi(llamafile_log_callback_null, NULL);
else
g_cuda.log_set.default_abi(llamafile_log_callback_null, NULL);
}

// Verify the backend has at least one device before committing. The DSO
// loads fine even when no compatible hardware is present, so we must
// probe device count to avoid registering a 0-device backend (which
// would then prevent fallback to other GPU backends in AUTO mode).
int count;
if (IsWindows())
count = g_cuda.get_device_count.windows_abi();
else
count = g_cuda.get_device_count.default_abi();
if (count <= 0) {
llamafile_info("cuda", "%s library loaded but no devices detected; trying next backend",
is_amd ? "ROCm" : "CUDA");
UnlinkCuda();
return false;
}

g_cuda.is_amd = is_amd;
return true;
}

static bool ImportCudaImpl(void) {
// Skip on Apple Silicon (use Metal instead)
if (IsXnuSilicon()) {
Expand All @@ -168,9 +216,7 @@ static bool ImportCudaImpl(void) {
switch (FLAG_gpu) {
case LLAMAFILE_GPU_AUTO:
case LLAMAFILE_GPU_NVIDIA:
break;
case LLAMAFILE_GPU_AMD:
g_cuda.is_amd = true;
break;
default:
return false;
Expand All @@ -183,19 +229,16 @@ static bool ImportCudaImpl(void) {
snprintf(cuda_dso, sizeof(cuda_dso), "ggml-cuda.%s", ext);
snprintf(rocm_dso, sizeof(rocm_dso), "ggml-rocm.%s", ext);

// Try to load pre-built DSO
if (FLAG_gpu == LLAMAFILE_GPU_AMD || FLAG_gpu == LLAMAFILE_GPU_AUTO) {
if (llamafile_try_load_prebuilt_dso(rocm_dso, "cuda", LinkCuda)) {
g_cuda.is_amd = true;
// In AUTO mode, prefer CUDA over ROCm: it covers the common NVIDIA case
// and lets ROCm be the fallback when CUDA is absent or has no devices.
if (FLAG_gpu == LLAMAFILE_GPU_NVIDIA || FLAG_gpu == LLAMAFILE_GPU_AUTO) {
if (TryGpuBackend(cuda_dso, false))
goto RegisterBackend;
}
}

if (FLAG_gpu == LLAMAFILE_GPU_NVIDIA || FLAG_gpu == LLAMAFILE_GPU_AUTO) {
if (llamafile_try_load_prebuilt_dso(cuda_dso, "cuda", LinkCuda)) {
g_cuda.is_amd = false;
if (FLAG_gpu == LLAMAFILE_GPU_AMD || FLAG_gpu == LLAMAFILE_GPU_AUTO) {
if (TryGpuBackend(rocm_dso, true))
goto RegisterBackend;
}
}

// No pre-built DSO found
Expand All @@ -206,16 +249,6 @@ static bool ImportCudaImpl(void) {
return false;

RegisterBackend:
// Suppress DSO's ggml logging before backend registration, which triggers
// ggml_cuda_init() inside the DSO. Without this, CUDA device enumeration
// messages appear even when --verbose is not set.
if (!FLAG_verbose && (g_cuda.log_set.default_abi || g_cuda.log_set.windows_abi)) {
if (IsWindows())
g_cuda.log_set.windows_abi(llamafile_log_callback_null, NULL);
else
g_cuda.log_set.default_abi(llamafile_log_callback_null, NULL);
}

// Register the CUDA backend with GGML
if (g_cuda.backend_reg.default_abi || g_cuda.backend_reg.windows_abi) {
ggml_backend_reg_t reg;
Expand All @@ -238,14 +271,12 @@ static void ImportCuda(void) {
g_cuda.supported = true;
llamafile_info("cuda", "%s GPU support successfully loaded",
g_cuda.is_amd ? "AMD ROCm" : "NVIDIA CUDA");
if (g_cuda.get_device_count.default_abi || g_cuda.get_device_count.windows_abi) {
int count;
if (IsWindows())
count = g_cuda.get_device_count.windows_abi();
else
count = g_cuda.get_device_count.default_abi();
llamafile_info("cuda", "found %d GPU device(s)", count);
}
int count;
if (IsWindows())
count = g_cuda.get_device_count.windows_abi();
else
count = g_cuda.get_device_count.default_abi();
llamafile_info("cuda", "found %d GPU device(s)", count);
} else if (FLAG_gpu == LLAMAFILE_GPU_NVIDIA || FLAG_gpu == LLAMAFILE_GPU_AMD) {
fprintf(stderr, "fatal error: support for --gpu %s was explicitly requested, "
"but it wasn't available\n", llamafile_describe_gpu());
Expand Down
Loading