1- // See LICENSE for license details.
21
32#include <stdint.h>
43#include <stddef.h>
@@ -21,6 +20,14 @@ typedef acc_t ACC_T;
2120typedef elem_t ACC_T ;
2221#endif
2322
23+ #ifdef FAST
24+
25+ #define MAT_DIM_I 19
26+ #define MAT_DIM_K 18
27+ #define MAT_DIM_J 17
28+
29+ #else
30+
2431#ifndef BAREMETAL
2532#define MAT_DIM_I 500
2633#define MAT_DIM_K 412
@@ -29,8 +36,11 @@ typedef elem_t ACC_T;
2936#define MAT_DIM_I 60
3037#define MAT_DIM_K 50
3138#define MAT_DIM_J 30
39+
3240#endif
3341
42+ #endif // ifdef FAST
43+
3444void print_tile (elem_t * in , int tile_dim ) {
3545 for (size_t r = 0 ; r < tile_dim ; r ++ ) {
3646 printf ("row starts at: %p\n" , in + r * MAT_DIM_J );
@@ -104,30 +114,51 @@ int main() {
104114 // printf("Init A\n");
105115 for (size_t i = 0 ; i < MAT_DIM_I ; ++ i ) {
106116 for (size_t j = 0 ; j < MAT_DIM_K ; ++ j ) {
117+ #ifdef FAST
118+ full_A [i ][j ] = 1 ;
119+ #else
107120 full_A [i ][j ] = rand () % 2 ;
121+ #endif
108122 }
109123 }
110124
111125 // printf("Init B\n");
112126 for (size_t i = 0 ; i < MAT_DIM_J ; ++ i ) {
113127 for (size_t j = 0 ; j < MAT_DIM_K ; ++ j ) {
128+ #ifdef FAST
129+ full_B [i ][j ] = 1 ;
130+ #else
114131 full_B [i ][j ] = rand () % 2 ;
132+ #endif
115133 }
116134 }
117135
118136 // printf("Init D\n");
119137 for (size_t i = 0 ; i < MAT_DIM_I ; ++ i ) {
120138 for (size_t j = 0 ; j < MAT_DIM_J ; ++ j ) {
139+ #ifdef FAST
140+ full_D [i ][j ] = NO_BIAS ? 0 : 1 ;
141+ #else
121142 full_D [i ][j ] = NO_BIAS ? 0 : rand () % 2 ;
143+ #endif
122144 }
123145 }
124146
147+ #ifdef FAST
148+ for (size_t i = 0 ; i < MAT_DIM_I ; ++ i ) {
149+ for (size_t j = 0 ; j < MAT_DIM_J ; ++ j ) {
150+ gold [i ][j ] = MAT_DIM_K + !NO_BIAS ;
151+ }
152+ }
153+ #else
125154 printf ("Starting slow CPU matmul\n" );
126155 unsigned long cpu_start = read_cycles ();
127156 full_matmul (full_A , full_B , full_D , gold_full );
128157 unsigned long cpu_end = read_cycles ();
129158 printf ("Cycles taken: %u\n" , cpu_end - cpu_start );
130159 full_matscale (gold_full , gold , ACC_SCALE_IDENTITY );
160+ #endif // #ifdef FAST
161+
131162#endif
132163
133164 printf ("Starting gemmini matmul\n" );
@@ -149,9 +180,14 @@ int main() {
149180 if (!full_is_equal (full_C , gold )) {
150181 printf ("C:\n" );
151182 full_printMatrix (full_C );
183+
152184 printf ("Gold:\n" );
185+ #ifdef FAST
186+ printf ("All elements must be %d\n" , MAT_DIM_K + !NO_BIAS );
187+ #else
153188 full_printMatrix (gold );
154189 printf ("\n" );
190+ #endif // ifdef FAST
155191
156192 exit (1 );
157193 }
0 commit comments