@@ -2745,7 +2745,7 @@ impl BackendStorage for CpuStorage {
27452745 let kernel_l = Layout :: contiguous_with_offset ( ( 1 , n, k) , kernel_l. start_offset ( ) )
27462746 . transpose ( 1 , 2 ) ?
27472747 . broadcast_as ( ( b, k, n) ) ?;
2748- col. matmul ( kernel, ( b, m, n, k) , & col_l, & kernel_l) ?
2748+ col. matmul_with_alpha ( kernel, None , ( b, m, n, k) , & col_l, & kernel_l) ?
27492749 } else {
27502750 // Make the kernel contiguous if not already the case.
27512751 let mut kernel_c = unsafe {
@@ -2756,7 +2756,7 @@ impl BackendStorage for CpuStorage {
27562756 let kernel_l = Layout :: contiguous_with_offset ( ( 1 , n, k) , kernel_l. start_offset ( ) )
27572757 . transpose ( 1 , 2 ) ?
27582758 . broadcast_as ( ( b, k, n) ) ?;
2759- col. matmul ( kernel, ( b, m, n, k) , & col_l, & kernel_l) ?
2759+ col. matmul_with_alpha ( kernel, None , ( b, m, n, k) , & col_l, & kernel_l) ?
27602760 } ;
27612761 let res_l = Layout :: contiguous ( ( b, l_out, params. c_out ) ) . transpose ( 1 , 2 ) ?;
27622762 let mut res_t = unsafe { self . device ( ) . alloc_uninit ( res_l. shape ( ) , res. dtype ( ) ) ? } ;
@@ -2797,8 +2797,9 @@ impl BackendStorage for CpuStorage {
27972797 vec ! [ 0 , k_size * c_out, 1 ] ,
27982798 kernel_l. start_offset ( ) ,
27992799 ) ;
2800- self . matmul (
2800+ self . matmul_with_alpha (
28012801 kernel,
2802+ None ,
28022803 (
28032804 b_size,
28042805 /* m */ l_in,
@@ -2942,14 +2943,39 @@ impl BackendStorage for CpuStorage {
29422943 }
29432944 }
29442945
2945- fn matmul (
2946+ fn matmul_with_alpha_beta (
29462947 & self ,
29472948 rhs : & Self ,
2949+ c : & mut Self ,
2950+ s : Option < f64 > ,
2951+ bmnk : ( usize , usize , usize , usize ) ,
2952+ lhs_l : & Layout ,
2953+ rhs_l : & Layout ,
2954+ c_layout : & Layout ,
2955+ ) -> Result < ( ) > {
2956+ let mm = self . matmul_with_alpha ( rhs, s, bmnk, lhs_l, rhs_l) ?;
2957+ let mm_l = Layout :: contiguous ( c_layout. shape ( ) ) ;
2958+ * c = c. binary_impl :: < crate :: op:: Add > ( & mm, c_layout, & mm_l) ?;
2959+ Ok ( ( ) )
2960+ }
2961+
2962+ fn matmul_with_alpha (
2963+ & self ,
2964+ rhs : & Self ,
2965+ s : Option < f64 > ,
29482966 bmnk : ( usize , usize , usize , usize ) ,
29492967 lhs_l : & Layout ,
29502968 rhs_l : & Layout ,
29512969 ) -> Result < Self > {
2952- MatMul ( bmnk) . map ( self , lhs_l, rhs, rhs_l)
2970+ let mm = MatMul ( bmnk) . map ( self , lhs_l, rhs, rhs_l) ?;
2971+ match s {
2972+ None => Ok ( mm) ,
2973+ Some ( alpha) => {
2974+ let ( b, m, n, _) = bmnk;
2975+ let mm_l = Layout :: contiguous ( ( b, m, n) ) ;
2976+ mm. affine ( & mm_l, alpha, 0.0 )
2977+ }
2978+ }
29532979 }
29542980
29552981 fn device ( & self ) -> & Self :: Device {
0 commit comments