diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index cd361bbd3a..e5616cc947 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -1,4 +1,6 @@
 //! Candle-specific Error and Result
+use std::{convert::Infallible, fmt::Display};
+
 use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
 
 #[derive(Debug, Clone)]
@@ -209,6 +211,13 @@ pub enum Error {
     #[error("{0}")]
     Wrapped(Box<dyn std::error::Error + Send + Sync>),
 
+    /// Arbitrary errors wrapped together with additional context.
+    #[error("{wrapped:?}\n{context:?}")]
+    WrappedContext {
+        wrapped: Box<dyn std::error::Error + Send + Sync>,
+        context: String,
+    },
+
     #[error("{context}\n{inner}")]
     Context {
         inner: Box<Self>,
         context: Box<dyn std::fmt::Display + Send + Sync>,
@@ -299,40 +308,87 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
     }
 }
 
-// Taken from anyhow.
-pub trait Context<T> {
+pub(crate) mod private {
+    pub trait Sealed {}
+
+    impl<T, E> Sealed for std::result::Result<T, E> where E: std::error::Error {}
+    impl<T> Sealed for Option<T> {}
+}
+
+/// Attach more context to an error.
+///
+/// Inspired by [`anyhow::Context`].
+pub trait Context<T, E>: private::Sealed {
     /// Wrap the error value with additional context.
-    fn context<C>(self, context: C) -> Result<T>
+    fn context<C>(self, context: C) -> std::result::Result<T, Error>
     where
-        C: std::fmt::Display + Send + Sync + 'static;
+        C: Display + Send + Sync + 'static;
 
     /// Wrap the error value with additional context that is evaluated lazily
     /// only once an error does occur.
-    fn with_context<C, F>(self, f: F) -> Result<T>
+    fn with_context<C, F>(self, f: F) -> std::result::Result<T, Error>
    where
-        C: std::fmt::Display + Send + Sync + 'static,
+        C: Display + Send + Sync + 'static,
         F: FnOnce() -> C;
 }
 
-impl<T> Context<T> for Option<T> {
-    fn context<C>(self, context: C) -> Result<T>
+impl<T, E> Context<T, E> for std::result::Result<T, E>
+where
+    E: std::error::Error + Send + Sync + 'static,
+{
+    fn context<C>(self, context: C) -> std::result::Result<T, Error>
     where
-        C: std::fmt::Display + Send + Sync + 'static,
+        C: Display + Send + Sync + 'static,
     {
+        // Not using map_err to save 2 useless frames off the captured backtrace
+        // in ext_context.
         match self {
-            Some(v) => Ok(v),
-            None => Err(Error::UnwrapNone.context(context).bt()),
+            Ok(ok) => Ok(ok),
+            Err(error) => Err(Error::WrappedContext {
+                wrapped: Box::new(error),
+                context: context.to_string(),
+            }
+            .bt()),
+        }
+    }
+
+    fn with_context<C, F>(self, context: F) -> std::result::Result<T, Error>
+    where
+        C: Display + Send + Sync + 'static,
+        F: FnOnce() -> C,
+    {
+        match self {
+            Ok(ok) => Ok(ok),
+            Err(error) => Err(Error::WrappedContext {
+                wrapped: Box::new(error),
+                context: context().to_string(),
+            }
+            .bt()),
+        }
+    }
+}
+
+impl<T> Context<T, Infallible> for Option<T> {
+    fn context<C>(self, context: C) -> std::result::Result<T, Error>
+    where
+        C: Display + Send + Sync + 'static,
+    {
+        // Not using ok_or_else to save 2 useless frames off the captured
+        // backtrace.
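+        // e.g. (illustrative) `None::<usize>.context("reading header")` now yields an
+        // error that carries just the message "reading header", instead of wrapping
+        // `Error::UnwrapNone` as before.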
+ match self { + Some(ok) => Ok(ok), + None => Err(Error::msg(context).bt()), } } - fn with_context(self, f: F) -> Result + fn with_context(self, context: F) -> std::result::Result where - C: std::fmt::Display + Send + Sync + 'static, + C: Display + Send + Sync + 'static, F: FnOnce() -> C, { match self { Some(v) => Ok(v), - None => Err(Error::UnwrapNone.context(f()).bt()), + None => Err(Error::UnwrapNone.context(context()).bt()), } } } diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index f5f78bb271..ca40b02d30 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -65,6 +65,12 @@ pub const RESOURCE_OPTIONS: MTLResourceOptions = //| MTLResourceOptions::HazardTrackingModeUntracked.bits(), //); +// Resource options used for `new_private_buffer`. This uses `private` where supported. +#[cfg(target_os = "ios")] +pub const PRIVATE_RESOURCE_OPTIONS: MTLResourceOptions = MTLResourceOptions::StorageModeShared; +#[cfg(not(target_os = "ios"))] +pub const PRIVATE_RESOURCE_OPTIONS: MTLResourceOptions = MTLResourceOptions::StorageModePrivate; + impl std::fmt::Debug for MetalDevice { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "MetalDevice({:?})", self.id) @@ -167,6 +173,23 @@ impl MetalDevice { self.allocate_buffer(size) } + /// Creates a new private buffer (not necessarily zeroed). + /// + /// This is intentionally not in the Metal buffer pool to allow the efficient implementation of persistent buffers. + pub fn new_private_buffer( + &self, + element_count: usize, + dtype: DType, + _name: &str, + ) -> Result> { + let size = element_count * dtype.size_in_bytes(); + let buffer = self + .device + .new_buffer(size, PRIVATE_RESOURCE_OPTIONS) + .map_err(MetalError::from)?; + Ok(Arc::new(buffer)) + } + /// Creates a new buffer from data. /// /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes) diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 6db6625428..3faf9f695f 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -406,7 +406,125 @@ fn mul_mat_via_q8_1( Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) } +fn indexed_moe_forward_fused_q8_1_input( + weight: &CudaView, + w_shape: &crate::Shape, //[num_experts, n, k] + w_dtype: GgmlDType, + input: &CudaSlice, + in_shape: &crate::Shape, //[batch, topk or 1, k] + ids: &CudaView, + idx_shape: &crate::Shape, //[batch, topk] + dev: &CudaDevice, +) -> Result<(CudaStorage, crate::Shape)> { + let (_, n, k) = w_shape.dims3()?; + let batch = in_shape.dims()[0]; + let input_dim1 = in_shape.dims()[1]; + + let topk = idx_shape.dims()[1]; + assert!(batch == idx_shape.dims()[0], "batch dim not match!"); + + // Quantize input into q8_1. + let total_rows = batch * input_dim1; + let k_padded = pad(k, MATRIX_ROW_PADDING); + // Get Q8_1 metadata. + let q8_1_block_size = GgmlDType::Q8_1.block_size(); + let q8_1_type_size = GgmlDType::Q8_1.type_size(); + + // Calculate the size of the output buffer in bytes. + let num_blocks_per_row = k_padded / q8_1_block_size; + let dst_row_size_bytes = num_blocks_per_row * q8_1_type_size; + let y_size_in_bytes = total_rows * dst_row_size_bytes; + let mut input_quant = unsafe { dev.alloc::(y_size_in_bytes)? 
}; + + let input_view = input.slice(0..); + quantize_q8_1(&input_view, &mut input_quant, k, total_rows, dev)?; + + // output buffer + let outsize = batch * topk * n; + let out = unsafe { dev.alloc::(outsize)? }; + + let kernel_name = match w_dtype { + GgmlDType::Q2K => "indexed_moe_forward_q2k_q8_1", + GgmlDType::Q3K => "indexed_moe_forward_q3k_q8_1", + GgmlDType::Q4K => "indexed_moe_forward_q4k_q8_1", + GgmlDType::Q5K => "indexed_moe_forward_q5k_q8_1", + GgmlDType::Q6K => "indexed_moe_forward_q6k_q8_1", + GgmlDType::Q8_0 => "indexed_moe_forward_q8_0_q8_1", + _ => crate::bail!("unsupported dtype for indexed_moe_forward {w_dtype:?}"), + }; + let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?; + let (nblocks, nwarps) = (n as u32, 4); + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (nblocks, batch as u32, topk as u32), + block_dim: (WARP_SIZE as u32, nwarps, 1), + shared_mem_bytes: 0, + }; + + let mut builder = func.builder(); + builder.arg(weight); + builder.arg(&input_quant); + builder.arg(ids); + builder.arg(&out); + + barg!( + builder, + n as i32, + k as i32, + batch as i32, + topk as i32, + k_padded as i32, + input_dim1 as i32 + ); + unsafe { builder.launch(cfg) }.w()?; + + let mut out_shape = in_shape.dims().to_vec(); + out_shape.pop(); + out_shape.push(n); + out_shape[1] = topk; + Ok(( + CudaStorage::wrap_cuda_slice(out, dev.clone()), + out_shape.into(), + )) +} + impl QCudaStorage { + pub fn indexed_moe_forward( + &self, + self_shape: &crate::Shape, //[num_experts, n, k] + input: &CudaStorage, //[batch, topk or 1, k] + input_l: &crate::Layout, + ids: &CudaStorage, //[batch, topk] + ids_l: &crate::Layout, + ) -> Result<(CudaStorage, crate::Shape)> { + if matches!( + self.dtype(), + GgmlDType::Q8_0 + | GgmlDType::Q2K + | GgmlDType::Q3K + | GgmlDType::Q4K + | GgmlDType::Q5K + | GgmlDType::Q6K + ) { + let input_storage = input.as_cuda_slice::()?; + let ids_storage = ids.as_cuda_slice::()?; + indexed_moe_forward_fused_q8_1_input( + &self.data.inner.slice(0..), + self_shape, //[num_experts, n, k] + self.dtype(), + &input_storage, + input_l.shape(), //[batch, topk or 1, k] + &ids_storage.slice(0..), + ids_l.shape(), //[batch, topk] + &self.device, + ) + } else { + crate::bail!( + "The given quantized dtype {:?} is not supported for indexed_moe_forward!", + self.dtype() + ); + } + } + pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result { let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size(); let padded_size_in_bytes = diff --git a/candle-core/src/quantized/dummy_cuda.rs b/candle-core/src/quantized/dummy_cuda.rs index 1636f50bb7..7194439a09 100644 --- a/candle-core/src/quantized/dummy_cuda.rs +++ b/candle-core/src/quantized/dummy_cuda.rs @@ -70,6 +70,17 @@ impl QCudaStorage { pub fn data(&self) -> Result> { Err(Error::NotCompiledWithCudaSupport) } + + pub fn indexed_moe_forward( + &self, + _: &crate::Shape, + _: &CudaStorage, + _: &crate::Layout, + _: &CudaStorage, + _: &crate::Layout, + ) -> Result<(CudaStorage, crate::Shape)> { + Err(Error::NotCompiledWithCudaSupport) + } } pub fn load_quantized( diff --git a/candle-core/src/quantized/dummy_metal.rs b/candle-core/src/quantized/dummy_metal.rs index d4d87861f9..6f470e9099 100644 --- a/candle-core/src/quantized/dummy_metal.rs +++ b/candle-core/src/quantized/dummy_metal.rs @@ -66,6 +66,17 @@ impl QMetalStorage { pub fn data(&self) -> Result> { Err(Error::NotCompiledWithMetalSupport) } + + pub fn indexed_moe_forward( + &self, + _: &crate::Shape, + _: &MetalStorage, 
+ _: &crate::Layout, + _: &MetalStorage, + _: &crate::Layout, + ) -> Result<(MetalStorage, crate::Shape)> { + Err(Error::NotCompiledWithMetalSupport) + } } pub fn load_quantized( diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index d7768a94de..cee8ccc2ad 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -642,6 +642,34 @@ impl QTensor { pub fn data(&self) -> Result> { self.storage.data() } + + pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result { + match &self.storage { + QStorage::Cuda(s) => match (&*x.storage(), &*ids.storage()) { + (Storage::Cuda(x_storage), Storage::Cuda(ids_storage)) => { + let (storage, out_shape) = s.indexed_moe_forward( + self.shape(), + x_storage, + x.layout(), + ids_storage, + ids.layout(), + )?; + Ok(crate::tensor::from_storage( + Storage::Cuda(storage), + out_shape, + crate::op::BackpropOp::none(), + false, + )) + } + _ => { + panic!("Non-cuda indexed_moe_forward is not implemented!"); + } + }, + _ => { + panic!("indexed_moe_forward is not implemented in this platform!"); + } + } + } } #[derive(Clone, Debug)] @@ -713,6 +741,15 @@ impl QMatMul { }; xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype) } + + pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result { + match self { + Self::QTensor(t) => t.indexed_moe_forward(x, ids), + _ => { + panic!("Not implemented!") + } + } + } } impl crate::CustomOp1 for QTensor { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 36a177959a..0c01ba94ae 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -270,6 +270,51 @@ impl Tensor { Tensor::zeros(self.shape(), self.dtype(), self.device()) } + // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from + // the variable module. + pub(crate) unsafe fn empty_impl>( + shape: S, + dtype: DType, + device: &Device, + is_variable: bool, + ) -> Result { + let none = BackpropOp::none(); + let shape = shape.into(); + let storage = device.alloc_uninit(&shape, dtype)?; + Ok(from_storage(storage, shape, none, is_variable)) + } + + /// Creates a new tensor filled with uninitialized memory. + /// + /// # Safety + /// This returns uninitialized memory. + /// + /// ```rust + /// use candle_core::{Tensor, DType, Device}; + /// let a = unsafe { Tensor::empty((2, 3), DType::F32, &Device::Cpu)? }; + /// // a == b + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub unsafe fn empty>(shape: S, dtype: DType, device: &Device) -> Result { + Self::empty_impl(shape, dtype, device, false) + } + + /// Creates a new tensor filled with uninitialized memory of the same shape, dtype, and device as the other + /// tensor. + /// + /// # Safety + /// This returns uninitialized memory. + /// + /// ```rust + /// use candle_core::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = unsafe { a.empty_like()? }; + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub unsafe fn empty_like(&self) -> Result { + Tensor::empty(self.shape(), self.dtype(), self.device()) + } + pub(crate) fn rand_impl, T: crate::FloatDType>( lo: T, up: T, @@ -2768,6 +2813,49 @@ impl Tensor { } Ok(result) } + + /// Returns a view of which contains all slices of size `size` from self tensor in the dimension + /// `dim` and stepped by `step`. 
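+    ///
+    /// For example (a minimal sketch), non-overlapping windows of size 2 with step 2
+    /// along dim 0:
+    ///
+    /// ```rust
+    /// use candle_core::{Tensor, Device};
+    /// let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?;
+    /// // Three windows: [[0, 1], [2, 3], [4, 5]]
+    /// let w = t.unfold(0, 2, 2)?;
+    /// assert_eq!(w.dims(), &[3, 2]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```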
+ pub fn unfold(&self, dim: D, size: usize, step: usize) -> Result { + // https://github.com/pytorch/pytorch/blob/75b0720a97ac5d82e8a7a1a6ae7c5f7a87d7183d/aten/src/ATen/native/TensorShape.cpp#L3785-L3804 + let mut sizes = self.dims().to_vec(); + let mut strides = self.stride().to_vec(); + + let dim = dim.to_index(self.shape(), "unfold")?; + + let max_len = if self.dims().is_empty() { + 1 + } else { + sizes[dim] + }; + if size > max_len { + bail!( + "unsqueeze: maximum size for tensor at dimension {dim} is {max_len} but size is {size}" + ) + } + sizes.push(size); + strides.push(if self.dims().is_empty() { + 1 + } else { + strides[dim] + }); + + if !self.dims().is_empty() { + sizes[dim] = ((sizes[dim] as f32 - size as f32) / step as f32 + 1.) as usize; + strides[dim] *= step; + } + + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: Layout::new(sizes.into(), strides, self.layout.start_offset()), + op: BackpropOp::new1(self, Op::Reshape), + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) + } } macro_rules! bin_trait { diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 643783b350..3f90ec3a47 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -487,9 +487,9 @@ impl FlashAttnVarLen { None => candle::bail!("seqlens_k has to be contiguous"), }; - let q = q.as_cuda_slice::()?; - let k = k.as_cuda_slice::()?; - let v = v.as_cuda_slice::()?; + let q = q.as_cuda_slice::()?; + let k = k.as_cuda_slice::()?; + let v = v.as_cuda_slice::()?; let q = q.slice(q_l.start_offset()..); let k = k.slice(k_l.start_offset()..); let v = v.slice(v_l.start_offset()..); @@ -604,7 +604,7 @@ impl FlashAttnVarLen { let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128); let elem_count = out_shape.elem_count(); - let dst = unsafe { dev.alloc::(elem_count)? }; + let dst = unsafe { dev.alloc::(elem_count)? }; let softmax_lse = dev.alloc_zeros::(num_heads * total_q)?; let is_bf16 = if is_bf16 { 1 } else { 0 }; diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu index b6a4310005..b888b3e8a8 100644 --- a/candle-kernels/src/quantized.cu +++ b/candle-kernels/src/quantized.cu @@ -4329,3 +4329,209 @@ extern "C" __global__ void load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); } + + +/** + * @brief Performs an indexed, batched matrix-vector multiplication for quantized tensors (for MoE models). + * + * This kernel handles a batch of `total_tasks` independent operations. Each task consists + * of multiplying a Q8_1 quantized input vector with a Q4_K quantized weight matrix selected + * by an index. + * + * Parallelization Strategy: + * - The grid is 2D: gridDim.y corresponds to the task index, and gridDim.x corresponds to the row blocks of the output matrix. + * - `blockIdx.y`: Identifies which task to perform from the batch (`0` to `total_tasks - 1`). + * - `blockIdx.x`: Used internally by `mul_mat_vec_q` to parallelize the dot products across the rows of the weight matrix. + * + * @author + * Guoqing Bao + * Part of the project: https://github.com/guoqingbao/vllm.rs/ + * @param all_weights Pointer to the beginning of the weight tensor [num_experts, n, k]. + * @param all_inputs Pointer to the beginning of the input tensor [batch * topk, k]. + * @param indices Pointer to the expert indices for each task [batch * topk]. 
+ * @param all_outputs Pointer to the beginning of the output tensor [batch * topk, n]. + * @param n The number of output features (rows in the weight matrix). + * @param k The number of input features (columns in the weight matrix). + * @param total_tasks The total number of tasks to process, typically batch_size * topk. + * @param k_padded The value of k padded to a multiple of MATRIX_ROW_PADDING. + * @param weight_expert_stride_bytes The stride in bytes to get from one expert matrix to the next. + * @param input_task_stride_bytes The stride in bytes to get from one quantized input vector to the next. + * @param output_task_stride_elems The stride in elements (f32) to get from one output vector to the next. + */ +template +__device__ void indexed_moe_forward( + const void * __restrict__ all_weights, + const void * __restrict__ all_inputs, + const unsigned int * __restrict__ indices, + float * __restrict__ all_outputs, + const int n, + const int k, + const int batch, + const int topk, + const int k_padded, + const int input_dim1) { + + // `blockIdx.y` corresponds to the batch index (0 to batch_size-1) + const int current_batch = blockIdx.y; + // `blockIdx.z` corresponds to the topk index (0 to topk-1) + const int current_topk = blockIdx.z; + + // `gridDim.z` is the number of blocks in the z-dim, which is `topk`. + // This correctly flattens the (batch, topk) index into a single task ID. + const int task_id = current_batch * gridDim.z + current_topk; + if (task_id >= gridDim.y * gridDim.z) { + return; + } + // If input_dim1 is 1, all experts in a batch use the same input vector. + // Otherwise, each expert has a unique input vector. + const int input_idx = (input_dim1 == 1) ? current_batch : task_id; + + // The expert to use is found in the `indices` array at the flattened `task_id`. 
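+    // Worked example (illustrative): with batch = 2 and topk = 4, the task at
+    // (current_batch = 1, current_topk = 2) gets task_id = 1 * 4 + 2 = 6; it reads
+    // input row 1 when input_dim1 == 1 (shared input) and input row 6 otherwise.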
+ const unsigned int expert_id = indices[task_id]; + + // Calculate strides + const size_t weight_block_size = sizeof(block_q_t); + const size_t input_block_size = sizeof(block_q8_1); + const size_t weight_expert_stride_bytes = (size_t)(n * k) / QK_K * weight_block_size; + const size_t input_task_stride_bytes = (size_t)k_padded / QK8_1 * input_block_size; + const size_t output_task_stride_elems = n; + + //data offsets of current task + const void * current_input_ptr = (const char *)all_inputs + input_idx * input_task_stride_bytes; + const void * current_weight_ptr = (const char *)all_weights + expert_id * weight_expert_stride_bytes; + float * current_output_ptr = all_outputs + task_id * output_task_stride_elems; + + //fixed for inner compute + constexpr int ncols_y = 1; + constexpr int nwarps = 4; + constexpr int rows_per_cuda_block = 1; + + const int tid = WARP_SIZE * threadIdx.y + threadIdx.x; + const int row0 = rows_per_cuda_block * blockIdx.x; // `blockIdx.x` is the row within the task + + if (row0 >= n) { + return; + } + + const int blocks_per_row_x = k / qk; + const int blocks_per_col_y = k_padded / QK8_1; + constexpr int blocks_per_iter = vdr * nwarps * WARP_SIZE / qi; + + float tmp = 0.0f; + + const block_q_t * w = (const block_q_t *) current_weight_ptr; + const block_q8_1 * x = (const block_q8_1 *) current_input_ptr; + + for (int kbx = tid / (qi / vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk / QK8_1); + const int kqs = vdr * (tid % (qi / vdr)); + tmp += vec_dot_q_cuda(&w[kbx + row0 * blocks_per_row_x], &x[kby], kqs); + } + + // --- Inter-warp reduction using shared memory --- + __shared__ float tmp_shared[nwarps - 1][WARP_SIZE]; + if (threadIdx.y > 0) { + tmp_shared[threadIdx.y - 1][threadIdx.x] = tmp; + } + __syncthreads(); + + if (threadIdx.y == 0) { + for (int l = 0; l < nwarps - 1; ++l) { + tmp += tmp_shared[l][threadIdx.x]; + } + tmp = warp_reduce_sum(tmp); + if (threadIdx.x == 0) { + current_output_ptr[row0] = tmp; + } + } +} + +extern "C" __global__ void indexed_moe_forward_q2k_q8_1( + const void * __restrict__ all_weights, + const void * __restrict__ all_inputs, + const unsigned int * __restrict__ indices, + float * __restrict__ all_outputs, + const int n, + const int k, + const int batch, + const int topk, + const int k_padded, + const int input_dim1) { + indexed_moe_forward + (all_weights, all_inputs, indices, all_outputs, n, k, batch, topk, k_padded, input_dim1); +} + +extern "C" __global__ void indexed_moe_forward_q3k_q8_1( + const void * __restrict__ all_weights, + const void * __restrict__ all_inputs, + const unsigned int * __restrict__ indices, + float * __restrict__ all_outputs, + const int n, + const int k, + const int batch, + const int topk, + const int k_padded, + const int input_dim1) { + indexed_moe_forward + (all_weights, all_inputs, indices, all_outputs, n, k, batch, topk, k_padded, input_dim1); +} + +extern "C" __global__ void indexed_moe_forward_q4k_q8_1( + const void * __restrict__ all_weights, + const void * __restrict__ all_inputs, + const unsigned int * __restrict__ indices, + float * __restrict__ all_outputs, + const int n, + const int k, + const int batch, + const int topk, + const int k_padded, + const int input_dim1) { + indexed_moe_forward + (all_weights, all_inputs, indices, all_outputs, n, k, batch, topk, k_padded, input_dim1); +} + +extern "C" __global__ void indexed_moe_forward_q5k_q8_1( + const void * __restrict__ all_weights, + const void * __restrict__ all_inputs, + const unsigned int * __restrict__ 
indices, + float * __restrict__ all_outputs, + const int n, + const int k, + const int batch, + const int topk, + const int k_padded, + const int input_dim1) { + indexed_moe_forward + (all_weights, all_inputs, indices, all_outputs, n, k, batch, topk, k_padded, input_dim1); +} + +extern "C" __global__ void indexed_moe_forward_q6k_q8_1( + const void * __restrict__ all_weights, + const void * __restrict__ all_inputs, + const unsigned int * __restrict__ indices, + float * __restrict__ all_outputs, + const int n, + const int k, + const int batch, + const int topk, + const int k_padded, + const int input_dim1) { + indexed_moe_forward + (all_weights, all_inputs, indices, all_outputs, n, k, batch, topk, k_padded, input_dim1); +} + +extern "C" __global__ void indexed_moe_forward_q8_0_q8_1( + const void * __restrict__ all_weights, + const void * __restrict__ all_inputs, + const unsigned int * __restrict__ indices, + float * __restrict__ all_outputs, + const int n, + const int k, + const int batch, + const int topk, + const int k_padded, + const int input_dim1) { + indexed_moe_forward + (all_weights, all_inputs, indices, all_outputs, n, k, batch, topk, k_padded, input_dim1); +} diff --git a/candle-metal-kernels/src/kernels/sdpa.rs b/candle-metal-kernels/src/kernels/sdpa.rs index 03bde7a0f9..a81e4b79a4 100644 --- a/candle-metal-kernels/src/kernels/sdpa.rs +++ b/candle-metal-kernels/src/kernels/sdpa.rs @@ -25,170 +25,200 @@ pub fn call_sdpa_full( kernels: &Kernels, q_offset: usize, q_shape: &[usize], + q_strides: &[usize], q_buffer: &Buffer, k_offset: usize, + k_shape: &[usize], + k_strides: &[usize], k_buffer: &Buffer, v_offset: usize, v_buffer: &Buffer, + v_strides: &[usize], + mask_type: Option, + mask_buffer: Option<&Buffer>, + m_strides: Option<&[usize]>, output: &Buffer, - alpha: f32, - softcapping: f32, + o_strides: &[usize], + scale: f32, + do_causal: bool, itype: SdpaDType, ) -> Result<(), MetalKernelError> { #[derive(Debug)] #[repr(C)] - struct MLXFastAttentionParams { - m: i32, - n: i32, - k: i32, - - ldq: i32, // ldq == ldo - ldk: i32, - ldv: i32, - lds: i32, - ldo: i32, - - tiles_n: i32, - tiles_m: i32, - - batch_stride_q: i32, - batch_stride_k: i32, - batch_stride_v: i32, - batch_stride_o: i32, - - swizzle_log: i32, - gemm_n_iterations_aligned: i32, - gemm_k_iterations_aligned: i32, - gemm_sv_m_block_iterations: i32, - - batch_ndim: i32, - alpha: f32, - softcapping: f32, + struct AttnParams { + b: i32, + h: i32, + d: i32, + ql: i32, + kl: i32, + gqa_factor: i32, + scale: f32, + nq: i32, + nk: i32, + nq_aligned: i32, + nk_aligned: i32, + ql_rem: i32, + kl_rem: i32, + ql_off: i32, + q_strides: [i64; 3], + k_strides: [i64; 3], + v_strides: [i64; 3], + o_strides: [i64; 3], } - let bk = q_shape.last().unwrap(); + #[derive(Debug)] + #[repr(C)] + struct AttnMaskParams { + m_strides: [i64; 3], + } - const BN: usize = 16; - const BM: usize = 16; - const WM: usize = 2; - const WN: usize = 2; + const WM: usize = 4; + const WN: usize = 1; + + const BQ: usize = 32; + let bd = q_shape[q_shape.len() - 1]; + if ![32, 64, 72, 80, 96, 128, 256].contains(&bd) { + return Err(MetalKernelError::SdpaHeadSizeMismatch { + variation: "full", + got: bd, + expected: vec![32, 64, 72, 80, 96, 128, 256], + }); + }; + let bk = if bd < 128 { 32 } else { 16 }; - let name = match (bk, itype) { - (32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half", - (64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half", - (96, SdpaDType::F16) => 
"steel_gemm_attention_bm_16_bn_16_bk_96_itype_half", - (128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half", - (256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half", - (32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float", - (64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float", - (96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float", - (128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float", - (256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float", - (other, SdpaDType::F16 | SdpaDType::F32) => { - return Err(MetalKernelError::SdpaHeadSizeMismatch { - variation: "full", - got: *other, - expected: vec![32, 64, 96, 128, 256], - }) - } - (_, SdpaDType::BF16) => { - return Err(MetalKernelError::SdpaHeadDTypeMismatch { - variation: "full", - got: SdpaDType::BF16, - }) - } + let b = q_shape[0]; + let h = q_shape[1]; + let d = q_shape[3]; + let gqa_factor = q_shape[1] / k_shape[1]; + + let ql = q_shape[2]; + let kl = k_shape[2]; + + let align_q = (ql % BQ) == 0; + let align_k = (kl % bk) == 0; + let has_mask = mask_buffer.is_some(); + + let itype_repr = match itype { + SdpaDType::BF16 => "bfloat16", + SdpaDType::F16 => "float16", + SdpaDType::F32 => "float32", + }; + let mask_repr = match mask_type { + Some(SdpaDType::BF16) => "bfloat16", + Some(SdpaDType::F16) => "float16", + Some(SdpaDType::F32) => "float32", + None => itype_repr, }; + let name = + format!("steel_attention_{itype_repr}_bq{BQ}_bk{bk}_bd{bd}_wm{WM}_wn{WN}_mask{mask_repr}"); - let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?; + let constants = Some(ConstantValues::new(vec![ + (200, Value::Bool(/* align_Q */ align_q)), + (201, Value::Bool(/* align_K */ align_k)), + (300, Value::Bool(/* has_mask */ has_mask)), + (301, Value::Bool(/* do_causal */ do_causal)), + ])); + + let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name, constants)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoder = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); - // q = (bs, qhead, seq, hidden) - // k/v = (bs, kv_head, seq, hidden) - - let qseq = q_shape[q_shape.len() - 2]; - - let m = q_shape[q_shape.len() - 2]; - let n = m; - let k = q_shape[q_shape.len() - 1]; - let bs_out = q_shape[0] * q_shape[1]; - - let batch_shape = [q_shape[0] * q_shape[1]]; - let dk = q_shape[q_shape.len() - 1]; - let ldq = dk; - let ldk = dk; - let ldv = dk; - let lds = BN; - let ldo = dk; - - let tn = 1; - let tm = m.div_ceil(BM); - - let b_stride_q = dk * qseq; - let b_stride_k = dk * qseq; - let b_stride_v = dk * qseq; - let b_stride_o = dk * qseq; - let swizzle_log = 0; - let gemm_n_iterations_aligned = n.div_ceil(BN); - let gemm_k_iterations_aligned = k.div_ceil(*bk); - let gemm_sv_m_block_iterations = m.div_ceil(BM); - let batch_ndim = batch_shape.len(); - - let alpha = if softcapping != 1. 
{ - alpha / softcapping - } else { - alpha + let nq = (ql + BQ - 1) / BQ; + let nk = (kl + bk - 1) / bk; + + let nq_aligned = ql / BQ; + let nk_aligned = kl / bk; + + let params = AttnParams { + b: b as i32, + h: h as i32, + d: d as i32, + ql: ql as i32, + kl: kl as i32, + gqa_factor: gqa_factor as i32, + scale, + nq: nq as i32, + nk: nk as i32, + nq_aligned: nq_aligned as i32, + nk_aligned: nk_aligned as i32, + ql_rem: ql.wrapping_sub(nq_aligned * BQ) as i32, + kl_rem: kl.wrapping_sub(nk_aligned * bk) as i32, + ql_off: kl.wrapping_sub(ql) as i32, + q_strides: [ + q_strides[0] as i64, + q_strides[1] as i64, + q_strides[2] as i64, + ], + k_strides: [ + k_strides[0] as i64, + k_strides[1] as i64, + k_strides[2] as i64, + ], + v_strides: [ + v_strides[0] as i64, + v_strides[1] as i64, + v_strides[2] as i64, + ], + o_strides: [ + o_strides[0] as i64, + o_strides[1] as i64, + o_strides[2] as i64, + ], }; - let params = MLXFastAttentionParams { - m: m as i32, - n: n as i32, - k: k as i32, - ldq: ldq as i32, - ldk: ldk as i32, - ldv: ldv as i32, - lds: lds as i32, - ldo: ldo as i32, - tiles_n: tn, - tiles_m: tm as i32, - batch_stride_q: b_stride_q as i32, - batch_stride_k: b_stride_k as i32, - batch_stride_v: b_stride_v as i32, - batch_stride_o: b_stride_o as i32, - swizzle_log, - gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32, - gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32, - gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32, - batch_ndim: batch_ndim as i32, - alpha, - softcapping, - }; - let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o]; + impl EncoderParam for AttnParams { + fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) { + encoder.set_bytes(position, &data); + } + } - impl EncoderParam for MLXFastAttentionParams { + impl EncoderParam for AttnMaskParams { fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) { encoder.set_bytes(position, &data); } } - set_params!( - encoder, - ( - (q_buffer, q_offset), - (k_buffer, k_offset), - (v_buffer, v_offset), - output, - params, - &batch_shape[..], - &batch_strides[..] 
- ) - ); + if let Some(mask) = mask_buffer { + let mask_strides = m_strides.unwrap(); + let mask_params = AttnMaskParams { + m_strides: [ + mask_strides[0] as i64, + mask_strides[1] as i64, + mask_strides[2] as i64, + ], + }; + encoder.use_resource(mask, MTLResourceUsage::Read); + + set_params!( + encoder, + ( + (q_buffer, q_offset), + (k_buffer, k_offset), + (v_buffer, v_offset), + output, + params, + mask_params, + mask + ) + ); + } else { + set_params!( + encoder, + ( + (q_buffer, q_offset), + (k_buffer, k_offset), + (v_buffer, v_offset), + output, + params + ) + ); + } let grid_dims = MTLSize { - width: 1, - height: tm, - depth: bs_out, + width: nq, + height: h, + depth: b, }; let group_dims = MTLSize { width: 32, @@ -200,6 +230,7 @@ pub fn call_sdpa_full( encoder.use_resource(v_buffer, MTLResourceUsage::Read); encoder.use_resource(output, MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_dims, group_dims); + Ok(()) } diff --git a/candle-metal-kernels/src/metal/device.rs b/candle-metal-kernels/src/metal/device.rs index b9a9f9ec48..9380d19bb9 100644 --- a/candle-metal-kernels/src/metal/device.rs +++ b/candle-metal-kernels/src/metal/device.rs @@ -93,4 +93,12 @@ impl Device { let raw = self.as_ref().newCommandQueue().unwrap(); Ok(raw) } + + pub fn recommended_max_working_set_size(&self) -> usize { + self.as_ref().recommendedMaxWorkingSetSize() as usize + } + + pub fn current_allocated_size(&self) -> usize { + self.as_ref().currentAllocatedSize() + } } diff --git a/candle-metal-kernels/src/metal_src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/metal_src/scaled_dot_product_attention.metal index 1876252eee..e1057a994b 100644 --- a/candle-metal-kernels/src/metal_src/scaled_dot_product_attention.metal +++ b/candle-metal-kernels/src/metal_src/scaled_dot_product_attention.metal @@ -5,6 +5,262 @@ using namespace metal; +#define STEEL_CONST static constant constexpr const +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +#if defined(__HAVE_BFLOAT__) + +typedef bfloat bfloat16_t; +typedef half float16_t; + +#else + +///////////////////////////////////////////////////////////////////////////// +// Helpers +///////////////////////////////////////////////////////////////////////////// + +constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) { + // Check for nan + if ((as_type(x) & ~_fp_encoding_traits::sign_mask) > + _fp_encoding_traits::inf_mask) { + return uint16_t(as_type(0x7FC0)); + } + // Take bits + uint32_t float_bits = as_type(x); + + // Round to nearest even + float_bits += ((float_bits >> 16) & 1) + as_type(0x7FFF); + + // Take upper 16 bits + return float_bits >> 16; +} + +constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) { + // Upper 16 bits are the data and lower 16 bits are 0s + return as_type((uint32_t)x << 16); +} + +struct _MLX_BFloat16; + +template +static constexpr constant bool can_convert_to_bfloat = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_bfloat = + !is_same_v && is_convertible_v; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat struct +///////////////////////////////////////////////////////////////////////////// + +struct _MLX_BFloat16 { + ///////////////////////////////////////////////////////////////////////////// + // Constructors + uint16_t bits_; + _MLX_BFloat16() thread = default; + _MLX_BFloat16() threadgroup = default; + _MLX_BFloat16() device = default; + _MLX_BFloat16() constant = default; + + struct 
bits_to_bfloat_struct {}; + static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() { + return bits_to_bfloat_struct(); + } + constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct) + : bits_(bits) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions to bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) thread + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) device + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) constant + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions from bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const thread { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const threadgroup { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const device { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const constant { + return static_cast(bfloat_bits_to_float(bits_)); + } +}; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat operators +///////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////// +// Unary ops +constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { + return -static_cast(x); +} + +///////////////////////////////////////////////////////////////////////////// +// Binary operators +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +///////////////////////////////////////////////////////////////////////////// +// Arithmetic Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base( \ + _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, float, half, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, 
operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +///////////////////////////////////////////////////////////////////////////// +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, half, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop +#undef bfloat_binop_base +#undef bfloat_binop_helper +#undef bfloat_binop + +///////////////////////////////////////////////////////////////////////////// +// Inplace Operators +#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, itype rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } \ + constexpr METAL_FUNC addr_space itype& __operator__( \ + addr_space itype& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ + bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); + +#define bfloat_inplace_op(itype) \ + bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ + bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ + bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ + bfloat_inplace_op_addr_space_helper(/, operator/=, itype); + +bfloat_inplace_op(float); +bfloat_inplace_op(half); +bfloat_inplace_op(int16_t); +bfloat_inplace_op(int32_t); +bfloat_inplace_op(int64_t); +bfloat_inplace_op(uint16_t); +bfloat_inplace_op(uint32_t); +bfloat_inplace_op(uint64_t); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper +#undef bfloat_inplace_op + +#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ + bfloat_inplace_op_helper(__op__, __operator__, device); \ + bfloat_inplace_op_helper(__op__, __operator__, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, threadgroup); + +bfloat_inplace_op_addr_space_helper(+, operator+=); +bfloat_inplace_op_addr_space_helper(-, operator-=); +bfloat_inplace_op_addr_space_helper(*, operator*=); +bfloat_inplace_op_addr_space_helper(/, operator/=); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper + +///////////////////////////////////////////////////////////////////////////// +// Bfloat typedef +///////////////////////////////////////////////////////////////////////////// + +typedef struct _MLX_BFloat16 bfloat16_t; + +#endif + // ============ 
"mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" struct MLXFastAttentionParams { @@ -140,6 +396,9 @@ template // Move the pointers to the next kv keys += stride; values += stride; + if (sdpa_vector_has_mask) { + mask += BN * mask_seq_stride; + } } // Each thread has a partial part of the output so we need to combine them. @@ -275,6 +534,43 @@ template mask += BN * blocks * mask_seq_stride; } } + + // Each thread has a partial part of the output so we need to combine them. + + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0; + sum_exp_score = simd_sum(sum_exp_score * factor); + + // Write the sum and new max + if (simd_gid == 0) { + sums[0] = sum_exp_score; + maxs[0] = new_max; + } + + // Now we need to aggregate all the outputs + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BN + simd_gid] = + o[i] * fast::exp(max_scores[simd_gid] - new_max); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // And write the output + if (simd_gid == 0) { + U output = outputs[simd_lid * BN]; + for (int j = 1; j < BN; j++) { + output += outputs[simd_lid * BN + j]; + } + out[i] = static_cast(output); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } } template @@ -329,114 +625,55 @@ template } } -// ============ "mlx/backend/metal/kernels/steel/defines.h" - -#define STEEL_CONST static constant constexpr const -#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") - -// ============ "mlx/backend/metal/kernels/steel/gemm/transforms.h" - -template -struct TransformNone { - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - static METAL_FUNC OutT apply(InT x, OutT) { - return static_cast(x); - } -}; - -template -struct TransformAdd { - TransformAdd(const float, const float) {} - - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - static METAL_FUNC OutT apply(InT x, OutT c) { - return static_cast(x) + c; - } -}; - -template -struct TransformAxpby { - const float alpha; - const float beta; - - TransformAxpby(const float alpha_, const float beta_) - : alpha(alpha_), beta(beta_) {} - - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - METAL_FUNC OutT apply(InT x, OutT c) const { - return static_cast(x * alpha + (beta * c)); - } -}; - -template -struct AccumHelper { - typedef float accum_type; -}; +// ============ "mlx/backend/metal/kernels/utils.h" -struct BlockSwizzle { - static METAL_FUNC int2 - swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { - const int tid_x = (tid.x) >> swizzle_log; - const int tid_y = - ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); - return int2(tid_x, tid_y); - } +template +struct Limits { + static const constant U max = metal::numeric_limits::max(); + static const constant U min = metal::numeric_limits::min(); + static const constant U finite_max = metal::numeric_limits::max(); + static const constant U finite_min = metal::numeric_limits::min(); }; -// ============ "mlx/backend/metal/kernels/utils.h" +#define instantiate_default_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = metal::numeric_limits::max(); \ + static constexpr 
constant type min = metal::numeric_limits::min(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + metal::numeric_limits::min(); \ + }; -#if defined(__HAVE_BFLOAT__) -typedef bfloat bfloat16_t; -#endif -typedef half float16_t; +instantiate_default_limit(uint8_t); +instantiate_default_limit(uint16_t); +instantiate_default_limit(uint32_t); +instantiate_default_limit(uint64_t); +instantiate_default_limit(int8_t); +instantiate_default_limit(int16_t); +instantiate_default_limit(int32_t); +instantiate_default_limit(int64_t); + +#define instantiate_float_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = \ + metal::numeric_limits::infinity(); \ + static constexpr constant type min = \ + -metal::numeric_limits::infinity(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + -metal::numeric_limits::max(); \ + }; -METAL_FUNC ulong2 elem_to_loc_broadcast( - uint elem, - constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - int ndim) { - ulong loc_a{0}; - ulong loc_b{0}; - for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - int pos_in_dim = (elem % shape[i]); - elem /= shape[i]; - loc_a += pos_in_dim * a_strides[i]; - loc_b += pos_in_dim * b_strides[i]; - } - return ulong2(loc_a, loc_b); -} +instantiate_float_limit(half); +instantiate_float_limit(float); +instantiate_float_limit(bfloat16_t); -METAL_FUNC ulong3 elem_to_loc_broadcast( - uint elem, - constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - constant const size_t* c_strides, - int ndim) { - ulong loc_a{0}; - ulong loc_b{0}; - ulong loc_c{0}; - for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - int pos_in_dim = (elem % shape[i]); - elem /= shape[i]; - loc_a += pos_in_dim * a_strides[i]; - loc_b += pos_in_dim * b_strides[i]; - loc_c += pos_in_dim * c_strides[i]; - } - return ulong3(loc_a, loc_b, loc_c); -} -// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.metal" +// ============ "mlx/backend/metal/kernels/steel/attn/loader.h" template < typename T, @@ -449,7 +686,7 @@ template < short n_reads = (BCOLS * BROWS) / (tgp_size), short TCOLS = BCOLS / n_reads, short TROWS = tgp_size / TCOLS> -struct BlockLoaderFA { +struct BlockLoader { STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; STEEL_CONST short vec_size = n_reads; @@ -471,7 +708,7 @@ struct BlockLoaderFA { }; /* Constructor */ - METAL_FUNC BlockLoaderFA( + METAL_FUNC BlockLoader( const device T* src_, const int src_ld_, threadgroup T* dst_, @@ -485,6 +722,18 @@ struct BlockLoaderFA { dst(dst_ + bi * dst_ld + bj), src(src_ + bi * src_ld + bj) {} + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } + /* Load from device memory into threadgroup memory - without bound checking */ METAL_FUNC void load_unsafe() const { STEEL_PRAGMA_UNROLL @@ -528,7 +777,7 @@ struct BlockLoaderFA { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out unneeded values + // Zero out uneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? 
tmp_val[j] : T(0); @@ -546,242 +795,925 @@ struct BlockLoaderFA { METAL_FUNC void next() { src += tile_stride; } - METAL_FUNC void next(short n) { - src += n * tile_stride; - } }; -template -struct LoopAlignment {}; +template +struct CShape { + STEEL_CONST int kRows = R; + STEEL_CONST int kCols = C; +}; template < typename T, - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - short lda_tgp, - short ldb_tgp, - typename AccumType = float, - typename Epilogue = TransformNone> -struct BlockMMAFA { - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TM_stride = 8 * WM; - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TN_stride = 8 * WN; - - // Warp tile size along M - STEEL_CONST short TM = BM / TM_stride; - // Warp tile size along N - STEEL_CONST short TN = BN / TN_stride; - - // Strides of A, B along reduction axis - STEEL_CONST short simd_stride_a = { - transpose_a ? TM_stride : TM_stride * lda_tgp}; - STEEL_CONST short simd_stride_b = { - transpose_b ? TN_stride * ldb_tgp : TN_stride}; - - // Jump between elements - STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; - STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; - - STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; - STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp}; - - // Simdgroup matrices - simdgroup_matrix Asimd[TM]; - simdgroup_matrix Bsimd[TN]; - simdgroup_matrix results[TM * TN] = { - simdgroup_matrix(0)}; - - // Offsets within threadgroup - const short tm; - const short tn; + short BROWS, + short BCOLS, + short kDstStrRow, + short kDstStrCol, + short reduction_dim, + short tgp_size, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoaderT { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; - short sm; - short sn; + // Leading dimension for src + const int src_ld; + const int tile_stride; - ushort sid; - ushort slid; + // Thread location indices + const short thread_idx; + const short bi; + const short bj; - short As_offset; - short Bs_offset; + // threadgroup and device memory + threadgroup T* dst; + const device T* src; /* Constructor */ - METAL_FUNC BlockMMAFA( + METAL_FUNC BlockLoaderT( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, ushort simd_group_id [[simdgroup_index_in_threadgroup]], ushort simd_lane_id [[thread_index_in_simdgroup]]) - : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { - // Determine thread position in simdgroup matrix - short qid = simd_lane_id / 4; - slid = simd_lane_id; - sid = simd_group_id; - - sm = (qid & 4) + (simd_lane_id / 2) % 4; - sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - - // Determine thread and simdgroup offset - As_offset = - transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); - Bs_offset = - transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); - } - - /* (BM, BK) X (BK, BN) multiply accumulate function */ - METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { - // Adjust for simdgroup and thread location - As += As_offset; - Bs += Bs_offset; + : src_ld(src_ld_), + tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * kDstStrRow + bj * kDstStrCol), + src(src_ + bi * src_ld + bj) {} - // Iterate over BK in blocks of 8 + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < BK; kk += 8) { - simdgroup_barrier(mem_flags::mem_none); - - // Load elements from threadgroup A as simdgroup matrices + for (short i = 0; i < BROWS; i += TROWS) { STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - Asimd[i].thread_elements()[0] = - static_cast(As[i * simd_stride_a + 0]); - Asimd[i].thread_elements()[1] = - static_cast(As[i * simd_stride_a + jump_a]); + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = + op.apply(dst[i * kDstStrRow + j * kDstStrCol]); } + } + } - simdgroup_barrier(mem_flags::mem_none); - - // Load elements from threadgroup B as simdgroup matrices + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - Bsimd[j].thread_elements()[0] = - static_cast(Bs[j * simd_stride_b + 0]); - Bsimd[j].thread_elements()[1] = - static_cast(Bs[j * simd_stride_b + jump_b]); + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j]; } + } + } - simdgroup_barrier(mem_flags::mem_none); + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); - // Multiply and accumulate into result simdgroup matrices + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { + for (short i = 0; i < BROWS; i += TROWS) { STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - short j_serp = (i % 2) ? (TN - 1 - j) : j; - - simdgroup_multiply_accumulate( - results[i * TN + j_serp], - Asimd[i], - Bsimd[j_serp], - results[i * TN + j_serp]); + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = T(0); } } - - // Progress to next simdgroup tile - As += tile_stride_a; - Bs += tile_stride_b; + return; } - } - METAL_FUNC void rescale_output(const threadgroup float* Corrections) { - // Loop over all simdgroup tiles + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - short row = sm + tm + i * TM_stride; - float scale_value = Corrections[row]; + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? 
tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +// ============ "mlx/backend/metal/kernels/steel/utils/type_traits.h" + +template +struct make_void { + typedef void type; +}; + +template +using void_t = typename make_void::type; + +template +struct pointer_element {}; + +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; + +template +using pointer_element_t = typename pointer_element>::type; + +// ============ "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +/////////////////////////////////////////////////////////////////////////////// +// Integral constant with casting +/////////////////////////////////////////////////////////////////////////////// + +template +using Int = integral_constant; + +/////////////////////////////////////////////////////////////////////////////// +// Binary Operators on Integral constants +/////////////////////////////////////////////////////////////////////////////// + +#define integral_const_binop(__op__, __operator__) \ + template \ + METAL_FUNC constexpr auto __operator__( \ + integral_constant, integral_constant) { \ + constexpr auto res = tv __op__ uv; \ + return integral_constant{}; \ + } + +integral_const_binop(+, operator+); +integral_const_binop(-, operator-); +integral_const_binop(*, operator*); +integral_const_binop(/, operator/); + +integral_const_binop(==, operator==); +integral_const_binop(!=, operator!=); +integral_const_binop(<, operator<); +integral_const_binop(>, operator>); +integral_const_binop(<=, operator<=); +integral_const_binop(>=, operator>=); + +integral_const_binop(&&, operator&&); +integral_const_binop(||, operator||); + +#undef integral_const_binop + +/////////////////////////////////////////////////////////////////////////////// +// Reduction operators +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC constexpr T sum(T x) { + return x; +} + +template +METAL_FUNC constexpr auto sum(T x, Us... 
us) { + return x + sum(us...); +} + +// ============ "mlx/backend/metal/kernels/steel/gemm/transforms.h" + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast(x * alpha + (beta * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +// ============ "mlx/backend/metal/kernels/steel/attn/mma.h" + +template +struct Shape2D { + RInt r; + CInt c; + + Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {} +}; + +template +struct Layout2D { + Shape shape; + Layout layout; +}; + +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; + +template +struct BaseMMAFrag { + STEEL_CONST int kFragRows = 8; + STEEL_CONST int kFragCols = 8; + + STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST int kElemRows = 1; + STEEL_CONST int kElemCols = 2; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + typedef metal::vec row_frag_type; + typedef metal::vec col_frag_type; + + template + using dtype_mat_t = typename metal::simdgroup_matrix; + + template + using dtype_frag_t = typename metal::vec; + + METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id + [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + return short2{fn, fm}; + } + + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x.value + j * str_y.value]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[(off_x + i) * str_x + (off_y + j) * 
str_y.value]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y.value] = static_cast(src[i * kElemCols + j]); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y.value] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread dtype_frag_t& A, + thread dtype_frag_t& B, + thread dtype_frag_t& C) { + mat_type D_mat; + dtype_mat_t A_mat; + dtype_mat_t B_mat; + dtype_mat_t C_mat; + + reinterpret_cast&>(A_mat.thread_elements()) = A; + reinterpret_cast&>(B_mat.thread_elements()) = B; + reinterpret_cast&>(C_mat.thread_elements()) = C; + + mma(D_mat, A_mat, B_mat, C_mat); + + D = reinterpret_cast(D_mat.thread_elements()); + } + + template + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread dtype_mat_t& A, + thread dtype_mat_t& B, + thread dtype_mat_t& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const frag_type& inp_vals, + thread T* reduced_vals) { + T thr_reduce = Op::apply(inp_vals.x, inp_vals.y); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce); + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread frag_type& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; + STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; + STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + + STEEL_CONST int kTileRows = kTileRows_; + STEEL_CONST int kTileCols = kTileCols_; + + STEEL_CONST int kRows = kTileRows * kFragRows; + STEEL_CONST int kCols = kTileCols * kFragCols; + + STEEL_CONST int kNumFrags = kTileRows * kTileCols; + STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols; + + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + + frag_type val_frags[kNumFrags]; // = {frag_type(0)}; + + 
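+  // Per-thread register storage for the tile: each 8x8 fragment is spread
+  // across the 32 threads of a simdgroup, so every thread holds
+  // kElemsPerFrag = 2 values per fragment, i.e. kElemsPerTile =
+  // 2 * kTileRows * kTileCols values for the whole tile.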
METAL_FUNC MMATile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::template row_reduce( + frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::template row_bin_op( + frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * 
kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; + +template < + typename Dtype, + typename Atype, + typename Btype, + typename Ctype, + int M, + int N, + int K, + class MMAFragD, + class MMAFragA, + class MMAFragB, + class MMAFragC> +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { + STEEL_PRAGMA_UNROLL + for (short m = 0; m < M; ++m) { + STEEL_PRAGMA_UNROLL + for (short n = 0; n < N; ++n) { + short m_serp = m; //(n % 2) ? (M - 1 - m) : m; + short n_serp = (m % 2) ? (N - 1 - n) : n; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < K; ++k) { + MMAFragD::mma( + D.frag_at(m_serp, n_serp), + A.frag_at(m_serp, k), + B.frag_at(k, n_serp), + C.frag_at(m_serp, n_serp)); + } + } + } +} + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = kFragSize * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = kFragSize * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K + + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N + + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; + + // Simdgroup matrices + MMATile Atile; + MMATile Btile; + MMATile Ctile; + + // Offsets within threadgroup + short sm; + short sn; + + short As_offset; + short Bs_offset; - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread auto& accum = results[i * TN + j].thread_elements(); - // int offset = (i * TM_stride) * ldc + (j * TN_stride); - accum[0] *= scale_value; - accum[1] *= scale_value; - } + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + // Determine thread position in simdgroup matrix + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + + // Determine thread and simdgroup offset + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N + + sm += tm; + sn += tn; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of kFragSize + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += kFragSize) { + simdgroup_barrier(mem_flags::mem_none); + + Atile.template load(As); + + simdgroup_barrier(mem_flags::mem_none); + + Btile.template load(Bs); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Ctile, Atile, Btile, Ctile); + + // Progress to next 
simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; } } /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result(device U* C, const int ldc) const { + METAL_FUNC void store_result(device U* D, const int ldd) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + tn + sn; + D += sm * ldd + sn; - // Loop over all simdgroup tiles + Ctile.template store(D, ldd); + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + // Apply epilogue STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } - // Apply epilogue - U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); - // Write out C - C[offset] = outs[0]; - C[offset + 1] = outs[1]; - } + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Ctile.template store_safe(D, ldd, dst_tile_dims); + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); } } - METAL_FUNC void store_result_to_tgp_memory( - threadgroup U* C, + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, const int ldc, - short2 dst_tile_dims) const { + const int fdc, + thread const BinaryEpilogue& epilogue_op) { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); - - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - C[offset] = Epilogue::apply(accum[0]); - } + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - if (j * TN_stride + 1 < dst_tile_dims.x) { - C[offset + 1] = Epilogue::apply(accum[1]); - } + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); } } } } - METAL_FUNC void - store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const { + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { // Adjust for simdgroup and 
thread location - C += (sm + tm) * ldc + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - C[offset] = Epilogue::apply(accum[0]); - } + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Read C + U c_elems[kelems] = {0}; - if (j * TN_stride + 1 < dst_tile_dims.x) { - C[offset + 1] = Epilogue::apply(accum[1]); + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; } } + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); + } } } } @@ -795,8 +1727,10 @@ struct BlockMMAFA { const int fdc, thread const Epilogue& epilogue_op) const { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL @@ -804,18 +1738,15 @@ struct BlockMMAFA { STEEL_PRAGMA_UNROLL for (short j = 0; j < TN; j++) { // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); // Apply epilogue - U outs[2] = { - epilogue_op.apply(accum[0], C[offset_c]), - epilogue_op.apply(accum[1], C[offset_c + fdc])}; - - // Write out D - D[offset_d] = outs[0]; - D[offset_d + 1] = outs[1]; + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } @@ -829,9 +1760,14 @@ struct BlockMMAFA { short2 dst_tile_dims, thread const Epilogue& epilogue_op) const { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; STEEL_PRAGMA_UNROLL for (int i = 0; i < TM; i++) { @@ -839,556 +1775,551 @@ struct BlockMMAFA { STEEL_PRAGMA_UNROLL for (int j = 0; j < TN; j++) { // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) 
{ - D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); - } - - if (j * TN_stride + 1 < dst_tile_dims.x) { - D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } } } +}; - METAL_FUNC void clear_results() { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - results[i * TN + j] = simdgroup_matrix(0); - } - } +// ============ "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h" + +struct AttnParams { + int B; ///< Batch Size + int H; ///< Heads + int D; ///< Head Dim + + int qL; ///< Query Sequence Length + int kL; ///< Key Sequence Length + + int gqa_factor; ///< Group Query factor + float scale; ///< Attention scale + + int NQ; ///< Number of query blocks + int NK; ///< Number of key/value blocks + + int NQ_aligned; ///< Number of full query blocks + int NK_aligned; ///< Number of full key/value blocks + + int qL_rem; ///< Remainder in last query block + int kL_rem; ///< Remainder in last key/value block + int qL_off; ///< Offset in query sequence start + + int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) + int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) + int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) + int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1) +}; + +struct AttnMaskParams { + int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1) +}; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool align_Q [[function_constant(200)]]; +constant bool align_K [[function_constant(201)]]; + +constant bool has_mask [[function_constant(300)]]; +constant bool do_causal [[function_constant(301)]]; + +template +struct TransformScale { + T scale; + METAL_FUNC TransformScale(T scale_) : scale(scale_) {} + + METAL_FUNC T apply(T x) const { + return scale * x; } }; +struct MaxOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return metal::max(x, y); + } +}; + +struct SumOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x + y; + } +}; + +struct MulOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x * y; + } +}; + +struct SubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x - y; + } +}; + +struct ExpSubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return fast::exp2(x - y); + } +}; + +struct DivOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x / y; + } +}; + +// clang-format off template < typename T, - typename U, - int BM, - int BN, + int BQ, int BK, + int BD, int WM, int WN, - bool transpose_q, - bool transpose_k, - bool transpose_v, - bool MN_aligned, - bool K_aligned, - typename AccumType = typename AccumHelper::accum_type, - typename Epilogue = TransformNone> -struct FastAttentionKernel { - STEEL_CONST short tgp_padding = 16 / sizeof(T); - STEEL_CONST short float_padding = 16 / sizeof(float); - STEEL_CONST short tgp_mem_size_q = - transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_k = - transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_v = - transpose_v ? 
BK * (BN + tgp_padding) : BN * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding); - - // maxes, rowsums, rescale - STEEL_CONST short tgp_mem_size_corrections = - 4 * (BM * sizeof(float) + float_padding); - - STEEL_CONST bool share_kv_smem = transpose_k != transpose_v; - - STEEL_CONST short tgp_mem_size = share_kv_smem - ? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + - tgp_mem_size_corrections - : tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + - tgp_mem_size_corrections + tgp_mem_size_v; - - STEEL_CONST short tgp_size = WM * WN * 32; - - static_assert(transpose_q == false, "Expected Q not transposed."); - static_assert(transpose_k == true, "Expected K transposed."); - static_assert(transpose_v == false, "Expected V not transposed."); - static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested."); - - using loader_q_t = BlockLoaderFA< - T, - transpose_q ? BK : BM, - transpose_q ? BM : BK, - transpose_q ? BM + tgp_padding : BK + tgp_padding, - !transpose_q, - tgp_size>; - - using loader_k_t = BlockLoaderFA< - T, - transpose_k ? BN : BK, - transpose_k ? BK : BN, - transpose_k ? BK + tgp_padding : BN + tgp_padding, - transpose_k, - tgp_size>; - - using loader_v_t = BlockLoaderFA< - T, - transpose_v ? BK : BN, - transpose_v ? BN : BK, - transpose_v ? BN + tgp_padding : BK + tgp_padding, - transpose_v, - tgp_size>; - - using mma_qk_t = BlockMMAFA< - T, - U, - BM, - BN, - BK, - WM, - WN, - transpose_q, - transpose_k, - transpose_q ? BM + tgp_padding : BK + tgp_padding, - transpose_k ? BK + tgp_padding : BN + tgp_padding, - AccumType, - Epilogue>; - - using mma_sv_t = BlockMMAFA< - T, - U, - BM, - BK, - BN, - WM, - WN, - false, - transpose_v, - BN + tgp_padding, - BK + tgp_padding, - AccumType, - Epilogue>; - - /* Main kernel function */ - template - static METAL_FUNC void gemm_loop( - threadgroup T* As [[threadgroup(0)]], - threadgroup T* Bs [[threadgroup(1)]], - const int gemm_k_iterations, - thread loader_k_t& loader_b, - thread mma_qk_t& mma_op, - thread const short& tgp_bm, - thread const short& tgp_bn, - LoopAlignment l = {}) { - // Appease the compiler - (void)l; - (void)tgp_bm; - - short2 tile_dims_B = transpose_k ? 
short2(BK, tgp_bn) : short2(tgp_bn, BK); - - // not valid for gemm_k_iterations > 1 (so, BK == d_k) - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (N_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe(tile_dims_B); - } + typename MaskType = float, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant AttnParams* params [[buffer(4)]], + const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], + const device MaskType* mask [[buffer(6), function_constant(has_mask)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on - threadgroup_barrier(mem_flags::mem_threadgroup); + // Pacifying compiler + (void)lid; - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - } + // Move to correct block + ulong3 tidl{tid.x, tid.y, tid.z}; + + Q += tidl.z * params->Q_strides[0] + // Batch + tidl.y * params->Q_strides[1] + // Head + tidl.x * BQ * params->Q_strides[2]; // Seqeunce + + ulong kv_head_idx = int(tid.y) / params->gqa_factor; + K += tidl.z * params->K_strides[0] + // Batch + kv_head_idx * params->K_strides[1]; // Head + + V += tidl.z * params->V_strides[0] + // Batch + kv_head_idx * params->V_strides[1]; // Head + + O += tidl.z * params->O_strides[0] + // Batch + tidl.y * params->O_strides[1] + // Head + tidl.x * BQ * params->O_strides[2]; // Seqeunce + + if (has_mask) { + mask += tidl.z * mask_params->M_strides[0] + // Batch + tidl.y * mask_params->M_strides[1]; // Head + } + + // Prepare threadgroup memory + constexpr short padQ = 16 / sizeof(T); + constexpr short padK = 16 / sizeof(T); + constexpr short padV = 16 / sizeof(T); + + constexpr short LDQ_tgp = BD + padQ; + constexpr short LDK_tgp = BK + padK; + constexpr short LDV_tgp = BD + padV; + + constexpr short tgp_mem_0 = (BK + padK) * (BD); + constexpr short tgp_mem_1 = BK * (BD + padV); + constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? 
tgp_mem_0 : tgp_mem_1; + + threadgroup T Q_smem[BQ * (BD + padQ)]; + threadgroup T KV_smem[tgp_mem_s]; + + threadgroup T* Qs = Q_smem; + threadgroup T* Ks = KV_smem; + threadgroup T* Vs = KV_smem; + + // Prepare block loaders + using QBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BQ, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDQ_tgp, + /* short kDstStrCol = */ 1, + /* short reduction_dim = */ 1, + /* short tgp_size = */ WM * WN * 32>; + + // K is loaded in transposed + using KBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ 1, + /* short kDstStrCol = */ LDK_tgp, + /* short reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + using VBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDV_tgp, + /* short kDstStrCol = */ 1, + /* short reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + QBlockLoader loader_q( + Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id); + KBlockLoader loader_k( + K, params->K_strides[2], Ks, simd_group_id, simd_lane_id); + VBlockLoader loader_v( + V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); + + TransformScale ts(static_cast(params->scale * 1.44269504089)); + + // Prepare MMA tiles + constexpr short kFragSize = 8; // MMAFrag size + using MMAFrag_acc_t = BaseMMAFrag; + + constexpr int kNWarps = WM * WN; + static_assert( + BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0, + "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); + + // Q seq frags per warp + constexpr int TQ = BQ / (kNWarps * kFragSize); + // KV sequence frags (all warps load the same frags) + constexpr int TK = BK / kFragSize; + // HeadDim frags (all warps load the same frags) + constexpr int TD = BD / kFragSize; + + static_assert(TQ == 1, "Check TQ"); + + MMATile Qtile; + MMATile Ktile; + MMATile Stile; + MMATile Vtile; + MMATile Otile; + + Otile.clear(); + + // Prepare mma tile offsets + const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = kFragSize * TQ * simd_group_id; + + const short Qs_offset = (tm + sm) * LDQ_tgp + sn; + const short Ks_offset = sm * LDK_tgp + sn; + const short Vs_offset = sm * LDV_tgp + sn; + + constexpr short Qs_tile_stride = kFragSize; + constexpr short Ks_tile_stride = kFragSize * LDK_tgp; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load Q blocks apply scale + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + loader_q.load_safe(short2(BD, params->qL_rem)); + } else { + loader_q.load_unsafe(); + } + loader_q.apply_inplace_op(ts); + + // Init row reduction variables + constexpr short kRowsPT = decltype(Stile)::kRowsPerThread; + + AccumType max_score[kRowsPT]; + AccumType sum_score[kRowsPT] = {0}; + + // Init to -Inf + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = Limits::min; + } + + int kb_lim = params->NK; + + if (do_causal) { + int q_max = (tid.x + 1) * BQ + params->qL_off; + kb_lim = (q_max + BK - 1) / BK; } - static METAL_FUNC void initialize_corrections( - threadgroup float* C, - uint simd_lane_id, - uint simd_group_id) { - if (simd_group_id == 0) { - threadgroup float* maxes = C; - threadgroup float* sums = C + (BM + float_padding); - threadgroup float* o_rescale = sums + (BM + float_padding); - threadgroup float* output_rescale = o_rescale + (BM + 
float_padding); - - if (simd_lane_id < BM) { - maxes[simd_lane_id] = -INFINITY; // m_i - sums[simd_lane_id] = 0.f; // l_i - o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new) - output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i + // Loop over KV seq length + for (int kb = 0; kb < kb_lim; kb++) { + // Load K block and apply scale + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!align_K && kb == (params->NK_aligned)) { + loader_k.load_safe(short2(BD, params->kL_rem)); + } else { + loader_k.load_unsafe(); + } + + // Do S = Q @ K.T + Stile.clear(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_UNROLL + for (short dd = 0; dd < TD; dd++) { + simdgroup_barrier(mem_flags::mem_none); + + Qtile.template load( + &Qs[Qs_offset + dd * Qs_tile_stride]); + Ktile.template load( + &Ks[Ks_offset + dd * Ks_tile_stride]); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Stile, Qtile, Ktile, Stile); + } + + // Mask out length sequence + if (!align_K && kb == (params->NK_aligned)) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + short col_pos = sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if ((col_pos + jj) >= params->kL_rem) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } } } - } - static METAL_FUNC void rescale_ss( - threadgroup T* Ss, - threadgroup float* Corrections, - uint simd_group_id, - uint simd_lane_id, - short2 local_blocks, - float alpha, - float softcapping) { - if (simd_group_id == 0) { - short row_offset = BM + float_padding; - threadgroup float* maxes = Corrections; - threadgroup float* sums = Corrections + row_offset; - threadgroup float* o_rescale = sums + row_offset; - threadgroup float* output_scales = o_rescale + row_offset; - - if (simd_lane_id < uint(local_blocks.y)) { - float m_i_old = maxes[simd_lane_id]; - float l_i_old = sums[simd_lane_id]; - - float m_i_new = m_i_old; - float l_i_new = l_i_old; - - short offset = simd_lane_id * (BN + tgp_padding); - - float m_ij = -INFINITY; - - for (short j = 0; j < local_blocks.x; j++) { - float val = alpha * float(Ss[offset + j]); - if (softcapping != 1.) 
{ - val = precise::tanh(val); - val = val * softcapping; + // Mask out if causal + if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = + tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if (row_pos < (col_pos + jj)) { + Stile.frag_at(i, j)[jj] = neg_inf; + } } - m_ij = max(m_ij, val); } + } + } + + // Other masking as needed + if (has_mask) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); - m_i_new = max(m_ij, m_i_new); + constexpr bool is_bool = is_same_v; + using melem_t = typename metal::conditional_t; - float rowsum = 0.f; // lij + using MMAFrag_mask_t = BaseMMAFrag; + using frag_t = typename MMAFrag_mask_t::frag_type; - for (short j = 0; j < local_blocks.x; j++) { - float val = alpha * float(Ss[offset + j]); - if (softcapping != 1.) { - val = precise::tanh(val); - val = val * softcapping; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + + frag_t mfrag; + + MMAFrag_mask_t::load_safe( + mfrag, + mask, + int(mask_params->M_strides[2]), + Int<1>{}, + params->qL, + params->kL, + row_pos, + col_pos); + + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) { + if constexpr (is_bool) { + Stile.frag_at(i, j)[jj] = + mfrag[jj] ? 
Stile.frag_at(i, j)[jj] : neg_inf; + } else { + Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]); + } } - float P_i_j = exp(val - m_ij); - rowsum += P_i_j; - P_i_j = P_i_j * exp(m_ij - m_i_new); - Ss[offset + j] = T(P_i_j); } - - l_i_new = - exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum; - maxes[simd_lane_id] = m_i_new; - sums[simd_lane_id] = l_i_new; - float rescale = l_i_old * exp(m_i_old - m_i_new); - o_rescale[simd_lane_id] = rescale; - output_scales[simd_lane_id] = 1.0 / l_i_new; } } - } - /* Main kernel function */ - static METAL_FUNC void run( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V [[buffer(2)]], - device U* O [[buffer(3)]], - const constant MLXFastAttentionParams* params [[buffer(4)]], - threadgroup T* Qs [[threadgroup(0)]], - threadgroup T* Ks [[threadgroup(1)]], - threadgroup T* Ss [[threadgroup(2)]], - threadgroup T* Vs [[threadgroup(3)]], - threadgroup float* Corrections [[threadgroup(4)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // Pacifying compiler - (void)lid; - - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load V blocks + if (!align_K && kb == (params->NK_aligned)) { + loader_v.load_safe(short2(BD, params->kL_rem)); + } else { + loader_v.load_unsafe(); } - threadgroup_barrier(mem_flags::mem_none); - - // Find block in Q, O; and head in K, V. - const int c_row = tid_y * BM; - - Q += transpose_q ? c_row : c_row * params->ldq; - thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id); - - short tgp_bm = min(BM, params->M - c_row); - short2 tile_dims_Q = transpose_q ? 
short2(tgp_bm, BK) : short2(BK, tgp_bm); - - loader_q.load_safe(tile_dims_Q); - - initialize_corrections(Corrections, simd_lane_id, simd_group_id); - - O += c_row * params->ldo; - - // Prepare threadgroup mma operation - thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id); - thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id); - thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id); - thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id); - - for (short n_block = 0; n_block < params->gemm_n_iterations_aligned; - n_block++) { - short c_col = BN; - - // Prepare threadgroup loading operations - short gemm_k_iterations = params->gemm_k_iterations_aligned; - short tgp_bn_qk = min(BN, params->N - c_col * n_block); - threadgroup_barrier(mem_flags::mem_none); - - /////////////////////////////////////////////////////////////////////////////// - { // Loop over K - unaligned case - - if (tgp_bm == BM && tgp_bn_qk == BN) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - } else if (tgp_bn_qk == BN) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - - } else if (tgp_bm == BM) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); + // Do softmax - } else { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - } - } + // Temp variables + AccumType new_max[kRowsPT]; + AccumType factor[kRowsPT]; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + new_max[i] = max_score[i]; + } - mma_qk_op.store_result_to_tgp_memory( - Ss, BN + tgp_padding, short2(BN, BM)); + // Row max + Stile.template row_reduce(new_max); - threadgroup_barrier(mem_flags::mem_threadgroup); + // exp(Si - rowmax(Si)) + Stile.template row_bin_op(new_max); - rescale_ss( - Ss, - Corrections, - simd_group_id, - simd_lane_id, - short2(tgp_bn_qk, tgp_bm), - params->alpha, - params->softcapping); + // Factor exp(rowmax(Si) - rowmax(Si-1)) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + factor[i] = fast::exp2(max_score[i] - new_max[i]); + } + + // Save max for next iteration + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = new_max[i]; + } + + // Row Sum + AccumType sum_score_tmp[kRowsPT] = {0}; + Stile.template row_reduce(sum_score_tmp); - loader_v.load_safe(short2(BK, tgp_bn_qk)); + // Update norm + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i]; + } - threadgroup_barrier(mem_flags::mem_threadgroup); + // Update O + Otile.template row_bin_op(factor); - threadgroup float* o_scales = Corrections + 2 * (BM + float_padding); - mma_softmax_sv_op.rescale_output(o_scales); + // Load V into registers + threadgroup_barrier(mem_flags::mem_threadgroup); - mma_softmax_sv_op.mma(Ss, Vs); + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + if constexpr (BD == 128) { + simdgroup_barrier(mem_flags::mem_none); + } - threadgroup float* final_output_scales = - Corrections + 3 * (BM + float_padding); + const short kk = ik * kFragSize; + const short dd = id * kFragSize; - mma_softmax_sv_op.rescale_output(final_output_scales); + Vtile.template load( + &Vs[Vs_offset + kk * LDV_tgp + dd]); - loader_v.next(); - loader_k.next(BN); + if constexpr (BD == 128) { + 
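+          // (BD == 128 only) matching barrier after the V fragment load, before the MMA below.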
simdgroup_barrier(mem_flags::mem_none); + } - mma_qk_op.clear_results(); + MMAFrag_acc_t::mma( + Otile.frag_at(iq, id), + Stile.frag_at(iq, ik), + Vtile.frag_at(0, 0), + Otile.frag_at(iq, id)); + } + } } - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm)); + // Prepare for next iteration + loader_k.next(); + loader_v.next(); } -}; -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_q, - bool transpose_k, - bool transpose_v, - bool MN_aligned, - bool K_aligned> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V [[buffer(2)]], - device T* O [[buffer(3)]], - const constant MLXFastAttentionParams* params [[buffer(4)]], - const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using attention_kernel = FastAttentionKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_q, - transpose_k, - transpose_v, - MN_aligned, - K_aligned>; - - // Adjust for batch - if (params->batch_ndim > 1) { - const constant size_t* Q_bstrides = batch_strides; - const constant size_t* KV_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim); - - Q += batch_offsets.x; - K += batch_offsets.y; - V += batch_offsets.y; + // Normalize output + Otile.template row_bin_op(sum_score); + threadgroup_barrier(mem_flags::mem_none); + // Store results + O += (tm + sm) * params->O_strides[2] + sn; + + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm)); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Otile.template store_safe(O, params->O_strides[2], dst_tile_dims); } else { - Q += params->batch_stride_q * tid.z; - K += params->batch_stride_k * tid.z; - V += params->batch_stride_v * tid.z; - } - - // same shape as input - O += params->batch_stride_o * tid.z; - threadgroup T Qs[attention_kernel::tgp_mem_size_q]; - threadgroup T Ss[attention_kernel::tgp_mem_size_s]; - threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections]; - - if (attention_kernel::share_kv_smem) { - threadgroup T Ks[attention_kernel::tgp_mem_size_k]; - threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v]; - attention_kernel::run( - Q, - K, - V, - O, - params, - Qs, - Ks, - Ss, - Vs, - Corrections, - simd_lane_id, - simd_group_id, - tid, - lid); - } else { - threadgroup T Ks[attention_kernel::tgp_mem_size_k]; - threadgroup T Vs[attention_kernel::tgp_mem_size_v]; - attention_kernel::run( - Q, - K, - V, - O, - params, - Qs, - Ks, - Ss, - Vs, - Corrections, - simd_lane_id, - simd_group_id, - tid, - lid); + Otile.template store(O, params->O_strides[2]); } } // clang-format off // SDPA full instantiations -#define instantiate_fast_inference_self_attention_kernel( \ - itype, otype, bm, bn, bk, wm, wn) \ - template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \ - "_itype_" #itype)]] [[kernel]] void \ - attention( \ - const device itype* Q [[buffer(0)]], \ - const device itype* K [[buffer(1)]], \ - const device itype* V [[buffer(2)]], \ - device otype* O 
[[buffer(3)]], \ - const constant MLXFastAttentionParams* params [[buffer(4)]], \ - const constant int* batch_shape [[buffer(5)]], \ - const constant size_t* batch_strides [[buffer(6)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); - -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 32, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 64, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 96, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 128, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 256, - 2, - 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 32, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 96, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); + +// Instantiate a templated kernel. +// Extra args are used as template parameters: +// e.g. instantiate_kernel(binary_int, binary, a, b) -> +// [[host_name(binary_int)]] [kernel] binary +#define instantiate_kernel(name, func, ...) \ + template [[host_name( \ + name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; + +#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \ + instantiate_kernel( \ + "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \ + "_wm" #wm "_wn" #wn "_mask" #mname, \ + attention, dtype, bq, bk, bd, wm, wn, mtype, float) + +#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ + instantiate_attn(iname, itype, 32, 16, 256, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 96, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 72, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 32, 4, 1, mname, mtype) + +#define instantiate_attn_mask_helper(iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, bool_, bool) + +instantiate_attn_mask_helper(float16, half); +instantiate_attn_mask_helper(bfloat16, bfloat16_t); +instantiate_attn_mask_helper(float32, float); // SDPA vector instantiations #define instantiate_sdpa_vector(type, head_dim) \ @@ -1443,13 +2374,13 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); #define instantiate_sdpa_vector_heads(type) \ instantiate_sdpa_vector(type, 32) \ instantiate_sdpa_vector(type, 64) \ + instantiate_sdpa_vector(type, 72) \ + instantiate_sdpa_vector(type, 80) \ instantiate_sdpa_vector(type, 96) \ instantiate_sdpa_vector(type, 128) \ instantiate_sdpa_vector(type, 256) instantiate_sdpa_vector_heads(float) -#if defined(__HAVE_BFLOAT__) instantiate_sdpa_vector_heads(bfloat16_t) -#endif instantiate_sdpa_vector_heads(float16_t) - // clang-format on + // clang-format on \ No newline at end of file diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index d34d4748b5..7f21aa9b21 
100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -990,6 +990,8 @@ impl Module for Identity { struct Sdpa { scale: f32, softcapping: f32, + mask: Option, + do_causal: bool, } impl candle::CustomOp3 for Sdpa { @@ -1026,6 +1028,8 @@ impl candle::CustomOp3 for Sdpa { let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?]; let elem_count: usize = out_dims.iter().product(); + let out_shape = Shape::from_dims(&out_dims); + let out_layout = Layout::contiguous(out_shape.clone()); let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?; @@ -1047,16 +1051,20 @@ impl candle::CustomOp3 for Sdpa { let k_head = k_l.dim(D::Minus1)?; let q_head = q_l.dim(D::Minus1)?; let q_seq = q_l.dim(2)?; + let k_seq = k_l.dim(2)?; let mut implementation_supports_use_case = q_head == k_head; - let supported_head_dim = - q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256; - - const SDPA_FULL_THRESHOLD: usize = 2; - - let supports_sdpa_full = - q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head; - let supports_sdpa_vector = q_seq == 1 && supported_head_dim; + let supported_head_dim = q_head == 32 + || q_head == 64 + || q_head == 72 + || q_head == 80 + || q_head == 96 + || q_head == 128 + || q_head == 256; + + let supports_sdpa_full_mask = !self.mask.is_some() || q_seq <= k_seq; + let supports_sdpa_full = q_seq > 8 && supported_head_dim && supports_sdpa_full_mask; + let supports_sdpa_vector = q_seq <= 8 && supported_head_dim && q_seq <= k_seq; implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector; @@ -1095,7 +1103,7 @@ impl candle::CustomOp3 for Sdpa { // Route to the 2 pass fused attention if the k seqlen is large. // https://github.com/ml-explore/mlx/pull/1597 const TWO_PASS_K_THRESHOLD: usize = 1024; - if k_l.dim(2)? >= TWO_PASS_K_THRESHOLD { + if k_seq >= TWO_PASS_K_THRESHOLD { let mut intermediate_shape = [ &out_dims[0..out_dims.len() - 2], &[candle_metal_kernels::SDPA_2PASS_BLOCKS], @@ -1167,27 +1175,70 @@ impl candle::CustomOp3 for Sdpa { .map_err(candle::Error::wrap)?; } } else if supports_sdpa_full { - if q_l.dim(2)? != k_l.dim(2)? { - candle::bail!( - "query and key sequence length must be equal if using full metal sdpa" - ) + encoder.set_label("full_attention"); + if self.softcapping != 1. 
{ + candle::bail!("SDPA full requires softcapping to be disabled (1.0)"); } - encoder.set_label("full_attention"); + let mask_s_l = self.mask.as_ref().map(|m| m.storage_and_layout()); + + let (mask_type, mask_buffer, mask_strides) = if let Some(mask) = &self.mask { + let (mask_s, mask_l) = mask_s_l.as_ref().unwrap(); + + let mask_buffer = match &**mask_s { + candle::Storage::Metal(m) => m.buffer(), + _ => candle::bail!("Expected metal device for mask"), + }; + + let mask_type = match mask.dtype() { + DType::BF16 => SdpaDType::BF16, + DType::F16 => SdpaDType::F16, + DType::F32 => SdpaDType::F32, + other => candle::bail!("unsupported sdpa type {other:?}"), + }; + if mask_type != itype { + candle::bail!("Mask type {mask_type:?} must match q type {itype:?}"); + } + + if mask_l.dims() != [q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, k_seq] { + candle::bail!( + "Mask shape must be {:?} (bs, qheads, qseq, kseq), got {:?}", + [q_l.dim(0)?, q_head, q_l.dim(2)?, k_seq], + mask_l.dims() + ); + } + + ( + Some(mask_type), + Some(mask_buffer), + Some(mask_l.stride().to_vec()), + ) + } else { + (None, None, None) + }; + candle_metal_kernels::call_sdpa_full( q.device().device(), &encoder, q.device().kernels(), q_l.start_offset(), q_l.dims(), + q_l.stride(), q.buffer(), k_l.start_offset(), + k_l.dims(), + k_l.stride(), k.buffer(), v_l.start_offset(), v.buffer(), + v_l.stride(), + mask_type, + mask_buffer, + mask_strides.as_deref(), &output, + out_layout.stride(), self.scale, - self.softcapping, + self.do_causal, itype, ) .map_err(candle::Error::wrap)?; @@ -1196,7 +1247,7 @@ impl candle::CustomOp3 for Sdpa { } let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype()); - Ok((newstorage, Shape::from_dims(&out_dims))) + Ok((newstorage, out_shape)) } } @@ -1208,13 +1259,15 @@ impl candle::CustomOp3 for Sdpa { /// - `q`: (bs, qhead, seq, hidden) /// - `k`: (bs, kv_head, kv_seq, hidden) /// - `k`: (bs, kv_head, kv_seq, v_hidden) +/// - `mask`: (bs, qhead, seq, kv_seq) +/// - `do_causal`: Apply causal masking. If this is true, the mask does not need to be provided. /// - `scale` is applied before softmax. /// - If `softcapping` != 1.0: /// - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v /// /// **Output shape:** (bs, qhead, seq, v_hidden) /// -/// **Supported head dims:** 32, 64, 96, 128, 256. +/// Note: For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q. /// /// ## On Metal: /// - If `seq` == 1: @@ -1222,9 +1275,27 @@ impl candle::CustomOp3 for Sdpa { /// - Supports `seq` != `kv_seq` (cross attn. support) /// - Supports GQA when `qhead` is a multiple of `kv_head` /// - Otherwise: -/// - Use an alternate kernel -/// - Requires `seq` == `kv_seq` -/// - GQA is not supported (requires `qhead` == `kv_head`) -pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result { - q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping }) +/// - Masking is supported +/// - Supports `seq` != `kv_seq` (cross attn. support) +/// - Supports GQA when `qhead` is a multiple of `kv_head` +/// - Softcapping is not supported. 
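+/// ## Example
+///
+/// A minimal sketch of calling the full-attention path on Metal; the shapes,
+/// head counts, and device index below are illustrative only.
+///
+/// ```no_run
+/// # fn main() -> candle::Result<()> {
+/// use candle::{DType, Device, Tensor};
+///
+/// let dev = Device::new_metal(0)?;
+/// let q = Tensor::zeros((1, 8, 16, 64), DType::F32, &dev)?; // (bs, qhead, seq, hidden)
+/// let k = Tensor::zeros((1, 8, 32, 64), DType::F32, &dev)?; // (bs, kv_head, kv_seq, hidden)
+/// let v = Tensor::zeros((1, 8, 32, 64), DType::F32, &dev)?; // (bs, kv_head, kv_seq, v_hidden)
+/// let scale = 1f32 / (64f32).sqrt();
+/// // No explicit mask, causal masking enabled, softcapping disabled (1.0).
+/// let out = candle_nn::ops::sdpa(&q, &k, &v, None, true, scale, 1.0)?;
+/// assert_eq!(out.dims(), &[1, 8, 16, 64]);
+/// # Ok(())
+/// # }
+/// ```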
+pub fn sdpa( + q: &Tensor, + k: &Tensor, + v: &Tensor, + mask: Option<&Tensor>, + do_causal: bool, + scale: f32, + softcapping: f32, +) -> Result { + q.apply_op3_no_bwd( + k, + v, + &Sdpa { + scale, + softcapping, + mask: mask.cloned(), + do_causal, + }, + ) } diff --git a/candle-nn/tests/sdpa.rs b/candle-nn/tests/sdpa.rs index f63d1f05e4..9fd24aedbb 100644 --- a/candle-nn/tests/sdpa.rs +++ b/candle-nn/tests/sdpa.rs @@ -38,7 +38,7 @@ mod metal_sdpa_tests { .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, 1.)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? @@ -68,7 +68,7 @@ mod metal_sdpa_tests { .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, 1.)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? @@ -104,7 +104,8 @@ mod metal_sdpa_tests { .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; + let sdpa_output = + candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, SOFTCAP as f32)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? @@ -140,7 +141,8 @@ mod metal_sdpa_tests { .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; + let sdpa_output = + candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, SOFTCAP as f32)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? @@ -170,7 +172,7 @@ mod metal_sdpa_tests { .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, 1.)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index e171b54fd8..1c416b12f2 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -225,7 +225,15 @@ impl LayerWeights { let y = if q.device().is_metal() && seq_len == 1 { // SDPA will do MQA for us - candle_nn::ops::sdpa(&q, &k, &v, 1. / (self.head_dim as f32).sqrt(), 1.)? + candle_nn::ops::sdpa( + &q, + &k, + &v, + None, + false, + 1. / (self.head_dim as f32).sqrt(), + 1., + )? } else { // Support for MQA, useful for 70B models and mistral. let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;