@@ -1030,6 +1030,7 @@ void gemmini_t::loop_conv_ws(reg_t rs1, reg_t rs2) {
10301030 }
10311031 }
10321032
1033+ // Mvout results
10331034 if (output != 0 && no_pool) {
10341035 for (uint16_t b = 0 ; b < batches; b++)
10351036 for (uint16_t orow = 0 ; orow < orows; orow++)
@@ -1045,6 +1046,37 @@ void gemmini_t::loop_conv_ws(reg_t rs1, reg_t rs2) {
10451046 ((uint64_t )I << 48 ) | ((uint64_t )J << 32 ) | C_sp_addr);
10461047 }
10471048 }
1049+ } else if (output != 0 && !no_pool) {
1050+ // gemmini_extended_config_st(out_channels * sizeof(elem_t), pool_stride, pool_size, pool_out_dim, porows, pocols, orows, ocols, pupad, plpad);
1051+ config (
1052+ ((uint64_t )ocols << 56 ) |
1053+ ((uint64_t )orows << 48 ) |
1054+ ((uint64_t )pocols << 40 ) |
1055+ ((uint64_t )porows << 32 ) |
1056+ ((uint64_t )pool_out_dim << 24 ) |
1057+ ((uint64_t )plpad << 10 ) |
1058+ ((uint64_t )pupad << 8 ) |
1059+ ((uint64_t )pool_size << 6 ) |
1060+ ((uint64_t )pool_stride << 4 ) |
1061+ 2 ,
1062+ out_channels * sizeof (elem_t ));
1063+
1064+
1065+ (pool_out_dim << 24 ) | (plpad << 10 ) | (pupad << 8 ) | (pool_size << 6 ) | (pool_stride << 4 )
1066+
1067+ for (int b = 0 ; b < batches; b++) {
1068+ for (int poch = 0 ; poch < pochs; poch += DIM) {
1069+ const int channels = poch + DIM >= pochs ? pochs - poch : DIM;
1070+
1071+ const uint32_t C_sp_addr = C_sp_addr_start + (poch / DIM) * batches * orows * ocols + b * orows * ocols;
1072+
1073+ mvout (output + ((b * pool_out_dim * pool_out_dim)*out_channels + poch) * sizeof (elem_t ),
1074+ ((uint64_t )channels << 32 ) | C_sp_addr);
1075+ }
1076+ }
1077+
1078+ // gemmini_config_st(out_channels * sizeof(elem_t));
1079+ config (2 , out_channels * sizeof (elem_t ));
10481080 }
10491081}
10501082
0 commit comments