Skip to content

Commit 9533911

Browse files
committed
Merge branch 'dev' into first-layer-overlap
2 parents 1d2e7b3 + d68fe69 commit 9533911

File tree

7 files changed

+431
-71
lines changed

7 files changed

+431
-71
lines changed

bareMetalC/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ tests = \
1919
mvin_scale \
2020
conv \
2121
conv_with_pool \
22+
conv_with_dilation \
2223
tiled_matmul_os \
2324
tiled_matmul_ws \
2425
tiled_matmul_ws_At \

bareMetalC/conv.c

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#define KERNEL_DIM 3
1818
#define PADDING 1
1919
#define STRIDE 2
20+
#define DILATION 1
2021

2122
#else
2223

@@ -38,26 +39,27 @@
3839
#define KERNEL_DIM 3
3940
#define PADDING 1
4041
#define STRIDE 2
42+
#define DILATION 1
4143

4244
#endif
4345

4446
#define NO_BIAS false
4547

46-
#define OUT_DIM ((IN_DIM + 2*PADDING - KERNEL_DIM) / STRIDE + 1)
48+
#define OUT_DIM ((IN_DIM + 2*PADDING - DILATION * (KERNEL_DIM - 1) - 1) / STRIDE + 1)
4749
#define PATCH_SIZE (KERNEL_DIM * KERNEL_DIM * IN_CHANNELS)
4850
#define N_PATCHES (BATCH_SIZE * OUT_DIM * OUT_DIM)
4951

5052
void conv(int batch_size, int in_channels, int in_dim,
5153
int out_channels, int kernel_dim,
5254
int out_dim,
53-
int stride, int padding,
55+
int stride, int dilation, int padding,
5456
elem_t input[batch_size][in_dim][in_dim][in_channels],
5557
elem_t weights[out_channels][kernel_dim][kernel_dim][in_channels],
5658
acc_t bias[out_channels],
5759
elem_t output[batch_size][out_dim][out_dim][out_channels]) {
5860

5961
#ifdef GEMMINI_ASSERTIONS
60-
if (out_dim != (in_dim + 2*padding - kernel_dim) / stride + 1) {
62+
if (out_dim != (in_dim + 2*padding - dilation * (kernel_dim - 1) - 1) / stride + 1) {
6163
printf("conv out_dim is not correct\n");
6264
exit(1);
6365
}
@@ -72,8 +74,8 @@ void conv(int batch_size, int in_channels, int in_dim,
7274
for (int krow = 0; krow < kernel_dim; krow++) {
7375
for (int kcol = 0; kcol < kernel_dim; kcol++) {
7476
for (int kch = 0; kch < in_channels; kch++) {
75-
int irow = orow * stride + krow - padding;
76-
int icol = ocol * stride + kcol - padding;
77+
int irow = orow * stride + krow * dilation - padding;
78+
int icol = ocol * stride + kcol * dilation - padding;
7779

7880
elem_t pixel = irow < 0 || irow >= in_dim ||
7981
icol < 0 || icol >= in_dim ?
@@ -193,7 +195,7 @@ int main() {
193195
conv(BATCH_SIZE, IN_CHANNELS, IN_DIM,
194196
OUT_CHANNELS, KERNEL_DIM,
195197
OUT_DIM,
196-
STRIDE, PADDING,
198+
STRIDE, DILATION, PADDING,
197199
input,
198200
weights,
199201
bias,
@@ -216,7 +218,7 @@ int main() {
216218
tiled_conv_A_stride_auto(
217219
BATCH_SIZE, IN_DIM, IN_CHANNELS,
218220
OUT_CHANNELS, OUT_DIM,
219-
STRIDE, PADDING, KERNEL_DIM,
221+
STRIDE, DILATION, PADDING, KERNEL_DIM,
220222

221223
(elem_t*)input,
222224
(elem_t*)weights_mat,

0 commit comments

Comments
 (0)