@@ -702,7 +702,7 @@ static void matmul_cpu(bool transA, bool transB, size_t DIM_I, size_t DIM_J, siz
702702 int act , acc_scale_t scale , size_t relu6_shift , bool repeating_bias ) {
703703
704704 const int no_bias = D == NULL ;
705- if (DIM_I % 4 == 0 && DIM_J % 4 == 0 ) {
705+ if (! transA && ! transB && DIM_I % 4 == 0 && DIM_J % 4 == 0 ) {
706706 for (size_t i = 0 ; i < DIM_I ; i += 4 ) {
707707 for (size_t j = 0 ; j < DIM_J ; j += 4 ) {
708708
@@ -801,21 +801,20 @@ static void matmul_cpu(bool transA, bool transB, size_t DIM_I, size_t DIM_J, siz
801801 }
802802 }
803803 } else {
804+ size_t A_dim_strides [2 ] = {!transA ? stride_A : 1 , !transA ? 1 : stride_A }; // i, k stride
805+ size_t B_dim_strides [2 ] = {!transB ? 1 : stride_B , !transB ? stride_B : 1 }; // j, k stride
804806 for (size_t i = 0 ; i < DIM_I ; i ++ ) {
805807 for (size_t j = 0 ; j < DIM_J ; j ++ ) {
806- const elem_t * a = !transA ? (A + (i * stride_A )) : A + i ;
807- const elem_t * b = !transB ? (B + j ) : (B + (j * stride_B ));
808808 elem_t * c = C + (i * stride_C ) + j ;
809809
810810 const size_t bias_row = repeating_bias ? 0 : i ;
811811 acc_t sum = no_bias ? 0 : GEMMINI_ACC_SCALE (* (D + bias_row * stride_D + j ), D_scale_factor );
812812
813813 for (size_t k = 0 ; k < DIM_K ; k ++ ) {
814+ const elem_t * a = A + i * A_dim_strides [0 ] + k * A_dim_strides [1 ];
815+ const elem_t * b = B + j * B_dim_strides [0 ] + k * B_dim_strides [1 ];
814816 sum += (GEMMINI_SCALE (* a , A_scale_factor ) * GEMMINI_SCALE (* b , B_scale_factor ));
815- b += !transB ? stride_B : 1 ;
816- a += !transA ? 1 : stride_A ;
817817 }
818-
819818 * c = scale_and_sat (sum , act , scale , relu6_shift );
820819 }
821820 }
0 commit comments