Skip to content

Commit b503458

Browse files
authored
Add fast CUDA MMVQ GGUF kernels (huggingface#3463)
* Add fast CUDA MMVQ GGUF kernels * Support f16 * Apply review comments * Format * Fix
1 parent 34625ab commit b503458

12 files changed

Lines changed: 1671 additions & 60 deletions

File tree

candle-core/src/quantized/cuda.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ pub struct QCudaStorage {
1919
device: CudaDevice,
2020
}
2121

22-
static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
22+
pub(crate) static FORCE_DMMV: std::sync::atomic::AtomicBool =
23+
std::sync::atomic::AtomicBool::new(false);
2324

2425
pub fn set_force_dmmv(f: bool) {
2526
FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
@@ -720,6 +721,14 @@ impl QCudaStorage {
720721
storage: &CudaStorage,
721722
layout: &crate::Layout,
722723
) -> Result<(CudaStorage, crate::Shape)> {
724+
// Try the fast MMVQ path first (supports BF16/F16/F32, batch 1-8, all quant types, reuses per-device workspace).
725+
if !FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
726+
if let Some(result) = super::fast_mmvq::try_fwd(self, self_shape, storage, layout)? {
727+
return Ok(result);
728+
}
729+
}
730+
731+
// Fallback
723732
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
724733
1
725734
} else {
Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
//! CUDA fast path for GGUF matmul with BF16/F16/F32 activations.
2+
3+
use std::collections::HashMap;
4+
use std::sync::{Mutex, OnceLock};
5+
6+
use super::cuda::{QCudaStorage, MATRIX_ROW_PADDING};
7+
use super::GgmlDType;
8+
use crate::cuda_backend::DeviceId;
9+
use crate::{backend::BackendStorage, CudaDevice, CudaStorage, DType, Result, Shape};
10+
11+
use cudarc::driver::{CudaSlice, DevicePtr};
12+
13+
const Q8_1_BLOCK_SIZE: usize = 32;
14+
const Q8_1_TYPE_SIZE: usize = 36; // 2 halves (4 bytes) + QK8_1 int8 = 4 + 32 = 36
15+
16+
/// Rounds `p` up to the nearest multiple of `q` (returns `p` unchanged when
/// it is already a multiple).
///
/// # Panics
/// Panics if `q` is zero.
#[inline]
fn pad(p: usize, q: usize) -> usize {
    p.next_multiple_of(q)
}
20+
21+
/// Quant types supported by the fast MMVQ kernels.
22+
fn supports(dtype: GgmlDType) -> bool {
23+
matches!(
24+
dtype,
25+
GgmlDType::Q4_0
26+
| GgmlDType::Q4_1
27+
| GgmlDType::Q5_0
28+
| GgmlDType::Q5_1
29+
| GgmlDType::Q8_0
30+
| GgmlDType::Q2K
31+
| GgmlDType::Q3K
32+
| GgmlDType::Q4K
33+
| GgmlDType::Q5K
34+
| GgmlDType::Q6K
35+
)
36+
}
37+
38+
const MMVQ_MAX_BATCH: usize = 8;
39+
40+
// ---------------------------------------------------------------------------
41+
// Per-device Q8_1 scratch workspace (grows-only, reused across calls).
42+
// ---------------------------------------------------------------------------
43+
44+
struct WorkspaceSlot {
45+
slice: CudaSlice<u8>,
46+
cap: usize,
47+
}
48+
49+
static WORKSPACE: OnceLock<Mutex<HashMap<DeviceId, WorkspaceSlot>>> = OnceLock::new();
50+
51+
/// Returns a device pointer to the scratch workspace, growing it if needed.
52+
/// The returned `MutexGuard` must be held alive until the kernels using
53+
/// this pointer have been launched (all launches are on the device's
54+
/// default stream, so they are serialised).
55+
fn workspace_ensure(
56+
dev: &CudaDevice,
57+
bytes: usize,
58+
) -> Result<(
59+
u64,
60+
std::sync::MutexGuard<'static, HashMap<DeviceId, WorkspaceSlot>>,
61+
)> {
62+
let map = WORKSPACE.get_or_init(|| Mutex::new(HashMap::new()));
63+
let device_key = dev.id();
64+
let mut guard = map.lock().unwrap();
65+
let slot = match guard.entry(device_key) {
66+
std::collections::hash_map::Entry::Occupied(entry) => {
67+
let slot = entry.into_mut();
68+
if slot.cap < bytes {
69+
slot.slice = unsafe { dev.alloc::<u8>(bytes)? };
70+
slot.cap = bytes;
71+
}
72+
slot
73+
}
74+
std::collections::hash_map::Entry::Vacant(entry) => {
75+
let slice = unsafe { dev.alloc::<u8>(bytes)? };
76+
entry.insert(WorkspaceSlot { slice, cap: bytes })
77+
}
78+
};
79+
let ptr = slot.slice.device_ptr(slot.slice.stream()).0;
80+
Ok((ptr, guard))
81+
}
82+
83+
// ---------------------------------------------------------------------------
84+
// Launcher dispatch by weight dtype and output dtype.
85+
// ---------------------------------------------------------------------------
86+
87+
/// Signature shared by every `launch_mmvq_gguf_*_plain` FFI entry point.
///
/// As used by `try_fwd`: `vx` receives the quantized weight pointer, `vy` the
/// Q8_1 scratch buffer holding the quantized activations, `dst` the output
/// buffer, `ncols_x`/`nrows_x` the weight's column/row counts,
/// `stride_col_y` the number of Q8_1 blocks per activation row,
/// `stride_col_dst` the weight row count, `b_size` the number of activation
/// rows and `stream` the raw CUDA stream handle.
type PlainLauncher = unsafe extern "C" fn(
    vx: *const std::ffi::c_void,
    vy: *const std::ffi::c_void,
    dst: *mut std::ffi::c_void,
    ncols_x: i32,
    nrows_x: i32,
    stride_col_y: i32,
    stride_col_dst: i32,
    b_size: i32,
    stream: *mut std::ffi::c_void,
);
98+
99+
fn plain_launcher_bf16(dtype: GgmlDType) -> Option<PlainLauncher> {
100+
use candle_kernels::ffi;
101+
let f: PlainLauncher = match dtype {
102+
GgmlDType::Q4_0 => ffi::launch_mmvq_gguf_q4_0_bf16_plain,
103+
GgmlDType::Q4_1 => ffi::launch_mmvq_gguf_q4_1_bf16_plain,
104+
GgmlDType::Q5_0 => ffi::launch_mmvq_gguf_q5_0_bf16_plain,
105+
GgmlDType::Q5_1 => ffi::launch_mmvq_gguf_q5_1_bf16_plain,
106+
GgmlDType::Q8_0 => ffi::launch_mmvq_gguf_q8_0_bf16_plain,
107+
GgmlDType::Q2K => ffi::launch_mmvq_gguf_q2_k_bf16_plain,
108+
GgmlDType::Q3K => ffi::launch_mmvq_gguf_q3_k_bf16_plain,
109+
GgmlDType::Q4K => ffi::launch_mmvq_gguf_q4_k_bf16_plain,
110+
GgmlDType::Q5K => ffi::launch_mmvq_gguf_q5_k_bf16_plain,
111+
GgmlDType::Q6K => ffi::launch_mmvq_gguf_q6_k_bf16_plain,
112+
_ => return None,
113+
};
114+
Some(f)
115+
}
116+
117+
fn plain_launcher_f16(dtype: GgmlDType) -> Option<PlainLauncher> {
118+
use candle_kernels::ffi;
119+
let f: PlainLauncher = match dtype {
120+
GgmlDType::Q4_0 => ffi::launch_mmvq_gguf_q4_0_f16_plain,
121+
GgmlDType::Q4_1 => ffi::launch_mmvq_gguf_q4_1_f16_plain,
122+
GgmlDType::Q5_0 => ffi::launch_mmvq_gguf_q5_0_f16_plain,
123+
GgmlDType::Q5_1 => ffi::launch_mmvq_gguf_q5_1_f16_plain,
124+
GgmlDType::Q8_0 => ffi::launch_mmvq_gguf_q8_0_f16_plain,
125+
GgmlDType::Q2K => ffi::launch_mmvq_gguf_q2_k_f16_plain,
126+
GgmlDType::Q3K => ffi::launch_mmvq_gguf_q3_k_f16_plain,
127+
GgmlDType::Q4K => ffi::launch_mmvq_gguf_q4_k_f16_plain,
128+
GgmlDType::Q5K => ffi::launch_mmvq_gguf_q5_k_f16_plain,
129+
GgmlDType::Q6K => ffi::launch_mmvq_gguf_q6_k_f16_plain,
130+
_ => return None,
131+
};
132+
Some(f)
133+
}
134+
135+
fn plain_launcher_f32(dtype: GgmlDType) -> Option<PlainLauncher> {
136+
use candle_kernels::ffi;
137+
let f: PlainLauncher = match dtype {
138+
GgmlDType::Q4_0 => ffi::launch_mmvq_gguf_q4_0_f32_plain,
139+
GgmlDType::Q4_1 => ffi::launch_mmvq_gguf_q4_1_f32_plain,
140+
GgmlDType::Q5_0 => ffi::launch_mmvq_gguf_q5_0_f32_plain,
141+
GgmlDType::Q5_1 => ffi::launch_mmvq_gguf_q5_1_f32_plain,
142+
GgmlDType::Q8_0 => ffi::launch_mmvq_gguf_q8_0_f32_plain,
143+
GgmlDType::Q2K => ffi::launch_mmvq_gguf_q2_k_f32_plain,
144+
GgmlDType::Q3K => ffi::launch_mmvq_gguf_q3_k_f32_plain,
145+
GgmlDType::Q4K => ffi::launch_mmvq_gguf_q4_k_f32_plain,
146+
GgmlDType::Q5K => ffi::launch_mmvq_gguf_q5_k_f32_plain,
147+
GgmlDType::Q6K => ffi::launch_mmvq_gguf_q6_k_f32_plain,
148+
_ => return None,
149+
};
150+
Some(f)
151+
}
152+
153+
// ---------------------------------------------------------------------------
154+
// Public entry point
155+
// ---------------------------------------------------------------------------
156+
157+
/// Try the fast MMVQ path. Returns `Ok(None)` when the fast path is not applicable:
158+
/// - unsupported quant dtype
159+
/// - batch too large
160+
/// - non-BF16/F32 input
161+
pub fn try_fwd(
162+
qstorage: &QCudaStorage,
163+
self_shape: &Shape,
164+
rhs: &CudaStorage,
165+
rhs_l: &crate::Layout,
166+
) -> Result<Option<(CudaStorage, Shape)>> {
167+
use candle_kernels::ffi;
168+
169+
// Gate checks.
170+
let w_dtype = qstorage.dtype();
171+
if !supports(w_dtype) {
172+
return Ok(None);
173+
}
174+
let input_dtype = rhs.dtype();
175+
if !matches!(input_dtype, DType::BF16 | DType::F16 | DType::F32) {
176+
return Ok(None);
177+
}
178+
179+
let (nrows, ncols) = self_shape.dims2()?;
180+
181+
let (b_size, k) = match rhs_l.shape().dims() {
182+
[b, m, k] => (b * m, *k),
183+
[b, k] => (*b, *k),
184+
_ => return Ok(None),
185+
};
186+
if ncols != k {
187+
return Ok(None);
188+
}
189+
if b_size == 0 || b_size > MMVQ_MAX_BATCH {
190+
return Ok(None);
191+
}
192+
193+
let (o1, o2) = match rhs_l.contiguous_offsets() {
194+
Some(offsets) => offsets,
195+
None => return Ok(None),
196+
};
197+
198+
let dev = qstorage.device();
199+
let stream_ptr = dev.cuda_stream().cu_stream() as *mut std::ffi::c_void;
200+
201+
let k_padded = pad(k, MATRIX_ROW_PADDING);
202+
let num_blocks_per_row = k_padded / Q8_1_BLOCK_SIZE;
203+
let dst_row_bytes = num_blocks_per_row * Q8_1_TYPE_SIZE;
204+
let scratch_bytes = b_size * dst_row_bytes;
205+
206+
let (scratch_ptr, _workspace_guard) = workspace_ensure(dev, scratch_bytes)?;
207+
let scratch_ptr = scratch_ptr as *mut std::ffi::c_void;
208+
let stride_col_y = (k_padded / Q8_1_BLOCK_SIZE) as i32;
209+
let stride_col_dst = nrows as i32;
210+
let weight_ptr = qstorage.device_ptr()? as *const std::ffi::c_void;
211+
212+
let mut out_shape = rhs_l.shape().dims().to_vec();
213+
out_shape.pop();
214+
out_shape.push(nrows);
215+
216+
let stream = dev.cuda_stream();
217+
218+
match input_dtype {
219+
DType::BF16 => {
220+
let rhs_slice = rhs.as_cuda_slice::<half::bf16>()?;
221+
let rhs_slice = rhs_slice.slice(o1..o2);
222+
let out = unsafe { dev.alloc::<half::bf16>(nrows * b_size)? };
223+
224+
let rhs_ptr = rhs_slice.device_ptr(&stream).0 as *const std::ffi::c_void;
225+
let out_ptr = out.device_ptr(&stream).0 as *mut std::ffi::c_void;
226+
227+
unsafe {
228+
ffi::launch_mmvq_gguf_quantize_q8_1_bf16(
229+
rhs_ptr,
230+
scratch_ptr,
231+
k as i32,
232+
k_padded as i32,
233+
b_size as i32,
234+
stream_ptr,
235+
);
236+
let launcher = plain_launcher_bf16(w_dtype).unwrap();
237+
launcher(
238+
weight_ptr,
239+
scratch_ptr as *const std::ffi::c_void,
240+
out_ptr,
241+
k as i32,
242+
nrows as i32,
243+
stride_col_y,
244+
stride_col_dst,
245+
b_size as i32,
246+
stream_ptr,
247+
);
248+
}
249+
250+
let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone());
251+
Ok(Some((out_storage, out_shape.into())))
252+
}
253+
DType::F16 => {
254+
let rhs_slice = rhs.as_cuda_slice::<half::f16>()?;
255+
let rhs_slice = rhs_slice.slice(o1..o2);
256+
let out = unsafe { dev.alloc::<half::f16>(nrows * b_size)? };
257+
258+
let rhs_ptr = rhs_slice.device_ptr(&stream).0 as *const std::ffi::c_void;
259+
let out_ptr = out.device_ptr(&stream).0 as *mut std::ffi::c_void;
260+
261+
unsafe {
262+
ffi::launch_mmvq_gguf_quantize_q8_1_f16(
263+
rhs_ptr,
264+
scratch_ptr,
265+
k as i32,
266+
k_padded as i32,
267+
b_size as i32,
268+
stream_ptr,
269+
);
270+
let launcher = plain_launcher_f16(w_dtype).unwrap();
271+
launcher(
272+
weight_ptr,
273+
scratch_ptr as *const std::ffi::c_void,
274+
out_ptr,
275+
k as i32,
276+
nrows as i32,
277+
stride_col_y,
278+
stride_col_dst,
279+
b_size as i32,
280+
stream_ptr,
281+
);
282+
}
283+
284+
let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone());
285+
Ok(Some((out_storage, out_shape.into())))
286+
}
287+
DType::F32 => {
288+
let rhs_slice = rhs.as_cuda_slice::<f32>()?;
289+
let rhs_slice = rhs_slice.slice(o1..o2);
290+
let out = unsafe { dev.alloc::<f32>(nrows * b_size)? };
291+
292+
let rhs_ptr = rhs_slice.device_ptr(&stream).0 as *const std::ffi::c_void;
293+
let out_ptr = out.device_ptr(&stream).0 as *mut std::ffi::c_void;
294+
295+
unsafe {
296+
ffi::launch_mmvq_gguf_quantize_q8_1_f32(
297+
rhs_ptr,
298+
scratch_ptr,
299+
k as i32,
300+
k_padded as i32,
301+
b_size as i32,
302+
stream_ptr,
303+
);
304+
let launcher = plain_launcher_f32(w_dtype).unwrap();
305+
launcher(
306+
weight_ptr,
307+
scratch_ptr as *const std::ffi::c_void,
308+
out_ptr,
309+
k as i32,
310+
nrows as i32,
311+
stride_col_y,
312+
stride_col_dst,
313+
b_size as i32,
314+
stream_ptr,
315+
);
316+
}
317+
318+
let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone());
319+
Ok(Some((out_storage, out_shape.into())))
320+
}
321+
_ => Ok(None),
322+
}
323+
}

candle-core/src/quantized/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ mod metal {
2222
}
2323
#[cfg(feature = "cuda")]
2424
pub mod cuda;
25+
#[cfg(feature = "cuda")]
26+
pub mod fast_mmvq;
2527
#[cfg(not(feature = "cuda"))]
2628
mod cuda {
2729
pub use super::dummy_cuda::*;

candle-examples/examples/gemma4/main.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,7 @@ fn main() -> Result<()> {
293293
None => {
294294
let config_file = repo.get("config.json")?;
295295
// For text-only, try to parse the text_config sub-object
296-
let raw: serde_json::Value =
297-
serde_json::from_slice(&std::fs::read(config_file)?)?;
296+
let raw: serde_json::Value = serde_json::from_slice(&std::fs::read(config_file)?)?;
298297
if let Some(text_cfg) = raw.get("text_config") {
299298
serde_json::from_value(text_cfg.clone())?
300299
} else {

candle-kernels/build.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ fn main() -> Result<()> {
1313
let ptx_path = out_dir.join("ptx.rs");
1414
let bindings = KernelBuilder::new()
1515
.source_dir("src") // Scan src/ for .cu files
16-
.exclude(&["moe_*.cu"]) // Exclude moe kernels for ptx build
16+
.exclude(&["moe_*.cu", "mmvq_gguf.cu"]) // Exclude statically compiled kernels from ptx build
1717
.arg("--expt-relaxed-constexpr")
1818
.arg("-std=c++17")
1919
.arg("-O3")
@@ -26,6 +26,7 @@ fn main() -> Result<()> {
2626
"src/moe/moe_gguf.cu",
2727
"src/moe/moe_wmma.cu",
2828
"src/moe/moe_wmma_gguf.cu",
29+
"src/mmvq_gguf.cu",
2930
])
3031
.arg("--expt-relaxed-constexpr")
3132
.arg("-std=c++17")

0 commit comments

Comments
 (0)