@@ -1400,4 +1400,92 @@ namespace nz::krnl {
14001400 StreamManager<float >::Instance ().submit (col2imgBackwardKernel, gridDim , blockDim , 0 , out, in, H_out, W_out,
14011401 C_out, batches);
14021402 }
1403+
/**
 * @brief Average-pooling forward kernel; each thread produces one output element.
 *
 * The flat thread index is decoded as (batch, channel, h, w) over a contiguous
 * batches x channels x H_out x W_out output, and the input is read as a
 * contiguous batches x channels x H_in x W_in buffer. The pooling window that
 * starts at (h * stride - padding, w * stride - padding) is clipped to the
 * input bounds, and the mean is taken over the in-bounds elements only, so
 * padded positions contribute neither to the sum nor to the divisor. A window
 * that lies entirely outside the input yields 0.
 *
 * Launch layout: 1-D grid and 1-D blocks along x; threads whose index is past
 * the last output element return immediately.
 *
 * @param out       Output buffer with batches * channels * H_out * W_out floats.
 * @param in        Input buffer with batches * channels * H_in * W_in floats.
 * @param pool_size Side length of the square pooling window.
 * @param stride    Step between consecutive windows (same in both directions).
 * @param padding   Implicit border width on every side of the input.
 * @param batches   Number of batches.
 * @param channels  Number of channels per batch.
 * @param H_in      Input height.
 * @param W_in      Input width.
 * @param H_out     Output height.
 * @param W_out     Output width.
 */
__global__ void AveragePoolingKernel(float* out, const float* in, const size_t pool_size, const size_t stride, const size_t padding,
                                     const size_t batches, const size_t channels, const size_t H_in, const size_t W_in, const size_t H_out, const size_t W_out) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise a 32-bit product that can wrap.
    const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= batches * channels * H_out * W_out) {
        return;
    }

    const size_t outPlane = H_out * W_out;
    const size_t inPlane = H_in * W_in;
    const size_t currentBatch = idx / (channels * outPlane);
    const size_t currentChannel = (idx % (channels * outPlane)) / outPlane;
    const size_t h = (idx % outPlane) / W_out;
    const size_t w = idx % W_out;

    // Window origin in input coordinates; computed in signed arithmetic because padding can make it negative.
    const long long window = static_cast<long long>(pool_size);
    const long long h_start = static_cast<long long>(h * stride) - static_cast<long long>(padding);
    const long long w_start = static_cast<long long>(w * stride) - static_cast<long long>(padding);

    // Clip the window to [0, H_in) x [0, W_in) once instead of testing every element.
    const long long h_limit = static_cast<long long>(H_in);
    const long long w_limit = static_cast<long long>(W_in);
    const long long h_begin = h_start > 0 ? h_start : 0;
    const long long w_begin = w_start > 0 ? w_start : 0;
    const long long h_end = (h_start + window) < h_limit ? (h_start + window) : h_limit;
    const long long w_end = (w_start + window) < w_limit ? (w_start + window) : w_limit;
    const long long rows = h_end > h_begin ? h_end - h_begin : 0;
    const long long cols = w_end > w_begin ? w_end - w_begin : 0;
    const long long count = rows * cols;

    // Accumulate in a register and store once, rather than read-modify-writing out[idx] per element.
    const float* plane = in + (currentBatch * channels + currentChannel) * inPlane;
    float sum = 0.0f;
    for (long long r = h_begin; r < h_end; ++r) {
        const float* row = plane + static_cast<size_t>(r) * W_in;
        for (long long c = w_begin; c < w_end; ++c) {
            sum += row[c];
        }
    }
    out[idx] = count > 0 ? sum / static_cast<float>(count) : 0.0f;
}
1430+
/**
 * @brief Host-side wrapper that submits AveragePoolingKernel through the float StreamManager.
 *
 * All arguments are forwarded unchanged. The literal 0 is passed as the fourth
 * argument of submit (presumably the dynamic shared-memory size in bytes;
 * confirm against StreamManager::submit).
 *
 * @param gridDim   Grid dimensions for the launch.
 * @param blockDim  Block dimensions for the launch.
 * @param out       Output buffer handed to the kernel.
 * @param in        Input buffer handed to the kernel.
 * @param pool_size Side length of the square pooling window.
 * @param stride    Step between consecutive windows.
 * @param padding   Implicit border width around the input.
 * @param batches   Number of batches.
 * @param channels  Number of channels per batch.
 * @param H_in      Input height.
 * @param W_in      Input width.
 * @param H_out     Output height.
 * @param W_out     Output width.
 */
void AveragePooling(const dim3 gridDim, const dim3 blockDim, float* out, float* in,
                    const size_t pool_size, const size_t stride, const size_t padding,
                    const size_t batches, const size_t channels, const size_t H_in, const size_t W_in,
                    const size_t H_out, const size_t W_out) {
    auto&& streams = StreamManager<float>::Instance();
    streams.submit(AveragePoolingKernel, gridDim, blockDim, 0,
                   out, in,
                   pool_size, stride, padding,
                   batches, channels,
                   H_in, W_in, H_out, W_out);
}
1438+
/**
 * @brief Average-pooling backward kernel; each thread scatters one upstream gradient.
 *
 * The flat thread index is decoded as (batch, channel, h, w) over a contiguous
 * batches x channels x H_out x W_out gradient `in`. The pooling window that
 * starts at (h * stride - padding, w * stride - padding) is clipped to the
 * input bounds, and in[idx] divided by the number of in-bounds window elements
 * is atomically added to every one of those elements of `out`, which is
 * addressed as a contiguous batches x channels x H_in x W_in buffer.
 *
 * The divisor is always the in-bounds element count, mirroring the forward
 * kernel's divisor. For a window that lies fully inside the input this equals
 * pool_size * pool_size; for a window clipped by any edge (with or without
 * padding) it is smaller. A window with no in-bounds element contributes
 * nothing.
 *
 * The kernel only accumulates into `out`; it never initializes it, so the
 * starting contents of `out` are part of the result.
 *
 * Launch layout: 1-D grid and 1-D blocks along x; threads whose index is past
 * the last gradient element return immediately.
 *
 * @param out       Accumulator with batches * channels * H_in * W_in floats.
 * @param in        Upstream gradient with batches * channels * H_out * W_out floats.
 * @param pool_size Side length of the square pooling window.
 * @param stride    Step between consecutive windows (same in both directions).
 * @param padding   Implicit border width on every side of the input.
 * @param batches   Number of batches.
 * @param channels  Number of channels per batch.
 * @param H_in      Input height.
 * @param W_in      Input width.
 * @param H_out     Output height.
 * @param W_out     Output width.
 */
__global__ void AveragePoolingBackwardKernel(float* out, const float* in, const size_t pool_size, const size_t stride, const size_t padding,
                                             const size_t batches, const size_t channels, const size_t H_in, const size_t W_in, const size_t H_out, const size_t W_out) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise a 32-bit product that can wrap.
    const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= batches * channels * H_out * W_out) {
        return;
    }

    const size_t outPlane = H_out * W_out;
    const size_t inPlane = H_in * W_in;
    const size_t currentBatch = idx / (channels * outPlane);
    const size_t currentChannel = (idx % (channels * outPlane)) / outPlane;
    const size_t h = (idx % outPlane) / W_out;
    const size_t w = idx % W_out;

    // Window origin in input coordinates; computed in signed arithmetic because padding can make it negative.
    const long long window = static_cast<long long>(pool_size);
    const long long h_start = static_cast<long long>(h * stride) - static_cast<long long>(padding);
    const long long w_start = static_cast<long long>(w * stride) - static_cast<long long>(padding);

    // Clip the window to [0, H_in) x [0, W_in); the clipped area is the divisor.
    const long long h_limit = static_cast<long long>(H_in);
    const long long w_limit = static_cast<long long>(W_in);
    const long long h_begin = h_start > 0 ? h_start : 0;
    const long long w_begin = w_start > 0 ? w_start : 0;
    const long long h_end = (h_start + window) < h_limit ? (h_start + window) : h_limit;
    const long long w_end = (w_start + window) < w_limit ? (w_start + window) : w_limit;
    const long long rows = h_end > h_begin ? h_end - h_begin : 0;
    const long long cols = w_end > w_begin ? w_end - w_begin : 0;
    const long long count = rows * cols;
    if (count <= 0) {
        // No in-bounds element: nothing to scatter, and this avoids dividing by zero.
        return;
    }

    // Every in-bounds element of the window receives the same share; compute it once.
    const float share = in[idx] / static_cast<float>(count);
    float* plane = out + (currentBatch * channels + currentChannel) * inPlane;
    for (long long r = h_begin; r < h_end; ++r) {
        float* row = plane + static_cast<size_t>(r) * W_in;
        for (long long c = w_begin; c < w_end; ++c) {
            // Windows of different threads can overlap, so the update must be atomic.
            atomicAdd(row + c, share);
        }
    }
}
1483+
/**
 * @brief Host-side wrapper that submits AveragePoolingBackwardKernel through the float StreamManager.
 *
 * All arguments are forwarded unchanged. The literal 0 is passed as the fourth
 * argument of submit (presumably the dynamic shared-memory size in bytes;
 * confirm against StreamManager::submit).
 *
 * @param gridDim   Grid dimensions for the launch.
 * @param blockDim  Block dimensions for the launch.
 * @param out       Buffer the kernel accumulates into.
 * @param in        Buffer the kernel reads from.
 * @param pool_size Side length of the square pooling window.
 * @param stride    Step between consecutive windows.
 * @param padding   Implicit border width around the input.
 * @param batches   Number of batches.
 * @param channels  Number of channels per batch.
 * @param H_in      Input height.
 * @param W_in      Input width.
 * @param H_out     Output height.
 * @param W_out     Output width.
 */
void AveragePoolingBackward(const dim3 gridDim, const dim3 blockDim, float* out, float* in,
                            const size_t pool_size, const size_t stride, const size_t padding,
                            const size_t batches, const size_t channels, const size_t H_in, const size_t W_in,
                            const size_t H_out, const size_t W_out) {
    auto&& streams = StreamManager<float>::Instance();
    streams.submit(AveragePoolingBackwardKernel, gridDim, blockDim, 0,
                   out, in,
                   pool_size, stride, padding,
                   batches, channels,
                   H_in, W_in, H_out, W_out);
}
14031491}
0 commit comments