1717#define KERNEL_DIM 3
1818#define PADDING 1
1919#define STRIDE 2
20+ #define DILATION 1
2021
2122#else
2223
3839#define KERNEL_DIM 3
3940#define PADDING 1
4041#define STRIDE 2
42+ #define DILATION 1
4143
4244#endif
4345
4446#define NO_BIAS false
4547
46- #define OUT_DIM ((IN_DIM + 2*PADDING - KERNEL_DIM) / STRIDE + 1)
48+ #define OUT_DIM ((IN_DIM + 2*PADDING - DILATION * ( KERNEL_DIM - 1) - 1 ) / STRIDE + 1)
4749#define PATCH_SIZE (KERNEL_DIM * KERNEL_DIM * IN_CHANNELS)
4850#define N_PATCHES (BATCH_SIZE * OUT_DIM * OUT_DIM)
4951
5052void conv (int batch_size , int in_channels , int in_dim ,
5153 int out_channels , int kernel_dim ,
5254 int out_dim ,
53- int stride , int padding ,
55+ int stride , int dilation , int padding ,
5456 elem_t input [batch_size ][in_dim ][in_dim ][in_channels ],
5557 elem_t weights [out_channels ][kernel_dim ][kernel_dim ][in_channels ],
5658 acc_t bias [out_channels ],
5759 elem_t output [batch_size ][out_dim ][out_dim ][out_channels ]) {
5860
5961#ifdef GEMMINI_ASSERTIONS
60- if (out_dim != (in_dim + 2 * padding - kernel_dim ) / stride + 1 ) {
62+ if (out_dim != (in_dim + 2 * padding - dilation * ( kernel_dim - 1 ) - 1 ) / stride + 1 ) {
6163 printf ("conv out_dim is not correct\n" );
6264 exit (1 );
6365 }
@@ -72,8 +74,8 @@ void conv(int batch_size, int in_channels, int in_dim,
7274 for (int krow = 0 ; krow < kernel_dim ; krow ++ ) {
7375 for (int kcol = 0 ; kcol < kernel_dim ; kcol ++ ) {
7476 for (int kch = 0 ; kch < in_channels ; kch ++ ) {
75- int irow = orow * stride + krow - padding ;
76- int icol = ocol * stride + kcol - padding ;
77+ int irow = orow * stride + krow * dilation - padding ;
78+ int icol = ocol * stride + kcol * dilation - padding ;
7779
7880 elem_t pixel = irow < 0 || irow >= in_dim ||
7981 icol < 0 || icol >= in_dim ?
@@ -193,7 +195,7 @@ int main() {
193195 conv (BATCH_SIZE , IN_CHANNELS , IN_DIM ,
194196 OUT_CHANNELS , KERNEL_DIM ,
195197 OUT_DIM ,
196- STRIDE , PADDING ,
198+ STRIDE , DILATION , PADDING ,
197199 input ,
198200 weights ,
199201 bias ,
@@ -216,7 +218,7 @@ int main() {
216218 tiled_conv_A_stride_auto (
217219 BATCH_SIZE , IN_DIM , IN_CHANNELS ,
218220 OUT_CHANNELS , OUT_DIM ,
219- STRIDE , PADDING , KERNEL_DIM ,
221+ STRIDE , DILATION , PADDING , KERNEL_DIM ,
220222
221223 (elem_t * )input ,
222224 (elem_t * )weights_mat ,
0 commit comments