|
//! CUDA fast path for GGUF matmul with BF16/F16/F32 activations.
| 2 | +
|
| 3 | +use std::collections::HashMap; |
| 4 | +use std::sync::{Mutex, OnceLock}; |
| 5 | + |
| 6 | +use super::cuda::{QCudaStorage, MATRIX_ROW_PADDING}; |
| 7 | +use super::GgmlDType; |
| 8 | +use crate::cuda_backend::DeviceId; |
| 9 | +use crate::{backend::BackendStorage, CudaDevice, CudaStorage, DType, Result, Shape}; |
| 10 | + |
| 11 | +use cudarc::driver::{CudaSlice, DevicePtr}; |
| 12 | + |
/// Number of int8 values stored in one Q8_1 block (QK8_1).
const Q8_1_BLOCK_SIZE: usize = 32;
/// Size in bytes of one Q8_1 block: two halves (4 bytes) followed by the
/// block's int8 values, i.e. 4 + 32 = 36.
const Q8_1_TYPE_SIZE: usize = 4 + Q8_1_BLOCK_SIZE;
| 15 | + |
/// Rounds `p` up to the nearest multiple of `q`.
///
/// A `p` that is already a multiple of `q` (including zero) is returned
/// unchanged. Panics if `q` is zero.
#[inline]
fn pad(p: usize, q: usize) -> usize {
    match p % q {
        0 => p,
        remainder => p + (q - remainder),
    }
}
| 20 | + |
| 21 | +/// Quant types supported by the fast MMVQ kernels. |
| 22 | +fn supports(dtype: GgmlDType) -> bool { |
| 23 | + matches!( |
| 24 | + dtype, |
| 25 | + GgmlDType::Q4_0 |
| 26 | + | GgmlDType::Q4_1 |
| 27 | + | GgmlDType::Q5_0 |
| 28 | + | GgmlDType::Q5_1 |
| 29 | + | GgmlDType::Q8_0 |
| 30 | + | GgmlDType::Q2K |
| 31 | + | GgmlDType::Q3K |
| 32 | + | GgmlDType::Q4K |
| 33 | + | GgmlDType::Q5K |
| 34 | + | GgmlDType::Q6K |
| 35 | + ) |
| 36 | +} |
| 37 | + |
/// Largest activation batch (number of input rows) that `try_fwd` will send
/// through the MMVQ fast path; larger batches make it return `Ok(None)`.
const MMVQ_MAX_BATCH: usize = 8;
| 39 | + |
| 40 | +// --------------------------------------------------------------------------- |
| 41 | +// Per-device Q8_1 scratch workspace (grows-only, reused across calls). |
| 42 | +// --------------------------------------------------------------------------- |
| 43 | + |
| 44 | +struct WorkspaceSlot { |
| 45 | + slice: CudaSlice<u8>, |
| 46 | + cap: usize, |
| 47 | +} |
| 48 | + |
| 49 | +static WORKSPACE: OnceLock<Mutex<HashMap<DeviceId, WorkspaceSlot>>> = OnceLock::new(); |
| 50 | + |
| 51 | +/// Returns a device pointer to the scratch workspace, growing it if needed. |
| 52 | +/// The returned `MutexGuard` must be held alive until the kernels using |
| 53 | +/// this pointer have been launched (all launches are on the device's |
| 54 | +/// default stream, so they are serialised). |
| 55 | +fn workspace_ensure( |
| 56 | + dev: &CudaDevice, |
| 57 | + bytes: usize, |
| 58 | +) -> Result<( |
| 59 | + u64, |
| 60 | + std::sync::MutexGuard<'static, HashMap<DeviceId, WorkspaceSlot>>, |
| 61 | +)> { |
| 62 | + let map = WORKSPACE.get_or_init(|| Mutex::new(HashMap::new())); |
| 63 | + let device_key = dev.id(); |
| 64 | + let mut guard = map.lock().unwrap(); |
| 65 | + let slot = match guard.entry(device_key) { |
| 66 | + std::collections::hash_map::Entry::Occupied(entry) => { |
| 67 | + let slot = entry.into_mut(); |
| 68 | + if slot.cap < bytes { |
| 69 | + slot.slice = unsafe { dev.alloc::<u8>(bytes)? }; |
| 70 | + slot.cap = bytes; |
| 71 | + } |
| 72 | + slot |
| 73 | + } |
| 74 | + std::collections::hash_map::Entry::Vacant(entry) => { |
| 75 | + let slice = unsafe { dev.alloc::<u8>(bytes)? }; |
| 76 | + entry.insert(WorkspaceSlot { slice, cap: bytes }) |
| 77 | + } |
| 78 | + }; |
| 79 | + let ptr = slot.slice.device_ptr(slot.slice.stream()).0; |
| 80 | + Ok((ptr, guard)) |
| 81 | +} |
| 82 | + |
| 83 | +// --------------------------------------------------------------------------- |
| 84 | +// Launcher dispatch by weight dtype and output dtype. |
| 85 | +// --------------------------------------------------------------------------- |
| 86 | + |
/// Common signature of the `launch_mmvq_gguf_*_plain` FFI entry points.
///
/// As called from `try_fwd`:
/// - `vx`: device pointer to the quantized weight data
/// - `vy`: device pointer to the Q8_1-quantized activations (scratch workspace)
/// - `dst`: device pointer to the output buffer
/// - `ncols_x` / `nrows_x`: column and row count of the weight matrix
/// - `stride_col_y`: Q8_1 blocks per padded activation row
/// - `stride_col_dst`: output elements per batch entry (the weight row count)
/// - `b_size`: number of activation rows
/// - `stream`: raw CUDA stream handle to launch on
type PlainLauncher = unsafe extern "C" fn(
    vx: *const std::ffi::c_void,
    vy: *const std::ffi::c_void,
    dst: *mut std::ffi::c_void,
    ncols_x: i32,
    nrows_x: i32,
    stride_col_y: i32,
    stride_col_dst: i32,
    b_size: i32,
    stream: *mut std::ffi::c_void,
);
| 98 | + |
| 99 | +fn plain_launcher_bf16(dtype: GgmlDType) -> Option<PlainLauncher> { |
| 100 | + use candle_kernels::ffi; |
| 101 | + let f: PlainLauncher = match dtype { |
| 102 | + GgmlDType::Q4_0 => ffi::launch_mmvq_gguf_q4_0_bf16_plain, |
| 103 | + GgmlDType::Q4_1 => ffi::launch_mmvq_gguf_q4_1_bf16_plain, |
| 104 | + GgmlDType::Q5_0 => ffi::launch_mmvq_gguf_q5_0_bf16_plain, |
| 105 | + GgmlDType::Q5_1 => ffi::launch_mmvq_gguf_q5_1_bf16_plain, |
| 106 | + GgmlDType::Q8_0 => ffi::launch_mmvq_gguf_q8_0_bf16_plain, |
| 107 | + GgmlDType::Q2K => ffi::launch_mmvq_gguf_q2_k_bf16_plain, |
| 108 | + GgmlDType::Q3K => ffi::launch_mmvq_gguf_q3_k_bf16_plain, |
| 109 | + GgmlDType::Q4K => ffi::launch_mmvq_gguf_q4_k_bf16_plain, |
| 110 | + GgmlDType::Q5K => ffi::launch_mmvq_gguf_q5_k_bf16_plain, |
| 111 | + GgmlDType::Q6K => ffi::launch_mmvq_gguf_q6_k_bf16_plain, |
| 112 | + _ => return None, |
| 113 | + }; |
| 114 | + Some(f) |
| 115 | +} |
| 116 | + |
| 117 | +fn plain_launcher_f16(dtype: GgmlDType) -> Option<PlainLauncher> { |
| 118 | + use candle_kernels::ffi; |
| 119 | + let f: PlainLauncher = match dtype { |
| 120 | + GgmlDType::Q4_0 => ffi::launch_mmvq_gguf_q4_0_f16_plain, |
| 121 | + GgmlDType::Q4_1 => ffi::launch_mmvq_gguf_q4_1_f16_plain, |
| 122 | + GgmlDType::Q5_0 => ffi::launch_mmvq_gguf_q5_0_f16_plain, |
| 123 | + GgmlDType::Q5_1 => ffi::launch_mmvq_gguf_q5_1_f16_plain, |
| 124 | + GgmlDType::Q8_0 => ffi::launch_mmvq_gguf_q8_0_f16_plain, |
| 125 | + GgmlDType::Q2K => ffi::launch_mmvq_gguf_q2_k_f16_plain, |
| 126 | + GgmlDType::Q3K => ffi::launch_mmvq_gguf_q3_k_f16_plain, |
| 127 | + GgmlDType::Q4K => ffi::launch_mmvq_gguf_q4_k_f16_plain, |
| 128 | + GgmlDType::Q5K => ffi::launch_mmvq_gguf_q5_k_f16_plain, |
| 129 | + GgmlDType::Q6K => ffi::launch_mmvq_gguf_q6_k_f16_plain, |
| 130 | + _ => return None, |
| 131 | + }; |
| 132 | + Some(f) |
| 133 | +} |
| 134 | + |
| 135 | +fn plain_launcher_f32(dtype: GgmlDType) -> Option<PlainLauncher> { |
| 136 | + use candle_kernels::ffi; |
| 137 | + let f: PlainLauncher = match dtype { |
| 138 | + GgmlDType::Q4_0 => ffi::launch_mmvq_gguf_q4_0_f32_plain, |
| 139 | + GgmlDType::Q4_1 => ffi::launch_mmvq_gguf_q4_1_f32_plain, |
| 140 | + GgmlDType::Q5_0 => ffi::launch_mmvq_gguf_q5_0_f32_plain, |
| 141 | + GgmlDType::Q5_1 => ffi::launch_mmvq_gguf_q5_1_f32_plain, |
| 142 | + GgmlDType::Q8_0 => ffi::launch_mmvq_gguf_q8_0_f32_plain, |
| 143 | + GgmlDType::Q2K => ffi::launch_mmvq_gguf_q2_k_f32_plain, |
| 144 | + GgmlDType::Q3K => ffi::launch_mmvq_gguf_q3_k_f32_plain, |
| 145 | + GgmlDType::Q4K => ffi::launch_mmvq_gguf_q4_k_f32_plain, |
| 146 | + GgmlDType::Q5K => ffi::launch_mmvq_gguf_q5_k_f32_plain, |
| 147 | + GgmlDType::Q6K => ffi::launch_mmvq_gguf_q6_k_f32_plain, |
| 148 | + _ => return None, |
| 149 | + }; |
| 150 | + Some(f) |
| 151 | +} |
| 152 | + |
| 153 | +// --------------------------------------------------------------------------- |
| 154 | +// Public entry point |
| 155 | +// --------------------------------------------------------------------------- |
| 156 | + |
| 157 | +/// Try the fast MMVQ path. Returns `Ok(None)` when the fast path is not applicable: |
| 158 | +/// - unsupported quant dtype |
| 159 | +/// - batch too large |
| 160 | +/// - non-BF16/F32 input |
| 161 | +pub fn try_fwd( |
| 162 | + qstorage: &QCudaStorage, |
| 163 | + self_shape: &Shape, |
| 164 | + rhs: &CudaStorage, |
| 165 | + rhs_l: &crate::Layout, |
| 166 | +) -> Result<Option<(CudaStorage, Shape)>> { |
| 167 | + use candle_kernels::ffi; |
| 168 | + |
| 169 | + // Gate checks. |
| 170 | + let w_dtype = qstorage.dtype(); |
| 171 | + if !supports(w_dtype) { |
| 172 | + return Ok(None); |
| 173 | + } |
| 174 | + let input_dtype = rhs.dtype(); |
| 175 | + if !matches!(input_dtype, DType::BF16 | DType::F16 | DType::F32) { |
| 176 | + return Ok(None); |
| 177 | + } |
| 178 | + |
| 179 | + let (nrows, ncols) = self_shape.dims2()?; |
| 180 | + |
| 181 | + let (b_size, k) = match rhs_l.shape().dims() { |
| 182 | + [b, m, k] => (b * m, *k), |
| 183 | + [b, k] => (*b, *k), |
| 184 | + _ => return Ok(None), |
| 185 | + }; |
| 186 | + if ncols != k { |
| 187 | + return Ok(None); |
| 188 | + } |
| 189 | + if b_size == 0 || b_size > MMVQ_MAX_BATCH { |
| 190 | + return Ok(None); |
| 191 | + } |
| 192 | + |
| 193 | + let (o1, o2) = match rhs_l.contiguous_offsets() { |
| 194 | + Some(offsets) => offsets, |
| 195 | + None => return Ok(None), |
| 196 | + }; |
| 197 | + |
| 198 | + let dev = qstorage.device(); |
| 199 | + let stream_ptr = dev.cuda_stream().cu_stream() as *mut std::ffi::c_void; |
| 200 | + |
| 201 | + let k_padded = pad(k, MATRIX_ROW_PADDING); |
| 202 | + let num_blocks_per_row = k_padded / Q8_1_BLOCK_SIZE; |
| 203 | + let dst_row_bytes = num_blocks_per_row * Q8_1_TYPE_SIZE; |
| 204 | + let scratch_bytes = b_size * dst_row_bytes; |
| 205 | + |
| 206 | + let (scratch_ptr, _workspace_guard) = workspace_ensure(dev, scratch_bytes)?; |
| 207 | + let scratch_ptr = scratch_ptr as *mut std::ffi::c_void; |
| 208 | + let stride_col_y = (k_padded / Q8_1_BLOCK_SIZE) as i32; |
| 209 | + let stride_col_dst = nrows as i32; |
| 210 | + let weight_ptr = qstorage.device_ptr()? as *const std::ffi::c_void; |
| 211 | + |
| 212 | + let mut out_shape = rhs_l.shape().dims().to_vec(); |
| 213 | + out_shape.pop(); |
| 214 | + out_shape.push(nrows); |
| 215 | + |
| 216 | + let stream = dev.cuda_stream(); |
| 217 | + |
| 218 | + match input_dtype { |
| 219 | + DType::BF16 => { |
| 220 | + let rhs_slice = rhs.as_cuda_slice::<half::bf16>()?; |
| 221 | + let rhs_slice = rhs_slice.slice(o1..o2); |
| 222 | + let out = unsafe { dev.alloc::<half::bf16>(nrows * b_size)? }; |
| 223 | + |
| 224 | + let rhs_ptr = rhs_slice.device_ptr(&stream).0 as *const std::ffi::c_void; |
| 225 | + let out_ptr = out.device_ptr(&stream).0 as *mut std::ffi::c_void; |
| 226 | + |
| 227 | + unsafe { |
| 228 | + ffi::launch_mmvq_gguf_quantize_q8_1_bf16( |
| 229 | + rhs_ptr, |
| 230 | + scratch_ptr, |
| 231 | + k as i32, |
| 232 | + k_padded as i32, |
| 233 | + b_size as i32, |
| 234 | + stream_ptr, |
| 235 | + ); |
| 236 | + let launcher = plain_launcher_bf16(w_dtype).unwrap(); |
| 237 | + launcher( |
| 238 | + weight_ptr, |
| 239 | + scratch_ptr as *const std::ffi::c_void, |
| 240 | + out_ptr, |
| 241 | + k as i32, |
| 242 | + nrows as i32, |
| 243 | + stride_col_y, |
| 244 | + stride_col_dst, |
| 245 | + b_size as i32, |
| 246 | + stream_ptr, |
| 247 | + ); |
| 248 | + } |
| 249 | + |
| 250 | + let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone()); |
| 251 | + Ok(Some((out_storage, out_shape.into()))) |
| 252 | + } |
| 253 | + DType::F16 => { |
| 254 | + let rhs_slice = rhs.as_cuda_slice::<half::f16>()?; |
| 255 | + let rhs_slice = rhs_slice.slice(o1..o2); |
| 256 | + let out = unsafe { dev.alloc::<half::f16>(nrows * b_size)? }; |
| 257 | + |
| 258 | + let rhs_ptr = rhs_slice.device_ptr(&stream).0 as *const std::ffi::c_void; |
| 259 | + let out_ptr = out.device_ptr(&stream).0 as *mut std::ffi::c_void; |
| 260 | + |
| 261 | + unsafe { |
| 262 | + ffi::launch_mmvq_gguf_quantize_q8_1_f16( |
| 263 | + rhs_ptr, |
| 264 | + scratch_ptr, |
| 265 | + k as i32, |
| 266 | + k_padded as i32, |
| 267 | + b_size as i32, |
| 268 | + stream_ptr, |
| 269 | + ); |
| 270 | + let launcher = plain_launcher_f16(w_dtype).unwrap(); |
| 271 | + launcher( |
| 272 | + weight_ptr, |
| 273 | + scratch_ptr as *const std::ffi::c_void, |
| 274 | + out_ptr, |
| 275 | + k as i32, |
| 276 | + nrows as i32, |
| 277 | + stride_col_y, |
| 278 | + stride_col_dst, |
| 279 | + b_size as i32, |
| 280 | + stream_ptr, |
| 281 | + ); |
| 282 | + } |
| 283 | + |
| 284 | + let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone()); |
| 285 | + Ok(Some((out_storage, out_shape.into()))) |
| 286 | + } |
| 287 | + DType::F32 => { |
| 288 | + let rhs_slice = rhs.as_cuda_slice::<f32>()?; |
| 289 | + let rhs_slice = rhs_slice.slice(o1..o2); |
| 290 | + let out = unsafe { dev.alloc::<f32>(nrows * b_size)? }; |
| 291 | + |
| 292 | + let rhs_ptr = rhs_slice.device_ptr(&stream).0 as *const std::ffi::c_void; |
| 293 | + let out_ptr = out.device_ptr(&stream).0 as *mut std::ffi::c_void; |
| 294 | + |
| 295 | + unsafe { |
| 296 | + ffi::launch_mmvq_gguf_quantize_q8_1_f32( |
| 297 | + rhs_ptr, |
| 298 | + scratch_ptr, |
| 299 | + k as i32, |
| 300 | + k_padded as i32, |
| 301 | + b_size as i32, |
| 302 | + stream_ptr, |
| 303 | + ); |
| 304 | + let launcher = plain_launcher_f32(w_dtype).unwrap(); |
| 305 | + launcher( |
| 306 | + weight_ptr, |
| 307 | + scratch_ptr as *const std::ffi::c_void, |
| 308 | + out_ptr, |
| 309 | + k as i32, |
| 310 | + nrows as i32, |
| 311 | + stride_col_y, |
| 312 | + stride_col_dst, |
| 313 | + b_size as i32, |
| 314 | + stream_ptr, |
| 315 | + ); |
| 316 | + } |
| 317 | + |
| 318 | + let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone()); |
| 319 | + Ok(Some((out_storage, out_shape.into()))) |
| 320 | + } |
| 321 | + _ => Ok(None), |
| 322 | + } |
| 323 | +} |
0 commit comments