Skip to content

Commit 3a10c23

Browse files
committed
Add pooling support to Gemmini's conv-fsm
1 parent 9b0082a commit 3a10c23

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

gemmini/gemmini.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,7 @@ void gemmini_t::loop_conv_ws(reg_t rs1, reg_t rs2) {
10301030
}
10311031
}
10321032

1033+
// Mvout results
10331034
if (output != 0 && no_pool) {
10341035
for (uint16_t b = 0; b < batches; b++)
10351036
for (uint16_t orow = 0; orow < orows; orow++)
@@ -1045,6 +1046,37 @@ void gemmini_t::loop_conv_ws(reg_t rs1, reg_t rs2) {
10451046
((uint64_t)I << 48) | ((uint64_t)J << 32) | C_sp_addr);
10461047
}
10471048
}
1049+
} else if (output != 0 && !no_pool) {
1050+
// gemmini_extended_config_st(out_channels * sizeof(elem_t), pool_stride, pool_size, pool_out_dim, porows, pocols, orows, ocols, pupad, plpad);
1051+
config(
1052+
((uint64_t)ocols << 56) |
1053+
((uint64_t)orows << 48) |
1054+
((uint64_t)pocols << 40) |
1055+
((uint64_t)porows << 32) |
1056+
((uint64_t)pool_out_dim << 24) |
1057+
((uint64_t)plpad << 10) |
1058+
((uint64_t)pupad << 8) |
1059+
((uint64_t)pool_size << 6) |
1060+
((uint64_t)pool_stride << 4) |
1061+
2,
1062+
out_channels * sizeof(elem_t));
1063+
1064+
1065+
(pool_out_dim << 24) | (plpad << 10) | (pupad << 8) | (pool_size << 6) | (pool_stride << 4)
1066+
1067+
for (int b = 0; b < batches; b++) {
1068+
for (int poch = 0; poch < pochs; poch += DIM) {
1069+
const int channels = poch + DIM >= pochs ? pochs - poch : DIM;
1070+
1071+
const uint32_t C_sp_addr = C_sp_addr_start + (poch / DIM) * batches * orows * ocols + b * orows * ocols;
1072+
1073+
mvout(output + ((b * pool_out_dim * pool_out_dim)*out_channels + poch) * sizeof(elem_t),
1074+
((uint64_t)channels << 32) | C_sp_addr);
1075+
}
1076+
}
1077+
1078+
// gemmini_config_st(out_channels * sizeof(elem_t));
1079+
config(2, out_channels * sizeof(elem_t));
10481080
}
10491081
}
10501082

gemmini/gemmini_params.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ typedef uint32_t acc_scale_t_bits;
5959
i % 2 == 0 ? i : next)); \
6060
result; })
6161

62+
// Rounding right shift equation: https://riscv.github.io/documents/riscv-v-spec/#_vector_fixed_point_rounding_mode_register_vxrm
63+
#define ROUNDING_RIGHT_SHIFT_BITS(x, shift) \
64+
((shift) > 0 ? (((x) >> (shift)) + \
65+
(((shift) == 0 ? 0 : (((x) >> ((shift)-1)) & 1)) & \
66+
((((shift) <= 1 ? 0 : ((x) & ((1 << ((shift)-1)) - 1))) != 0) | (((x) >> (shift)) & 1)))) : ((x) << (-(shift))))
67+
6268
#define ACC_SCALE(x, scale) \
6369
({float y = ROUND_NEAR_EVEN((x) * (scale)); y > INT8_MAX ? INT8_MAX : (y < INT8_MIN ? INT8_MIN : (acc_t)y);})
6470

0 commit comments

Comments
 (0)