Skip to content

Commit 9f16dbe

Browse files
committed
Apply review comments
1 parent 68d17ab commit 9f16dbe

File tree

4 files changed

+43
-153
lines changed

4 files changed

+43
-153
lines changed

candle-core/src/cpu_backend/mod.rs

Lines changed: 37 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@
22
use crate::backend::{BackendDevice, BackendStorage};
33
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
44
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
5-
use float8::F8E4M3 as f8e4m3;
5+
use float8::F8E4M3;
66
use half::{bf16, f16};
77
use rayon::prelude::*;
88

99
mod utils;
1010
pub use utils::{
1111
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
1212
};
13+
mod conv2d;
14+
use conv2d::Conv2D;
1315

1416
const USE_IM2COL_CONV1D: bool = true;
1517
const USE_COL2IM_CONV1D_TR: bool = true;
16-
const USE_IM2COL_CONV2D: bool = true;
1718

1819
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
1920
// intercept the oom errors to avoid panicking and provide a proper error.
@@ -28,7 +29,7 @@ pub enum CpuStorage {
2829
F16(Vec<f16>),
2930
F32(Vec<f32>),
3031
F64(Vec<f64>),
31-
F8E4M3(Vec<f8e4m3>),
32+
F8E4M3(Vec<F8E4M3>),
3233
// Dummy types that store raw bytes
3334
F6E2M3(Vec<u8>),
3435
F6E3M2(Vec<u8>),
@@ -47,7 +48,7 @@ pub enum CpuStorageRef<'a> {
4748
F16(&'a [f16]),
4849
F32(&'a [f32]),
4950
F64(&'a [f64]),
50-
F8E4M3(&'a [f8e4m3]),
51+
F8E4M3(&'a [F8E4M3]),
5152
// Dummy types that store raw bytes
5253
F6E2M3(&'a [u8]),
5354
F6E3M2(&'a [u8]),
@@ -1103,94 +1104,6 @@ impl Map2 for ConvTranspose1D<'_> {
11031104
}
11041105
}
11051106

1106-
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
1107-
1108-
impl Map2 for Conv2D<'_> {
1109-
const OP: &'static str = "conv2d";
1110-
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
1111-
let p = self.0;
1112-
let inp = &inp[inp_l.start_offset()..];
1113-
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
1114-
let k = &k[k_l.start_offset()..];
1115-
let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
1116-
let (out_h, out_w) = (p.out_h(), p.out_w());
1117-
1118-
// Output shape: [b_size, c_out, out_h, out_w].
1119-
let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
1120-
1121-
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
1122-
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
1123-
let cont_s0 = p.i_h * p.i_w * p.c_in;
1124-
let cont_s1 = p.i_w * p.c_in;
1125-
let cont_s2 = p.c_in;
1126-
for b_idx in 0..p.b_size {
1127-
for h_idx in 0..p.i_h {
1128-
for w_idx in 0..p.i_w {
1129-
for c_idx in 0..p.c_in {
1130-
let src_idx =
1131-
b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
1132-
let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
1133-
inp_cont[dst_idx] = inp[src_idx]
1134-
}
1135-
}
1136-
}
1137-
}
1138-
1139-
for offset_h in 0..p.k_h {
1140-
for offset_w in 0..p.k_w {
1141-
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
1142-
let dst_idx = dst_c_idx * out_w * out_h;
1143-
let k_cont = (0..p.c_in)
1144-
.map(|c_in_idx| {
1145-
k[dst_c_idx * k_s0
1146-
+ c_in_idx * k_s1
1147-
+ offset_h * k_s2
1148-
+ offset_w * k_s3]
1149-
})
1150-
.collect::<Vec<_>>();
1151-
for b_idx in 0..p.b_size {
1152-
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
1153-
for dst_h in 0..out_h {
1154-
let dst_idx = dst_idx + dst_h * out_w;
1155-
let src_h = p.stride * dst_h + offset_h * p.dilation;
1156-
if src_h < p.padding || src_h >= p.i_h + p.padding {
1157-
continue;
1158-
}
1159-
let src_h = src_h - p.padding;
1160-
for dst_w in 0..out_w {
1161-
let dst_idx = dst_idx + dst_w;
1162-
let src_w = p.stride * dst_w + offset_w * p.dilation;
1163-
if src_w < p.padding || src_w >= p.i_w + p.padding {
1164-
continue;
1165-
}
1166-
let src_w = src_w - p.padding;
1167-
let inp_cont = &inp_cont
1168-
[b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..];
1169-
assert!(inp_cont.len() >= p.c_in);
1170-
assert!(k_cont.len() >= p.c_in);
1171-
let mut d = T::zero();
1172-
unsafe {
1173-
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
1174-
}
1175-
let dst_p = dst.as_ptr();
1176-
// Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
1177-
// the different tasks so no two threads can try to write at the same
1178-
// location.
1179-
unsafe {
1180-
let ptr = dst_p.add(dst_idx) as *mut T;
1181-
*ptr += d
1182-
}
1183-
}
1184-
}
1185-
}
1186-
});
1187-
}
1188-
}
1189-
1190-
Ok(dst)
1191-
}
1192-
}
1193-
11941107
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
11951108

11961109
impl Map2 for ConvTranspose2D<'_> {
@@ -2013,31 +1926,31 @@ impl BackendStorage for CpuStorage {
20131926
}
20141927
// Conversions to F8E4M3
20151928
(Self::U8(storage), DType::F8E4M3) => {
2016-
let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32));
1929+
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
20171930
Ok(Self::F8E4M3(data))
20181931
}
20191932
(Self::U32(storage), DType::F8E4M3) => {
2020-
let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32));
1933+
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
20211934
Ok(Self::F8E4M3(data))
20221935
}
20231936
(Self::I64(storage), DType::F8E4M3) => {
2024-
let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32));
1937+
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
20251938
Ok(Self::F8E4M3(data))
20261939
}
20271940
(Self::BF16(storage), DType::F8E4M3) => {
2028-
let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v.to_f32()));
1941+
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
20291942
Ok(Self::F8E4M3(data))
20301943
}
20311944
(Self::F16(storage), DType::F8E4M3) => {
2032-
let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v.to_f32()));
1945+
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
20331946
Ok(Self::F8E4M3(data))
20341947
}
20351948
(Self::F32(storage), DType::F8E4M3) => {
2036-
let data = unary_map(storage, layout, f8e4m3::from_f32);
1949+
let data = unary_map(storage, layout, F8E4M3::from_f32);
20371950
Ok(Self::F8E4M3(data))
20381951
}
20391952
(Self::F64(storage), DType::F8E4M3) => {
2040-
let data = unary_map(storage, layout, f8e4m3::from_f64);
1953+
let data = unary_map(storage, layout, F8E4M3::from_f64);
20411954
Ok(Self::F8E4M3(data))
20421955
}
20431956
(Self::F8E4M3(storage), DType::F8E4M3) => {
@@ -2185,7 +2098,7 @@ impl BackendStorage for CpuStorage {
21852098
Ok(Self::F64(data))
21862099
}
21872100
(Self::I16(storage), DType::F8E4M3) => {
2188-
let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32));
2101+
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
21892102
Ok(Self::F8E4M3(data))
21902103
}
21912104
// Conversions from I32
@@ -2218,7 +2131,7 @@ impl BackendStorage for CpuStorage {
22182131
Ok(Self::F64(data))
22192132
}
22202133
(Self::I32(storage), DType::F8E4M3) => {
2221-
let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32));
2134+
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
22222135
Ok(Self::F8E4M3(data))
22232136
}
22242137
// Dummy types - return error for all conversions to/from dummy types
@@ -2345,7 +2258,7 @@ impl BackendStorage for CpuStorage {
23452258
Ok(Self::F64(data))
23462259
}
23472260
Self::F8E4M3(storage) => {
2348-
let data = unary_map(storage, layout, |v| v.powf(f8e4m3::from_f64(e)));
2261+
let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e)));
23492262
Ok(Self::F8E4M3(data))
23502263
}
23512264
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "powf").bt()),
@@ -2380,7 +2293,7 @@ impl BackendStorage for CpuStorage {
23802293
Ok(Self::F64(data))
23812294
}
23822295
Self::F8E4M3(storage) => {
2383-
let data = unary_map(storage, layout, |v| elu(v, f8e4m3::from_f64(alpha)));
2296+
let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha)));
23842297
Ok(Self::F8E4M3(data))
23852298
}
23862299
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
@@ -2775,46 +2688,7 @@ impl BackendStorage for CpuStorage {
27752688
kernel_l: &Layout,
27762689
params: &crate::conv::ParamsConv2D,
27772690
) -> Result<Self> {
2778-
if !USE_IM2COL_CONV2D {
2779-
return Conv2D(params).map(self, l, kernel, kernel_l);
2780-
}
2781-
let op = Im2Col {
2782-
h_k: params.k_h,
2783-
w_k: params.k_w,
2784-
padding: params.padding,
2785-
stride: params.stride,
2786-
dilation: params.dilation,
2787-
};
2788-
let col = op.map(self, l)?;
2789-
let b = params.b_size;
2790-
let n = params.c_out;
2791-
let (h_out, w_out) = (params.out_h(), params.out_w());
2792-
let k = op.h_k * op.w_k * params.c_in;
2793-
let m = h_out * w_out;
2794-
let col_l = Layout::contiguous((b, m, k));
2795-
let res = if kernel_l.is_contiguous() {
2796-
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2797-
.transpose(1, 2)?
2798-
.broadcast_as((b, k, n))?;
2799-
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2800-
} else {
2801-
// Make the kernel contiguous if not already the case.
2802-
let mut kernel_c = unsafe {
2803-
self.device()
2804-
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
2805-
};
2806-
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
2807-
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2808-
.transpose(1, 2)?
2809-
.broadcast_as((b, k, n))?;
2810-
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2811-
};
2812-
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
2813-
.transpose(1, 2)?
2814-
.transpose(1, 3)?;
2815-
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
2816-
res.copy_strided_src(&mut res_t, 0, &res_l)?;
2817-
Ok(res_t)
2691+
Conv2D(params).map(self, l, kernel, kernel_l)
28182692
}
28192693

28202694
fn conv_transpose2d(
@@ -3057,7 +2931,6 @@ impl BackendDevice for CpuDevice {
30572931
| DType::I16
30582932
| DType::I32
30592933
| DType::I64
3060-
| DType::F8E4M3
30612934
| DType::F6E2M3
30622935
| DType::F6E3M2
30632936
| DType::F4
@@ -3080,6 +2953,16 @@ impl BackendDevice for CpuDevice {
30802953
}
30812954
Ok(CpuStorage::F16(data))
30822955
}
2956+
DType::F8E4M3 => {
2957+
let mut data = Vec::with_capacity(elem_count);
2958+
let uniform =
2959+
rand::distr::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max))
2960+
.map_err(Error::wrap)?;
2961+
for _i in 0..elem_count {
2962+
data.push(rng.sample::<F8E4M3, _>(uniform))
2963+
}
2964+
Ok(CpuStorage::F8E4M3(data))
2965+
}
30832966
DType::F32 => {
30842967
let mut data = Vec::with_capacity(elem_count);
30852968
let uniform =
@@ -3111,7 +2994,6 @@ impl BackendDevice for CpuDevice {
31112994
| DType::I16
31122995
| DType::I32
31132996
| DType::I64
3114-
| DType::F8E4M3
31152997
| DType::F6E2M3
31162998
| DType::F6E3M2
31172999
| DType::F4
@@ -3134,6 +3016,15 @@ impl BackendDevice for CpuDevice {
31343016
}
31353017
Ok(CpuStorage::F16(data))
31363018
}
3019+
DType::F8E4M3 => {
3020+
let mut data = Vec::with_capacity(elem_count);
3021+
let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std))
3022+
.map_err(Error::wrap)?;
3023+
for _i in 0..elem_count {
3024+
data.push(normal.sample(&mut rng))
3025+
}
3026+
Ok(CpuStorage::F8E4M3(data))
3027+
}
31373028
DType::F32 => {
31383029
let mut data = Vec::with_capacity(elem_count);
31393030
let normal =
@@ -3231,7 +3122,7 @@ impl BackendDevice for CpuDevice {
32313122
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
32323123
DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
32333124
DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
3234-
DType::F8E4M3 => CpuStorage::F8E4M3(vec![f8e4m3::ZERO; elem_count]),
3125+
DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]),
32353126
DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
32363127
return Err(Error::UnsupportedDTypeForOp(dtype, "zeros").bt())
32373128
}

candle-core/src/cuda_backend/utils.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ pub trait Map1 {
2626
S::F16(s) => S::F16(self.f(s, d, l)?),
2727
S::F32(s) => S::F32(self.f(s, d, l)?),
2828
S::F64(s) => S::F64(self.f(s, d, l)?),
29-
S::F8E4M3(_) | S::F4(_) | S::F6E2M3(_) | S::F6E3M2(_) | S::F8E8M0(_) => {
29+
S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?,
30+
S::F4(_) | S::F6E2M3(_) | S::F6E3M2(_) | S::F8E8M0(_) => {
3031
crate::bail!("Map1 does not uspport this dtype.");
3132
}
3233
};
@@ -55,9 +56,7 @@ pub trait Map2 {
5556
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
5657
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
5758
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
58-
(S::F8E4M3(_), S::F8E4M3(_)) => {
59-
Err(CudaError::InternalError("Map2 not supported for F8E4M3"))?
60-
}
59+
(S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?,
6160
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
6261
};
6362
Ok(out)

candle-core/src/op.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -947,7 +947,7 @@ impl UnaryOpT for GeluErf {
947947
}
948948
#[inline(always)]
949949
fn f32(v: f32) -> f32 {
950-
Self::f64(v as f64) as f32
950+
(crate::cpu::erf::erf_f32(v * std::f32::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v
951951
}
952952
#[inline(always)]
953953
fn f64(v: f64) -> f64 {
@@ -975,7 +975,7 @@ impl UnaryOpT for GeluErf {
975975
}
976976
#[inline(always)]
977977
fn f8e4m3(v: f8e4m3) -> f8e4m3 {
978-
f8e4m3::from_f64(Self::f64(v.to_f64()))
978+
f8e4m3::from_f32(Self::f32(v.to_f32()))
979979
}
980980
}
981981

candle-core/src/sort.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ impl crate::CustomOp1 for ArgSort {
184184
DType::I16 => "asort_asc_i16",
185185
DType::I32 => "asort_asc_i32",
186186
DType::I64 => "asort_asc_i64",
187-
DType::F8E4M3 => "asort_asc_f8e4m3",
187+
DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."),
188188
DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
189189
return Err(
190190
crate::Error::UnsupportedDTypeForOp(storage.dtype(), "argsort").bt(),

0 commit comments

Comments
 (0)