diff --git a/bareMetalC/Makefile b/bareMetalC/Makefile
index 5d6b57f5..db8fc8b2 100644
--- a/bareMetalC/Makefile
+++ b/bareMetalC/Makefile
@@ -18,6 +18,7 @@ tests = \
   padded \
   mvin_scale \
   conv \
+  conv_dw \
   conv_with_pool \
   conv_with_dilation \
   conv_with_dilation_and_rot180 \
diff --git a/bareMetalC/conv_dw.c b/bareMetalC/conv_dw.c
new file mode 100644
index 00000000..e8239152
--- /dev/null
+++ b/bareMetalC/conv_dw.c
@@ -0,0 +1,325 @@
+#include <stdint.h>
+#include <stddef.h>
+#include <assert.h>
+#include <stdlib.h>
+#include <stdio.h>
+#ifndef BAREMETAL
+#include <sys/mman.h>
+#endif
+#include "include/gemmini_testutils.h"
+
+#ifndef BAREMETAL
+
+#define BATCH_SIZE 4
+#define IN_DIM 224
+#define IN_CHANNELS 3
+#define OUT_CHANNELS 32
+#define KERNEL_DIM 3
+#define PADDING 1
+#define STRIDE 2
+
+#else
+
+#define IN_DIM 14
+#define IN_CHANNELS 34
+#define OUT_CHANNELS IN_CHANNELS
+#define BATCH_SIZE 1
+#define KERNEL_DIM 3
+#define PADDING 1
+#define STRIDE 2
+
+#endif
+
+#define NO_BIAS false
+
+#define OUT_DIM ((IN_DIM + 2*PADDING - KERNEL_DIM) / STRIDE + 1)
+#define PATCH_SIZE (KERNEL_DIM * KERNEL_DIM)
+#define N_PATCHES (BATCH_SIZE * OUT_DIM * OUT_DIM)
+
+static void conv_dw(const size_t batch_size, const size_t channels, const size_t in_dim, const size_t out_dim, const size_t kernel_size,
+        const size_t padding, const size_t stride,
+        const elem_t input[batch_size][in_dim][in_dim][channels],
+        const elem_t weight[channels][kernel_size][kernel_size],
+        const acc_t * bias,
+        elem_t output[batch_size][out_dim][out_dim][channels])
+        //elem_t output [I][J],
+        //const struct ConvParams * params)
+{
+    for (int batch = 0; batch < batch_size; batch++) {
+        for (int channel = 0; channel < channels; channel++) {
+            for (int out_row = 0; out_row < out_dim; out_row++) {
+                for (int out_col = 0; out_col < out_dim; out_col++) {
+                    int in_row = out_row * stride - padding;
+
+                    acc_t result = 0;
+                    if (bias != NULL) {
+                        result = bias[channel];
+                    }
+
+                    for (int kernel_row = 0; kernel_row < kernel_size; kernel_row++) {
+                        int in_col = out_col * stride - padding;
+
+                        for (int kernel_col = 0; kernel_col < kernel_size; kernel_col++) {
+                            if (in_row >= 0 && in_row < in_dim && in_col >= 0 && in_col < in_dim) {
+                                result += input[batch][in_row][in_col][channel] * weight[channel][kernel_row][kernel_col];
+                            }
+
+                            in_col++;
+                        }
+
+                        in_row++;
+                    }
+
+                    // if (result < 0) {
+                    //     result = 0;
+                    // }
+
+                    acc_t scaled = result; //ACC_SCALE(result, params->output_scale);
+
+                    if (scaled > elem_t_max) {
+                        scaled = elem_t_max;
+                    } else if (scaled < elem_t_min) {
+                        scaled = elem_t_min;
+                    }
+
+                    size_t r = batch * out_dim * out_dim + out_row * out_dim + out_col;
+                    //output[r][channel] = scaled;
+                    output[batch][out_row][out_col][channel] = scaled;
+                }
+            }
+        }
+    }
+}
+
+void flatten_weights(int out_channels, int kernel_dim,
+        int patch_size,
+        elem_t weights[out_channels][kernel_dim][kernel_dim],
+        elem_t weights_mat[patch_size][out_channels]) {
+
+    assert(patch_size == kernel_dim * kernel_dim);
+
+    for (int outc = 0; outc < out_channels; outc++) {
+        for (int krow = 0; krow < kernel_dim; krow++) {
+            for (int kcol = 0; kcol < kernel_dim; kcol++) {
+                int wmatrow = krow * kernel_dim + kcol;
+
+                weights_mat[wmatrow][outc] = weights[outc][krow][kcol];
+            }
+        }
+    }
+}
+
+bool vec_is_equal(elem_t * a, elem_t * b, int len) {
+    for (int i = 0; i < len; i++)
+        if (a[i] != b[i])
+            return false;
+    return true;
+}
+
+void init_random(elem_t * buf, int len) {
+    elem_t i = 0;
+    for (elem_t * ptr = buf; ptr < buf + len; ptr++) {
+        // *ptr = (rand() % 32) - 16;
+#ifdef FAST
+        *ptr = 1;
+#else
+        *ptr = (rand() % 5) - 2;
+#endif
+    }
+}
+
+void init_random_acc(acc_t * buf, int len) {
+    elem_t i = 0;
+    for (acc_t * ptr = buf; ptr < buf + len; ptr++) {
+        // *ptr = (rand() % 32) - 16;
+#ifdef FAST
+        *ptr = 1;
+#else
+        *ptr = (rand() % 5) - 2;
+#endif
+    }
+}
+
+void init_zeros_acc(acc_t * buf, int len) {
+    for (acc_t * ptr = buf; ptr < buf + len; ptr++) {
+        *ptr = 0;
+    }
+}
+
+int main() {
+#ifndef BAREMETAL
+    if (mlockall(MCL_CURRENT | MCL_FUTURE) != 0) {
+        perror("mlockall failed");
+        exit(1);
+    }
+#endif
+
+    gemmini_flush(0);
+
+    // assert((in_dim + 2*padding - kernel_dim) % stride == 0);
+
+    printf("Output dimension: %u\n\n", OUT_DIM);
+
+    static elem_t input[BATCH_SIZE][IN_DIM][IN_DIM][IN_CHANNELS];
+    static elem_t weights[OUT_CHANNELS][KERNEL_DIM][KERNEL_DIM];
+    static acc_t bias[OUT_CHANNELS];
+    static elem_t output[BATCH_SIZE][OUT_DIM][OUT_DIM][OUT_CHANNELS];
+
+    printf("Randomize inputs...\n");
+    init_random(&input[0][0][0][0], sizeof(input) / sizeof(elem_t));
+
+    printf("Randomize weights...\n");
+    init_random(&weights[0][0][0], sizeof(weights) / sizeof(elem_t));
+
+    printf("Randomize bias...\n");
+    if (NO_BIAS)
+        init_zeros_acc(&bias[0], sizeof(bias) / sizeof(acc_t));
+    else
+        init_random_acc(&bias[0], sizeof(bias) / sizeof(acc_t));
+
+    printf("CPU conv...\n");
+    uint64_t start_cpu = read_cycles();
+#ifndef FAST
+    conv_dw(BATCH_SIZE, IN_CHANNELS, IN_DIM,
+            OUT_DIM, KERNEL_DIM,
+            PADDING, STRIDE,
+            input,
+            weights,
+            NO_BIAS ? NULL : bias,
+            output);
+#endif
+    uint64_t end_cpu = read_cycles();
+    printf("CPU conv took %llu cycles\n", end_cpu - start_cpu);
+
+    static elem_t weights_mat[PATCH_SIZE][OUT_CHANNELS];
+    static elem_t output_mat[N_PATCHES][OUT_CHANNELS];
+
+    printf("Flatten weights...\n");
+    flatten_weights(OUT_CHANNELS, KERNEL_DIM,
+            PATCH_SIZE,
+            weights,
+            weights_mat);
+
+    printf("Gemmini conv...\n");
+    uint64_t start_gemmini = read_cycles();
+    tiled_conv_A_stride_dw_auto(
+        BATCH_SIZE, IN_DIM, IN_CHANNELS,
+        OUT_CHANNELS, OUT_DIM,
+        STRIDE, 1, PADDING, KERNEL_DIM,
+
+        (elem_t*)input,
+        (elem_t*)weights_mat,
+        NO_BIAS ? NULL : (acc_t*)bias,
+        (elem_t*)output_mat,
+
+        NO_ACTIVATION, ACC_SCALE_IDENTITY, 0, 0, 0, 0,
+
+        WS);
+    uint64_t end_gemmini = read_cycles();
+    printf("Gemmini conv took %llu cycles\n", end_gemmini - start_gemmini);
+
+    assert(sizeof(output_mat) == sizeof(output));
+
+#ifdef FAST
+    bool success = true;
+    for (int orow = 0; orow < BATCH_SIZE * OUT_DIM * OUT_DIM; orow++) {
+        for (int ocol = 0; ocol < OUT_CHANNELS; ocol++) {
+            elem_t v = output_mat[orow][ocol];
+            if (v != 21 && v != 31 && v != 46) {
+                success = false;
+                break;
+            }
+        }
+    }
+#else
+    bool success = vec_is_equal(&output[0][0][0][0], &output_mat[0][0], sizeof(output) / sizeof(elem_t));
+#endif
+
+    if (!success) {
+        // return 1;
+        printf("bias:\n");
+        for (int och = 0; och < OUT_CHANNELS; och++) {
+            printf("%d,", bias[och]);
+        }
+        printf("\b\n\n");
+
+        printf("weights:\n");
+        for (int och = 0; och < OUT_CHANNELS; och++) {
+            printf("[");
+            for (int wrow = 0; wrow < KERNEL_DIM; wrow++) {
+                printf("[");
+                for (int wcol = 0; wcol < KERNEL_DIM; wcol++) {
+                    printf("%d,", weights[och][wrow][wcol]);
+                }
+                printf("\b],");
+            }
+            printf("\b],\n");
+        }
+        printf("\b\n\n");
+
+        printf("weights_mat:\n");
+        for (int wrow = 0; wrow < KERNEL_DIM * KERNEL_DIM; wrow++) {
+            printf("[");
+            for (int wcol = 0; wcol < OUT_CHANNELS; wcol++) {
+                printf("%d,", weights_mat[wrow][wcol]);
+            }
+            printf("\b],\n");
+        }
+        printf("\b\n\n");
+
+        printf("input:\n");
+        for (int batch = 0; batch < BATCH_SIZE; batch++) {
+            printf("[");
+            for (int irow = 0; irow < IN_DIM; irow++) {
+                printf("[");
+                for (int icol = 0; icol < IN_DIM; icol++) {
+                    printf("[");
+                    for (int ich = 0; ich < IN_CHANNELS; ich++) {
+                        printf("%d,", input[batch][irow][icol][ich]);
+                    }
+                    printf("\b],");
+                }
+                printf("\b],\n");
+            }
+            printf("\b],");
+        }
+        printf("\b\n\n");
+
+        printf("output:\n");
+        for (int batch = 0; batch < BATCH_SIZE; batch++) {
+            printf("[");
+            for (int orow = 0; orow < OUT_DIM; orow++) {
+                printf("[");
+                for (int ocol = 0; ocol < OUT_DIM; ocol++) {
+                    printf("[");
+                    for (int och = 0; och < OUT_CHANNELS; och++) {
+                        printf("%d,", output[batch][orow][ocol][och]);
+                    }
+                    printf("\b],\n");
+                }
+                printf("\b],\n");
+            }
+            printf("\b],");
+        }
+        printf("\b\n\n");
+
+        printf("output_mat:\n");
+        for (int orow = 0; orow < BATCH_SIZE * OUT_DIM * OUT_DIM; orow++) {
+            printf("[");
+            for (int ocol = 0; ocol < OUT_CHANNELS; ocol++) {
+                printf("%d,", output_mat[orow][ocol]);
+            }
+            printf("\b],\n");
+        }
+        printf("\b\n\n");
+
+        return 1;
+    }
+
+    return 0;
+}
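
Aside, not part of the patch: the final check compares the NHWC `output` tensor and the `N_PATCHES x OUT_CHANNELS` `output_mat` with one flat `vec_is_equal` call. That is valid because both buffers store the channel index fastest and order the patches batch-major, so element (b, orow, ocol, och) lands at the same flat offset in either view. A minimal standalone sketch of the two index computations (the macro values below are the baremetal configuration from the test, and are assumptions of this sketch only):

// Sketch: the NHWC offset and the patch-matrix offset coincide, which is
// what makes the single flat vec_is_equal() comparison in conv_dw.c sufficient.
#include <assert.h>
#include <stddef.h>

#define OUT_DIM 7        // baremetal config: (14 + 2*1 - 3)/2 + 1 = 6 + 1 = 7
#define OUT_CHANNELS 34

static size_t nhwc_offset(size_t b, size_t orow, size_t ocol, size_t och) {
    return ((b * OUT_DIM + orow) * OUT_DIM + ocol) * OUT_CHANNELS + och;
}

static size_t patch_offset(size_t b, size_t orow, size_t ocol, size_t och) {
    const size_t patch_row = b * OUT_DIM * OUT_DIM + orow * OUT_DIM + ocol;
    return patch_row * OUT_CHANNELS + och;
}

int main(void) {
    // Same flat offset from either view, shown for one arbitrary element.
    assert(nhwc_offset(0, 3, 5, 20) == patch_offset(0, 3, 5, 20));
    return 0;
}
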
diff --git a/include/gemmini.h b/include/gemmini.h
index c7da2f27..e0415337 100644
--- a/include/gemmini.h
+++ b/include/gemmini.h
@@ -284,7 +284,7 @@ static acc_scale_t_bits acc_scale_t_to_acc_scale_t_bits(acc_scale_t x) {
 }
 
 // weight-stationary matmul loop
-#define gemmini_loop_conv_ws(batch_size, in_dim, in_channels, out_channels, out_dim, pool_out_dim, stride, padding, kernel_dim, pool_size, pool_stride, pool_padding, batches, porows, pocols, pochs, krows, kcols, kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, ocols, weights, output, bias, input, no_bias, no_pool, downsample) \
+#define gemmini_loop_conv_ws(batch_size, in_dim, in_channels, out_channels, out_dim, pool_out_dim, stride, padding, kernel_dim, pool_size, pool_stride, pool_padding, batches, porows, pocols, pochs, krows, kcols, kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, ocols, weights, output, bias, input, no_bias, no_pool, downsample, depthwise) \
   { \
     ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(out_channels) << 48) | ((uint64_t)(in_channels) << 32) | ((uint64_t)(in_dim) << 16) | (uint64_t)(batch_size), \
       ((uint64_t)(padding) << 48) | ((uint64_t)(stride) << 32) | ((uint64_t)(pool_out_dim) << 16) | (uint64_t)(out_dim), k_LOOP_CONV_WS_CONFIG_1) \
@@ -299,7 +299,7 @@ static acc_scale_t_bits acc_scale_t_to_acc_scale_t_bits(acc_scale_t x) {
     ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, bias, \
       input, k_LOOP_CONV_WS_CONFIG_6) \
     ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, no_bias, \
-      ((downsample) << 1) | (no_pool), k_LOOP_CONV_WS) \
+      ((uint64_t)(depthwise) << 63) | ((downsample) << 1) | (no_pool), k_LOOP_CONV_WS) \
   }
 
 // Tiling functions
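
Aside, not part of the patch: the new depthwise flag rides in bit 63 of the rs2 operand of the k_LOOP_CONV_WS command, next to the existing downsample (bit 1) and no_pool (bit 0) flags, while no_bias stays in rs1. A small sketch of that packing as the modified macro emits it; the unpack helpers simply assume the decode side uses the same bit positions:

// Sketch of the rs2 bit layout used by the updated gemmini_loop_conv_ws macro.
#include <stdbool.h>
#include <stdint.h>

static inline uint64_t loop_conv_ws_rs2(bool depthwise, bool downsample, bool no_pool) {
    return ((uint64_t)depthwise << 63) | ((uint64_t)downsample << 1) | (uint64_t)no_pool;
}

static inline bool rs2_depthwise(uint64_t rs2)  { return (rs2 >> 63) & 1; }
static inline bool rs2_downsample(uint64_t rs2) { return (rs2 >> 1) & 1;  }
static inline bool rs2_no_pool(uint64_t rs2)    { return rs2 & 1;         }
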
@@ -1110,10 +1110,10 @@ static void sp_tiled_conv_A_stride(
     if (output != 0) {
         C_sp_addr_row = (C_sp_addr_row + ACC_ROWS / 2) % ACC_ROWS;
     }
+    bool depthwise = (kchs == 1) && (in_channels != 1);
+    gemmini_loop_conv_ws(batch_size, in_dim, in_channels, out_channels, out_dim, pool_out_dim, stride, padding, kernel_dim, pool_size, pool_stride, pool_padding, batches, porows, pocols, pochs, krows, kcols, kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, ocols, weights, output, bias, input, no_bias, no_pool, downsample, depthwise);
-    gemmini_loop_conv_ws(batch_size, in_dim, in_channels, out_channels, out_dim, pool_out_dim, stride, padding, kernel_dim, pool_size, pool_stride, pool_padding, batches, porows, pocols, pochs, krows, kcols, kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, ocols, weights, output, bias, input, no_bias, no_pool, downsample);
-
-    /*
+/*
     // mvin bias
     if (!no_bias && bias != NULL) {
         // TODO we probably don't need quite this many nested loops for this part
@@ -1205,8 +1205,8 @@ static void sp_tiled_conv_A_stride(
                 const int K = kchs - kch > DIM ? DIM : kchs - kch;
 
                 const uint32_t B_sp_addr = B_sp_addr_start + (och / DIM) * krows * kcols * kchs + krow * kcols * kchs + kcol * kchs + kch;
-
-                gemmini_extended_mvin2(weights + (krow*kernel_dim*in_channels + kcol*in_channels + kch) * out_channels + och,
+                gemmini_extended_mvin2(weights + (krow*kernel_dim + kcol + kch) * out_channels + och,
+                //gemmini_extended_mvin2(weights + (krow*kernel_dim*in_channels + kcol*in_channels + kch) * out_channels + och,
                         B_sp_addr,
                         J, K);
             }
@@ -1258,8 +1258,10 @@ static void sp_tiled_conv_A_stride(
                             J, K, J, I);
 
                     if (new_weights) {
-                        gemmini_extended_compute_preloaded(A_sp_addr, GARBAGE_ADDR, K, I, J, I);
+                        printf("preloaded krow: %d, kcol: %d \n", krow, kcol);
+                        gemmini_extended_compute_preloaded(A_sp_addr, GARBAGE_ADDR, K, I, J, I);
                     } else {
+                        printf("accumulated krow: %d, kcol: %d \n", krow, kcol);
                         gemmini_extended_compute_accumulated(A_sp_addr, GARBAGE_ADDR, K, I, J, I);
                     }
                 }
@@ -1310,7 +1312,7 @@ static void sp_tiled_conv_A_stride(
                 gemmini_fence();
             }
         }
-    */
+*/
 }
 
 //resnet downsampling layer (no padding, kernel size 1, stride 2)
@@ -2625,6 +2627,155 @@ static void tiled_conv_A_stride(
     }
 }
 
+static void tiled_conv_A_stride_dw(
+        int batch_size, int in_dim, int in_channels,
+        int out_channels, int out_dim,
+        int stride, int dilation, int padding, int kernel_dim,
+
+        int batches,
+        int porows, int pocols, int pochs,
+        int krows, int kcols, int kchs,
+
+        elem_t * input,
+        elem_t * weights,
+        acc_t * bias,
+        elem_t * output,
+
+        int act, acc_scale_t scale, size_t relu6_shift,
+        int pool_size, int pool_stride, int pool_padding,
+
+        enum tiled_matmul_type_t tiled_conv_type) {
+
+    // TODO move everything below this into a tiled_conv_outer function to match the tiled_matmul function
+
+    bool no_bias = false;
+    if (bias == NULL) {
+        bias = (acc_t*)1;
+        no_bias = true;
+    }
+
+    bool no_pool = pool_stride == 0;
+    if (no_pool) {
+        pool_size = 1;
+        pool_stride = 1;
+        pool_padding = 0;
+    }
+
+#ifdef GEMMINI_ASSERTIONS
+    {
+        // const int orows = porows * pool_stride + pool_size - 1;
+        // const int ocols = pocols * pool_stride + pool_size - 1;
+
+        // Check that data will fit in scratchpad
+        const int spad_rows = tiled_conv_total_spad_rows_A_stride(false,
+            stride, dilation, false, batches, porows, pocols, pochs, krows, kcols, kchs, pool_size, pool_stride);
+        const int acc_rows = tiled_conv_total_spad_rows_A_stride(true,
+            stride, dilation, false, batches, porows, pocols, pochs, krows, kcols, kchs, pool_size, pool_stride);
+
+        if (spad_rows > BANK_NUM * BANK_ROWS / 2) {
+            printf("not enough scratchpad space to store inputs and weights, %d\n", spad_rows);
+            exit(1);
+        }
+        if (acc_rows > ACC_ROWS / 2) {
+            printf("not enough accumulator space to store outputs\n");
+            exit(1);
+        }
+        if (kernel_dim <= padding) {
+            printf("kernel_dim must be larger than padding\n");
+            exit(1);
+        }
+        if (dilation != 1) {
+            printf("dilation is only supported on CPU\n");
+            exit(1);
+        }
+    }
+#endif
+
+    int och_stride = out_channels;
+    int ich_stride = in_channels;
+    gemmini_config_st(och_stride * sizeof(elem_t));
+    gemmini_extended_config_ex(WEIGHT_STATIONARY, act, 0, scale, relu6_shift, stride, false, false);
+
+    const int pool_out_dim = (out_dim + 2*pool_padding - pool_size) / pool_stride + 1;
+
+    //pochs = 1
+    for (int b = 0; b < batch_size; b += batches) {
+        for (int porow = 0; porow < pool_out_dim; porow += porows) {
+            const int orow = porow * pool_stride - pool_padding;
+            for (int pocol = 0; pocol < pool_out_dim; pocol += pocols) {
+                const int ocol = pocol * pool_stride - pool_padding;
+                for (int poch = 0; poch < out_channels; poch += pochs) {
+                    for (int krow = 0; krow < kernel_dim; krow += krows) {
+                        const int orow_floored = orow < 0 ? 0 : orow;
+                        const int irow = orow_floored * stride + krow - padding;
+                        for (int kcol = 0; kcol < kernel_dim; kcol += kcols) {
+                            const int ocol_floored = ocol < 0 ? 0 : ocol;
+                            const int icol = ocol_floored * stride + kcol - padding;
+
+                            elem_t * out = output + (b*pool_out_dim*pool_out_dim + porow*pool_out_dim + pocol) * och_stride + poch;
+
+                            if (krow + krows < kernel_dim ||
+                                    kcol + kcols < kernel_dim) {
+                                out = NULL;
+                            }
+
+                            acc_t * bias_ = bias + poch;
+                            if (krow > 0 ||
+                                    kcol > 0) {
+                                bias_ = NULL;
+                            }
+
+                            const int batches_ = batch_size - b > batches ? batches : batch_size - b;
+                            const int porows_ = pool_out_dim - porow > porows ? porows : pool_out_dim - porow;
+                            const int pocols_ = pool_out_dim - pocol > pocols ? pocols : pool_out_dim - pocol;
+                            const int pochs_ = out_channels - poch > pochs ? pochs : out_channels - poch;
+                            const int krows_ = kernel_dim - krow > krows ? krows : kernel_dim - krow;
+                            const int kcols_ = kernel_dim - kcol > kcols ? kcols : kernel_dim - kcol;
+
+                            const int ocols_ = pocols_ * pool_stride + pool_size - 1;
+                            const int orows_ = porows_ * pool_stride + pool_size - 1;
+
+                            const int plpad = ocol < 0 ? -ocol : 0;
+                            const int prpad = ocol + ocols_ > out_dim ? ocol + ocols_ - out_dim : 0;
+                            const int pupad = orow < 0 ? -orow : 0;
+                            const int pdpad = orow + orows_ > out_dim ? orow + orows_ - out_dim : 0;
+
+                            const int icols_ = (ocols_ - plpad - prpad) * stride + kcols_ - 1;
+                            const int irows_ = (orows_ - pupad - pdpad) * stride + krows_ - 1;
+
+                            const int lpad = icol < 0 ? -icol : 0;
+                            const int rpad = icol + icols_ > in_dim ? icol + icols_ - in_dim : 0;
+                            const int upad = irow < 0 ? -irow : 0;
+                            const int dpad = irow + irows_ > in_dim ? irow + irows_ - in_dim : 0;
+
+                            sp_tiled_conv_A_stride(
+                                batch_size, in_dim, in_channels,
+                                out_channels, out_dim, pool_out_dim,
+
+                                stride, padding, kernel_dim,
+
+                                pool_size, pool_stride, pool_padding,
+
+                                batches_,
+                                porows_, pocols_, pochs_,
+                                krows_, kcols_, 1, //hardcode kchs as 1
+
+                                lpad, rpad, upad, dpad,
+                                plpad, prpad, pupad, pdpad,
+
+                                input + (b*in_dim*in_dim + (irow+upad)*in_dim + (icol+lpad)) * ich_stride + poch, //instead of och
+                                weights + poch, // + (krow*kernel_dim*in_channels + kcol*in_channels + kch) * och_stride + poch,
+                                out,
+                                bias_,
+
+                                no_bias, no_pool, false);
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
 static void tiled_conv_A_stride_auto(
         int batch_size, int in_dim, int in_channels,
         int out_channels, int out_dim,
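
Aside, a sketch rather than part of the patch: with kchs hardcoded to 1 and the weight pointer advanced by poch only, the depthwise path expects the flattened PATCH_SIZE x channels weight matrix that flatten_weights() builds in bareMetalC/conv_dw.c; the mvin2 index in the (commented-out) software move-in loop, (krow*kernel_dim + kcol + kch) * out_channels + och, reduces to the same offset when kch == 0:

// Sketch: flat offset of depthwise weight (och, krow, kcol) inside a
// (kernel_dim*kernel_dim) x out_channels matrix, i.e. the layout produced by
// flatten_weights() in the test above.
#include <stddef.h>

static inline size_t dw_weight_offset(size_t krow, size_t kcol, size_t och,
                                      size_t kernel_dim, size_t out_channels) {
    return (krow * kernel_dim + kcol) * out_channels + och;
}
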
@@ -2811,6 +2962,185 @@ static void tiled_conv_A_stride_auto(
+// with bubble insertion functionality
+static void tiled_conv_A_stride_dw_auto(
+        int batch_size, int in_dim, int in_channels,
+        int out_channels, int out_dim,
+        int stride, int dilation, int padding, int kernel_dim,
+
+        elem_t * input,
+        elem_t * weights,
+        acc_t * bias,
+        elem_t * output,
+
+        int act, acc_scale_t scale, size_t relu6_shift,
+        int pool_size, int pool_stride, int pool_padding,
+
+        enum tiled_matmul_type_t tiled_conv_type) {
+
+    const bool no_pool = pool_stride == 0;
+    if (no_pool) {
+        pool_size = 1;
+        pool_stride = 1;
+        pool_padding = 0;
+    }
+
+    const int pool_out_dim = (out_dim + 2*pool_padding - pool_size) / pool_stride + 1;
+
+    // Tile convolution params
+
+    // int args[] = {batch_size, porows, pocols, pochs, krows, kcols, kchs};
+    int args[] = {batch_size, pool_out_dim, pool_out_dim, 1, kernel_dim, kernel_dim, 1};
+    const int max_args[] = {batch_size, pool_out_dim, pool_out_dim, out_channels, kernel_dim, kernel_dim, in_channels};
+    //input, output channel the same
+
+    const int orows_idx = 1;
+    const int ocols_idx = 2;
+    const int out_channels_idx = 3;
+    const int in_channels_idx = 6;
+
+    // We divide by 2 for the sake of double-buffering
+    const int max_spad_rows = (BANK_NUM*BANK_ROWS / 2);
+    const int max_acc_rows = (ACC_ROWS / 2);
+
+    int spad_rows = tiled_conv_total_spad_rows_A_stride(false,
+        stride, dilation, false, args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
+    int acc_rows = tiled_conv_total_spad_rows_A_stride(true,
+        stride, dilation, false, args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
+
+    bool down_sample = (in_dim > out_dim);
+
+    while (spad_rows > max_spad_rows || acc_rows > max_acc_rows) {
+        int max_val = -1;
+        int max_idx = -1;
+
+        //don't need to tile channels
+        for (size_t j = 0; j < sizeof(args)/sizeof(args[0]) - 2; j++) {
+            // We avoid reducing ocols when possible to keep the spatial array fully utilized
+            size_t i = 0;
+            // if(!down_sample){
+            if (j == 0) i = 0;
+            else if (j == 4) i = orows_idx;
+            else if (j == 1) i = ocols_idx;
+            else if (j == 2) i = 4;
+            else if (j == 3) i = 5;
+
+            if (i == 0 && args[0] > 1) { // batch first
+                max_val = args[0];
+                max_idx = 0;
+                break;
+            } else if (!(i == ocols_idx && args[i] <= DIM)
+                    && args[i] > max_val) { // and then move on to channels
+                max_val = args[i];
+                max_idx = i;
+            }
+        }
+        // printf("max_val: %d, max_idx: %d \n", max_val, max_idx);
+
+        if (max_idx == ocols_idx) {
+            if (args[max_idx] % DIM != 0) args[max_idx] = (args[max_idx]/DIM)*DIM;
+            else args[max_idx] -= DIM;
+        } else {
+            if (max_idx == 4 || max_idx == 5) args[max_idx] = 1;
+            else args[max_idx]--;
+        }
+
+        spad_rows = tiled_conv_total_spad_rows_A_stride(false,
+            stride, dilation, false, args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
+        acc_rows = tiled_conv_total_spad_rows_A_stride(true,
+            stride, dilation, false, args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
+    }
+/*
+    printf("batches = %d\n", args[0]);
+    printf("orows = %d\n", args[1]);
+    printf("ocols = %d\n", args[2]);
+    printf("ochs = %d\n", args[3]);
+    printf("krows = %d\n", args[4]);
+    printf("kcols = %d\n", args[5]);
+    printf("kchs = %d\n\n", args[6]);
+*/
+
+    // Check if we can increase ocols
+    bool not_increased = false;
+
+    // Check if there are any parameters that we can currently still increase
+    bool nothing_increased = false;
+    bool kdim_increase = true;
+
+    while (!nothing_increased) {
+        nothing_increased = true;
+        //kdim_increase = true;
+
+        for (size_t j = 0; j < sizeof(args)/sizeof(args[0]) - 2; j++) {
+            //size_t i = j;// down_sample ? j : 6-j;
+            size_t i = j;
+            if (j == 0) i = 5; //in_channels_idx;
+            else if (j == 1) i = 4;
+            else if (j == 2) i = ocols_idx;
+            else if (j == 3) i = orows_idx;
+            else if (j == 4) i = 0;
+
+            int args_candidate[] = {args[0], args[1], args[2], args[3], args[4], args[5], args[6]};
+            if (i == ocols_idx && (args[i] % DIM == 0)) args_candidate[i] += DIM;
+            else args_candidate[i] += kdim_increase && (i == 4 || i == 5) ? 2 : 1;
+
+            if (args_candidate[i] > max_args[i])
+                continue;
+
+            spad_rows = tiled_conv_total_spad_rows_A_stride(false,
+                stride, dilation, false, args_candidate[0], args_candidate[1], args_candidate[2], args_candidate[3], args_candidate[4], args_candidate[5], args_candidate[6], pool_size, pool_stride);
+            acc_rows = tiled_conv_total_spad_rows_A_stride(true,
+                stride, dilation, false, args_candidate[0], args_candidate[1], args_candidate[2], args_candidate[3], args_candidate[4], args_candidate[5], args_candidate[6], pool_size, pool_stride);
+
+            if (spad_rows <= max_spad_rows && acc_rows <= max_acc_rows) {
+                args[i] = args_candidate[i];
+                nothing_increased = false;
+                kdim_increase = false;
+            }
+        }
+    }
+
+    const int batches = args[0];
+    const int orows = args[1];
+    const int ocols = args[2];
+    const int ochs = args[3];
+    const int krows = args[4];
+    const int kcols = args[5];
+    const int kchs = args[6];
+
+/*
+//    spad_rows = tiled_conv_total_spad_rows_A_stride(false,
+//        stride, dilation, args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
+//    acc_rows = tiled_conv_total_spad_rows_A_stride(true,
+//        stride, dilation, args[0], args[1], args[2], args[3], args[4], args[5], args[6], pool_size, pool_stride);
+
+    printf("batches = %d\n", batches);
+    printf("orows = %d\n", orows);
+    printf("ocols = %d\n", ocols);
+    printf("ochs = %d\n", ochs);
+    printf("krows = %d\n", krows);
+    printf("kcols = %d\n", kcols);
+    printf("kchs = %d\n\n", kchs);
+
+//    printf("total spad_rows reserved: %d\n", spad_rows);
+//    printf("total acc_rows reserved: %d\n\n", acc_rows);
+//    printf("scratchpad row utilization: %d%%\n", (spad_rows*100) / max_spad_rows);
+//    printf("accumulator row utilization: %d%%\n\n", (acc_rows*100) / max_acc_rows);
+
+    printf("inner matmul size: i=%d, j=%d, k=%d\n\n", ocols, ochs, kchs);
+*/
+
+    tiled_conv_A_stride_dw(
+        batch_size, in_dim, in_channels,
+        out_channels, out_dim,
+        stride, dilation, padding, kernel_dim,
+
+        batches,
+        orows, ocols, ochs,
+        krows, kcols, kchs,
+
+        input,
+        weights,
+        bias,
+        output,
+
+        act, scale, relu6_shift,
+        pool_size, no_pool ? 0 : pool_stride, pool_padding,
+
+        tiled_conv_type); //, ich_padding, och_padding, och_divide, skip_weight);
+}
+
 // This function is for a convolution with kernel_dim=1, stride==2, padding=0, and no pooling
 static void tiled_conv_downsample(
     int batch_size, int in_dim, int in_channels,
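
Aside, a sketch rather than part of the patch: both while-loops in tiled_conv_A_stride_dw_auto shrink or grow the tiling factors and then re-check whether the candidate tile still fits on chip, with half of the scratchpad and accumulator reserved for double-buffering. The helper below pulls that fit test out; it assumes the tiled_conv_total_spad_rows_A_stride signature and the BANK_NUM, BANK_ROWS, ACC_ROWS, and type definitions from include/gemmini.h, called exactly as the patch calls them.

// Sketch: would this candidate depthwise tiling fit on chip? This mirrors the
// check re-evaluated after every adjustment of args[] in the loops above.
#include <stdbool.h>
#include "include/gemmini.h"

static bool dw_tiling_fits(int stride, int dilation,
        int batches, int porows, int pocols, int pochs,
        int krows, int kcols, int kchs,
        int pool_size, int pool_stride) {
    const int max_spad_rows = BANK_NUM * BANK_ROWS / 2; // half kept for double-buffering
    const int max_acc_rows = ACC_ROWS / 2;

    const int spad_rows = tiled_conv_total_spad_rows_A_stride(false,
        stride, dilation, false, batches, porows, pocols, pochs, krows, kcols, kchs, pool_size, pool_stride);
    const int acc_rows = tiled_conv_total_spad_rows_A_stride(true,
        stride, dilation, false, batches, porows, pocols, pochs, krows, kcols, kchs, pool_size, pool_stride);

    return spad_rows <= max_spad_rows && acc_rows <= max_acc_rows;
}
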