From 35614f49374b1f4ab4d6bdd91cd5c29111fd8535 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 17 Nov 2025 14:00:58 -0500 Subject: [PATCH 1/6] Add dummy i32/i16/f6e2m3/f6e3m2/f4/f8e8m0 dtypes --- candle-core/src/backend.rs | 1 + candle-core/src/convert.rs | 19 +- candle-core/src/cpu/kernels.rs | 34 + candle-core/src/cpu_backend/mod.rs | 649 +++++++++++++++++--- candle-core/src/cpu_backend/utils.rs | 21 + candle-core/src/cuda_backend/device.rs | 169 ++++- candle-core/src/cuda_backend/mod.rs | 133 +++- candle-core/src/cuda_backend/utils.rs | 28 +- candle-core/src/device.rs | 8 + candle-core/src/display.rs | 34 +- candle-core/src/dtype.rs | 101 ++- candle-core/src/dummy_cuda_backend.rs | 4 + candle-core/src/dummy_dtype.rs | 268 ++++++++ candle-core/src/dummy_metal_backend.rs | 4 + candle-core/src/lib.rs | 4 +- candle-core/src/metal_backend/device.rs | 2 + candle-core/src/metal_backend/mod.rs | 36 +- candle-core/src/npy.rs | 35 +- candle-core/src/op.rs | 249 +++++--- candle-core/src/safetensors.rs | 181 +++++- candle-core/src/scalar.rs | 24 +- candle-core/src/sort.rs | 62 +- candle-pyo3/src/lib.rs | 20 +- candle-transformers/src/models/deepseek2.rs | 21 + 24 files changed, 1822 insertions(+), 285 deletions(-) create mode 100644 candle-core/src/dummy_dtype.rs diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index a85f8d36d2..b61d46d2de 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -158,6 +158,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone { fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result; fn set_seed(&self, _: u64) -> Result<()>; + fn get_current_seed(&self) -> Result; /// Synchronize should block until all the operations on the device are completed. fn synchronize(&self) -> Result<()>; diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs index db7bf6a4a8..38e7a7c9a6 100644 --- a/candle-core/src/convert.rs +++ b/candle-core/src/convert.rs @@ -1,6 +1,5 @@ //! 
Implement conversion traits for tensors use crate::{DType, Device, Error, Tensor, WithDType}; -use float8::F8E4M3; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::convert::TryFrom; @@ -94,6 +93,8 @@ from_tensor!(f32); from_tensor!(f16); from_tensor!(bf16); from_tensor!(i64); +from_tensor!(i32); +from_tensor!(i16); from_tensor!(u32); from_tensor!(u8); @@ -131,6 +132,16 @@ impl Tensor { f.write_u32::(v)? } } + DType::I16 => { + for v in vs.to_vec1::()? { + f.write_i16::(v)? + } + } + DType::I32 => { + for v in vs.to_vec1::()? { + f.write_i32::(v)? + } + } DType::I64 => { for v in vs.to_vec1::()? { f.write_i64::(v)? @@ -141,10 +152,14 @@ impl Tensor { f.write_all(&vs)?; } DType::F8E4M3 => { - for v in vs.to_vec1::()? { + let vs = vs.to_vec1::()?; + for v in vs { f.write_u8(v.to_bits())? } } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err(crate::Error::UnsupportedDTypeForOp(self.dtype(), "write_bytes").bt()) + } } Ok(()) } diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs index 64f728f63f..bca76adcc8 100644 --- a/candle-core/src/cpu/kernels.rs +++ b/candle-core/src/cpu/kernels.rs @@ -151,6 +151,28 @@ impl VecOps for u32 { ::max(self, other) } } +impl VecOps for i16 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} +impl VecOps for i32 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} impl VecOps for i64 { #[inline(always)] fn min(self, other: Self) -> Self { @@ -163,6 +185,18 @@ impl VecOps for i64 { } } +impl VecOps for float8::F8E4M3 { + #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + Self::max(self, other) + } +} + #[inline(always)] pub fn 
par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) { if n_threads == 1 { diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 8d8219ec9d..7d35c9e52a 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -2,7 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; -use float8::F8E4M3; +use float8::F8E4M3 as f8e4m3; use half::{bf16, f16}; use rayon::prelude::*; @@ -10,11 +10,10 @@ mod utils; pub use utils::{ binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8, }; -mod conv2d; -use conv2d::Conv2D; const USE_IM2COL_CONV1D: bool = true; const USE_COL2IM_CONV1D_TR: bool = true; +const USE_IM2COL_CONV2D: bool = true; // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator + // intercept the oom errors to avoid panicking and provide a proper error. 
@@ -22,24 +21,38 @@ const USE_COL2IM_CONV1D_TR: bool = true; pub enum CpuStorage { U8(Vec), U32(Vec), + I16(Vec), + I32(Vec), I64(Vec), BF16(Vec), F16(Vec), F32(Vec), F64(Vec), - F8E4M3(Vec), + F8E4M3(Vec), + // Dummy types that store raw bytes + F6E2M3(Vec), + F6E3M2(Vec), + F4(Vec), + F8E8M0(Vec), } #[derive(Debug, Clone)] pub enum CpuStorageRef<'a> { U8(&'a [u8]), U32(&'a [u32]), + I16(&'a [i16]), + I32(&'a [i32]), I64(&'a [i64]), BF16(&'a [bf16]), F16(&'a [f16]), F32(&'a [f32]), F64(&'a [f64]), - F8E4M3(&'a [F8E4M3]), + F8E4M3(&'a [f8e4m3]), + // Dummy types that store raw bytes + F6E2M3(&'a [u8]), + F6E3M2(&'a [u8]), + F4(&'a [u8]), + F8E8M0(&'a [u8]), } #[derive(Debug, Clone)] @@ -1090,6 +1103,94 @@ impl Map2 for ConvTranspose1D<'_> { } } +struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); + +impl Map2 for Conv2D<'_> { + const OP: &'static str = "conv2d"; + fn f(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result> { + let p = self.0; + let inp = &inp[inp_l.start_offset()..]; + let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?; + let k = &k[k_l.start_offset()..]; + let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?; + let (out_h, out_w) = (p.out_h(), p.out_w()); + + // Output shape: [b_size, c_out, out_h, out_w]. + let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; + + // TODO: Avoid making this copy if `inp` already has the appropriate layout. 
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w]; + let cont_s0 = p.i_h * p.i_w * p.c_in; + let cont_s1 = p.i_w * p.c_in; + let cont_s2 = p.c_in; + for b_idx in 0..p.b_size { + for h_idx in 0..p.i_h { + for w_idx in 0..p.i_w { + for c_idx in 0..p.c_in { + let src_idx = + b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3; + let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx; + inp_cont[dst_idx] = inp[src_idx] + } + } + } + } + + for offset_h in 0..p.k_h { + for offset_w in 0..p.k_w { + (0..p.c_out).into_par_iter().for_each(|dst_c_idx| { + let dst_idx = dst_c_idx * out_w * out_h; + let k_cont = (0..p.c_in) + .map(|c_in_idx| { + k[dst_c_idx * k_s0 + + c_in_idx * k_s1 + + offset_h * k_s2 + + offset_w * k_s3] + }) + .collect::>(); + for b_idx in 0..p.b_size { + let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w; + for dst_h in 0..out_h { + let dst_idx = dst_idx + dst_h * out_w; + let src_h = p.stride * dst_h + offset_h * p.dilation; + if src_h < p.padding || src_h >= p.i_h + p.padding { + continue; + } + let src_h = src_h - p.padding; + for dst_w in 0..out_w { + let dst_idx = dst_idx + dst_w; + let src_w = p.stride * dst_w + offset_w * p.dilation; + if src_w < p.padding || src_w >= p.i_w + p.padding { + continue; + } + let src_w = src_w - p.padding; + let inp_cont = &inp_cont + [b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..]; + assert!(inp_cont.len() >= p.c_in); + assert!(k_cont.len() >= p.c_in); + let mut d = T::zero(); + unsafe { + T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) + } + let dst_p = dst.as_ptr(); + // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise + // the different tasks so no two threads can try to write at the same + // location. 
+ unsafe { + let ptr = dst_p.add(dst_idx) as *mut T; + *ptr += d + } + } + } + } + }); + } + } + + Ok(dst) + } +} + struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D); impl Map2 for ConvTranspose2D<'_> { @@ -1552,6 +1653,28 @@ impl CpuStorage { .concat(); Self::U32(storages) } + Self::I16(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I16(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::I16(storages) + } + Self::I32(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I32(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::I32(storages) + } Self::I64(_) => { let storages = storages .iter() @@ -1618,6 +1741,50 @@ impl CpuStorage { .concat(); Self::F8E4M3(storages) } + Self::F6E2M3(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F6E2M3(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F6E2M3(storages) + } + Self::F6E3M2(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F6E3M2(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F6E3M2(storages) + } + Self::F4(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F4(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F4(storages) + } + Self::F8E8M0(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F8E8M0(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? 
+ .concat(); + Self::F8E8M0(storages) + } }; Ok(s) } @@ -1630,12 +1797,18 @@ impl BackendStorage for CpuStorage { match self { Self::U8(_) => DType::U8, Self::U32(_) => DType::U32, + Self::I16(_) => DType::I16, + Self::I32(_) => DType::I32, Self::I64(_) => DType::I64, Self::BF16(_) => DType::BF16, Self::F16(_) => DType::F16, Self::F32(_) => DType::F32, Self::F64(_) => DType::F64, Self::F8E4M3(_) => DType::F8E4M3, + Self::F6E2M3(_) => DType::F6E2M3, + Self::F6E3M2(_) => DType::F6E3M2, + Self::F4(_) => DType::F4, + Self::F8E8M0(_) => DType::F8E8M0, } } @@ -1670,10 +1843,6 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, bf16::from_f64); Ok(Self::BF16(data)) } - (Self::F8E4M3(storage), DType::BF16) => { - let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32())); - Ok(Self::BF16(data)) - } (Self::U8(storage), DType::F16) => { let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) @@ -1702,10 +1871,6 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, f16::from_f64); Ok(Self::F16(data)) } - (Self::F8E4M3(storage), DType::F16) => { - let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32())); - Ok(Self::F16(data)) - } (Self::U8(storage), DType::F32) => { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) @@ -1734,10 +1899,6 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) } - (Self::F8E4M3(storage), DType::F32) => { - let data = unary_map(storage, layout, |v| v.to_f32()); - Ok(Self::F32(data)) - } (Self::U8(storage), DType::U8) => { let data = unary_map(storage, layout, |v| v); Ok(Self::U8(data)) @@ -1766,10 +1927,6 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u8); Ok(Self::U8(data)) } - (Self::F8E4M3(storage), DType::U8) => { - let data = unary_map(storage, layout, |v| v.to_f32() as u8); - Ok(Self::U8(data)) - } (Self::U8(storage), 
DType::U32) => { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) @@ -1798,10 +1955,6 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) } - (Self::F8E4M3(storage), DType::U32) => { - let data = unary_map(storage, layout, |v| v.to_f32() as u32); - Ok(Self::U32(data)) - } (Self::U8(storage), DType::I64) => { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) @@ -1830,10 +1983,6 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) } - (Self::F8E4M3(storage), DType::I64) => { - let data = unary_map(storage, layout, |v| v.to_f32() as i64); - Ok(Self::I64(data)) - } (Self::U8(storage), DType::F64) => { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) @@ -1862,42 +2011,226 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v); Ok(Self::F64(data)) } - (Self::F8E4M3(storage), DType::F64) => { - let data = unary_map(storage, layout, |v| v.to_f64()); - Ok(Self::F64(data)) - } + // Conversions to F8E4M3 (Self::U8(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); Ok(Self::F8E4M3(data)) } (Self::U32(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); Ok(Self::F8E4M3(data)) } (Self::I64(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); Ok(Self::F8E4M3(data)) } (Self::BF16(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, |v| F8E4M3::from(v.to_f32())); + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v.to_f32())); Ok(Self::F8E4M3(data)) } (Self::F16(storage), 
DType::F8E4M3) => { - let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32())); + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v.to_f32())); Ok(Self::F8E4M3(data)) } (Self::F32(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, F8E4M3::from_f32); + let data = unary_map(storage, layout, f8e4m3::from_f32); Ok(Self::F8E4M3(data)) } (Self::F64(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, F8E4M3::from_f64); + let data = unary_map(storage, layout, f8e4m3::from_f64); Ok(Self::F8E4M3(data)) } (Self::F8E4M3(storage), DType::F8E4M3) => { let data = unary_map(storage, layout, |v| v); Ok(Self::F8E4M3(data)) } + // Conversions from F8E4M3 + (Self::F8E4M3(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u8); + Ok(Self::U8(data)) + } + (Self::F8E4M3(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u32); + Ok(Self::U32(data)) + } + (Self::F8E4M3(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i64); + Ok(Self::I64(data)) + } + (Self::F8E4M3(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32())); + Ok(Self::BF16(data)) + } + (Self::F8E4M3(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32())); + Ok(Self::F16(data)) + } + (Self::F8E4M3(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v.to_f32()); + Ok(Self::F32(data)) + } + (Self::F8E4M3(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v.to_f64()); + Ok(Self::F64(data)) + } + // Conversions to I16 + (Self::U8(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::U32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::I16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v); + 
Ok(Self::I16(data)) + } + (Self::I32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::I64(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::BF16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + (Self::F16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + (Self::F32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::F64(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::F8E4M3(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + // Conversions to I32 + (Self::U8(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::U32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::I16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::I32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::I32(data)) + } + (Self::I64(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::BF16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + (Self::F16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + (Self::F32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::F64(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); 
+ Ok(Self::I32(data)) + } + (Self::F8E4M3(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + // Conversions from I16 + (Self::I16(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } + (Self::I16(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } + (Self::I16(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I16(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::I16(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::I16(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } + (Self::I16(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } + (Self::I16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + // Conversions from I32 + (Self::I32(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } + (Self::I32(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } + (Self::I32(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I32(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::I32(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::I32(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } + 
(Self::I32(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } + (Self::I32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + // Dummy types - return error for all conversions to/from dummy types + (_, DType::F6E2M3) | (_, DType::F6E3M2) | (_, DType::F4) | (_, DType::F8E8M0) => { + Err(Error::UnsupportedDTypeForOp(dtype, "to_dtype").bt()) + } + (Self::F6E2M3(_), _) + | (Self::F6E3M2(_), _) + | (Self::F4(_), _) + | (Self::F8E8M0(_), _) => { + Err(Error::UnsupportedDTypeForOp(self.dtype(), "to_dtype").bt()) + } } } @@ -2012,12 +2345,18 @@ impl BackendStorage for CpuStorage { Ok(Self::F64(data)) } Self::F8E4M3(storage) => { - let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e))); + let data = unary_map(storage, layout, |v| v.powf(f8e4m3::from_f64(e))); Ok(Self::F8E4M3(data)) } - Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), - Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), - Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), + Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "powf").bt()), + Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "powf").bt()), + Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "powf").bt()), + Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "powf").bt()), + Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "powf").bt()), + Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "powf").bt()), + Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "powf").bt()), + Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "powf").bt()), + Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "powf").bt()), } } @@ -2041,12 +2380,18 @@ impl BackendStorage for CpuStorage { Ok(Self::F64(data)) } Self::F8E4M3(storage) => 
{ - let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha))); + let data = unary_map(storage, layout, |v| elu(v, f8e4m3::from_f64(alpha))); Ok(Self::F8E4M3(data)) } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), + Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()), + Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), + Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "elu").bt()), + Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "elu").bt()), + Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "elu").bt()), + Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "elu").bt()), } } @@ -2088,15 +2433,6 @@ impl BackendStorage for CpuStorage { Ok(Self::F64(data)) } } - Self::F8E4M3(storage) => { - if B::F8E4M3_VEC { - let data = unary_map_vec(storage, layout, B::f8e4m3, B::f8e4m3_vec); - Ok(Self::F8E4M3(data)) - } else { - let data = unary_map(storage, layout, B::f8e4m3); - Ok(Self::F8E4M3(data)) - } - } Self::U8(storage) => { let data = unary_map(storage, layout, B::u8); Ok(Self::U8(data)) @@ -2105,10 +2441,26 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, B::u32); Ok(Self::U32(data)) } + Self::I16(storage) => { + let data = unary_map(storage, layout, B::i16); + Ok(Self::I16(data)) + } + Self::I32(storage) => { + let data = unary_map(storage, layout, B::i32); + Ok(Self::I32(data)) + } Self::I64(storage) => { let data = unary_map(storage, layout, B::i64); Ok(Self::I64(data)) } + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, B::f8e4m3); + Ok(Self::F8E4M3(data)) + } + Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "unary").bt()), + Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, 
"unary").bt()), + Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "unary").bt()), + Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "unary").bt()), } } @@ -2159,6 +2511,14 @@ impl BackendStorage for CpuStorage { }; Ok(Self::U32(data)) } + (Self::I16(lhs), Self::I16(rhs)) => { + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::i16); + Ok(Self::I16(data)) + } + (Self::I32(lhs), Self::I32(rhs)) => { + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::i32); + Ok(Self::I32(data)) + } (Self::I64(lhs), Self::I64(rhs)) => { let data = if B::I64_VEC { binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec) @@ -2175,6 +2535,10 @@ impl BackendStorage for CpuStorage { }; Ok(Self::U8(data)) } + (Self::F8E4M3(lhs), Self::F8E4M3(rhs)) => { + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f8e4m3); + Ok(Self::F8E4M3(data)) + } _ => { // This should be covered by the dtype check above. Err(Error::DTypeMismatchBinaryOp { @@ -2202,6 +2566,12 @@ impl BackendStorage for CpuStorage { (Self::U32(src), Self::U32(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } + (Self::I16(src), Self::I16(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::I32(src), Self::I32(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } (Self::I64(src), Self::I64(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } @@ -2217,6 +2587,19 @@ impl BackendStorage for CpuStorage { (Self::F64(src), Self::F64(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } + (Self::F8E4M3(src), Self::F8E4M3(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F6E2M3(src), Self::F6E2M3(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F6E3M2(src), Self::F6E3M2(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F4(src), Self::F4(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o), + (Self::F8E8M0(src), 
Self::F8E8M0(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } (_, dst) => { return Err(Error::DTypeMismatchBinaryOp { lhs: self.dtype(), @@ -2233,11 +2616,26 @@ impl BackendStorage for CpuStorage { match (self, dst) { (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::I16(src), Self::I16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F8E4M3(src), Self::F8E4M3(dst)) => { + copy_strided_src_(src, dst, dst_offset, src_l) + } + (Self::F6E2M3(src), Self::F6E2M3(dst)) => { + copy_strided_src_(src, dst, dst_offset, src_l) + } + (Self::F6E3M2(src), Self::F6E3M2(dst)) => { + copy_strided_src_(src, dst, dst_offset, src_l) + } + (Self::F4(src), Self::F4(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F8E8M0(src), Self::F8E8M0(dst)) => { + copy_strided_src_(src, dst, dst_offset, src_l) + } (_, dst) => { // This should be covered by the dtype check above. 
return Err(Error::DTypeMismatchBinaryOp { @@ -2262,6 +2660,8 @@ impl BackendStorage for CpuStorage { match self { Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::I16(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")), } @@ -2375,7 +2775,46 @@ impl BackendStorage for CpuStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv2D, ) -> Result { - Conv2D(params).map(self, l, kernel, kernel_l) + if !USE_IM2COL_CONV2D { + return Conv2D(params).map(self, l, kernel, kernel_l); + } + let op = Im2Col { + h_k: params.k_h, + w_k: params.k_w, + padding: params.padding, + stride: params.stride, + dilation: params.dilation, + }; + let col = op.map(self, l)?; + let b = params.b_size; + let n = params.c_out; + let (h_out, w_out) = (params.out_h(), params.out_w()); + let k = op.h_k * op.w_k * params.c_in; + let m = h_out * w_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = unsafe { + self.device() + .alloc_uninit(kernel_l.shape(), kernel.dtype())? + }; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, h_out, w_out, params.c_out)) + .transpose(1, 2)? + .transpose(1, 3)?; + let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? 
}; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) } fn conv_transpose2d( @@ -2435,6 +2874,8 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l), Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l), + Self::I16(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l), + Self::I32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l), Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()), } @@ -2464,6 +2905,20 @@ impl BackendStorage for CpuStorage { }; IndexAdd { ids, dim }.map(self, l, src, src_l) } + Self::I16(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } + Self::I32(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } Self::I64(ids) => { let ids = match ids_l.contiguous_offsets() { Some((a, b)) => &ids[a..b], @@ -2529,8 +2984,23 @@ impl BackendStorage for CpuStorage { (Self::F64(storage), Scalar::F64(v)) => set(storage, l, v), (Self::U8(storage), Scalar::U8(v)) => set(storage, l, v), (Self::U32(storage), Scalar::U32(v)) => set(storage, l, v), + (Self::I16(storage), Scalar::I16(v)) => set(storage, l, v), + (Self::I32(storage), Scalar::I32(v)) => set(storage, l, v), (Self::I64(storage), Scalar::I64(v)) => set(storage, l, v), (Self::F8E4M3(storage), Scalar::F8E4M3(v)) => set(storage, l, v), + // Dummy types don't support scalar operations + (Self::F6E2M3(_), _) => { + crate::bail!("const_set not supported for dummy type F6E2M3") + } + (Self::F6E3M2(_), _) => { + crate::bail!("const_set 
not supported for dummy type F6E3M2") + } + (Self::F4(_), _) => { + crate::bail!("const_set not supported for dummy type F4") + } + (Self::F8E8M0(_), _) => { + crate::bail!("const_set not supported for dummy type F8E8M0") + } (st, s) => crate::bail!( "const_set dtype mismatch, expected {:?} but got {:?}", st.dtype(), @@ -2572,15 +3042,26 @@ impl BackendDevice for CpuDevice { crate::bail!("cannot seed the CPU rng with set_seed") } + fn get_current_seed(&self) -> Result { + crate::bail!("cannot get the CPU rng seed with get_current_seed") + } + fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result { use rand::prelude::*; let elem_count = shape.elem_count(); let mut rng = rand::rng(); match dtype { - DType::U8 | DType::U32 | DType::I64 => { - Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()) - } + DType::U8 + | DType::U32 + | DType::I16 + | DType::I32 + | DType::I64 + | DType::F8E4M3 + | DType::F6E2M3 + | DType::F6E3M2 + | DType::F4 + | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()), DType::BF16 => { let mut data = Vec::with_capacity(elem_count); let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max)) @@ -2599,16 +3080,6 @@ impl BackendDevice for CpuDevice { } Ok(CpuStorage::F16(data)) } - DType::F8E4M3 => { - let mut data = Vec::with_capacity(elem_count); - let uniform = - rand::distr::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max)) - .map_err(Error::wrap)?; - for _i in 0..elem_count { - data.push(rng.sample::(uniform)) - } - Ok(CpuStorage::F8E4M3(data)) - } DType::F32 => { let mut data = Vec::with_capacity(elem_count); let uniform = @@ -2635,9 +3106,16 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::rng(); match dtype { - DType::U8 | DType::U32 | DType::I64 => { - Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()) - } + DType::U8 + | DType::U32 + | DType::I16 + | DType::I32 + | DType::I64 + | 
DType::F8E4M3 + | DType::F6E2M3 + | DType::F6E3M2 + | DType::F4 + | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()), DType::BF16 => { let mut data = Vec::with_capacity(elem_count); let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std)) @@ -2656,15 +3134,6 @@ impl BackendDevice for CpuDevice { } Ok(CpuStorage::F16(data)) } - DType::F8E4M3 => { - let mut data = Vec::with_capacity(elem_count); - let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std)) - .map_err(Error::wrap)?; - for _i in 0..elem_count { - data.push(normal.sample(&mut rng)) - } - Ok(CpuStorage::F8E4M3(data)) - } DType::F32 => { let mut data = Vec::with_capacity(elem_count); let normal = @@ -2703,6 +3172,16 @@ impl BackendDevice for CpuDevice { v.set_len(elem_count); CpuStorage::U32(v) } + DType::I16 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I16(v) + } + DType::I32 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I32(v) + } DType::I64 => { let mut v = Vec::with_capacity(elem_count); v.set_len(elem_count); @@ -2733,6 +3212,9 @@ impl BackendDevice for CpuDevice { v.set_len(elem_count); CpuStorage::F8E4M3(v) } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err(Error::UnsupportedDTypeForOp(dtype, "alloc_uninit").bt()) + } }; Ok(storage) } @@ -2742,12 +3224,17 @@ impl BackendDevice for CpuDevice { let storage = match dtype { DType::U8 => CpuStorage::U8(vec![0u8; elem_count]), DType::U32 => CpuStorage::U32(vec![0u32; elem_count]), + DType::I16 => CpuStorage::I16(vec![0i16; elem_count]), + DType::I32 => CpuStorage::I32(vec![0i32; elem_count]), DType::I64 => CpuStorage::I64(vec![0i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]), - DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]), DType::F32 => 
CpuStorage::F32(vec![0f32; elem_count]), DType::F64 => CpuStorage::F64(vec![0f64; elem_count]), + DType::F8E4M3 => CpuStorage::F8E4M3(vec![f8e4m3::ZERO; elem_count]), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err(Error::UnsupportedDTypeForOp(dtype, "zeros").bt()) + } }; Ok(storage) } diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index dd27d3d18d..1f800a928b 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -10,12 +10,19 @@ pub trait Map1 { match vs { C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)), C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)), + C::I16(vs) => Ok(C::I16(self.f(vs, layout)?)), + C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)), C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)), C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)), C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)), C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)), C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)), C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)), + // Dummy types don't support Map1 operations + C::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()), + C::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()), + C::F4(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()), + C::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()), } } } @@ -27,12 +34,19 @@ pub trait Map1Any { match vs { C::U8(vs) => Ok(self.f(vs, layout, C::U8)?), C::U32(vs) => Ok(self.f(vs, layout, C::U32)?), + C::I16(vs) => Ok(self.f(vs, layout, C::I16)?), + C::I32(vs) => Ok(self.f(vs, layout, C::I32)?), C::I64(vs) => Ok(self.f(vs, layout, C::I64)?), C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?), C::F16(vs) => Ok(self.f(vs, layout, C::F16)?), C::F32(vs) => Ok(self.f(vs, layout, C::F32)?), C::F64(vs) => Ok(self.f(vs, layout, C::F64)?), C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?), + // Dummy types don't support Map1Any operations + 
C::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()), + C::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()), + C::F4(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()), + C::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()), } } } @@ -45,6 +59,8 @@ pub trait Map2 { match (v1, v2) { (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)), + (C::I16(v1), C::I16(v2)) => Ok(C::I16(self.f(v1, l1, v2, l2)?)), + (C::I32(v1), C::I32(v2)) => Ok(C::I32(self.f(v1, l1, v2, l2)?)), (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)), (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)), (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), @@ -69,11 +85,14 @@ pub trait Map2InPlace { match (v1, v2) { (C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?, (C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?, + (C::I16(v1), C::I16(v2)) => self.f(v1, l1, v2, l2)?, + (C::I32(v1), C::I32(v2)) => self.f(v1, l1, v2, l2)?, (C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?, (C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?, (C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?, (C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?, (C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?, + (C::F8E4M3(v1), C::F8E4M3(v2)) => self.f(v1, l1, v2, l2)?, (v1, v2) => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), @@ -93,6 +112,8 @@ pub trait Map2U8 { match (v1, v2) { (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::I16(v1), C::I16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::I32(v1), C::I32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F16(v1), C::F16(v2)) => 
Ok(C::U8(self.f(v1, l1, v2, l2)?)), diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index a1ed305b61..b3526ed7e5 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -1,12 +1,11 @@ -use crate::backend::BackendDevice; +use crate::backend::{BackendDevice, BackendStorage}; use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; pub use candle_kernels as kernels; pub use cudarc; use cudarc::driver::CudaFunction; -use float8::F8E4M3; use half::{bf16, f16}; use std::collections::HashMap; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr}; @@ -39,6 +38,7 @@ pub struct CudaDevice { stream: Arc, pub(crate) blas: Arc, curand: Arc>, + seed_value: Arc>, } impl std::fmt::Debug for CudaDevice { @@ -94,6 +94,18 @@ impl CudaDevice { self.stream.memcpy_dtod(src, dst).w() } + pub fn memcpy_dtoh< + T: cudarc::driver::DeviceRepr, + Src: cudarc::driver::DevicePtr, + Dst: cudarc::driver::HostSlice, + >( + &self, + src: &Src, + dst: &mut Dst, + ) -> Result<()> { + self.stream.memcpy_dtoh(src, dst).w() + } + pub fn memcpy_stod< T: cudarc::driver::DeviceRepr, Src: cudarc::driver::HostSlice + ?Sized, @@ -145,10 +157,6 @@ impl CudaDevice { self.stream.clone() } - pub fn cublas_handle(&self) -> Arc { - self.blas.clone() - } - /// When turned on, all cuda tensors **created after calling this function** will /// not track uses via cuda events. 
/// @@ -237,6 +245,10 @@ impl CudaDevice { stream: self.stream.clone(), }) } + + pub fn cublas_handle(&self) -> Arc { + self.blas.clone() + } } impl CudaDevice { @@ -256,6 +268,7 @@ impl CudaDevice { curand: Arc::new(Mutex::new(CudaRng(curand))), modules: Arc::new(std::sync::RwLock::new(module_store)), custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())), + seed_value: Arc::new(RwLock::new(299792458)), }) } } @@ -279,6 +292,7 @@ impl BackendDevice for CudaDevice { curand: Arc::new(Mutex::new(CudaRng(curand))), modules: Arc::new(std::sync::RwLock::new(module_store)), custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())), + seed_value: Arc::new(RwLock::new(299792458)), }) } @@ -287,9 +301,14 @@ impl BackendDevice for CudaDevice { // state will be identical and the same random numbers will be generated. let mut curand = self.curand.lock().unwrap(); curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?; + *self.seed_value.write().unwrap() = seed; Ok(()) } + fn get_current_seed(&self) -> Result { + Ok(*self.seed_value.read().unwrap()) + } + fn location(&self) -> crate::DeviceLocation { crate::DeviceLocation::Cuda { gpu_id: self.context.ordinal(), @@ -311,6 +330,14 @@ impl BackendDevice for CudaDevice { let data = self.alloc_zeros::(elem_count)?; CudaStorageSlice::U32(data) } + DType::I16 => { + let data = self.alloc_zeros::(elem_count)?; + CudaStorageSlice::I16(data) + } + DType::I32 => { + let data = self.alloc_zeros::(elem_count)?; + CudaStorageSlice::I32(data) + } DType::I64 => { let data = self.alloc_zeros::(elem_count)?; CudaStorageSlice::I64(data) @@ -332,8 +359,12 @@ impl BackendDevice for CudaDevice { CudaStorageSlice::F64(data) } DType::F8E4M3 => { - let data = self.alloc_zeros::(elem_count)?; - CudaStorageSlice::F8E4M3(data) + return Err(CudaError::InternalError("F8E4M3 not supported in CUDA backend").into()) + } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err( + 
CudaError::InternalError("Dummy types not supported in CUDA backend").into(), + ) } }; Ok(CudaStorage { @@ -348,13 +379,17 @@ impl BackendDevice for CudaDevice { let slice = match dtype { // TODO: Add support for F16 and BF16 though this is likely to require some upstream // cudarc changes. - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 | DType::F8E4M3 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_uniform", - }) - .w()? - } + DType::U8 + | DType::U32 + | DType::I16 + | DType::I32 + | DType::I64 + | DType::F16 + | DType::BF16 => Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_uniform", + }) + .w()?, DType::F32 => { let mut data = unsafe { self.alloc::(elem_count)? }; curand.0.fill_with_uniform(&mut data).w()?; @@ -365,6 +400,13 @@ impl BackendDevice for CudaDevice { curand.0.fill_with_uniform(&mut data).w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 | DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_uniform", + }) + .w()? + } }; let slice = if lo == 0. && up == 1.0 { slice @@ -392,13 +434,17 @@ impl BackendDevice for CudaDevice { elem_count }; let slice = match dtype { - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 | DType::F8E4M3 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_normal", - }) - .w()? - } + DType::U8 + | DType::U32 + | DType::I16 + | DType::I32 + | DType::I64 + | DType::F16 + | DType::BF16 => Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_normal", + }) + .w()?, DType::F32 => { let mut data = unsafe { self.alloc::(elem_count_round)? }; curand @@ -412,6 +458,13 @@ impl BackendDevice for CudaDevice { curand.0.fill_with_normal(&mut data, mean, std).w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 | DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_normal", + }) + .w()? 
+ } }; Ok(CudaStorage { slice, @@ -430,6 +483,14 @@ impl BackendDevice for CudaDevice { let data = self.alloc::(elem_count)?; CudaStorageSlice::U32(data) } + DType::I16 => { + let data = self.alloc::(elem_count)?; + CudaStorageSlice::I16(data) + } + DType::I32 => { + let data = self.alloc::(elem_count)?; + CudaStorageSlice::I32(data) + } DType::I64 => { let data = self.alloc::(elem_count)?; CudaStorageSlice::I64(data) @@ -451,8 +512,12 @@ impl BackendDevice for CudaDevice { CudaStorageSlice::F64(data) } DType::F8E4M3 => { - let data = self.alloc::(elem_count)?; - CudaStorageSlice::F8E4M3(data) + return Err(CudaError::InternalError("F8E4M3 not supported in CUDA backend").into()) + } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err( + CudaError::InternalError("Dummy types not supported in CUDA backend").into(), + ) } }; Ok(CudaStorage { @@ -471,6 +536,14 @@ impl BackendDevice for CudaDevice { let data = self.memcpy_stod(storage)?; CudaStorageSlice::U32(data) } + CpuStorageRef::I16(storage) => { + let data = self.memcpy_stod(storage)?; + CudaStorageSlice::I16(data) + } + CpuStorageRef::I32(storage) => { + let data = self.memcpy_stod(storage)?; + CudaStorageSlice::I32(data) + } CpuStorageRef::I64(storage) => { let data = self.memcpy_stod(storage)?; CudaStorageSlice::I64(data) @@ -495,6 +568,16 @@ impl BackendDevice for CudaDevice { let data = self.memcpy_stod(storage)?; CudaStorageSlice::F8E4M3(data) } + CpuStorageRef::F4(_) + | CpuStorageRef::F6E2M3(_) + | CpuStorageRef::F6E3M2(_) + | CpuStorageRef::F8E8M0(_) => { + return Err(CudaError::UnsupportedDtype { + dtype: T::DTYPE, + op: "storage_from_slice", + } + .into()); + } }; Ok(CudaStorage { slice, @@ -512,6 +595,14 @@ impl BackendDevice for CudaDevice { let data = self.memcpy_stod(storage)?; CudaStorageSlice::U32(data) } + CpuStorage::I16(storage) => { + let data = self.memcpy_stod(storage)?; + CudaStorageSlice::I16(data) + } + CpuStorage::I32(storage) => { + let data = 
self.memcpy_stod(storage)?; + CudaStorageSlice::I32(data) + } CpuStorage::I64(storage) => { let data = self.memcpy_stod(storage)?; CudaStorageSlice::I64(data) @@ -536,6 +627,16 @@ impl BackendDevice for CudaDevice { let data = self.memcpy_stod(storage)?; CudaStorageSlice::F8E4M3(data) } + CpuStorage::F4(_) + | CpuStorage::F6E2M3(_) + | CpuStorage::F6E3M2(_) + | CpuStorage::F8E8M0(_) => { + return Err(CudaError::UnsupportedDtype { + dtype: storage.dtype(), + op: "storage_from_cpu_storage", + } + .into()); + } }; Ok(CudaStorage { slice, @@ -553,6 +654,14 @@ impl BackendDevice for CudaDevice { let data = self.memcpy_stod(&storage)?; CudaStorageSlice::U32(data) } + CpuStorage::I16(storage) => { + let data = self.memcpy_stod(&storage)?; + CudaStorageSlice::I16(data) + } + CpuStorage::I32(storage) => { + let data = self.memcpy_stod(&storage)?; + CudaStorageSlice::I32(data) + } CpuStorage::I64(storage) => { let data = self.memcpy_stod(&storage)?; CudaStorageSlice::I64(data) @@ -577,6 +686,16 @@ impl BackendDevice for CudaDevice { let data = self.memcpy_stod(&storage)?; CudaStorageSlice::F8E4M3(data) } + CpuStorage::F4(_) + | CpuStorage::F6E2M3(_) + | CpuStorage::F6E3M2(_) + | CpuStorage::F8E8M0(_) => { + return Err(CudaError::UnsupportedDtype { + dtype: storage.dtype(), + op: "storage_from_cpu_storage_owned", + } + .into()); + } }; Ok(CudaStorage { slice, diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index b1f166a6ac..51edd5de44 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -9,7 +9,6 @@ use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{ CudaSlice, DevicePtr, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits, }; -use float8::F8E4M3; use half::{bf16, f16}; #[cfg(feature = "cudnn")] @@ -41,6 +40,8 @@ impl crate::scalar::Scalar { match self { Scalar::U8(v) => builder.arg(v), Scalar::U32(v) => builder.arg(v), + Scalar::I16(v) => builder.arg(v), + 
Scalar::I32(v) => builder.arg(v), Scalar::I64(v) => builder.arg(v), Scalar::F32(v) => builder.arg(v), Scalar::F64(v) => builder.arg(v), @@ -66,12 +67,19 @@ impl SlicePtrOrNull { pub enum CudaStorageSlice { U8(CudaSlice), U32(CudaSlice), + I16(CudaSlice), + I32(CudaSlice), I64(CudaSlice), BF16(CudaSlice), F16(CudaSlice), F32(CudaSlice), F64(CudaSlice), - F8E4M3(CudaSlice), + F8E4M3(CudaSlice), + // Dummy types that store raw bytes + F6E2M3(CudaSlice), + F6E3M2(CudaSlice), + F4(CudaSlice), + F8E8M0(CudaSlice), } struct Clone; @@ -1176,12 +1184,14 @@ macro_rules! cuda_dtype { } cuda_dtype!(u8, U8); cuda_dtype!(u32, U32); +cuda_dtype!(i16, I16); +cuda_dtype!(i32, I32); cuda_dtype!(i64, I64); cuda_dtype!(f16, F16); cuda_dtype!(bf16, BF16); cuda_dtype!(f32, F32); cuda_dtype!(f64, F64); -cuda_dtype!(F8E4M3, F8E4M3); +cuda_dtype!(float8::F8E4M3, F8E4M3); impl CudaStorage { pub fn wrap_cuda_slice(slice: CudaSlice, device: CudaDevice) -> CudaStorage { @@ -1302,12 +1312,18 @@ impl BackendStorage for CudaStorage { match self.slice { CudaStorageSlice::U8(_) => DType::U8, CudaStorageSlice::U32(_) => DType::U32, + CudaStorageSlice::I16(_) => DType::I16, + CudaStorageSlice::I32(_) => DType::I32, CudaStorageSlice::I64(_) => DType::I64, CudaStorageSlice::BF16(_) => DType::BF16, CudaStorageSlice::F16(_) => DType::F16, CudaStorageSlice::F32(_) => DType::F32, CudaStorageSlice::F64(_) => DType::F64, CudaStorageSlice::F8E4M3(_) => DType::F8E4M3, + CudaStorageSlice::F6E2M3(_) => DType::F6E2M3, + CudaStorageSlice::F6E3M2(_) => DType::F6E3M2, + CudaStorageSlice::F4(_) => DType::F4, + CudaStorageSlice::F8E8M0(_) => DType::F8E8M0, } } @@ -1326,12 +1342,21 @@ impl BackendStorage for CudaStorage { let ((src, _guard_src), kernel_name) = match &mut self.slice { S::U8(s) => (slice_ptr(s, src_o), "const_set_u8"), S::U32(s) => (slice_ptr(s, src_o), "const_set_u32"), + S::I16(s) => (slice_ptr(s, src_o), "const_set_i16"), + S::I32(s) => (slice_ptr(s, src_o), "const_set_i32"), S::I64(s) => 
(slice_ptr(s, src_o), "const_set_i64"), S::BF16(s) => (slice_ptr(s, src_o), "const_set_bf16"), S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"), S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"), S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"), - S::F8E4M3(s) => (slice_ptr(s, src_o), "const_set_f8_e4m3"), + S::F8E4M3(s) => (slice_ptr(s, src_o), "const_set_f8e4m3"), + S::F4(_) | S::F6E2M3(_) | S::F6E3M2(_) | S::F8E8M0(_) => { + return Err(CudaError::UnsupportedDtype { + dtype: self.dtype(), + op: "const_set", + } + .into()); + } }; let func = dev.get_or_load_func(kernel_name, &kernels::FILL)?; @@ -1360,12 +1385,24 @@ impl BackendStorage for CudaStorage { let (inp, _guard) = match &self.slice { CudaStorageSlice::U8(inp) => slice_ptr(inp, start_o), CudaStorageSlice::U32(inp) => slice_ptr(inp, start_o), + CudaStorageSlice::I16(inp) => slice_ptr(inp, start_o), + CudaStorageSlice::I32(inp) => slice_ptr(inp, start_o), CudaStorageSlice::I64(inp) => slice_ptr(inp, start_o), CudaStorageSlice::BF16(inp) => slice_ptr(inp, start_o), CudaStorageSlice::F16(inp) => slice_ptr(inp, start_o), CudaStorageSlice::F32(inp) => slice_ptr(inp, start_o), CudaStorageSlice::F64(inp) => slice_ptr(inp, start_o), CudaStorageSlice::F8E4M3(inp) => slice_ptr(inp, start_o), + CudaStorageSlice::F4(_) + | CudaStorageSlice::F6E2M3(_) + | CudaStorageSlice::F6E3M2(_) + | CudaStorageSlice::F8E8M0(_) => { + return Err(CudaError::UnsupportedDtype { + dtype: self.dtype(), + op: "to_dtype", + } + .into()); + } }; let inp = &inp; @@ -1450,8 +1487,7 @@ impl BackendStorage for CudaStorage { CudaStorageSlice::F64(out) } DType::F8E4M3 => { - let out: CudaSlice = unsafe { dev.alloc::(el) }?; - + let out = unsafe { dev.alloc::(el)? 
}; let mut builder = func.builder(); barg!(builder, el); barg!(builder, dims.len()); @@ -1459,9 +1495,16 @@ impl BackendStorage for CudaStorage { barg!(builder, *inp); builder.arg(&out); unsafe { builder.launch(cfg) }.w()?; - CudaStorageSlice::F8E4M3(out) } + DType::I16 | DType::I32 => { + return Err(CudaError::InternalError("i16,i32 dtypes are not supported").into()) + } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err( + CudaError::InternalError("Dummy types not supported in CUDA backend").into(), + ) + } }; Ok(Self { slice, @@ -1526,6 +1569,14 @@ impl BackendStorage for CudaStorage { let cpu_storage = slice.stream().memcpy_dtov(slice).w()?; Ok(CpuStorage::U32(cpu_storage)) } + CudaStorageSlice::I16(slice) => { + let cpu_storage = slice.stream().memcpy_dtov(slice).w()?; + Ok(CpuStorage::I16(cpu_storage)) + } + CudaStorageSlice::I32(slice) => { + let cpu_storage = slice.stream().memcpy_dtov(slice).w()?; + Ok(CpuStorage::I32(cpu_storage)) + } CudaStorageSlice::I64(slice) => { let cpu_storage = slice.stream().memcpy_dtov(slice).w()?; Ok(CpuStorage::I64(cpu_storage)) @@ -1550,6 +1601,14 @@ impl BackendStorage for CudaStorage { let cpu_storage = slice.stream().memcpy_dtov(slice).w()?; Ok(CpuStorage::F8E4M3(cpu_storage)) } + CudaStorageSlice::F4(_) + | CudaStorageSlice::F6E2M3(_) + | CudaStorageSlice::F6E3M2(_) + | CudaStorageSlice::F8E8M0(_) => Err(CudaError::UnsupportedDtype { + dtype: self.dtype(), + op: "to_cpu_storage", + } + .into()), } } @@ -1677,7 +1736,12 @@ impl BackendStorage for CudaStorage { S::F64(out) } (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv1d does not support u32"))?, + (S::I16(_), S::I16(_)) => Err(CudaError::InternalError("conv1d does not support i16"))?, + (S::I32(_), S::I32(_)) => Err(CudaError::InternalError("conv1d does not support i32"))?, (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv1d does not support i64"))?, + (S::F8E4M3(_), S::F8E4M3(_)) => { + 
Err(CudaError::InternalError("conv1d does not support f8e4m3"))? + } _ => Err(CudaError::InternalError("dtype mismatch in conv1d"))?, }; Ok(Self { slice, device }) @@ -1857,7 +1921,12 @@ impl BackendStorage for CudaStorage { S::F64(out) } (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?, + (S::I16(_), S::I16(_)) => Err(CudaError::InternalError("conv2d does not support i16"))?, + (S::I32(_), S::I32(_)) => Err(CudaError::InternalError("conv2d does not support i32"))?, (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?, + (S::F8E4M3(_), S::F8E4M3(_)) => { + Err(CudaError::InternalError("conv2d does not support f8e4m3"))? + } _ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?, }; Ok(Self { slice, device }) @@ -2041,13 +2110,15 @@ impl BackendStorage for CudaStorage { let ((src, _guard_src), (dst, _guard_dst), kname) = match (&self.slice, &mut dst.slice) { (S::U8(s), S::U8(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_u8"), (S::U32(s), S::U32(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_u32"), + (S::I16(s), S::I16(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_i16"), + (S::I32(s), S::I32(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_i32"), (S::I64(s), S::I64(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_i64"), (S::BF16(s), S::BF16(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_bf16"), (S::F16(s), S::F16(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_f16"), (S::F32(s), S::F32(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_f32"), (S::F64(s), S::F64(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_f64"), - (S::F8E4M3(s), S::F8E4M3(d)) => { - (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_f8_e4m3") + (S::F8E4M3(_s), S::F8E4M3(_d)) => { + Err(CudaError::InternalError("copy2d not supported for f8e4m3"))? 
} _ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?, }; @@ -2124,12 +2195,12 @@ impl BackendStorage for CudaStorage { unsafe { builder.launch(cfg) }.w()?; } } - (CudaStorageSlice::F8E4M3(src), CudaStorageSlice::F8E4M3(dst)) => { + (CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { dev.memcpy_dtod(&src, &mut dst)? } else { - let func = dev.get_or_load_func("ucopy_f8_e4m3", &kernels::UNARY)?; + let func = dev.get_or_load_func("ucopy_u8", &kernels::UNARY)?; let mut builder = func.builder(); barg!(builder, el_count); barg!(builder, dims.len()); @@ -2140,12 +2211,12 @@ impl BackendStorage for CudaStorage { unsafe { builder.launch(cfg) }.w()?; } } - (CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => { + (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { dev.memcpy_dtod(&src, &mut dst)? } else { - let func = dev.get_or_load_func("ucopy_u8", &kernels::UNARY)?; + let func = dev.get_or_load_func("ucopy_u32", &kernels::UNARY)?; let mut builder = func.builder(); barg!(builder, el_count); barg!(builder, dims.len()); @@ -2156,12 +2227,28 @@ impl BackendStorage for CudaStorage { unsafe { builder.launch(cfg) }.w()?; } } - (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => { + (CudaStorageSlice::I16(src), CudaStorageSlice::I16(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { dev.memcpy_dtod(&src, &mut dst)? } else { - let func = dev.get_or_load_func("ucopy_u32", &kernels::UNARY)?; + let func = dev.get_or_load_func("ucopy_i16", &kernels::UNARY)?; + let mut builder = func.builder(); + barg!(builder, el_count); + barg!(builder, dims.len()); + ds.builder_arg(&mut builder); + builder.arg(&src); + builder.arg(&mut dst); + // SAFETY: ffi. 
+ unsafe { builder.launch(cfg) }.w()?; + } + } + (CudaStorageSlice::I32(src), CudaStorageSlice::I32(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.memcpy_dtod(&src, &mut dst)? + } else { + let func = dev.get_or_load_func("ucopy_i32", &kernels::UNARY)?; let mut builder = func.builder(); barg!(builder, el_count); barg!(builder, dims.len()); @@ -2204,6 +2291,22 @@ impl BackendStorage for CudaStorage { unsafe { builder.launch(cfg) }.w()?; } } + (CudaStorageSlice::F8E4M3(src), CudaStorageSlice::F8E4M3(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.memcpy_dtod(&src, &mut dst)? + } else { + let func = dev.get_or_load_func("ucopy_f8e4m3", &kernels::UNARY)?; + let mut builder = func.builder(); + barg!(builder, el_count); + barg!(builder, dims.len()); + ds.builder_arg(&mut builder); + builder.arg(&src); + builder.arg(&mut dst); + // SAFETY: ffi. + unsafe { builder.launch(cfg) }.w()?; + } + } _ => Err(CudaError::InternalError( "dtype mismatch in copy_strided op", ))?, diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index 761262693e..014b9e6c39 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -19,12 +19,16 @@ pub trait Map1 { let out = match s { S::U8(s) => S::U8(self.f(s, d, l)?), S::U32(s) => S::U32(self.f(s, d, l)?), + S::I16(s) => S::I16(self.f(s, d, l)?), + S::I32(s) => S::I32(self.f(s, d, l)?), S::I64(s) => S::I64(self.f(s, d, l)?), S::BF16(s) => S::BF16(self.f(s, d, l)?), S::F16(s) => S::F16(self.f(s, d, l)?), S::F32(s) => S::F32(self.f(s, d, l)?), S::F64(s) => S::F64(self.f(s, d, l)?), - S::F8E4M3(s) => S::F8E4M3(self.f(s, d, l)?), + S::F8E4M3(_) | S::F4(_) | S::F6E2M3(_) | S::F6E3M2(_) | S::F8E8M0(_) => { + crate::bail!("Map1 does not support this dtype."); + } }; Ok(out) } @@ -44,12 +48,16 @@ pub trait Map2 { let out = match (s1, s2) { 
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?), (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?), + (S::I16(s1), S::I16(s2)) => S::I16(self.f(s1, l1, s2, l2, d)?), + (S::I32(s1), S::I32(s2)) => S::I32(self.f(s1, l1, s2, l2, d)?), (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?), (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?), (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), - (S::F8E4M3(s1), S::F8E4M3(s2)) => S::F8E4M3(self.f(s1, l1, s2, l2, d)?), + (S::F8E4M3(_), S::F8E4M3(_)) => { + Err(CudaError::InternalError("Map2 not supported for F8E4M3"))? + } _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, }; Ok(out) @@ -88,9 +96,6 @@ pub trait Map3 { (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?), - (S::F8E4M3(s1), S::F8E4M3(s2), S::F8E4M3(s3)) => { - S::F8E4M3(self.f(s1, l1, s2, l2, s3, l3, d)?) 
- } _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?, }; Ok(out) @@ -118,12 +123,16 @@ pub trait Map2InPlace { match (dst, src) { (S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d), (S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d), + (S::I16(dst), S::I16(src)) => self.f(dst, dst_l, src, src_l, d), + (S::I32(dst), S::I32(src)) => self.f(dst, dst_l, src, src_l, d), (S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d), (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d), (S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d), (S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d), (S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d), - (S::F8E4M3(dst), S::F8E4M3(src)) => self.f(dst, dst_l, src, src_l, d), + (S::F8E4M3(_), S::F8E4M3(_)) => Err(CudaError::InternalError( + "Map2InPlace not supported for F8E4M3", + ))?, _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, } } @@ -142,12 +151,16 @@ pub trait Map1Any { let out = match s { S::U8(s) => self.f(s, d, l, S::U8)?, S::U32(s) => self.f(s, d, l, S::U32)?, + S::I16(s) => self.f(s, d, l, S::I16)?, + S::I32(s) => self.f(s, d, l, S::I32)?, S::I64(s) => self.f(s, d, l, S::I64)?, S::BF16(s) => self.f(s, d, l, S::BF16)?, S::F16(s) => self.f(s, d, l, S::F16)?, S::F32(s) => self.f(s, d, l, S::F32)?, S::F64(s) => self.f(s, d, l, S::F64)?, - S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?, + S::F8E4M3(_) | S::F4(_) | S::F6E2M3(_) | S::F6E3M2(_) | S::F8E8M0(_) => { + crate::bail!("Map1Any does not support this dtype."); + } }; Ok(out) } @@ -172,7 +185,6 @@ pub trait Map2Any { (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, - (S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?, _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?, }; Ok(out) diff --git a/candle-core/src/device.rs 
b/candle-core/src/device.rs index 3db293cbd3..d0167c61e9 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -267,6 +267,14 @@ impl Device { } } + pub fn get_current_seed(&self) -> Result { + match self { + Self::Cpu => CpuDevice.get_current_seed(), + Self::Cuda(c) => c.get_current_seed(), + Self::Metal(m) => m.get_current_seed(), + } + } + pub fn same_device(&self, rhs: &Self) -> bool { match (self, rhs) { (Self::Cpu, Self::Cpu) => true, diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 422ca3525b..a9b53947f3 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -3,7 +3,6 @@ //! This implementation should be in line with the [PyTorch version](https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py). //! use crate::{DType, Result, Tensor, WithDType}; -use float8::F8E4M3; use half::{bf16, f16}; impl Tensor { @@ -57,12 +56,22 @@ impl std::fmt::Debug for Tensor { match self.dtype() { DType::U8 => self.fmt_dt::(f), DType::U32 => self.fmt_dt::(f), + DType::I16 => self.fmt_dt::(f), + DType::I32 => self.fmt_dt::(f), DType::I64 => self.fmt_dt::(f), DType::BF16 => self.fmt_dt::(f), DType::F16 => self.fmt_dt::(f), DType::F32 => self.fmt_dt::(f), DType::F64 => self.fmt_dt::(f), - DType::F8E4M3 => self.fmt_dt::(f), + DType::F8E4M3 => self.fmt_dt::(f), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + write!( + f, + "Tensor[{:?}; dtype={}, unsupported dummy type]", + self.shape(), + self.dtype().as_str() + ) + } } } } @@ -466,6 +475,18 @@ impl std::fmt::Display for Tensor { tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; writeln!(f)?; } + DType::I16 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + DType::I32 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, 
summarize, &po, f)?; + writeln!(f)?; + } DType::I64 => { let tf: IntFormatter = IntFormatter::new(); let max_w = tf.max_width(&to_display); @@ -501,12 +522,19 @@ impl std::fmt::Display for Tensor { } } DType::F8E4M3 => { - if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { let max_w = tf.max_width(&to_display); tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; writeln!(f)?; } } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + writeln!( + f, + "Dummy type {} (not supported for display)", + self.dtype().as_str() + )?; + } }; let device_str = match self.device().location() { diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index fd0ded5c3d..035ca6d503 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -1,18 +1,19 @@ //! Types for elements that can be stored and manipulated using tensors. #![allow(clippy::redundant_closure_call)] use crate::backend::BackendStorage; -use crate::cpu::kernels::VecOps; use crate::{CpuStorage, CpuStorageRef, Error, Result}; /// The different types of elements allowed in tensors. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DType { - // Floating-point 8 bits integer (4-bit exponent, 3-bit mantissa). - F8E4M3, // Unsigned 8 bits integer. U8, // Unsigned 32 bits integer. U32, + // Signed 16 bits integer. + I16, + // Signed 32 bits integer. + I32, // Signed 64 bits integer. I64, // Brain floating-point using half precision (16 bits). @@ -23,6 +24,16 @@ pub enum DType { F32, // Floating-point using double precision (64 bits). F64, + // 8-bit floating point with 4-bit exponent and 3-bit mantissa. 
+ F8E4M3, + /// 6-bit float with 2 exponent bits and 3 mantissa bits (MX6 format) + F6E2M3, + /// 6-bit float with 3 exponent bits and 2 mantissa bits (MX6 format) + F6E3M2, + /// 4-bit float (MX4 format) + F4, + /// 8-bit float with 8 exponent bits and 0 mantissa bits + F8E8M0, } #[derive(Debug, PartialEq, Eq)] @@ -42,12 +53,18 @@ impl std::str::FromStr for DType { match s { "u8" => Ok(Self::U8), "u32" => Ok(Self::U32), + "i16" => Ok(Self::I16), + "i32" => Ok(Self::I32), "i64" => Ok(Self::I64), "bf16" => Ok(Self::BF16), "f16" => Ok(Self::F16), "f32" => Ok(Self::F32), "f64" => Ok(Self::F64), - "f8_e4m3" => Ok(Self::F8E4M3), + "f8e4m3" => Ok(Self::F8E4M3), + "f6e2m3" => Ok(Self::F6E2M3), + "f6e3m2" => Ok(Self::F6E3M2), + "f4" => Ok(Self::F4), + "f8e8m0" => Ok(Self::F8E8M0), _ => Err(DTypeParseError(s.to_string())), } } @@ -59,12 +76,18 @@ impl DType { match self { Self::U8 => "u8", Self::U32 => "u32", + Self::I16 => "i16", + Self::I32 => "i32", Self::I64 => "i64", Self::BF16 => "bf16", Self::F16 => "f16", Self::F32 => "f32", Self::F64 => "f64", - Self::F8E4M3 => "f8_e4m3", + Self::F8E4M3 => "f8e4m3", + Self::F6E2M3 => "f6e2m3", + Self::F6E3M2 => "f6e3m2", + Self::F4 => "f4", + Self::F8E8M0 => "f8e8m0", } } @@ -72,27 +95,49 @@ impl DType { pub fn size_in_bytes(&self) -> usize { match self { Self::U8 => 1, - Self::F8E4M3 => 1, Self::U32 => 4, + Self::I16 => 2, + Self::I32 => 4, Self::I64 => 8, Self::BF16 => 2, Self::F16 => 2, Self::F32 => 4, Self::F64 => 8, + Self::F8E4M3 => 1, + Self::F6E2M3 => 0, // 6 bits + Self::F6E3M2 => 0, // 6 bits + Self::F4 => 0, // 4 bits + Self::F8E8M0 => 1, } } pub fn is_int(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I64 => true, - Self::BF16 | Self::F16 | Self::F32 | Self::F64 | Self::F8E4M3 => false, + Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => true, + Self::BF16 + | Self::F16 + | Self::F32 + | Self::F64 + | Self::F8E4M3 + | Self::F6E2M3 + | Self::F6E3M2 + | Self::F4 + | Self::F8E8M0 => false, } } pub 
fn is_float(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I64 => false, - Self::BF16 | Self::F16 | Self::F32 | Self::F64 | Self::F8E4M3 => true, + Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => false, + Self::BF16 + | Self::F16 + | Self::F32 + | Self::F64 + | Self::F8E4M3 + | Self::F6E2M3 + | Self::F6E3M2 + | Self::F4 + | Self::F8E8M0 => true, } } } @@ -176,27 +221,19 @@ macro_rules! with_dtype { } }; } -use float8::F8E4M3; +use float8::F8E4M3 as f8e4m3; use half::{bf16, f16}; with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64); with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64); +with_dtype!(i16, I16, |v: f64| v as i16, |v: i16| v as f64); +with_dtype!(i32, I32, |v: f64| v as i32, |v: i32| v as f64); with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64); with_dtype!(f16, F16, f16::from_f64, f16::to_f64); with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64); with_dtype!(f64, F64, |v: f64| v, |v: f64| v); -with_dtype!(F8E4M3, F8E4M3, |v: f64| F8E4M3::from_f64(v), |v: F8E4M3| v - .to_f64()); - -impl VecOps for F8E4M3 { - fn max(self, rhs: Self) -> Self { - F8E4M3::max(self, rhs) - } - fn min(self, rhs: Self) -> Self { - F8E4M3::min(self, rhs) - } -} +with_dtype!(f8e4m3, F8E4M3, f8e4m3::from_f64, |v: f8e4m3| v.to_f64()); pub trait IntDType: WithDType + num_traits::Bounded { fn is_true(&self) -> bool; @@ -230,10 +267,28 @@ impl IntDType for u8 { } } +impl IntDType for i16 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + +impl IntDType for i32 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + pub trait FloatDType: WithDType {} impl FloatDType for f16 {} impl FloatDType for bf16 {} impl FloatDType for f32 {} impl FloatDType for f64 {} -impl FloatDType for F8E4M3 {} +impl FloatDType for f8e4m3 {} diff --git a/candle-core/src/dummy_cuda_backend.rs 
b/candle-core/src/dummy_cuda_backend.rs index 329099354b..f55f39308d 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -218,6 +218,10 @@ impl crate::backend::BackendDevice for CudaDevice { Err(Error::NotCompiledWithCudaSupport) } + fn get_current_seed(&self) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + fn location(&self) -> crate::DeviceLocation { fail!() } diff --git a/candle-core/src/dummy_dtype.rs b/candle-core/src/dummy_dtype.rs new file mode 100644 index 0000000000..5fdb0961a8 --- /dev/null +++ b/candle-core/src/dummy_dtype.rs @@ -0,0 +1,268 @@ +//! Dummy data types for experimental/future float formats +//! +//! These are placeholder types for experimental floating-point formats +//! that are defined in the safetensors spec but not yet fully implemented. + +use crate::{DType, Error, Result, WithDType}; + +/// 6-bit float with 2 exponent bits and 3 mantissa bits (MX6 format) +/// This is a dummy type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct F6E2M3; + +/// 6-bit float with 3 exponent bits and 2 mantissa bits (MX6 format) +/// This is a dummy type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct F6E3M2; + +/// 4-bit float (MX4 format) +/// This is a dummy type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct F4; + +/// 8-bit float with 8 exponent bits and 0 mantissa bits +/// This is a dummy type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct F8E8M0; + +// Implement WithDType for dummy types +macro_rules! 
dummy_with_dtype { + ($ty:ty, $dtype:ident) => { + impl WithDType for $ty { + const DTYPE: DType = DType::$dtype; + + fn from_f64(_v: f64) -> Self { + panic!( + "{} is a dummy type and cannot be constructed", + stringify!($ty) + ) + } + + fn to_f64(self) -> f64 { + panic!( + "{} is a dummy type and cannot be converted", + stringify!($ty) + ) + } + + fn to_scalar(self) -> crate::scalar::Scalar { + panic!( + "{} is a dummy type and cannot be converted to scalar", + stringify!($ty) + ) + } + + fn cpu_storage_ref(_data: &[Self]) -> crate::CpuStorageRef<'_> { + panic!( + "{} is a dummy type and does not support storage", + stringify!($ty) + ) + } + + fn to_cpu_storage_owned(_data: Vec) -> crate::CpuStorage { + panic!( + "{} is a dummy type and does not support storage", + stringify!($ty) + ) + } + + fn cpu_storage_data(_s: crate::CpuStorage) -> Result> { + Err(Error::UnsupportedDTypeForOp(DType::$dtype, "cpu_storage_data").bt()) + } + + fn cpu_storage_as_slice(_s: &crate::CpuStorage) -> Result<&[Self]> { + Err(Error::UnsupportedDTypeForOp(DType::$dtype, "cpu_storage_as_slice").bt()) + } + } + }; +} + +dummy_with_dtype!(F6E2M3, F6E2M3); +dummy_with_dtype!(F6E3M2, F6E3M2); +dummy_with_dtype!(F4, F4); +dummy_with_dtype!(F8E8M0, F8E8M0); + +// Implement NumAssign traits for dummy types +macro_rules! 
dummy_num_assign { + ($ty:ty) => { + impl std::ops::AddAssign for $ty { + fn add_assign(&mut self, _other: Self) { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::SubAssign for $ty { + fn sub_assign(&mut self, _other: Self) { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::MulAssign for $ty { + fn mul_assign(&mut self, _other: Self) { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::DivAssign for $ty { + fn div_assign(&mut self, _other: Self) { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::RemAssign for $ty { + fn rem_assign(&mut self, _other: Self) { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::Add for $ty { + type Output = Self; + fn add(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::Sub for $ty { + type Output = Self; + fn sub(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::Mul for $ty { + type Output = Self; + fn mul(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::Div for $ty { + type Output = Self; + fn div(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::Rem for $ty { + type Output = Self; + fn rem(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl num_traits::Zero for $ty { + fn zero() -> Self { + panic!( + "{} is a dummy type and does not support operations", + 
stringify!($ty) + ) + } + + fn is_zero(&self) -> bool { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl num_traits::One for $ty { + fn one() -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl num_traits::Num for $ty { + type FromStrRadixErr = std::num::ParseFloatError; + + fn from_str_radix( + _str: &str, + _radix: u32, + ) -> std::result::Result { + panic!( + "{} is a dummy type and does not support parsing", + stringify!($ty) + ) + } + } + + impl crate::cpu::kernels::VecOps for $ty { + fn min(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + + fn max(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + }; +} + +dummy_num_assign!(F6E2M3); +dummy_num_assign!(F6E3M2); +dummy_num_assign!(F4); +dummy_num_assign!(F8E8M0); + +// Display implementations +impl std::fmt::Display for F6E2M3 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "F6E2M3") + } +} + +impl std::fmt::Display for F6E3M2 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "F6E3M2") + } +} + +impl std::fmt::Display for F4 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "F4") + } +} + +impl std::fmt::Display for F8E8M0 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "F8E8M0") + } +} diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index de43f243fb..f4955f2d17 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -222,6 +222,10 @@ impl crate::backend::BackendDevice for MetalDevice { Err(Error::NotCompiledWithMetalSupport) } + fn get_current_seed(&self) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + 
fn location(&self) -> crate::DeviceLocation { fail!() } diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 3c8ba16195..068f3a340e 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -44,7 +44,7 @@ //! - [candle-examples](https://docs.rs/candle-examples/). Examples of Candle in Use. //! - [candle-onnx](https://docs.rs/candle-onnx/). Loading and using ONNX models. //! - [candle-pyo3](https://docs.rs/candle-pyo3/). Access to Candle from Python. -//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implementation of many published transformer models. +//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implementation of many published transformer models. //! #[cfg(feature = "accelerate")] @@ -62,6 +62,7 @@ mod device; pub mod display; mod dtype; pub mod dummy_cuda_backend; +pub mod dummy_dtype; mod dummy_metal_backend; pub mod error; mod indexer; @@ -94,6 +95,7 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef}; pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1}; pub use device::{Device, DeviceLocation, NdArray}; pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; +pub use dummy_dtype::{F4, F6E2M3, F6E3M2, F8E8M0}; pub use error::{Context, Error, Result}; pub use indexer::{IndexOp, TensorIndexer}; pub use layout::Layout; diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 0a13bbfcf3..728c6b8324 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -57,6 +57,8 @@ pub struct MetalDevice { pub(crate) kernels: Arc, /// Seed for random number generation. pub(crate) seed: Arc>, + /// Last seed value set on this device. + pub(crate) seed_value: Arc>, } // Resource options used for creating buffers. Shared storage mode allows both CPU and GPU to access the buffer. 
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index e7a3324a3a..93817aeb95 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -3,7 +3,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; -use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; +use crate::{CpuStorage, CpuStorageRef, DType, Error, Layout, Result, Shape}; use candle_metal_kernels::{ metal::{Buffer, Commands, Device}, BufferOffset, CallConvTranspose2dCfg, Kernels, RESOURCE_OPTIONS, @@ -101,12 +101,17 @@ impl BackendStorage for MetalStorage { match self.dtype { DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)), DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)), + DType::I16 => Ok(CpuStorage::I16(self.to_cpu()?)), + DType::I32 => Ok(CpuStorage::I32(self.to_cpu()?)), DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)), DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)), DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)), DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)), DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)), - DType::F8E4M3 => Ok(CpuStorage::F64(self.to_cpu()?)), + DType::F8E4M3 => Ok(CpuStorage::F8E4M3(self.to_cpu()?)), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + Err(crate::Error::UnsupportedDTypeForOp(self.dtype, "to_cpu_storage").bt()) + } } } @@ -2099,6 +2104,7 @@ impl BackendDevice for MetalDevice { buffers: Arc::new(RwLock::new(HashMap::new())), kernels, seed, + seed_value: Arc::new(RwLock::new(299792458)), }) } @@ -2137,12 +2143,20 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match T::cpu_storage_ref(s) { CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I16(storage) 
=> (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)), - CpuStorageRef::F8E4M3(_) => crate::bail!("Metal device does not yet support F8E4M3."), + CpuStorageRef::F8E4M3(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::F6E2M3(_) + | CpuStorageRef::F6E3M2(_) + | CpuStorageRef::F4(_) + | CpuStorageRef::F8E8M0(_) => { + return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "to_dtype").bt()) + } }; Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE)) } @@ -2151,12 +2165,20 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match storage { CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::I16(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)), - CpuStorage::F8E4M3(_) => crate::bail!("Metal device does not yet support F8E4M3."), + CpuStorage::F8E4M3(storage) => (storage.len(), 
self.new_buffer_with_data(storage)), + CpuStorage::F6E2M3(_) + | CpuStorage::F6E3M2(_) + | CpuStorage::F4(_) + | CpuStorage::F8E8M0(_) => { + return Err(Error::UnsupportedDTypeForOp(storage.dtype(), "to_dtype").bt()) + } }; Ok(Self::Storage::new( buffer?, @@ -2245,6 +2267,8 @@ impl BackendDevice for MetalDevice { } fn set_seed(&self, seed: u64) -> Result<()> { + *self.seed_value.write().unwrap() = seed; + let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?; let contents = seed_buffer.data(); unsafe { @@ -2255,6 +2279,10 @@ impl BackendDevice for MetalDevice { Ok(()) } + fn get_current_seed(&self) -> Result { + Ok(*self.seed_value.read().unwrap()) + } + fn synchronize(&self) -> Result<()> { self.wait_until_completed() } diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index 5cded74361..496465ec33 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -27,13 +27,11 @@ //! ``` use crate::{DType, Device, Error, Result, Shape, Tensor}; use byteorder::{LittleEndian, ReadBytesExt}; -use float8::F8E4M3; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::collections::HashMap; use std::fs::File; use std::io::{BufReader, Read, Write}; use std::path::Path; -use std::slice; const NPY_MAGIC_STRING: &[u8] = b"\x93NUMPY"; const NPY_SUFFIX: &str = ".npy"; @@ -87,10 +85,16 @@ impl Header { DType::F16 => "f2", DType::F32 => "f4", DType::F64 => "f8", + DType::I16 => "i2", + DType::I32 => "i4", DType::I64 => "i8", DType::U32 => "u4", DType::U8 => "u1", DType::F8E4M3 => Err(Error::Npy("f8e4m3 is not supported".into()))?, + DType::F6E2M3 => Err(Error::Npy("f6e2m3 is not supported".into()))?, + DType::F6E3M2 => Err(Error::Npy("f6e3m2 is not supported".into()))?, + DType::F4 => Err(Error::Npy("f4 is not supported".into()))?, + DType::F8E8M0 => Err(Error::Npy("f8e8m0 is not supported".into()))?, }; if !shape.is_empty() { shape.push(',') @@ -163,9 +167,9 @@ impl Header { "e" | "f2" => DType::F16, "f" | "f4" => DType::F32, "d" | "f8" => 
DType::F64, - // "i" | "i4" => DType::S32, + "i" | "i4" => DType::I32, "q" | "i8" => DType::I64, - // "h" | "i2" => DType::S16, + "h" | "i2" => DType::I16, // "b" | "i1" => DType::S8, "B" | "u1" => DType::U8, "I" | "u4" => DType::U32, @@ -237,17 +241,30 @@ impl Tensor { reader.read_u32_into::(&mut data_t)?; Tensor::from_vec(data_t, shape, &Device::Cpu) } + DType::I16 => { + let mut data_t = vec![0i16; elem_count]; + reader.read_i16_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } + DType::I32 => { + let mut data_t = vec![0i32; elem_count]; + reader.read_i32_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } DType::I64 => { let mut data_t = vec![0i64; elem_count]; reader.read_i64_into::(&mut data_t)?; Tensor::from_vec(data_t, shape, &Device::Cpu) } DType::F8E4M3 => { - let mut data_t = vec![F8E4M3::ZERO; elem_count]; - let ptr = data_t.as_mut_ptr().cast::(); - let len = data_t.len(); - reader.read_i8_into(unsafe { slice::from_raw_parts_mut(ptr, len) })?; - Tensor::from_vec(data_t, shape, &Device::Cpu) + let mut data_t = vec![0u8; elem_count]; + reader.read_exact(&mut data_t)?; + let data_f8: Vec = + data_t.into_iter().map(float8::F8E4M3::from_bits).collect(); + Tensor::from_vec(data_f8, shape, &Device::Cpu) + } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + Err(Error::UnsupportedDTypeForOp(dtype, "from_reader").bt()) } } } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 367e850289..15d04729d0 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -1,8 +1,8 @@ -//! Tensor Operation Enums and Traits +//! Tensor Operation Enums and Traits //! 
#![allow(clippy::redundant_closure_call)] use crate::Tensor; -use float8::F8E4M3; +use float8::F8E4M3 as f8e4m3; use half::{bf16, f16}; use num_traits::float::Float; @@ -191,10 +191,12 @@ pub trait UnaryOpT { fn f16(v1: f16) -> f16; fn f32(v1: f32) -> f32; fn f64(v1: f64) -> f64; - fn f8e4m3(v1: F8E4M3) -> F8E4M3; fn u8(v1: u8) -> u8; fn u32(v1: u32) -> u32; + fn i16(v1: i16) -> i16; + fn i32(v1: i32) -> i32; fn i64(v1: i64) -> i64; + fn f8e4m3(v1: f8e4m3) -> f8e4m3; // There is no very good way to represent optional function in traits so we go for an explicit // boolean flag to mark the function as existing. @@ -202,8 +204,6 @@ pub trait UnaryOpT { fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {} const F16_VEC: bool = false; fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {} - const F8E4M3_VEC: bool = false; - fn f8e4m3_vec(_xs: &[F8E4M3], _ys: &mut [F8E4M3]) {} const F32_VEC: bool = false; fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {} const F64_VEC: bool = false; @@ -218,10 +218,12 @@ pub trait BinaryOpT { fn f16(v1: f16, v2: f16) -> f16; fn f32(v1: f32, v2: f32) -> f32; fn f64(v1: f64, v2: f64) -> f64; - fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3; fn u8(v1: u8, v2: u8) -> u8; fn u32(v1: u32, v2: u32) -> u32; + fn i16(v1: i16, v2: i16) -> i16; + fn i32(v1: i32, v2: i32) -> i32; fn i64(v1: i64, v2: i64) -> i64; + fn f8e4m3(v1: f8e4m3, v2: f8e4m3) -> f8e4m3; const BF16_VEC: bool = false; fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {} @@ -231,8 +233,6 @@ pub trait BinaryOpT { fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {} const F64_VEC: bool = false; fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {} - const F8E4M3_VEC: bool = false; - fn f8e4m3_vec(_xs1: &[F8E4M3], __xs2: &[F8E4M3], _ys: &mut [F8E4M3]) {} const U8_VEC: bool = false; fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {} const U32_VEC: bool = false; @@ -290,21 +290,29 @@ macro_rules! 
bin_op { $e(v1, v2) } #[inline(always)] - fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3 { + fn u8(v1: u8, v2: u8) -> u8 { $e(v1, v2) } #[inline(always)] - fn u8(v1: u8, v2: u8) -> u8 { + fn u32(v1: u32, v2: u32) -> u32 { $e(v1, v2) } #[inline(always)] - fn u32(v1: u32, v2: u32) -> u32 { + fn i16(v1: i16, v2: i16) -> i16 { + $e(v1, v2) + } + #[inline(always)] + fn i32(v1: i32, v2: i32) -> i32 { $e(v1, v2) } #[inline(always)] fn i64(v1: i64, v2: i64) -> i64 { $e(v1, v2) } + #[inline(always)] + fn f8e4m3(v1: f8e4m3, v2: f8e4m3) -> f8e4m3 { + $e(v1, v2) + } #[cfg(feature = "mkl")] const F32_VEC: bool = true; @@ -374,10 +382,6 @@ macro_rules! unary_op { $e } #[inline(always)] - fn f8e4m3($a: F8E4M3) -> F8E4M3 { - $e - } - #[inline(always)] fn f32($a: f32) -> f32 { $e } @@ -394,9 +398,21 @@ macro_rules! unary_op { todo!("no unary function for u32") } #[inline(always)] + fn i16(_: i16) -> i16 { + todo!("no unary function for i16") + } + #[inline(always)] + fn i32(_: i32) -> i32 { + todo!("no unary function for i32") + } + #[inline(always)] fn i64(_: i64) -> i64 { todo!("no unary function for i64") } + #[inline(always)] + fn f8e4m3($a: f8e4m3) -> f8e4m3 { + $e + } } }; @@ -422,10 +438,6 @@ macro_rules! unary_op { $e } #[inline(always)] - fn f8e4m3($a: F8E4M3) -> F8E4M3 { - $e - } - #[inline(always)] fn u8(_: u8) -> u8 { todo!("no unary function for u8") } @@ -434,9 +446,21 @@ macro_rules! 
unary_op { todo!("no unary function for u32") } #[inline(always)] + fn i16(_: i16) -> i16 { + todo!("no unary function for i16") + } + #[inline(always)] + fn i32(_: i32) -> i32 { + todo!("no unary function for i32") + } + #[inline(always)] fn i64(_: i64) -> i64 { todo!("no unary function for i64") } + #[inline(always)] + fn f8e4m3($a: f8e4m3) -> f8e4m3 { + $e + } #[cfg(feature = "mkl")] const F32_VEC: bool = true; @@ -517,17 +541,6 @@ impl UnaryOpT for Gelu { )) } #[inline(always)] - fn f8e4m3(v: F8E4M3) -> F8E4M3 { - F8E4M3::from_f32(0.5) - * v - * (F8E4M3::ONE - + F8E4M3::tanh( - F8E4M3::from_f32(SQRT_TWO_OVER_PI_F32) - * v - * (F8E4M3::ONE + F8E4M3::from_f32(0.044715) * v * v), - )) - } - #[inline(always)] fn f32(v: f32) -> f32 { 0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v))) } @@ -544,9 +557,28 @@ impl UnaryOpT for Gelu { 0 } #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + f8e4m3::from_f32(0.5) + * v + * (f8e4m3::ONE + + f8e4m3::tanh( + f8e4m3::from_f32(SQRT_TWO_OVER_PI_F32) + * v + * (f8e4m3::ONE + f8e4m3::from_f32(0.044715) * v * v), + )) + } const KERNEL: &'static str = "ugelu"; #[cfg(feature = "mkl")] @@ -601,12 +633,8 @@ impl UnaryOpT for Erf { f16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] - fn f8e4m3(v: F8E4M3) -> F8E4M3 { - F8E4M3::from_f64(Self::f64(v.to_f64())) - } - #[inline(always)] fn f32(v: f32) -> f32 { - crate::cpu::erf::erf_f32(v) + Self::f64(v as f64) as f32 } #[inline(always)] fn f64(v: f64) -> f64 { @@ -621,9 +649,21 @@ impl UnaryOpT for Erf { 0 } #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + f8e4m3::from_f64(Self::f64(v.to_f64())) + } } /// Silu operation @@ -639,10 +679,6 @@ 
impl UnaryOpT for Silu { v / (f16::ONE + (-v).exp()) } #[inline(always)] - fn f8e4m3(v: F8E4M3) -> F8E4M3 { - v / (F8E4M3::ONE + (-v).exp()) - } - #[inline(always)] fn f32(v: f32) -> f32 { v / (1.0 + (-v).exp()) } @@ -659,9 +695,21 @@ impl UnaryOpT for Silu { 0 } #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + v / (f8e4m3::ONE + (-v).exp()) + } const KERNEL: &'static str = "usilu"; #[cfg(feature = "mkl")] @@ -714,10 +762,6 @@ impl UnaryOpT for Abs { v.abs() } #[inline(always)] - fn f8e4m3(v: F8E4M3) -> F8E4M3 { - v.abs() - } - #[inline(always)] fn f32(v: f32) -> f32 { v.abs() } @@ -734,9 +778,21 @@ impl UnaryOpT for Abs { v } #[inline(always)] + fn i16(v: i16) -> i16 { + v.abs() + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v.abs() + } + #[inline(always)] fn i64(v: i64) -> i64 { v.abs() } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + v.abs() + } } impl UnaryOpT for Ceil { @@ -752,10 +808,6 @@ impl UnaryOpT for Ceil { v.ceil() } #[inline(always)] - fn f8e4m3(v: F8E4M3) -> F8E4M3 { - v.ceil() - } - #[inline(always)] fn f32(v: f32) -> f32 { v.ceil() } @@ -772,9 +824,21 @@ impl UnaryOpT for Ceil { v } #[inline(always)] + fn i16(v: i16) -> i16 { + v + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + v.ceil() + } } impl UnaryOpT for Floor { @@ -790,10 +854,6 @@ impl UnaryOpT for Floor { v.floor() } #[inline(always)] - fn f8e4m3(v: F8E4M3) -> F8E4M3 { - v.floor() - } - #[inline(always)] fn f32(v: f32) -> f32 { v.floor() } @@ -810,9 +870,21 @@ impl UnaryOpT for Floor { v } #[inline(always)] + fn i16(v: i16) -> i16 { + v + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + 
v.floor() + } } impl UnaryOpT for Round { @@ -828,10 +900,6 @@ impl UnaryOpT for Round { v.round() } #[inline(always)] - fn f8e4m3(v: F8E4M3) -> F8E4M3 { - v.round() - } - #[inline(always)] fn f32(v: f32) -> f32 { v.round() } @@ -848,9 +916,21 @@ impl UnaryOpT for Round { v } #[inline(always)] + fn i16(v: i16) -> i16 { + v + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + v.round() + } } impl UnaryOpT for GeluErf { @@ -866,12 +946,8 @@ impl UnaryOpT for GeluErf { f16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] - fn f8e4m3(v: F8E4M3) -> F8E4M3 { - F8E4M3::from_f64(Self::f64(v.to_f64())) - } - #[inline(always)] fn f32(v: f32) -> f32 { - (crate::cpu::erf::erf_f32(v * std::f32::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v + Self::f64(v as f64) as f32 } #[inline(always)] fn f64(v: f64) -> f64 { @@ -886,9 +962,21 @@ impl UnaryOpT for GeluErf { 0 } #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + f8e4m3::from_f64(Self::f64(v.to_f64())) + } } impl UnaryOpT for Relu { @@ -904,10 +992,6 @@ impl UnaryOpT for Relu { v.max(f16::ZERO) } #[inline(always)] - fn f8e4m3(v: F8E4M3) -> F8E4M3 { - v.max(F8E4M3::ZERO) - } - #[inline(always)] fn f32(v: f32) -> f32 { v.max(0f32) } @@ -924,8 +1008,20 @@ impl UnaryOpT for Relu { v } #[inline(always)] + fn i16(v: i16) -> i16 { + v.max(0) + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v.max(0) + } + #[inline(always)] fn i64(v: i64) -> i64 { - v + v.max(0) + } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + v.max(f8e4m3::ZERO) } } @@ -1006,11 +1102,6 @@ impl UnaryOpT for Sign { f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8) } #[inline(always)] - fn f8e4m3(v: F8E4M3) -> F8E4M3 { - F8E4M3::from((v > F8E4M3::ZERO) as i8 as f32) - - 
F8E4M3::from((v < F8E4M3::ZERO) as i8 as f32) - } - #[inline(always)] fn f32(v: f32) -> f32 { f32::from(v > 0.) - f32::from(v < 0.) } @@ -1027,7 +1118,25 @@ impl UnaryOpT for Sign { u32::min(1, v) } #[inline(always)] + fn i16(v: i16) -> i16 { + (v > 0) as i16 - (v < 0) as i16 + } + #[inline(always)] + fn i32(v: i32) -> i32 { + (v > 0) as i32 - (v < 0) as i32 + } + #[inline(always)] fn i64(v: i64) -> i64 { (v > 0) as i64 - (v < 0) as i64 } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + if v > f8e4m3::ZERO { + f8e4m3::ONE + } else if v < f8e4m3::ZERO { + -f8e4m3::ONE + } else { + f8e4m3::ZERO + } + } } diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index d3b80fccc3..b633b478f7 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -9,8 +9,10 @@ //! Tensors can also be serialized to safetensor format using the `save` function or //! `Tensor::save_safetensors` method. //! +use crate::op::BackpropOp; +use crate::storage::Storage; +use crate::tensor::from_storage; use crate::{DType, Device, Error, Result, Tensor, WithDType}; -use float8::F8E4M3; use safetensors::tensor as st; use safetensors::tensor::SafeTensors; use std::borrow::Cow; @@ -22,12 +24,18 @@ impl From for st::Dtype { match value { DType::U8 => st::Dtype::U8, DType::U32 => st::Dtype::U32, + DType::I16 => st::Dtype::I16, + DType::I32 => st::Dtype::I32, DType::I64 => st::Dtype::I64, DType::BF16 => st::Dtype::BF16, DType::F16 => st::Dtype::F16, DType::F32 => st::Dtype::F32, DType::F64 => st::Dtype::F64, DType::F8E4M3 => st::Dtype::F8_E4M3, + DType::F6E2M3 => st::Dtype::F6_E2M3, + DType::F6E3M2 => st::Dtype::F6_E3M2, + DType::F4 => st::Dtype::F4, + DType::F8E8M0 => st::Dtype::F8_E8M0, } } } @@ -38,12 +46,18 @@ impl TryFrom for DType { match value { st::Dtype::U8 => Ok(DType::U8), st::Dtype::U32 => Ok(DType::U32), + st::Dtype::I16 => Ok(DType::I16), + st::Dtype::I32 => Ok(DType::I32), st::Dtype::I64 => Ok(DType::I64), st::Dtype::BF16 => 
Ok(DType::BF16), st::Dtype::F16 => Ok(DType::F16), st::Dtype::F32 => Ok(DType::F32), st::Dtype::F64 => Ok(DType::F64), st::Dtype::F8_E4M3 => Ok(DType::F8E4M3), + st::Dtype::F6_E2M3 => Ok(DType::F6E2M3), + st::Dtype::F6_E3M2 => Ok(DType::F6E3M2), + st::Dtype::F4 => Ok(DType::F4), + st::Dtype::F8_E8M0 => Ok(DType::F8E8M0), dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } @@ -201,53 +215,185 @@ impl Tensor { match dtype { DType::U8 => convert_slice::(data, shape, device), DType::U32 => convert_slice::(data, shape, device), + DType::I16 => convert_slice::(data, shape, device), + DType::I32 => convert_slice::(data, shape, device), DType::I64 => convert_slice::(data, shape, device), DType::BF16 => convert_slice::(data, shape, device), DType::F16 => convert_slice::(data, shape, device), DType::F32 => convert_slice::(data, shape, device), DType::F64 => convert_slice::(data, shape, device), - DType::F8E4M3 => convert_slice::(data, shape, device), + DType::F8E4M3 => convert_slice::(data, shape, device), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + // For dummy types, create storage with raw bytes + let storage = match device { + Device::Cpu => { + let cpu_storage = match dtype { + DType::F6E2M3 => crate::cpu_backend::CpuStorage::F6E2M3(data.to_vec()), + DType::F6E3M2 => crate::cpu_backend::CpuStorage::F6E3M2(data.to_vec()), + DType::F4 => crate::cpu_backend::CpuStorage::F4(data.to_vec()), + DType::F8E8M0 => crate::cpu_backend::CpuStorage::F8E8M0(data.to_vec()), + _ => unreachable!(), + }; + Storage::Cpu(cpu_storage) + } + #[cfg(feature = "cuda")] + Device::Cuda(device) => { + let mut slice = unsafe { device.alloc::(data.len())? 
}; + device.memcpy_htod(data, &mut slice)?; + + let slice = match dtype { + DType::F6E2M3 => crate::cuda_backend::CudaStorageSlice::F6E2M3(slice), + DType::F6E3M2 => crate::cuda_backend::CudaStorageSlice::F6E3M2(slice), + DType::F4 => crate::cuda_backend::CudaStorageSlice::F4(slice), + DType::F8E8M0 => crate::cuda_backend::CudaStorageSlice::F8E8M0(slice), + _ => unreachable!(), + }; + let storage = crate::cuda_backend::CudaStorage { + slice, + device: device.clone(), + }; + Storage::Cuda(storage) + } + #[cfg(not(feature = "cuda"))] + Device::Cuda(_) => { + return Err(Error::Msg("CUDA support not compiled".to_string())); + } + #[cfg(feature = "metal")] + Device::Metal(device) => { + let buffer = device.new_buffer_with_data(data)?; + + let storage = crate::metal_backend::MetalStorage::new( + buffer, + device.clone(), + data.len(), + dtype, + ); + Storage::Metal(storage) + } + #[cfg(not(feature = "metal"))] + Device::Metal(_) => { + return Err(Error::Msg("Metal support not compiled".to_string())); + } + }; + + let op = BackpropOp::none(); + Ok(from_storage(storage, shape, op, false)) + } } } } fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { match view.dtype() { - st::Dtype::I8 => { - let conv = |x| Ok(i64::from(x)); - convert_with_cast_::(view, device, conv) - } st::Dtype::U8 => convert_::(view, device), st::Dtype::U16 => { let conv = |x| Ok(u32::from(x)); convert_with_cast_::(view, device, conv) } st::Dtype::U32 => convert_::(view, device), - st::Dtype::I32 => { - let conv = |x| Ok(i64::from(x)); - convert_with_cast_::(view, device, conv) - } + st::Dtype::I16 => convert_::(view, device), + st::Dtype::I32 => convert_::(view, device), st::Dtype::I64 => convert_::(view, device), st::Dtype::BF16 => convert_::(view, device), st::Dtype::F16 => convert_::(view, device), st::Dtype::F32 => convert_::(view, device), st::Dtype::F64 => convert_::(view, device), + st::Dtype::F8_E4M3 => convert_::(view, device), + st::Dtype::F6_E2M3 | st::Dtype::F6_E3M2 | 
st::Dtype::F4 | st::Dtype::F8_E8M0 => { + // For dummy types, we need to handle loading by creating a dummy tensor + // Since these types don't have actual data representation, we'll create + // a tensor that indicates it's a dummy type + convert_dummy(view, device) + } dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } +fn convert_dummy(view: &st::TensorView<'_>, device: &Device) -> Result { + // For dummy types, we'll create the appropriate storage variant that preserves + // both the raw data and the correct dtype + let (dtype, _dtype_name) = match view.dtype() { + st::Dtype::F6_E2M3 => (DType::F6E2M3, "F6_E2M3 (MX6)"), + st::Dtype::F6_E3M2 => (DType::F6E3M2, "F6_E3M2 (MX6)"), + st::Dtype::F4 => (DType::F4, "F4 (MX4)"), + st::Dtype::F8_E8M0 => (DType::F8E8M0, "F8_E8M0"), + _ => unreachable!("convert_dummy called with non-dummy dtype"), + }; + + // Load the raw bytes + let data = view.data(); + let shape = view.shape(); + + // Create storage with the appropriate dummy type variant + let storage = match device { + Device::Cpu => { + let cpu_storage = match dtype { + DType::F6E2M3 => crate::cpu_backend::CpuStorage::F6E2M3(data.to_vec()), + DType::F6E3M2 => crate::cpu_backend::CpuStorage::F6E3M2(data.to_vec()), + DType::F4 => crate::cpu_backend::CpuStorage::F4(data.to_vec()), + DType::F8E8M0 => crate::cpu_backend::CpuStorage::F8E8M0(data.to_vec()), + _ => unreachable!(), + }; + Storage::Cpu(cpu_storage) + } + #[cfg(feature = "cuda")] + Device::Cuda(device) => { + let mut slice = unsafe { device.alloc::(data.len())? 
}; + device.memcpy_htod(data, &mut slice)?; + + let slice = match dtype { + DType::F6E2M3 => crate::cuda_backend::CudaStorageSlice::F6E2M3(slice), + DType::F6E3M2 => crate::cuda_backend::CudaStorageSlice::F6E3M2(slice), + DType::F4 => crate::cuda_backend::CudaStorageSlice::F4(slice), + DType::F8E8M0 => crate::cuda_backend::CudaStorageSlice::F8E8M0(slice), + _ => unreachable!(), + }; + let storage = crate::cuda_backend::CudaStorage { + slice, + device: device.clone(), + }; + Storage::Cuda(storage) + } + #[cfg(not(feature = "cuda"))] + Device::Cuda(_) => { + return Err(Error::Msg("CUDA support not compiled".to_string())); + } + #[cfg(feature = "metal")] + Device::Metal(device) => { + let buffer = device.new_buffer_with_data(data)?; + + let storage = + crate::metal_backend::MetalStorage::new(buffer, device.clone(), data.len(), dtype); + Storage::Metal(storage) + } + #[cfg(not(feature = "metal"))] + Device::Metal(_) => { + return Err(Error::Msg("Metal support not compiled".to_string())); + } + }; + + // Create tensor with correct dtype + let op = BackpropOp::none(); + Ok(from_storage(storage, shape, op, false)) +} + fn convert_back(tensor: &Tensor) -> Result> { // TODO: This makes an unnecessary copy when the tensor is on the cpu. 
let tensor = tensor.flatten_all()?; match tensor.dtype() { DType::U8 => Ok(convert_back_::(tensor.to_vec1()?)), DType::U32 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::I16 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::I32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::I64 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::BF16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F64 => Ok(convert_back_::(tensor.to_vec1()?)), - DType::F8E4M3 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::F8E4M3 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + Err(Error::Msg("Internal error: dtype mismatch in storage".to_string()).bt()) + } } } @@ -482,17 +628,4 @@ mod tests { assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"); std::fs::remove_file("multi.safetensors").unwrap(); } - - #[test] - fn load_i8() { - let bytes = b"8\0\0\0\0\0\0\0{\"x\":{\"dtype\":\"I8\",\"shape\":[2],\"data_offsets\":[0,2]}} \x01\x03"; - std::fs::write("test_i8.safetensors", bytes).unwrap(); - let weights = load("test_i8.safetensors", &Device::Cpu).unwrap(); - let tensor = weights.get("x").unwrap(); - assert_eq!(tensor.dims(), &[2]); - assert_eq!(tensor.dtype(), DType::I64); - let data: Vec = tensor.to_vec1().unwrap(); - assert_eq!(data, vec![1, 3]); - std::fs::remove_file("test_i8.safetensors").unwrap(); - } } diff --git a/candle-core/src/scalar.rs b/candle-core/src/scalar.rs index 811c5b75e6..5c512c03b9 100644 --- a/candle-core/src/scalar.rs +++ b/candle-core/src/scalar.rs @@ -1,19 +1,21 @@ //! TensorScalar Enum and Trait //! 
use crate::{DType, Result, Tensor, WithDType}; -use float8::F8E4M3; +use float8::F8E4M3 as f8e4m3; use half::{bf16, f16}; #[derive(Debug, Clone, Copy, PartialEq)] pub enum Scalar { U8(u8), U32(u32), + I16(i16), + I32(i32), I64(i64), BF16(bf16), F16(f16), F32(f32), F64(f64), - F8E4M3(F8E4M3), + F8E4M3(f8e4m3), } impl From for Scalar { @@ -27,12 +29,17 @@ impl Scalar { match dtype { DType::U8 => Scalar::U8(0), DType::U32 => Scalar::U32(0), + DType::I16 => Scalar::I16(0), + DType::I32 => Scalar::I32(0), DType::I64 => Scalar::I64(0), DType::BF16 => Scalar::BF16(bf16::ZERO), DType::F16 => Scalar::F16(f16::ZERO), DType::F32 => Scalar::F32(0.0), DType::F64 => Scalar::F64(0.0), - DType::F8E4M3 => Scalar::F8E4M3(F8E4M3::ZERO), + DType::F8E4M3 => Scalar::F8E4M3(f8e4m3::ZERO), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + panic!("Cannot create zero scalar for dummy type {dtype:?}") + } } } @@ -40,12 +47,17 @@ impl Scalar { match dtype { DType::U8 => Scalar::U8(1), DType::U32 => Scalar::U32(1), + DType::I16 => Scalar::I16(1), + DType::I32 => Scalar::I32(1), DType::I64 => Scalar::I64(1), DType::BF16 => Scalar::BF16(bf16::ONE), DType::F16 => Scalar::F16(f16::ONE), DType::F32 => Scalar::F32(1.0), DType::F64 => Scalar::F64(1.0), - DType::F8E4M3 => Scalar::F8E4M3(F8E4M3::ONE), + DType::F8E4M3 => Scalar::F8E4M3(f8e4m3::ONE), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + panic!("Cannot create one scalar for dummy type {dtype:?}") + } } } @@ -53,6 +65,8 @@ impl Scalar { match self { Scalar::U8(_) => DType::U8, Scalar::U32(_) => DType::U32, + Scalar::I16(_) => DType::I16, + Scalar::I32(_) => DType::I32, Scalar::I64(_) => DType::I64, Scalar::BF16(_) => DType::BF16, Scalar::F16(_) => DType::F16, @@ -66,6 +80,8 @@ impl Scalar { match self { Scalar::U8(v) => *v as f64, Scalar::U32(v) => *v as f64, + Scalar::I16(v) => *v as f64, + Scalar::I32(v) => *v as f64, Scalar::I64(v) => *v as f64, Scalar::BF16(v) => v.to_f64(), Scalar::F16(v) => v.to_f64(), 
diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index efc8ad2b11..14ace645da 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -61,14 +61,6 @@ mod cuda { use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr}; use crate::{CudaDevice, WithDType}; - fn next_power_of_2(x: usize) -> usize { - let mut n = 1; - while n < x { - n *= 2 - } - n - } - impl crate::cuda_backend::Map1Any for ArgSort { fn f) -> S>( &self, @@ -94,7 +86,7 @@ mod cuda { let nrows = elem_count / ncols; let ncols_pad = next_power_of_2(ncols); let cfg = LaunchConfig { - grid_dim: (nrows as u32, 1, 1), + grid_dim: (1, nrows as u32, 1), block_dim: (ncols_pad as u32, 1, 1), shared_mem_bytes: (ncols_pad * std::mem::size_of::()) as u32, }; @@ -122,12 +114,33 @@ impl crate::CustomOp1 for ArgSort { let sort_indexes = match storage { crate::CpuStorage::U8(vs) => self.asort(vs, layout), crate::CpuStorage::U32(vs) => self.asort(vs, layout), + crate::CpuStorage::I16(vs) => self.asort(vs, layout), + crate::CpuStorage::I32(vs) => self.asort(vs, layout), crate::CpuStorage::I64(vs) => self.asort(vs, layout), crate::CpuStorage::BF16(vs) => self.asort(vs, layout), crate::CpuStorage::F16(vs) => self.asort(vs, layout), crate::CpuStorage::F32(vs) => self.asort(vs, layout), crate::CpuStorage::F64(vs) => self.asort(vs, layout), crate::CpuStorage::F8E4M3(vs) => self.asort(vs, layout), + // Dummy types don't support sorting + crate::CpuStorage::F6E2M3(_) => { + return Err( + crate::Error::UnsupportedDTypeForOp(crate::DType::F6E2M3, "argsort").bt(), + ) + } + crate::CpuStorage::F6E3M2(_) => { + return Err( + crate::Error::UnsupportedDTypeForOp(crate::DType::F6E3M2, "argsort").bt(), + ) + } + crate::CpuStorage::F4(_) => { + return Err(crate::Error::UnsupportedDTypeForOp(crate::DType::F4, "argsort").bt()) + } + crate::CpuStorage::F8E8M0(_) => { + return Err( + crate::Error::UnsupportedDTypeForOp(crate::DType::F8E8M0, "argsort").bt(), + ) + } }; let sort_indexes = 
crate::CpuStorage::U32(sort_indexes); Ok((sort_indexes, layout.shape().into())) @@ -168,8 +181,15 @@ impl crate::CustomOp1 for ArgSort { DType::F64 => "asort_asc_f64", DType::U8 => "asort_asc_u8", DType::U32 => "asort_asc_u32", + DType::I16 => "asort_asc_i16", + DType::I32 => "asort_asc_i32", DType::I64 => "asort_asc_i64", - DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."), + DType::F8E4M3 => "asort_asc_f8e4m3", + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err( + crate::Error::UnsupportedDTypeForOp(storage.dtype(), "argsort").bt(), + ) + } } } else { match storage.dtype() { @@ -179,14 +199,21 @@ impl crate::CustomOp1 for ArgSort { DType::F64 => "asort_desc_f64", DType::U8 => "asort_desc_u8", DType::U32 => "asort_desc_u32", + DType::I16 => "asort_desc_i16", + DType::I32 => "asort_desc_i32", DType::I64 => "asort_desc_i64", - DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."), + DType::F8E4M3 => "asort_desc_f8e4m3", + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err( + crate::Error::UnsupportedDTypeForOp(storage.dtype(), "argsort").bt(), + ) + } } } }; let device = storage.device(); let kernels = device.kernels(); - let command_encoder = device.command_encoder()?; + let command_buffer = device.command_buffer()?; let el = layout.shape().elem_count(); let ncols = self.last_dim; let nrows = el / ncols; @@ -198,7 +225,7 @@ impl crate::CustomOp1 for ArgSort { } candle_metal_kernels::call_arg_sort( device.metal_device(), - &command_encoder, + &command_buffer, kernels, name, nrows, @@ -213,6 +240,15 @@ impl crate::CustomOp1 for ArgSort { } } +#[allow(unused)] +fn next_power_of_2(x: usize) -> usize { + let mut n = 1; + while n < x { + n *= 2 + } + n +} + impl Tensor { /// Returns the indices that sort the tensor along the last dimension. 
/// diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 5e739ed78c..d11031d6d9 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,7 +1,5 @@ #![allow(clippy::redundant_closure_call)] #![allow(clippy::useless_conversion)] -use float8::F8E4M3; -use half::{bf16, f16}; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::pyclass::CompareOp; @@ -10,6 +8,8 @@ use pyo3::ToPyObject; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use half::{bf16, f16}; +use float8::F8E4M3; #[cfg(feature = "mkl")] extern crate intel_mkl_src; @@ -206,7 +206,21 @@ trait MapDType { DType::F16 => self.f::(t), DType::F32 => self.f::(t), DType::F64 => self.f::(t), - DType::F8E4M3 => self.f::(t), + DType::I16 => Err(PyErr::new::( + "i16 dtype is not supported in Python interface", + )), + DType::I32 => Err(PyErr::new::( + "i32 dtype is not supported in Python interface", + )), + DType::F8E4M3 => Err(PyErr::new::( + "f8e4m3 dtype is not supported in Python interface", + )), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + Err(PyErr::new::(format!( + "Dummy dtype {:?} is not supported", + t.dtype() + ))) + } } } } diff --git a/candle-transformers/src/models/deepseek2.rs b/candle-transformers/src/models/deepseek2.rs index 1b5d7a13f3..bd3cf76fcf 100644 --- a/candle-transformers/src/models/deepseek2.rs +++ b/candle-transformers/src/models/deepseek2.rs @@ -45,12 +45,33 @@ impl CustomOp1 for NonZero { let result = match storage { candle::CpuStorage::U8(vs) => self.nonzero(vs, layout), candle::CpuStorage::U32(vs) => self.nonzero(vs, layout), + candle::CpuStorage::I16(vs) => self.nonzero(vs, layout), + candle::CpuStorage::I32(vs) => self.nonzero(vs, layout), candle::CpuStorage::I64(vs) => self.nonzero(vs, layout), candle::CpuStorage::BF16(vs) => self.nonzero(vs, layout), candle::CpuStorage::F16(vs) => self.nonzero(vs, layout), candle::CpuStorage::F32(vs) => self.nonzero(vs, 
layout), candle::CpuStorage::F64(vs) => self.nonzero(vs, layout), candle::CpuStorage::F8E4M3(vs) => self.nonzero(vs, layout), + // Dummy types don't support nonzero operation + candle::CpuStorage::F6E2M3(_) => { + return Err( + candle::Error::UnsupportedDTypeForOp(candle::DType::F6E2M3, "nonzero").bt(), + ) + } + candle::CpuStorage::F6E3M2(_) => { + return Err( + candle::Error::UnsupportedDTypeForOp(candle::DType::F6E3M2, "nonzero").bt(), + ) + } + candle::CpuStorage::F4(_) => { + return Err(candle::Error::UnsupportedDTypeForOp(candle::DType::F4, "nonzero").bt()) + } + candle::CpuStorage::F8E8M0(_) => { + return Err( + candle::Error::UnsupportedDTypeForOp(candle::DType::F8E8M0, "nonzero").bt(), + ) + } }; let index_len = layout.dims().len(); let result_len = result.len() / index_len; From 8dd85fe57588585e43378fb123e872260dddd105 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 17 Nov 2025 14:06:44 -0500 Subject: [PATCH 2/6] Metal fixes --- candle-core/src/metal_backend/device.rs | 6 +++--- candle-core/src/metal_backend/mod.rs | 16 ++++++++++++++++ candle-core/src/sort.rs | 4 ++-- candle-pyo3/src/lib.rs | 4 ++-- 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 728c6b8324..109d67f878 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -127,7 +127,7 @@ impl MetalDevice { } pub fn command_encoder(&self) -> Result { - let mut commands = self.commands.write().map_err(MetalError::from)?; + let commands = self.commands.write().map_err(MetalError::from)?; let (flush, command_encoder) = commands.command_encoder().map_err(MetalError::from)?; if flush { self.drop_unused_buffers()? 
@@ -136,7 +136,7 @@ impl MetalDevice { } pub fn blit_command_encoder(&self) -> Result { - let mut commands = self.commands.write().map_err(MetalError::from)?; + let commands = self.commands.write().map_err(MetalError::from)?; let (flush, command_encoder) = commands.blit_command_encoder().map_err(MetalError::from)?; if flush { self.drop_unused_buffers()? @@ -145,7 +145,7 @@ impl MetalDevice { } pub fn wait_until_completed(&self) -> Result<()> { - let mut commands = self.commands.write().map_err(MetalError::from)?; + let commands = self.commands.write().map_err(MetalError::from)?; commands.wait_until_completed().map_err(MetalError::from)?; Ok(()) } diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 93817aeb95..48d151f3cf 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -476,6 +476,14 @@ impl BackendStorage for MetalStorage { DType::U8 => contiguous::const_set::U8, DType::F8E4M3 => crate::bail!("unsupported const-set f8e4m3"), DType::F64 => crate::bail!("unsupported const-set f64"), + DType::F4 + | DType::F6E2M3 + | DType::F6E3M2 + | DType::F8E8M0 + | DType::I16 + | DType::I32 => { + crate::bail!("unsupported const-set i32/i16/f6e2m3/f6e3m2/f4/f8e8m0") + } }; candle_metal_kernels::call_const_set_contiguous( &device.device, @@ -499,6 +507,14 @@ impl BackendStorage for MetalStorage { DType::U8 => strided::const_set::U8, DType::F8E4M3 => crate::bail!("unsupported const-set f8e4m3"), DType::F64 => crate::bail!("unsupported const-set f64"), + DType::F4 + | DType::F6E2M3 + | DType::F6E3M2 + | DType::F8E8M0 + | DType::I16 + | DType::I32 => { + crate::bail!("unsupported const-set i32/i16/f6e2m3/f6e3m2/f4/f8e8m0") + } }; candle_metal_kernels::call_const_set_strided( &device.device, diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 14ace645da..c6cadaaeda 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -213,7 +213,7 @@ impl crate::CustomOp1 for 
ArgSort { }; let device = storage.device(); let kernels = device.kernels(); - let command_buffer = device.command_buffer()?; + let command_encoder = device.command_encoder()?; let el = layout.shape().elem_count(); let ncols = self.last_dim; let nrows = el / ncols; @@ -225,7 +225,7 @@ impl crate::CustomOp1 for ArgSort { } candle_metal_kernels::call_arg_sort( device.metal_device(), - &command_buffer, + &command_encoder, kernels, name, nrows, diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index d11031d6d9..858d94243c 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,5 +1,7 @@ #![allow(clippy::redundant_closure_call)] #![allow(clippy::useless_conversion)] +use float8::F8E4M3; +use half::{bf16, f16}; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::pyclass::CompareOp; @@ -8,8 +10,6 @@ use pyo3::ToPyObject; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use half::{bf16, f16}; -use float8::F8E4M3; #[cfg(feature = "mkl")] extern crate intel_mkl_src; From f389c1f33918eac4ee03be3ec607ec2b18f70f6e Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 17 Nov 2025 17:53:08 -0500 Subject: [PATCH 3/6] Fix candle-onnx build --- candle-onnx/src/eval.rs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 59058977e0..ce44c361d6 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -774,6 +774,14 @@ fn simple_eval_( DType::F32 => arange_step!(f32), DType::F64 => arange_step!(f64), DType::F8E4M3 => arange_step!(f32), + DType::I32 + | DType::I16 + | DType::F6E2M3 + | DType::F6E3M2 + | DType::F4 + | DType::F8E8M0 => { + bail!("unsupported Range type i32/i16/f6e2m3/f6e3m2/f4/f8e8m0") + } }; values.insert(node.output[0].clone(), output); @@ -1695,7 +1703,15 @@ fn simple_eval_( let input = get(&node.input[0])?; let dt = input.dtype(); match dt { - DType::U8 | 
DType::U32 | DType::I64 => { + DType::U8 + | DType::U32 + | DType::I64 + | DType::I32 + | DType::I16 + | DType::F6E2M3 + | DType::F6E3M2 + | DType::F4 + | DType::F8E8M0 => { bail!( "unsupported dtype {}, only float types are allowed for LeakyRelu", dt.as_str() From ccaa447053e3ac7e21b5c6892cb09f84f5f6937c Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 19 Nov 2025 19:54:52 -0500 Subject: [PATCH 4/6] Apply review comments --- candle-core/src/cuda_backend/device.rs | 7 +++++-- candle-core/src/cuda_backend/mod.rs | 5 ++--- candle-core/src/lib.rs | 2 +- candle-core/src/metal_backend/mod.rs | 6 +++--- candle-core/src/op.rs | 4 ++-- candle-core/src/safetensors.rs | 13 +++++++++++++ candle-core/src/sort.rs | 2 +- 7 files changed, 27 insertions(+), 12 deletions(-) diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index b3526ed7e5..a46ea3a698 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -3,6 +3,7 @@ use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; pub use candle_kernels as kernels; pub use cudarc; use cudarc::driver::CudaFunction; +use float8::F8E4M3; use half::{bf16, f16}; use std::collections::HashMap; use std::sync::{Arc, Mutex, RwLock}; @@ -359,7 +360,8 @@ impl BackendDevice for CudaDevice { CudaStorageSlice::F64(data) } DType::F8E4M3 => { - return Err(CudaError::InternalError("F8E4M3 not supported in CUDA backend").into()) + let data = self.alloc_zeros::(elem_count)?; + CudaStorageSlice::F8E4M3(data) } DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { return Err( @@ -512,7 +514,8 @@ impl BackendDevice for CudaDevice { CudaStorageSlice::F64(data) } DType::F8E4M3 => { - return Err(CudaError::InternalError("F8E4M3 not supported in CUDA backend").into()) + let data = self.alloc::(elem_count)?; + CudaStorageSlice::F8E4M3(data) } DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { return Err( diff --git 
a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 51edd5de44..027110fd1c 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -2117,9 +2117,8 @@ impl BackendStorage for CudaStorage { (S::F16(s), S::F16(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_f16"), (S::F32(s), S::F32(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_f32"), (S::F64(s), S::F64(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_f64"), - (S::F8E4M3(_s), S::F8E4M3(_d)) => { - Err(CudaError::InternalError("copy2d not supported for f8e4m3"))? - } + (S::F8E4M3(s), S::F8E4M3(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_u8"), + (S::F8E8M0(s), S::F8E8M0(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_u8"), _ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?, }; let func = dev.get_or_load_func(kname, &kernels::FILL)?; diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 068f3a340e..65c9f1667c 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -44,7 +44,7 @@ //! - [candle-examples](https://docs.rs/candle-examples/). Examples of Candle in Use. //! - [candle-onnx](https://docs.rs/candle-onnx/). Loading and using ONNX models. //! - [candle-pyo3](https://docs.rs/candle-pyo3/). Access to Candle from Python. -//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implemntation of many published transformer models. +//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implementation of many published transformer models. //! 
#[cfg(feature = "accelerate")] diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 48d151f3cf..d3ab0da902 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -452,7 +452,7 @@ impl BackendStorage for MetalStorage { let kernel_name = match dtype { DType::F16 => contiguous_tiled::const_set::HALF, DType::BF16 => contiguous_tiled::const_set::BFLOAT, - _ => crate::bail!("internal bug in const_set"), + _ => unreachable!(), }; candle_metal_kernels::call_const_set_contiguous_tiled( &device.device, @@ -482,7 +482,7 @@ impl BackendStorage for MetalStorage { | DType::F8E8M0 | DType::I16 | DType::I32 => { - crate::bail!("unsupported const-set i32/i16/f6e2m3/f6e3m2/f4/f8e8m0") + return Err(Error::UnsupportedDTypeForOp(dtype, "const-set").bt()) } }; candle_metal_kernels::call_const_set_contiguous( @@ -513,7 +513,7 @@ impl BackendStorage for MetalStorage { | DType::F8E8M0 | DType::I16 | DType::I32 => { - crate::bail!("unsupported const-set i32/i16/f6e2m3/f6e3m2/f4/f8e8m0") + return Err(Error::UnsupportedDTypeForOp(dtype, "const-set").bt()) } }; candle_metal_kernels::call_const_set_strided( diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 15d04729d0..7d94e51d0d 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -1,4 +1,4 @@ -//! Tensor Opertion Enums and Traits +//! Tensor Operation Enums and Traits //! 
#![allow(clippy::redundant_closure_call)] use crate::Tensor; @@ -634,7 +634,7 @@ impl UnaryOpT for Erf { } #[inline(always)] fn f32(v: f32) -> f32 { - Self::f64(v as f64) as f32 + crate::cpu::erf::erf_f32(v) } #[inline(always)] fn f64(v: f64) -> f64 { diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index b633b478f7..a6d961ae46 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -628,4 +628,17 @@ mod tests { assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"); std::fs::remove_file("multi.safetensors").unwrap(); } + + #[test] + fn load_i8() { + let bytes = b"8\0\0\0\0\0\0\0{\"x\":{\"dtype\":\"I8\",\"shape\":[2],\"data_offsets\":[0,2]}} \x01\x03"; + std::fs::write("test_i8.safetensors", bytes).unwrap(); + let weights = load("test_i8.safetensors", &Device::Cpu).unwrap(); + let tensor = weights.get("x").unwrap(); + assert_eq!(tensor.dims(), &[2]); + assert_eq!(tensor.dtype(), DType::I64); + let data: Vec = tensor.to_vec1().unwrap(); + assert_eq!(data, vec![1, 3]); + std::fs::remove_file("test_i8.safetensors").unwrap(); + } } diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index c6cadaaeda..80a7833437 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -86,7 +86,7 @@ mod cuda { let nrows = elem_count / ncols; let ncols_pad = next_power_of_2(ncols); let cfg = LaunchConfig { - grid_dim: (1, nrows as u32, 1), + grid_dim: (nrows as u32, 1, 1), block_dim: (ncols_pad as u32, 1, 1), shared_mem_bytes: (ncols_pad * std::mem::size_of::()) as u32, }; From 68d17ab6eeb9856e48e2acbd2fde0772a43070a4 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 19 Nov 2025 20:29:23 -0500 Subject: [PATCH 5/6] Residual fixes --- candle-core/src/cuda_backend/mod.rs | 2 +- candle-core/src/safetensors.rs | 14 +++++++------- 2 files changed, 
8 insertions(+), 8 deletions(-) diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 027110fd1c..399900fc8c 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1349,7 +1349,7 @@ impl BackendStorage for CudaStorage { S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"), S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"), S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"), - S::F8E4M3(s) => (slice_ptr(s, src_o), "const_set_f8e4m3"), + S::F8E4M3(s) => (slice_ptr(s, src_o), "const_set_f8_e4m3"), S::F4(_) | S::F6E2M3(_) | S::F6E3M2(_) | S::F8E8M0(_) => { return Err(CudaError::UnsupportedDtype { dtype: self.dtype(), diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index a6d961ae46..bec233b614 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -630,15 +630,15 @@ mod tests { } #[test] - fn load_i8() { - let bytes = b"8\0\0\0\0\0\0\0{\"x\":{\"dtype\":\"I8\",\"shape\":[2],\"data_offsets\":[0,2]}} \x01\x03"; - std::fs::write("test_i8.safetensors", bytes).unwrap(); - let weights = load("test_i8.safetensors", &Device::Cpu).unwrap(); + fn load_u8() { + let bytes = b"8\0\0\0\0\0\0\0{\"x\":{\"dtype\":\"U8\",\"shape\":[2],\"data_offsets\":[0,2]}} \x01\x03"; + std::fs::write("test_u8.safetensors", bytes).unwrap(); + let weights = load("test_u8.safetensors", &Device::Cpu).unwrap(); let tensor = weights.get("x").unwrap(); assert_eq!(tensor.dims(), &[2]); - assert_eq!(tensor.dtype(), DType::I64); - let data: Vec = tensor.to_vec1().unwrap(); + assert_eq!(tensor.dtype(), DType::U8); + let data: Vec = tensor.to_vec1().unwrap(); assert_eq!(data, vec![1, 3]); - std::fs::remove_file("test_i8.safetensors").unwrap(); + std::fs::remove_file("test_u8.safetensors").unwrap(); } } From 9f16dbe0f3aee7497fa7b16b5b9649c35a047c44 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 21 Nov 2025 06:13:32 -0500 Subject: [PATCH 6/6] Apply review comments --- 
candle-core/src/cpu_backend/mod.rs | 183 ++++++-------------------- candle-core/src/cuda_backend/utils.rs | 7 +- candle-core/src/op.rs | 4 +- candle-core/src/sort.rs | 2 +- 4 files changed, 43 insertions(+), 153 deletions(-) diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 7d35c9e52a..afa3797353 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -2,7 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; -use float8::F8E4M3 as f8e4m3; +use float8::F8E4M3; use half::{bf16, f16}; use rayon::prelude::*; @@ -10,10 +10,11 @@ mod utils; pub use utils::{ binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8, }; +mod conv2d; +use conv2d::Conv2D; const USE_IM2COL_CONV1D: bool = true; const USE_COL2IM_CONV1D_TR: bool = true; -const USE_IM2COL_CONV2D: bool = true; // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator + // intercept the oom errors to avoid panicking and provide a proper error. 
@@ -28,7 +29,7 @@ pub enum CpuStorage { F16(Vec), F32(Vec), F64(Vec), - F8E4M3(Vec), + F8E4M3(Vec), // Dummy types that store raw bytes F6E2M3(Vec), F6E3M2(Vec), @@ -47,7 +48,7 @@ pub enum CpuStorageRef<'a> { F16(&'a [f16]), F32(&'a [f32]), F64(&'a [f64]), - F8E4M3(&'a [f8e4m3]), + F8E4M3(&'a [F8E4M3]), // Dummy types that store raw bytes F6E2M3(&'a [u8]), F6E3M2(&'a [u8]), @@ -1103,94 +1104,6 @@ impl Map2 for ConvTranspose1D<'_> { } } -struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); - -impl Map2 for Conv2D<'_> { - const OP: &'static str = "conv2d"; - fn f(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result> { - let p = self.0; - let inp = &inp[inp_l.start_offset()..]; - let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?; - let k = &k[k_l.start_offset()..]; - let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?; - let (out_h, out_w) = (p.out_h(), p.out_w()); - - // Output shape: [b_size, c_out, out_h, out_w]. - let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; - - // TODO: Avoid making this copy if `inp` already has the appropriate layout. 
- let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w]; - let cont_s0 = p.i_h * p.i_w * p.c_in; - let cont_s1 = p.i_w * p.c_in; - let cont_s2 = p.c_in; - for b_idx in 0..p.b_size { - for h_idx in 0..p.i_h { - for w_idx in 0..p.i_w { - for c_idx in 0..p.c_in { - let src_idx = - b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3; - let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx; - inp_cont[dst_idx] = inp[src_idx] - } - } - } - } - - for offset_h in 0..p.k_h { - for offset_w in 0..p.k_w { - (0..p.c_out).into_par_iter().for_each(|dst_c_idx| { - let dst_idx = dst_c_idx * out_w * out_h; - let k_cont = (0..p.c_in) - .map(|c_in_idx| { - k[dst_c_idx * k_s0 - + c_in_idx * k_s1 - + offset_h * k_s2 - + offset_w * k_s3] - }) - .collect::>(); - for b_idx in 0..p.b_size { - let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w; - for dst_h in 0..out_h { - let dst_idx = dst_idx + dst_h * out_w; - let src_h = p.stride * dst_h + offset_h * p.dilation; - if src_h < p.padding || src_h >= p.i_h + p.padding { - continue; - } - let src_h = src_h - p.padding; - for dst_w in 0..out_w { - let dst_idx = dst_idx + dst_w; - let src_w = p.stride * dst_w + offset_w * p.dilation; - if src_w < p.padding || src_w >= p.i_w + p.padding { - continue; - } - let src_w = src_w - p.padding; - let inp_cont = &inp_cont - [b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..]; - assert!(inp_cont.len() >= p.c_in); - assert!(k_cont.len() >= p.c_in); - let mut d = T::zero(); - unsafe { - T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) - } - let dst_p = dst.as_ptr(); - // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise - // the different tasks so no two threads can try to write at the same - // location. 
- unsafe { - let ptr = dst_p.add(dst_idx) as *mut T; - *ptr += d - } - } - } - } - }); - } - } - - Ok(dst) - } -} - struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D); impl Map2 for ConvTranspose2D<'_> { @@ -2013,31 +1926,31 @@ impl BackendStorage for CpuStorage { } // Conversions to F8E4M3 (Self::U8(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); Ok(Self::F8E4M3(data)) } (Self::U32(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); Ok(Self::F8E4M3(data)) } (Self::I64(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); Ok(Self::F8E4M3(data)) } (Self::BF16(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v.to_f32())); + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32())); Ok(Self::F8E4M3(data)) } (Self::F16(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v.to_f32())); + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32())); Ok(Self::F8E4M3(data)) } (Self::F32(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, f8e4m3::from_f32); + let data = unary_map(storage, layout, F8E4M3::from_f32); Ok(Self::F8E4M3(data)) } (Self::F64(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, f8e4m3::from_f64); + let data = unary_map(storage, layout, F8E4M3::from_f64); Ok(Self::F8E4M3(data)) } (Self::F8E4M3(storage), DType::F8E4M3) => { @@ -2185,7 +2098,7 @@ impl BackendStorage for CpuStorage { Ok(Self::F64(data)) } (Self::I16(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); + let data = 
unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); Ok(Self::F8E4M3(data)) } // Conversions from I32 @@ -2218,7 +2131,7 @@ impl BackendStorage for CpuStorage { Ok(Self::F64(data)) } (Self::I32(storage), DType::F8E4M3) => { - let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); Ok(Self::F8E4M3(data)) } // Dummy types - return error for all conversions to/from dummy types @@ -2345,7 +2258,7 @@ impl BackendStorage for CpuStorage { Ok(Self::F64(data)) } Self::F8E4M3(storage) => { - let data = unary_map(storage, layout, |v| v.powf(f8e4m3::from_f64(e))); + let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e))); Ok(Self::F8E4M3(data)) } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "powf").bt()), @@ -2380,7 +2293,7 @@ impl BackendStorage for CpuStorage { Ok(Self::F64(data)) } Self::F8E4M3(storage) => { - let data = unary_map(storage, layout, |v| elu(v, f8e4m3::from_f64(alpha))); + let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha))); Ok(Self::F8E4M3(data)) } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), @@ -2775,46 +2688,7 @@ impl BackendStorage for CpuStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv2D, ) -> Result { - if !USE_IM2COL_CONV2D { - return Conv2D(params).map(self, l, kernel, kernel_l); - } - let op = Im2Col { - h_k: params.k_h, - w_k: params.k_w, - padding: params.padding, - stride: params.stride, - dilation: params.dilation, - }; - let col = op.map(self, l)?; - let b = params.b_size; - let n = params.c_out; - let (h_out, w_out) = (params.out_h(), params.out_w()); - let k = op.h_k * op.w_k * params.c_in; - let m = h_out * w_out; - let col_l = Layout::contiguous((b, m, k)); - let res = if kernel_l.is_contiguous() { - let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) - .transpose(1, 2)? 
- .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? - } else { - // Make the kernel contiguous if not already the case. - let mut kernel_c = unsafe { - self.device() - .alloc_uninit(kernel_l.shape(), kernel.dtype())? - }; - kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; - let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) - .transpose(1, 2)? - .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? - }; - let res_l = Layout::contiguous((b, h_out, w_out, params.c_out)) - .transpose(1, 2)? - .transpose(1, 3)?; - let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; - res.copy_strided_src(&mut res_t, 0, &res_l)?; - Ok(res_t) + Conv2D(params).map(self, l, kernel, kernel_l) } fn conv_transpose2d( @@ -3057,7 +2931,6 @@ impl BackendDevice for CpuDevice { | DType::I16 | DType::I32 | DType::I64 - | DType::F8E4M3 | DType::F6E2M3 | DType::F6E3M2 | DType::F4 @@ -3080,6 +2953,16 @@ impl BackendDevice for CpuDevice { } Ok(CpuStorage::F16(data)) } + DType::F8E4M3 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = + rand::distr::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max)) + .map_err(Error::wrap)?; + for _i in 0..elem_count { + data.push(rng.sample::(uniform)) + } + Ok(CpuStorage::F8E4M3(data)) + } DType::F32 => { let mut data = Vec::with_capacity(elem_count); let uniform = @@ -3111,7 +2994,6 @@ impl BackendDevice for CpuDevice { | DType::I16 | DType::I32 | DType::I64 - | DType::F8E4M3 | DType::F6E2M3 | DType::F6E3M2 | DType::F4 @@ -3134,6 +3016,15 @@ impl BackendDevice for CpuDevice { } Ok(CpuStorage::F16(data)) } + DType::F8E4M3 => { + let mut data = Vec::with_capacity(elem_count); + let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std)) + .map_err(Error::wrap)?; + for _i in 0..elem_count { + data.push(normal.sample(&mut rng)) + } + Ok(CpuStorage::F8E4M3(data)) + } DType::F32 => { let mut data 
= Vec::with_capacity(elem_count); let normal = @@ -3231,7 +3122,7 @@ impl BackendDevice for CpuDevice { DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]), DType::F32 => CpuStorage::F32(vec![0f32; elem_count]), DType::F64 => CpuStorage::F64(vec![0f64; elem_count]), - DType::F8E4M3 => CpuStorage::F8E4M3(vec![f8e4m3::ZERO; elem_count]), + DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]), DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { return Err(Error::UnsupportedDTypeForOp(dtype, "zeros").bt()) } diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index 014b9e6c39..582ef54f08 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -26,7 +26,8 @@ pub trait Map1 { S::F16(s) => S::F16(self.f(s, d, l)?), S::F32(s) => S::F32(self.f(s, d, l)?), S::F64(s) => S::F64(self.f(s, d, l)?), - S::F8E4M3(_) | S::F4(_) | S::F6E2M3(_) | S::F6E3M2(_) | S::F8E8M0(_) => { + S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?, + S::F4(_) | S::F6E2M3(_) | S::F6E3M2(_) | S::F8E8M0(_) => { crate::bail!("Map1 does not uspport this dtype."); } }; @@ -55,9 +56,7 @@ pub trait Map2 { (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), - (S::F8E4M3(_), S::F8E4M3(_)) => { - Err(CudaError::InternalError("Map2 not supported for F8E4M3"))? - } + (S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?, _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, }; Ok(out) diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 7d94e51d0d..a4d5d6cb97 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -947,7 +947,7 @@ impl UnaryOpT for GeluErf { } #[inline(always)] fn f32(v: f32) -> f32 { - Self::f64(v as f64) as f32 + (crate::cpu::erf::erf_f32(v * std::f32::consts::FRAC_1_SQRT_2) + 1.) 
* 0.5 * v } #[inline(always)] fn f64(v: f64) -> f64 { @@ -975,7 +975,7 @@ impl UnaryOpT for GeluErf { } #[inline(always)] fn f8e4m3(v: f8e4m3) -> f8e4m3 { - f8e4m3::from_f64(Self::f64(v.to_f64())) + f8e4m3::from_f32(Self::f32(v.to_f32())) } } diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 80a7833437..6b68c11796 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -184,7 +184,7 @@ impl crate::CustomOp1 for ArgSort { DType::I16 => "asort_asc_i16", DType::I32 => "asort_asc_i32", DType::I64 => "asort_asc_i64", - DType::F8E4M3 => "asort_asc_f8e4m3", + DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."), DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { return Err( crate::Error::UnsupportedDTypeForOp(storage.dtype(), "argsort").bt(),