22use crate :: backend:: { BackendDevice , BackendStorage } ;
33use crate :: op:: { BinaryOpT , CmpOp , ReduceOp , UnaryOpT } ;
44use crate :: { DType , Error , IntDType , Layout , Result , Shape , WithDType } ;
5- use float8:: F8E4M3 as f8e4m3 ;
5+ use float8:: F8E4M3 ;
66use half:: { bf16, f16} ;
77use rayon:: prelude:: * ;
88
99mod utils;
1010pub use utils:: {
1111 binary_map, binary_map_vec, unary_map, unary_map_vec, Map1 , Map1Any , Map2 , Map2InPlace , Map2U8 ,
1212} ;
13+ mod conv2d;
14+ use conv2d:: Conv2D ;
1315
// Compile-time switches that pick between alternative convolution implementations.
// NOTE(review): the two CONV1D flags are not used in the visible hunks; presumably they select
// the im2col / col2im code paths for conv1d and conv_transpose1d - confirm at their use sites.
1416const USE_IM2COL_CONV1D : bool = true ;
1517const USE_COL2IM_CONV1D_TR : bool = true ;
// When this flag was false, conv2d returned the result of the direct Conv2D kernel; otherwise it
// built an Im2Col op and finished with a batched matmul against the kernel.
16- const USE_IM2COL_CONV2D : bool = true ;
1718
1819// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
1920// intercept the oom errors to avoid panicking and provide a proper error.
@@ -28,7 +29,7 @@ pub enum CpuStorage {
2829 F16 ( Vec < f16 > ) ,
2930 F32 ( Vec < f32 > ) ,
3031 F64 ( Vec < f64 > ) ,
31- F8E4M3 ( Vec < f8e4m3 > ) ,
32+ F8E4M3 ( Vec < F8E4M3 > ) ,
3233 // Dummy types that store raw bytes
3334 F6E2M3 ( Vec < u8 > ) ,
3435 F6E3M2 ( Vec < u8 > ) ,
@@ -47,7 +48,7 @@ pub enum CpuStorageRef<'a> {
4748 F16 ( & ' a [ f16 ] ) ,
4849 F32 ( & ' a [ f32 ] ) ,
4950 F64 ( & ' a [ f64 ] ) ,
50- F8E4M3 ( & ' a [ f8e4m3 ] ) ,
51+ F8E4M3 ( & ' a [ F8E4M3 ] ) ,
5152 // Dummy types that store raw bytes
5253 F6E2M3 ( & ' a [ u8 ] ) ,
5354 F6E3M2 ( & ' a [ u8 ] ) ,
@@ -1103,94 +1104,6 @@ impl Map2 for ConvTranspose1D<'_> {
11031104 }
11041105}
11051106
// Direct (non-im2col) 2D convolution kernel for the CPU backend.
// The struct only wraps a reference to the convolution parameters; the work happens in the
// Map2 implementation below, which takes the input and kernel slices with their layouts.
1106- struct Conv2D < ' a > ( & ' a crate :: conv:: ParamsConv2D ) ;
1107-
1108- impl Map2 for Conv2D < ' _ > {
1109- const OP : & ' static str = "conv2d" ;
// Computes the convolution of `inp` with kernel `k` and returns a freshly allocated buffer of
// b_size * c_out * out_h * out_w elements.
// Both layouts must expose exactly four strides, otherwise `dims4` returns an error that is
// propagated with `?`.
// NOTE(review): the code assumes inp is laid out as (batch, c_in, h, w) and k as
// (c_out, c_in, k_h, k_w); this follows from how the strides are paired with the loop indices.
1110- fn f < T : WithDType > ( & self , inp : & [ T ] , inp_l : & Layout , k : & [ T ] , k_l : & Layout ) -> Result < Vec < T > > {
1111- let p = self . 0 ;
// Skip to the first element of each tensor; all later indexing is relative to the start offset.
1112- let inp = & inp[ inp_l. start_offset ( ) ..] ;
1113- let ( inp_s0, inp_s1, inp_s2, inp_s3) = crate :: shape:: dims4 ( inp_l. stride ( ) ) ?;
1114- let k = & k[ k_l. start_offset ( ) ..] ;
1115- let ( k_s0, k_s1, k_s2, k_s3) = crate :: shape:: dims4 ( k_l. stride ( ) ) ?;
1116- let ( out_h, out_w) = ( p. out_h ( ) , p. out_w ( ) ) ;
1117-
1118- // Output shape: [b_size, c_out, out_h, out_w].
// The buffer starts zero-filled because results are accumulated into it with `+=` below.
1119- let dst = vec ! [ T :: zero( ) ; p. b_size * p. c_out * out_h * out_w] ;
1120-
1121- // TODO: Avoid making this copy if `inp` already has the appropriate layout.
// Copy the input into a dense buffer ordered (batch, h, w, channel) so that the c_in values
// of one pixel are adjacent and can be handed to `vec_dot` as a single run.
1122- let mut inp_cont = vec ! [ T :: zero( ) ; p. b_size * p. c_in * p. i_h * p. i_w] ;
1123- let cont_s0 = p. i_h * p. i_w * p. c_in ;
1124- let cont_s1 = p. i_w * p. c_in ;
1125- let cont_s2 = p. c_in ;
1126- for b_idx in 0 ..p. b_size {
1127- for h_idx in 0 ..p. i_h {
1128- for w_idx in 0 ..p. i_w {
1129- for c_idx in 0 ..p. c_in {
1130- let src_idx =
1131- b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
1132- let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
1133- inp_cont[ dst_idx] = inp[ src_idx]
1134- }
1135- }
1136- }
1137- }
1138-
// One pass per kernel tap (offset_h, offset_w). Within a pass the output channels are
// processed in parallel through rayon.
1139- for offset_h in 0 ..p. k_h {
1140- for offset_w in 0 ..p. k_w {
1141- ( 0 ..p. c_out ) . into_par_iter ( ) . for_each ( |dst_c_idx| {
1142- let dst_idx = dst_c_idx * out_w * out_h;
// Gather the c_in kernel weights for this output channel and kernel tap into a dense vector.
1143- let k_cont = ( 0 ..p. c_in )
1144- . map ( |c_in_idx| {
1145- k[ dst_c_idx * k_s0
1146- + c_in_idx * k_s1
1147- + offset_h * k_s2
1148- + offset_w * k_s3]
1149- } )
1150- . collect :: < Vec < _ > > ( ) ;
1151- for b_idx in 0 ..p. b_size {
1152- let dst_idx = dst_idx + b_idx * p. c_out * out_h * out_w;
1153- for dst_h in 0 ..out_h {
1154- let dst_idx = dst_idx + dst_h * out_w;
// Source row in padded coordinates; rows that fall inside the padding contribute nothing.
1155- let src_h = p. stride * dst_h + offset_h * p. dilation ;
1156- if src_h < p. padding || src_h >= p. i_h + p. padding {
1157- continue ;
1158- }
1159- let src_h = src_h - p. padding ;
1160- for dst_w in 0 ..out_w {
1161- let dst_idx = dst_idx + dst_w;
// Same padding test for the source column.
1162- let src_w = p. stride * dst_w + offset_w * p. dilation ;
1163- if src_w < p. padding || src_w >= p. i_w + p. padding {
1164- continue ;
1165- }
1166- let src_w = src_w - p. padding ;
1167- let inp_cont = & inp_cont
1168- [ b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..] ;
// Length checks guarding the raw-pointer call below, which reads c_in elements from each side.
1169- assert ! ( inp_cont. len( ) >= p. c_in) ;
1170- assert ! ( k_cont. len( ) >= p. c_in) ;
1171- let mut d = T :: zero ( ) ;
// NOTE(review): `vec_dot` is defined elsewhere; presumably it stores the dot product of the
// two c_in-long runs into `d` - confirm against the WithDType implementation.
1172- unsafe {
1173- T :: vec_dot ( inp_cont. as_ptr ( ) , k_cont. as_ptr ( ) , & mut d, p. c_in )
1174- }
// NOTE(review): `dst` is not declared `mut` and the write below goes through a pointer
// obtained from `as_ptr` and cast to `*mut T`. The standard library documents that such a
// pointer must not be written through (`as_mut_ptr` is the mutable variant) - worth revisiting.
1175- let dst_p = dst. as_ptr ( ) ;
1176- // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
1177- // the different tasks so no two threads can try to write at the same
1178- // location.
1179- unsafe {
1180- let ptr = dst_p. add ( dst_idx) as * mut T ;
1181- * ptr += d
1182- }
1183- }
1184- }
1185- }
1186- } ) ;
1187- }
1188- }
1189-
1190- Ok ( dst)
1191- }
1192- }
1193-
11941107struct ConvTranspose2D < ' a > ( & ' a crate :: conv:: ParamsConvTranspose2D ) ;
11951108
11961109impl Map2 for ConvTranspose2D < ' _ > {
@@ -2013,31 +1926,31 @@ impl BackendStorage for CpuStorage {
20131926 }
20141927 // Conversions to F8E4M3
20151928 ( Self :: U8 ( storage) , DType :: F8E4M3 ) => {
2016- let data = unary_map ( storage, layout, |v| f8e4m3 :: from_f32 ( v as f32 ) ) ;
1929+ let data = unary_map ( storage, layout, |v| F8E4M3 :: from_f32 ( v as f32 ) ) ;
20171930 Ok ( Self :: F8E4M3 ( data) )
20181931 }
20191932 ( Self :: U32 ( storage) , DType :: F8E4M3 ) => {
2020- let data = unary_map ( storage, layout, |v| f8e4m3 :: from_f32 ( v as f32 ) ) ;
1933+ let data = unary_map ( storage, layout, |v| F8E4M3 :: from_f32 ( v as f32 ) ) ;
20211934 Ok ( Self :: F8E4M3 ( data) )
20221935 }
20231936 ( Self :: I64 ( storage) , DType :: F8E4M3 ) => {
2024- let data = unary_map ( storage, layout, |v| f8e4m3 :: from_f32 ( v as f32 ) ) ;
1937+ let data = unary_map ( storage, layout, |v| F8E4M3 :: from_f32 ( v as f32 ) ) ;
20251938 Ok ( Self :: F8E4M3 ( data) )
20261939 }
20271940 ( Self :: BF16 ( storage) , DType :: F8E4M3 ) => {
2028- let data = unary_map ( storage, layout, |v| f8e4m3 :: from_f32 ( v. to_f32 ( ) ) ) ;
1941+ let data = unary_map ( storage, layout, |v| F8E4M3 :: from_f32 ( v. to_f32 ( ) ) ) ;
20291942 Ok ( Self :: F8E4M3 ( data) )
20301943 }
20311944 ( Self :: F16 ( storage) , DType :: F8E4M3 ) => {
2032- let data = unary_map ( storage, layout, |v| f8e4m3 :: from_f32 ( v. to_f32 ( ) ) ) ;
1945+ let data = unary_map ( storage, layout, |v| F8E4M3 :: from_f32 ( v. to_f32 ( ) ) ) ;
20331946 Ok ( Self :: F8E4M3 ( data) )
20341947 }
20351948 ( Self :: F32 ( storage) , DType :: F8E4M3 ) => {
2036- let data = unary_map ( storage, layout, f8e4m3 :: from_f32) ;
1949+ let data = unary_map ( storage, layout, F8E4M3 :: from_f32) ;
20371950 Ok ( Self :: F8E4M3 ( data) )
20381951 }
20391952 ( Self :: F64 ( storage) , DType :: F8E4M3 ) => {
2040- let data = unary_map ( storage, layout, f8e4m3 :: from_f64) ;
1953+ let data = unary_map ( storage, layout, F8E4M3 :: from_f64) ;
20411954 Ok ( Self :: F8E4M3 ( data) )
20421955 }
20431956 ( Self :: F8E4M3 ( storage) , DType :: F8E4M3 ) => {
@@ -2185,7 +2098,7 @@ impl BackendStorage for CpuStorage {
21852098 Ok ( Self :: F64 ( data) )
21862099 }
21872100 ( Self :: I16 ( storage) , DType :: F8E4M3 ) => {
2188- let data = unary_map ( storage, layout, |v| f8e4m3 :: from_f32 ( v as f32 ) ) ;
2101+ let data = unary_map ( storage, layout, |v| F8E4M3 :: from_f32 ( v as f32 ) ) ;
21892102 Ok ( Self :: F8E4M3 ( data) )
21902103 }
21912104 // Conversions from I32
@@ -2218,7 +2131,7 @@ impl BackendStorage for CpuStorage {
22182131 Ok ( Self :: F64 ( data) )
22192132 }
22202133 ( Self :: I32 ( storage) , DType :: F8E4M3 ) => {
2221- let data = unary_map ( storage, layout, |v| f8e4m3 :: from_f32 ( v as f32 ) ) ;
2134+ let data = unary_map ( storage, layout, |v| F8E4M3 :: from_f32 ( v as f32 ) ) ;
22222135 Ok ( Self :: F8E4M3 ( data) )
22232136 }
22242137 // Dummy types - return error for all conversions to/from dummy types
@@ -2345,7 +2258,7 @@ impl BackendStorage for CpuStorage {
23452258 Ok ( Self :: F64 ( data) )
23462259 }
23472260 Self :: F8E4M3 ( storage) => {
2348- let data = unary_map ( storage, layout, |v| v. powf ( f8e4m3 :: from_f64 ( e) ) ) ;
2261+ let data = unary_map ( storage, layout, |v| v. powf ( F8E4M3 :: from_f64 ( e) ) ) ;
23492262 Ok ( Self :: F8E4M3 ( data) )
23502263 }
23512264 Self :: U8 ( _) => Err ( Error :: UnsupportedDTypeForOp ( DType :: U8 , "powf" ) . bt ( ) ) ,
@@ -2380,7 +2293,7 @@ impl BackendStorage for CpuStorage {
23802293 Ok ( Self :: F64 ( data) )
23812294 }
23822295 Self :: F8E4M3 ( storage) => {
2383- let data = unary_map ( storage, layout, |v| elu ( v, f8e4m3 :: from_f64 ( alpha) ) ) ;
2296+ let data = unary_map ( storage, layout, |v| elu ( v, F8E4M3 :: from_f64 ( alpha) ) ) ;
23842297 Ok ( Self :: F8E4M3 ( data) )
23852298 }
23862299 Self :: U8 ( _) => Err ( Error :: UnsupportedDTypeForOp ( DType :: U8 , "elu" ) . bt ( ) ) ,
@@ -2775,46 +2688,7 @@ impl BackendStorage for CpuStorage {
27752688 kernel_l : & Layout ,
27762689 params : & crate :: conv:: ParamsConv2D ,
27772690 ) -> Result < Self > {
2778- if !USE_IM2COL_CONV2D {
2779- return Conv2D ( params) . map ( self , l, kernel, kernel_l) ;
2780- }
2781- let op = Im2Col {
2782- h_k : params. k_h ,
2783- w_k : params. k_w ,
2784- padding : params. padding ,
2785- stride : params. stride ,
2786- dilation : params. dilation ,
2787- } ;
2788- let col = op. map ( self , l) ?;
2789- let b = params. b_size ;
2790- let n = params. c_out ;
2791- let ( h_out, w_out) = ( params. out_h ( ) , params. out_w ( ) ) ;
2792- let k = op. h_k * op. w_k * params. c_in ;
2793- let m = h_out * w_out;
2794- let col_l = Layout :: contiguous ( ( b, m, k) ) ;
2795- let res = if kernel_l. is_contiguous ( ) {
2796- let kernel_l = Layout :: contiguous_with_offset ( ( 1 , n, k) , kernel_l. start_offset ( ) )
2797- . transpose ( 1 , 2 ) ?
2798- . broadcast_as ( ( b, k, n) ) ?;
2799- col. matmul ( kernel, ( b, m, n, k) , & col_l, & kernel_l) ?
2800- } else {
2801- // Make the kernel contiguous if not already the case.
2802- let mut kernel_c = unsafe {
2803- self . device ( )
2804- . alloc_uninit ( kernel_l. shape ( ) , kernel. dtype ( ) ) ?
2805- } ;
2806- kernel. copy_strided_src ( & mut kernel_c, 0 , kernel_l) ?;
2807- let kernel_l = Layout :: contiguous_with_offset ( ( 1 , n, k) , kernel_l. start_offset ( ) )
2808- . transpose ( 1 , 2 ) ?
2809- . broadcast_as ( ( b, k, n) ) ?;
2810- col. matmul ( kernel, ( b, m, n, k) , & col_l, & kernel_l) ?
2811- } ;
2812- let res_l = Layout :: contiguous ( ( b, h_out, w_out, params. c_out ) )
2813- . transpose ( 1 , 2 ) ?
2814- . transpose ( 1 , 3 ) ?;
2815- let mut res_t = unsafe { self . device ( ) . alloc_uninit ( res_l. shape ( ) , res. dtype ( ) ) ? } ;
2816- res. copy_strided_src ( & mut res_t, 0 , & res_l) ?;
2817- Ok ( res_t)
2691+ Conv2D ( params) . map ( self , l, kernel, kernel_l)
28182692 }
28192693
28202694 fn conv_transpose2d (
@@ -3057,7 +2931,6 @@ impl BackendDevice for CpuDevice {
30572931 | DType :: I16
30582932 | DType :: I32
30592933 | DType :: I64
3060- | DType :: F8E4M3
30612934 | DType :: F6E2M3
30622935 | DType :: F6E3M2
30632936 | DType :: F4
@@ -3080,6 +2953,16 @@ impl BackendDevice for CpuDevice {
30802953 }
30812954 Ok ( CpuStorage :: F16 ( data) )
30822955 }
2956+ DType :: F8E4M3 => {
2957+ let mut data = Vec :: with_capacity ( elem_count) ;
2958+ let uniform =
2959+ rand:: distr:: Uniform :: new ( F8E4M3 :: from_f64 ( min) , F8E4M3 :: from_f64 ( max) )
2960+ . map_err ( Error :: wrap) ?;
2961+ for _i in 0 ..elem_count {
2962+ data. push ( rng. sample :: < F8E4M3 , _ > ( uniform) )
2963+ }
2964+ Ok ( CpuStorage :: F8E4M3 ( data) )
2965+ }
30832966 DType :: F32 => {
30842967 let mut data = Vec :: with_capacity ( elem_count) ;
30852968 let uniform =
@@ -3111,7 +2994,6 @@ impl BackendDevice for CpuDevice {
31112994 | DType :: I16
31122995 | DType :: I32
31132996 | DType :: I64
3114- | DType :: F8E4M3
31152997 | DType :: F6E2M3
31162998 | DType :: F6E3M2
31172999 | DType :: F4
@@ -3134,6 +3016,15 @@ impl BackendDevice for CpuDevice {
31343016 }
31353017 Ok ( CpuStorage :: F16 ( data) )
31363018 }
3019+ DType :: F8E4M3 => {
3020+ let mut data = Vec :: with_capacity ( elem_count) ;
3021+ let normal = rand_distr:: Normal :: new ( F8E4M3 :: from_f64 ( mean) , F8E4M3 :: from_f64 ( std) )
3022+ . map_err ( Error :: wrap) ?;
3023+ for _i in 0 ..elem_count {
3024+ data. push ( normal. sample ( & mut rng) )
3025+ }
3026+ Ok ( CpuStorage :: F8E4M3 ( data) )
3027+ }
31373028 DType :: F32 => {
31383029 let mut data = Vec :: with_capacity ( elem_count) ;
31393030 let normal =
@@ -3231,7 +3122,7 @@ impl BackendDevice for CpuDevice {
32313122 DType :: F16 => CpuStorage :: F16 ( vec ! [ f16:: ZERO ; elem_count] ) ,
32323123 DType :: F32 => CpuStorage :: F32 ( vec ! [ 0f32 ; elem_count] ) ,
32333124 DType :: F64 => CpuStorage :: F64 ( vec ! [ 0f64 ; elem_count] ) ,
3234- DType :: F8E4M3 => CpuStorage :: F8E4M3 ( vec ! [ f8e4m3 :: ZERO ; elem_count] ) ,
3125+ DType :: F8E4M3 => CpuStorage :: F8E4M3 ( vec ! [ F8E4M3 :: ZERO ; elem_count] ) ,
32353126 DType :: F6E2M3 | DType :: F6E3M2 | DType :: F4 | DType :: F8E8M0 => {
32363127 return Err ( Error :: UnsupportedDTypeForOp ( dtype, "zeros" ) . bt ( ) )
32373128 }
0 commit comments