84 changes: 70 additions & 14 deletions candle-core/src/error.rs
@@ -1,4 +1,6 @@
 //! Candle-specific Error and Result
+use std::{convert::Infallible, fmt::Display};
+
 use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
 
 #[derive(Debug, Clone)]
@@ -209,6 +211,13 @@ pub enum Error {
     #[error("{0}")]
     Wrapped(Box<dyn std::fmt::Display + Send + Sync>),
 
+    /// Arbitrary errors wrapping with context.
+    #[error("{wrapped:?}\n{context:?}")]
+    WrappedContext {
+        wrapped: Box<dyn std::error::Error + Send + Sync>,
+        context: String,
+    },
+
     #[error("{context}\n{inner}")]
     Context {
         inner: Box<Self>,
@@ -299,40 +308,87 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
     }
 }
 
-// Taken from anyhow.
-pub trait Context<T> {
+pub(crate) mod private {
+    pub trait Sealed {}
+
+    impl<T, E> Sealed for std::result::Result<T, E> where E: std::error::Error {}
+    impl<T> Sealed for Option<T> {}
+}
+
+/// Attach more context to an error.
+///
+/// Inspired by [`anyhow::Context`].
+pub trait Context<T, E>: private::Sealed {
     /// Wrap the error value with additional context.
-    fn context<C>(self, context: C) -> Result<T>
+    fn context<C>(self, context: C) -> std::result::Result<T, Error>
     where
-        C: std::fmt::Display + Send + Sync + 'static;
+        C: Display + Send + Sync + 'static;
 
     /// Wrap the error value with additional context that is evaluated lazily
     /// only once an error does occur.
-    fn with_context<C, F>(self, f: F) -> Result<T>
+    fn with_context<C, F>(self, f: F) -> std::result::Result<T, Error>
     where
-        C: std::fmt::Display + Send + Sync + 'static,
+        C: Display + Send + Sync + 'static,
         F: FnOnce() -> C;
 }
 
-impl<T> Context<T> for Option<T> {
-    fn context<C>(self, context: C) -> Result<T>
+impl<T, E> Context<T, E> for std::result::Result<T, E>
+where
+    E: std::error::Error + Send + Sync + 'static,
+{
+    fn context<C>(self, context: C) -> std::result::Result<T, Error>
     where
-        C: std::fmt::Display + Send + Sync + 'static,
+        C: Display + Send + Sync + 'static,
     {
+        // Not using map_err to save 2 useless frames off the captured backtrace
+        // in ext_context.
         match self {
-            Some(v) => Ok(v),
-            None => Err(Error::UnwrapNone.context(context).bt()),
+            Ok(ok) => Ok(ok),
+            Err(error) => Err(Error::WrappedContext {
+                wrapped: Box::new(error),
+                context: context.to_string(),
+            }
+            .bt()),
+        }
+    }
+
+    fn with_context<C, F>(self, context: F) -> std::result::Result<T, Error>
+    where
+        C: Display + Send + Sync + 'static,
+        F: FnOnce() -> C,
+    {
+        match self {
+            Ok(ok) => Ok(ok),
+            Err(error) => Err(Error::WrappedContext {
+                wrapped: Box::new(error),
+                context: context().to_string(),
+            }
+            .bt()),
+        }
+    }
+}
+
+impl<T> Context<T, Infallible> for Option<T> {
+    fn context<C>(self, context: C) -> std::result::Result<T, Error>
+    where
+        C: Display + Send + Sync + 'static,
+    {
+        // Not using ok_or_else to save 2 useless frames off the captured
+        // backtrace.
+        match self {
+            Some(ok) => Ok(ok),
+            None => Err(Error::msg(context).bt()),
         }
     }
 
-    fn with_context<C, F>(self, f: F) -> Result<T>
+    fn with_context<C, F>(self, context: F) -> std::result::Result<T, Error>
     where
-        C: std::fmt::Display + Send + Sync + 'static,
+        C: Display + Send + Sync + 'static,
         F: FnOnce() -> C,
     {
         match self {
             Some(v) => Ok(v),
-            None => Err(Error::UnwrapNone.context(f()).bt()),
+            None => Err(Error::UnwrapNone.context(context()).bt()),
         }
     }
 }
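
For readers skimming the diff, a minimal sketch of how the reworked trait reads at a call site. The helper functions are hypothetical; only the `Context` trait and `Result` alias come from this file. `Result` errors are now boxed into the new `Error::WrappedContext` variant, while `Option` produces a plain `Error::msg`:

```rust
use candle_core::error::{Context, Result};

// Hypothetical helpers for illustration only.
fn parse_dim(raw: &str, tensor_name: &str) -> Result<usize> {
    // The ParseIntError is boxed into Error::WrappedContext; the context
    // string is only built if the parse actually fails.
    raw.trim()
        .parse::<usize>()
        .with_context(|| format!("invalid dimension {raw:?} for tensor {tensor_name}"))
}

fn first_dim(dims: &[usize]) -> Result<usize> {
    // A None becomes Error::msg(context) with a backtrace attached.
    dims.first().copied().context("tensor has no dimensions")
}
```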
23 changes: 23 additions & 0 deletions candle-core/src/metal_backend/device.rs
@@ -65,6 +65,12 @@ pub const RESOURCE_OPTIONS: MTLResourceOptions =
 //| MTLResourceOptions::HazardTrackingModeUntracked.bits(),
 //);
 
+// Resource options used for `new_private_buffer`. This uses `private` where supported.
+#[cfg(target_os = "ios")]
+pub const PRIVATE_RESOURCE_OPTIONS: MTLResourceOptions = MTLResourceOptions::StorageModeShared;
+#[cfg(not(target_os = "ios"))]
+pub const PRIVATE_RESOURCE_OPTIONS: MTLResourceOptions = MTLResourceOptions::StorageModePrivate;
+
 impl std::fmt::Debug for MetalDevice {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "MetalDevice({:?})", self.id)
@@ -167,6 +173,23 @@ impl MetalDevice {
         self.allocate_buffer(size)
     }
 
+    /// Creates a new private buffer (not necessarily zeroed).
+    ///
+    /// This is intentionally not in the Metal buffer pool to allow the efficient implementation of persistent buffers.
+    pub fn new_private_buffer(
+        &self,
+        element_count: usize,
+        dtype: DType,
+        _name: &str,
+    ) -> Result<Arc<Buffer>> {
+        let size = element_count * dtype.size_in_bytes();
+        let buffer = self
+            .device
+            .new_buffer(size, PRIVATE_RESOURCE_OPTIONS)
+            .map_err(MetalError::from)?;
+        Ok(Arc::new(buffer))
+    }
+
     /// Creates a new buffer from data.
     ///
     /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)

Review thread on lines +176 to +179 (the `new_private_buffer` signature):

[Member] I agree that this is nice to have, but I think we should name it something other than private buffer, since that already means something for Metal buffers (only available on the GPU). We don't want to use actual Metal private buffers, as that isn't supported on iOS. How about new_unpooled_buffer or new_persistent_buffer? :)

[EricLBuehler (Member, Author), Nov 20, 2025] Actually, this was a mistake on my part. The behavior I intended for this function is:

  • private if not on iOS
  • shared/RESOURCE_OPTIONS if on iOS

[ivarflakstad (Member), Nov 20, 2025] I see. Could I ask why you want it to be private? According to Apple's documentation there is no performance benefit, so private is usually used when you want to ensure that the CPU does not have access to the buffer for some specific reason. I'd wager a guess this kind of behaviour is frequently used in gaming.
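
A small usage sketch of the new allocator, assuming a Metal-capable build; the device setup and the buffer name are illustrative, and only `new_private_buffer` comes from this diff:

```rust
use candle_core::{DType, Device};

fn alloc_persistent_scratch() -> candle_core::Result<()> {
    // Room for 1024 f32 values. On macOS this uses StorageModePrivate; on iOS
    // PRIVATE_RESOURCE_OPTIONS falls back to StorageModeShared (see the cfg above).
    if let Device::Metal(dev) = Device::new_metal(0)? {
        let _buf = dev.new_private_buffer(1024, DType::F32, "kv-cache-scratch")?;
        // The buffer is allocated outside the pool, so it is never recycled
        // between kernel launches and can back a persistent structure.
    }
    Ok(())
}
```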
118 changes: 118 additions & 0 deletions candle-core/src/quantized/cuda.rs
@@ -406,7 +406,125 @@ fn mul_mat_via_q8_1(
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
 
+fn indexed_moe_forward_fused_q8_1_input(
+    weight: &CudaView<u8>,
+    w_shape: &crate::Shape, // [num_experts, n, k]
+    w_dtype: GgmlDType,
+    input: &CudaSlice<f32>,
+    in_shape: &crate::Shape, // [batch, topk or 1, k]
+    ids: &CudaView<u32>,
+    idx_shape: &crate::Shape, // [batch, topk]
+    dev: &CudaDevice,
+) -> Result<(CudaStorage, crate::Shape)> {
+    let (_, n, k) = w_shape.dims3()?;
+    let batch = in_shape.dims()[0];
+    let input_dim1 = in_shape.dims()[1];
+
+    let topk = idx_shape.dims()[1];
+    assert!(batch == idx_shape.dims()[0], "batch dim not match!");
+
+    // Quantize input into q8_1.
+    let total_rows = batch * input_dim1;
+    let k_padded = pad(k, MATRIX_ROW_PADDING);
+    // Get Q8_1 metadata.
+    let q8_1_block_size = GgmlDType::Q8_1.block_size();
+    let q8_1_type_size = GgmlDType::Q8_1.type_size();
+
+    // Calculate the size of the output buffer in bytes.
+    let num_blocks_per_row = k_padded / q8_1_block_size;
+    let dst_row_size_bytes = num_blocks_per_row * q8_1_type_size;
+    let y_size_in_bytes = total_rows * dst_row_size_bytes;
+    let mut input_quant = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
+
+    let input_view = input.slice(0..);
+    quantize_q8_1(&input_view, &mut input_quant, k, total_rows, dev)?;
+
+    // Output buffer.
+    let outsize = batch * topk * n;
+    let out = unsafe { dev.alloc::<f32>(outsize)? };
+
+    let kernel_name = match w_dtype {
+        GgmlDType::Q2K => "indexed_moe_forward_q2k_q8_1",
+        GgmlDType::Q3K => "indexed_moe_forward_q3k_q8_1",
+        GgmlDType::Q4K => "indexed_moe_forward_q4k_q8_1",
+        GgmlDType::Q5K => "indexed_moe_forward_q5k_q8_1",
+        GgmlDType::Q6K => "indexed_moe_forward_q6k_q8_1",
+        GgmlDType::Q8_0 => "indexed_moe_forward_q8_0_q8_1",
+        _ => crate::bail!("unsupported dtype for indexed_moe_forward {w_dtype:?}"),
+    };
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
+    let (nblocks, nwarps) = (n as u32, 4);
+    let cfg = cudarc::driver::LaunchConfig {
+        grid_dim: (nblocks, batch as u32, topk as u32),
+        block_dim: (WARP_SIZE as u32, nwarps, 1),
+        shared_mem_bytes: 0,
+    };
+
+    let mut builder = func.builder();
+    builder.arg(weight);
+    builder.arg(&input_quant);
+    builder.arg(ids);
+    builder.arg(&out);
+
+    barg!(
+        builder,
+        n as i32,
+        k as i32,
+        batch as i32,
+        topk as i32,
+        k_padded as i32,
+        input_dim1 as i32
+    );
+    unsafe { builder.launch(cfg) }.w()?;
+
+    let mut out_shape = in_shape.dims().to_vec();
+    out_shape.pop();
+    out_shape.push(n);
+    out_shape[1] = topk;
+    Ok((
+        CudaStorage::wrap_cuda_slice(out, dev.clone()),
+        out_shape.into(),
+    ))
+}
+
 impl QCudaStorage {
+    pub fn indexed_moe_forward(
+        &self,
+        self_shape: &crate::Shape, // [num_experts, n, k]
+        input: &CudaStorage,       // [batch, topk or 1, k]
+        input_l: &crate::Layout,
+        ids: &CudaStorage,         // [batch, topk]
+        ids_l: &crate::Layout,
+    ) -> Result<(CudaStorage, crate::Shape)> {
+        if matches!(
+            self.dtype(),
+            GgmlDType::Q8_0
+                | GgmlDType::Q2K
+                | GgmlDType::Q3K
+                | GgmlDType::Q4K
+                | GgmlDType::Q5K
+                | GgmlDType::Q6K
+        ) {
+            let input_storage = input.as_cuda_slice::<f32>()?;
+            let ids_storage = ids.as_cuda_slice::<u32>()?;
+            indexed_moe_forward_fused_q8_1_input(
+                &self.data.inner.slice(0..),
+                self_shape, // [num_experts, n, k]
+                self.dtype(),
+                &input_storage,
+                input_l.shape(), // [batch, topk or 1, k]
+                &ids_storage.slice(0..),
+                ids_l.shape(), // [batch, topk]
+                &self.device,
+            )
+        } else {
+            crate::bail!(
+                "The given quantized dtype {:?} is not supported for indexed_moe_forward!",
+                self.dtype()
+            );
+        }
+    }
+
     pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
         let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
         let padded_size_in_bytes =

Review thread on lines +521 to +524 (the unsupported-dtype bail):

[Member] Just thinking out loud here. It would be nice to have an automatic fallback to an approach that isn't as optimized but is still valid. Perhaps returning Result<Option<(CudaStorage, crate::Shape)>> is a decent starting point? If None, then fall back?

Not thinking we add this in this PR, of course.

[Member, Author] This might work; the issue is that indexed_moe_forward is effectively a grouped GEMM, so we'd need existing infrastructure to run a grouped GEMM as the fallback. Regardless, providing grouped-GEMM functionality would be very useful!
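
For orientation, a host-side sketch of the shape contract, assuming a CUDA build; the sizes and the Q4K choice are illustrative, and `indexed_moe_forward` is the `QTensor` entry point added below in `mod.rs`:

```rust
use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{DType, Device, Result, Tensor};

fn moe_forward_sketch() -> Result<Tensor> {
    let dev = Device::new_cuda(0)?;
    let (num_experts, n, k, batch, topk) = (8, 512, 256, 4, 2);

    // Per-expert weights [num_experts, n, k], quantized to one of the supported dtypes.
    let w = Tensor::randn(0f32, 1., (num_experts, n, k), &dev)?;
    let qw = QTensor::quantize(&w, GgmlDType::Q4K)?;

    // One f32 activation row per token, shared across its top-k experts.
    let x = Tensor::randn(0f32, 1., (batch, 1, k), &dev)?;
    // Router-selected expert indices, shape [batch, topk], dtype u32.
    let ids = Tensor::zeros((batch, topk), DType::U32, &dev)?;

    // The input rows are quantized to Q8_1 on the fly, then a single fused
    // kernel launch produces a [batch, topk, n] output.
    qw.indexed_moe_forward(&x, &ids)
}
```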
11 changes: 11 additions & 0 deletions candle-core/src/quantized/dummy_cuda.rs
@@ -70,6 +70,17 @@ impl QCudaStorage {
     pub fn data(&self) -> Result<Vec<u8>> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+
+    pub fn indexed_moe_forward(
+        &self,
+        _: &crate::Shape,
+        _: &CudaStorage,
+        _: &crate::Layout,
+        _: &CudaStorage,
+        _: &crate::Layout,
+    ) -> Result<(CudaStorage, crate::Shape)> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
 }
 
 pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
11 changes: 11 additions & 0 deletions candle-core/src/quantized/dummy_metal.rs
@@ -66,6 +66,17 @@ impl QMetalStorage {
     pub fn data(&self) -> Result<Vec<u8>> {
         Err(Error::NotCompiledWithMetalSupport)
     }
+
+    pub fn indexed_moe_forward(
+        &self,
+        _: &crate::Shape,
+        _: &MetalStorage,
+        _: &crate::Layout,
+        _: &MetalStorage,
+        _: &crate::Layout,
+    ) -> Result<(MetalStorage, crate::Shape)> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
 }
 
 pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
37 changes: 37 additions & 0 deletions candle-core/src/quantized/mod.rs
@@ -642,6 +642,34 @@ impl QTensor {
     pub fn data(&self) -> Result<Cow<'_, [u8]>> {
         self.storage.data()
     }
+
+    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
+        match &self.storage {
+            QStorage::Cuda(s) => match (&*x.storage(), &*ids.storage()) {
+                (Storage::Cuda(x_storage), Storage::Cuda(ids_storage)) => {
+                    let (storage, out_shape) = s.indexed_moe_forward(
+                        self.shape(),
+                        x_storage,
+                        x.layout(),
+                        ids_storage,
+                        ids.layout(),
+                    )?;
+                    Ok(crate::tensor::from_storage(
+                        Storage::Cuda(storage),
+                        out_shape,
+                        crate::op::BackpropOp::none(),
+                        false,
+                    ))
+                }
+                _ => {
+                    panic!("Non-cuda indexed_moe_forward is not implemented!");
+                }
+            },
+            _ => {
+                panic!("indexed_moe_forward is not implemented in this platform!");
+            }
+        }
+    }
 }
 
 #[derive(Clone, Debug)]
@@ -713,6 +741,15 @@ impl QMatMul {
         };
         xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
     }
+
+    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::QTensor(t) => t.indexed_moe_forward(x, ids),
+            _ => {
+                panic!("Not implemented!")
+            }
+        }
+    }
 }
 
 impl crate::CustomOp1 for QTensor {
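
Finally, a routing sketch showing where the `ids` argument comes from; `moe_step` and the top-k scheme are illustrative and not part of this PR:

```rust
use candle_core::quantized::QMatMul;
use candle_core::{Result, Tensor};

/// One MoE layer step: pick the top-k experts per token, then run the fused GEMM.
/// `router_logits` is [batch, num_experts]; `x` is [batch, 1, k].
fn moe_step(experts: &QMatMul, router_logits: &Tensor, x: &Tensor, topk: usize) -> Result<Tensor> {
    // Descending argsort of the router scores, truncated to the top-k u32 ids.
    let ids = router_logits.arg_sort_last_dim(false)?.narrow(1, 0, topk)?;
    // [batch, topk, n]: one output row per (token, selected expert).
    experts.indexed_moe_forward(x, &ids)
}
```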