diff --git a/src/lib.rs b/src/lib.rs
index a6632a94..7e6a6c90 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,6 +1,9 @@
 #![allow(clippy::uninlined_format_args)]
 #![cfg_attr(test, feature(test))]
 
+#[cfg(feature = "vulkan")]
+pub mod vulkan;
+
 mod common_logging;
 mod error;
 mod ggml_logging_hook;
diff --git a/src/vulkan.rs b/src/vulkan.rs
new file mode 100644
index 00000000..67456d8b
--- /dev/null
+++ b/src/vulkan.rs
@@ -0,0 +1,106 @@
+use std::{
+    ffi::CStr,
+    os::raw::{c_char, c_int},
+};
+use whisper_rs_sys::{
+    ggml_backend_buffer_type_t, ggml_backend_vk_buffer_type, ggml_backend_vk_get_device_count,
+    ggml_backend_vk_get_device_description, ggml_backend_vk_get_device_memory,
+};
+
+/// Capacity of the scratch buffer handed to ggml for a device description.
+const DESCRIPTION_CAPACITY: usize = 256;
+
+/// Free and total memory of one device, as reported by
+/// `ggml_backend_vk_get_device_memory`.
+#[derive(Debug, Clone)]
+pub struct VKVram {
+    pub free: usize,
+    pub total: usize,
+}
+
+/// Human-readable device information
+#[derive(Debug, Clone)]
+pub struct VkDeviceInfo {
+    /// Index of the device, in `0..ggml_backend_vk_get_device_count()`.
+    pub id: i32,
+    /// Device description returned by ggml, lossily converted to UTF-8.
+    pub name: String,
+    /// Memory figures for the device.
+    pub vram: VKVram,
+    /// Buffer type to pass to `whisper::Backend::create_buffer`
+    pub buf_type: ggml_backend_buffer_type_t,
+}
+
+/// Enumerate every physical GPU ggml can see.
+///
+/// Returns one [`VkDeviceInfo`] per index reported by
+/// `ggml_backend_vk_get_device_count`, in index order.
+///
+/// Note: integrated GPUs are returned *after* discrete ones,
+/// mirroring ggml’s C logic.
+pub fn list_devices() -> Vec<VkDeviceInfo> {
+    // SAFETY: the call takes no arguments and only returns a count.
+    let count = unsafe { ggml_backend_vk_get_device_count() };
+    (0..count).map(describe_device).collect()
+}
+
+/// Collect the description, memory figures and buffer type of device `id`.
+///
+/// `id` must be a valid device index, i.e. less than the ggml device count.
+fn describe_device(id: c_int) -> VkDeviceInfo {
+    // `c_char` is `i8` on some targets and `u8` on others, so the buffer is
+    // declared through the alias instead of a hard-coded `i8`.
+    let mut description: [c_char; DESCRIPTION_CAPACITY] = [0; DESCRIPTION_CAPACITY];
+    let mut free = 0usize;
+    let mut total = 0usize;
+
+    // SAFETY: `description` is writable for exactly the length passed with it,
+    // `free` and `total` are live for the duration of the call, and `id` comes
+    // from `0..count`, so it is a valid, non-negative device index.
+    let buf_type = unsafe {
+        ggml_backend_vk_get_device_description(id, description.as_mut_ptr(), description.len());
+        ggml_backend_vk_get_device_memory(id, &mut free, &mut total);
+        ggml_backend_vk_buffer_type(id as usize)
+    };
+
+    // Force a terminator so the read below can never run past the buffer,
+    // whatever the C side wrote.
+    description[DESCRIPTION_CAPACITY - 1] = 0;
+    // SAFETY: `description` is NUL-terminated (see above) and outlives the borrow.
+    let name = unsafe { CStr::from_ptr(description.as_ptr()) }
+        .to_string_lossy()
+        .into_owned();
+
+    VkDeviceInfo {
+        id,
+        name,
+        vram: VKVram { free, total },
+        buf_type,
+    }
+}
+
+#[cfg(test)]
+mod vulkan_tests {
+    use super::*;
+
+    #[test]
+    fn enumerate_must_not_panic() {
+        let _ = list_devices();
+    }
+
+    #[test]
+    fn sane_device_info() {
+        let gpus = list_devices();
+        let mut seen = std::collections::HashSet::new();
+
+        for dev in &gpus {
+            assert!(seen.insert(dev.id), "duplicated id {}", dev.id);
+            assert!(!dev.name.trim().is_empty(), "GPU {} has empty name", dev.id);
+            assert!(
+                dev.vram.total >= dev.vram.free,
+                "GPU {} total < free",
+                dev.id
+            );
+        }
+    }
+}
diff --git a/sys/build.rs b/sys/build.rs
index 5db8a88a..d7e6befc 100644
--- a/sys/build.rs
+++ b/sys/build.rs
@@ -119,10 +119,14 @@ fn main() {
         let _: u64 = std::fs::copy("src/bindings.rs", out.join("bindings.rs"))
             .expect("Failed to copy bindings.rs");
     } else {
         let bindings = bindgen::Builder::default().header("wrapper.h");
 
         #[cfg(feature = "metal")]
         let bindings = bindings.header("whisper.cpp/ggml/include/ggml-metal.h");
+        #[cfg(feature = "vulkan")]
+        let bindings = bindings
+            .header("whisper.cpp/ggml/include/ggml-vulkan.h")
+            .clang_arg("-DGGML_USE_VULKAN=1");
 
         let bindings = bindings
             .clang_arg("-I./whisper.cpp/")
diff --git a/sys/wrapper.h b/sys/wrapper.h
index 0b0bbb95..26d880c0 100644
--- a/sys/wrapper.h
+++ b/sys/wrapper.h
@@ -1,2 +1,6 @@
 #include <include/whisper.h>
 #include <ggml/include/ggml.h>
+
+#ifdef GGML_USE_VULKAN
+#include "ggml-vulkan.h"
+#endif