Skip to content

Commit 8540b6b

Browse files
committed
feat(pooling): add average pooling layer and related test cases
- Added an average pooling layer to the model architecture, enabling the downsampling of feature maps using average pooling.
- Implemented the forward pass logic for the average pooling layer, including the calculation of average values within pooling windows.
- Developed a set of test cases to verify the correctness of the average pooling layer implementation.
- The test cases cover different input sizes, pooling kernel sizes, and stride values to ensure the robustness of the layer.
1 parent f106dd7 commit 8540b6b

8 files changed

Lines changed: 403 additions & 0 deletions

File tree

include/NeuZephyr/Model.cuh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,26 @@ namespace nz {
6868
Tensor::size_type kernelWidth,
6969
Tensor::size_type stride, Tensor::size_type padding, bool bias = true);
7070

71+
// Appends an average-pooling layer to the model's compute graph.
// poolSize: edge of the square pooling window; stride: window step;
// padding: border padding applied on each side (default 0).
// Returns the graph node that produces the pooled feature map.
Node* AvgPool2d(Node* input, Tensor::size_type poolSize, Tensor::size_type stride,
                Tensor::size_type padding = 0);
73+
7174
void MSELoss(Node* input, Node* target);
7275

7376
void BCELoss(Node* input, Node* target);
7477

7578
void defaultOutput(Node* input);
7679
};
80+
81+
// Ensures `input` is registered in the compute graph, then creates an
// average-pooling node downstream of it. The new node is tracked in
// hiddenNodes and added to the graph before being returned to the caller.
inline Node* Model::AvgPool2d(Node* input, Tensor::size_type poolSize, Tensor::size_type stride,
                              Tensor::size_type padding) {
    if (!computeGraph.inGraph(input)) {
        computeGraph.addNode(input);
    }
    auto* pool = new calc::AveragePoolingNode(input, poolSize, stride, padding);
    hiddenNodes.push_back(pool);
    computeGraph.addNode(pool);
    return pool;
}
7791
}
7892

7993

include/NeuZephyr/Nodes.cuh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3329,6 +3329,20 @@ namespace nz::nodes {
33293329

33303330
void backward() override;
33313331
};
3332+
3333+
/**
 * Graph node implementing 2D average pooling over 4D (batch, channel, H, W)
 * tensors. The window is square (poolSize x poolSize); stride and padding
 * apply symmetrically to both spatial dimensions. Padded positions are
 * excluded from the average (the divisor is the in-bounds sample count).
 */
class DL_API AveragePoolingNode : public Node {
public:
    Tensor::size_type poolSize;  // edge length of the square pooling window
    Tensor::size_type stride;    // step between adjacent windows
    Tensor::size_type padding;   // border padding on each spatial side

    // Builds the node and allocates its output tensor using OUTPUT_DIM
    // for the pooled spatial extents.
    AveragePoolingNode(Node* input, Tensor::size_type poolSize, Tensor::size_type stride,
                       Tensor::size_type padding);

    // Computes the pooled output from inputs[0]->output.
    void forward() override;

    // Scatters this node's output gradient back into the input's gradient.
    void backward() override;
};
33323346
}
33333347

33343348
/**

include/NeuZephyr/OperationKernels.cuh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@
4848
#include <vector>
4949
#include "Dimension.cuh"
5050

51+
// Spatial output extent of a sliding-window op:
// floor((input + 2*padding - kernel) / stride) + 1.
// NOTE(review): all arithmetic is in size_t, so the expression wraps when
// KERNEL > INPUT + 2*PADDING — callers must guarantee the window fits.
#define OUTPUT_DIM(INPUT, KERNEL, STRIDE, PADDING) \
( ((size_t)(INPUT) + 2*(size_t)(PADDING) - (size_t)(KERNEL)) / (size_t)(STRIDE) + 1 )
53+
5154
/**
5255
* @namespace nz::krnl
5356
* @brief High-Performance CUDA Kernel Implementations for Tensor Computations
@@ -1019,6 +1022,16 @@ namespace nz::krnl {
10191022

10201023
void col2imgBackward(const dim3 gridDim, const dim3 blockDim, float* out, float* in, const size_t H_out,
10211024
const size_t W_out, const size_t C_out, const size_t batches);
1025+
1026+
// Launches the average-pooling forward kernel. Expects one thread per
// output element (batches * channels * H_out * W_out total); `out` is
// written, `in` is read.
void AveragePooling(const dim3 gridDim, const dim3 blockDim, float* out, float* in,
                    const size_t pool_size, const size_t stride, const size_t padding,
                    const size_t batches, const size_t channels, const size_t H_in, const size_t W_in,
                    const size_t H_out, const size_t W_out);

// Launches the average-pooling backward kernel: scatters each output
// gradient in `in` into the input-gradient buffer `out` via atomicAdd.
// NOTE(review): assumes `out` is zero-initialized before launch — confirm
// at the call sites.
void AveragePoolingBackward(const dim3 gridDim, const dim3 blockDim, float* out, float* in,
                            const size_t pool_size, const size_t stride, const size_t padding,
                            const size_t batches, const size_t channels, const size_t H_in, const size_t W_in,
                            const size_t H_out, const size_t W_out);
10221035
#endif
10231036
}
10241037

include/NeuZephyr/TensorOperations.cuh

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,5 +1144,33 @@ namespace nz::data {
11441144
}
11451145

11461146
DL_API void iCol2imgBackward(float* out, float* in, size_t H_out, size_t W_out, size_t C_out, size_t batches);
1147+
1148+
// Host entry point for average-pooling forward: computes the launch
// configuration (one thread per output element) and submits the kernel.
DL_API void iAveragePooling(float* out, float* in,
                            size_t pool_size, size_t stride, size_t padding,
                            size_t batches, size_t channels, size_t H_in, size_t W_in,
                            size_t H_out, size_t W_out);
1152+
1153+
// Applies 2D average pooling to a 4D (batch, channel, H, W) tensor and
// returns the pooled tensor. When the input tracks gradients, the same
// pooling is applied to the gradient buffer (average pooling is linear,
// so the op maps the tangent buffer identically).
template <typename T>
std::enable_if_t<is_valid_tensor_type<T>::value, T>
tensorAveragePooling(const T& in, const size_t pool_size, const size_t stride,
                     const size_t padding) {
    const auto& shape = in.shape();
    const size_t batches  = shape[0];
    const size_t channels = shape[1];
    const size_t height   = shape.H();
    const size_t width    = shape.W();
    const size_t pooledH  = OUTPUT_DIM(height, pool_size, stride, padding);
    const size_t pooledW  = OUTPUT_DIM(width, pool_size, stride, padding);

    T pooled({batches, channels, pooledH, pooledW}, in.requiresGrad());
    iAveragePooling(pooled.data(), in.data(), pool_size, stride, padding,
                    batches, channels, height, width, pooledH, pooledW);
    if (in.requiresGrad()) {
        iAveragePooling(pooled.grad(), in.grad(), pool_size, stride, padding,
                        batches, channels, height, width, pooledH, pooledW);
    }
    return pooled;
}
1170+
1171+
// Host entry point for average-pooling backward: computes the launch
// configuration (one thread per output-gradient element) and submits the
// scatter kernel.
// NOTE(review): the kernel accumulates with atomicAdd; `out` is assumed
// zero-filled by the caller — confirm.
DL_API void iAveragePoolingBackward(float* out, float* in,
                                    size_t pool_size, size_t stride, size_t padding,
                                    size_t batches, size_t channels, size_t H_in, size_t W_in,
                                    size_t H_out, size_t W_out);
11471175
}
11481176
#endif //TENSOROPERATIONS_CUH

src/Nodes.cu

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,29 @@ namespace nz::nodes {
661661
iCol2imgBackward(inputs[0]->output->grad(), output->grad(), outputHeight, outputWidth, outputChannels,
662662
inputs[0]->output->shape()[0]);
663663
}
664+
665+
// Creates an average-pooling node over `input` (batch, channel, H, W);
// the output's spatial extents follow OUTPUT_DIM for the given
// window/stride/padding, and gradient tracking mirrors the input's.
AveragePoolingNode::AveragePoolingNode(Node* input, Tensor::size_type poolSize, Tensor::size_type stride,
                                       Tensor::size_type padding) : poolSize(poolSize), stride(stride), padding(padding) {
    inputs.push_back(input);
    const auto& inShape = input->output->shape();
    Tensor::shape_type outShape{
        inShape[0], inShape[1],
        OUTPUT_DIM(inShape[2], poolSize, stride, padding),
        OUTPUT_DIM(inShape[3], poolSize, stride, padding)
    };
    output = std::make_shared<Tensor>(outShape, input->output->requiresGrad());
    type = "AveragePooling";
}
675+
676+
// Forward pass: one pooled value per output element, averaged over the
// in-bounds portion of each window (see AveragePoolingKernel).
void AveragePoolingNode::forward() {
    const auto& inShape = inputs[0]->output->shape();
    iAveragePooling(output->data(), inputs[0]->output->data(), poolSize, stride, padding,
                    inShape[0], inShape[1], inShape[2], inShape[3],
                    output->shape()[2], output->shape()[3]);
}
681+
682+
// Backward pass: scatters this node's output gradient into the input's
// gradient buffer.
// NOTE(review): the kernel accumulates with atomicAdd and assumes the
// input grad buffer is zeroed before backward runs — confirm against the
// graph executor.
void AveragePoolingNode::backward() {
    const auto& inShape = inputs[0]->output->shape();
    iAveragePoolingBackward(inputs[0]->output->grad(), output->grad(), poolSize, stride, padding,
                            inShape[0], inShape[1], inShape[2], inShape[3],
                            output->shape()[2], output->shape()[3]);
}
664687
}
665688

666689
namespace loss {

src/OperationKernels.cu

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,4 +1400,92 @@ namespace nz::krnl {
14001400
StreamManager<float>::Instance().submit(col2imgBackwardKernel, gridDim, blockDim, 0, out, in, H_out, W_out,
14011401
C_out, batches);
14021402
}
1403+
1404+
// Average-pooling forward. One thread per output element (flattened
// batch/channel/H_out/W_out). Each thread averages the in-bounds samples
// of its window; padded positions are excluded from the average (the
// divisor is the count of valid samples, not pool_size^2).
//
// Fixes vs. original:
//  - `h * stride - padding` was evaluated entirely in size_t, wrapping to a
//    huge unsigned value before conversion to long long when padding >
//    h * stride; the subtraction is now done in signed arithmetic.
//  - The partial sum was accumulated by read-modify-write on global
//    out[idx] every iteration; it now lives in a register and global memory
//    is written exactly once.
__global__ void AveragePoolingKernel(float* __restrict__ out, const float* __restrict__ in,
                                     const size_t pool_size, const size_t stride, const size_t padding,
                                     const size_t batches, const size_t channels,
                                     const size_t H_in, const size_t W_in,
                                     const size_t H_out, const size_t W_out) {
    const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= batches * channels * H_out * W_out) {
        return;
    }
    const size_t plane = H_out * W_out;
    const size_t currentBatch = idx / (channels * plane);
    const size_t currentChannel = (idx % (channels * plane)) / plane;
    const size_t h = (idx % plane) / W_out;
    const size_t w = (idx % plane) % W_out;
    // Signed window origin: may be negative when padding > h * stride.
    const long long h_start = (long long)(h * stride) - (long long)padding;
    const long long w_start = (long long)(w * stride) - (long long)padding;
    // Base of this (batch, channel) input plane.
    const float* channelBase = in + currentBatch * (channels * H_in * W_in)
                                  + currentChannel * (H_in * W_in);
    float sum = 0.0f;   // register accumulator; global store happens once below
    size_t count = 0;   // number of in-bounds samples in the window
    for (long long i = 0; i < (long long)pool_size; i++) {
        for (long long j = 0; j < (long long)pool_size; j++) {
            const long long h_in = h_start + i;
            const long long w_in = w_start + j;
            if (h_in >= 0 && h_in < (long long)H_in && w_in >= 0 && w_in < (long long)W_in) {
                sum += channelBase[h_in * W_in + w_in];
                count++;
            }
        }
    }
    // Guard count == 0 (window entirely in padding) to avoid a 0/0 NaN.
    out[idx] = count > 0 ? sum / (float)count : 0.0f;
}
1430+
1431+
// Submits the average-pooling forward kernel on the shared float stream
// manager. gridDim/blockDim must cover batches * channels * H_out * W_out
// threads (one per output element); excess threads exit at the kernel's
// bounds check.
void AveragePooling(const dim3 gridDim, const dim3 blockDim, float* out, float* in,
                    const size_t pool_size, const size_t stride, const size_t padding,
                    const size_t batches, const size_t channels, const size_t H_in, const size_t W_in,
                    const size_t H_out, const size_t W_out) {
    StreamManager<float>::Instance().submit(AveragePoolingKernel, gridDim, blockDim, 0, out, in,
                                            pool_size, stride, padding, batches, channels, H_in, W_in, H_out, W_out);
}
1438+
1439+
// Average-pooling backward. One thread per OUTPUT-gradient element: it
// scatters in[idx] / count to every in-bounds input cell its window covers,
// where count matches the forward divisor (number of valid, non-padded
// samples). atomicAdd is required because windows overlap when
// stride < pool_size.
// NOTE(review): `out` must be zero-filled before launch — confirm at the
// call sites.
//
// Fixes vs. original:
//  - `h * stride - padding` wrapped in unsigned arithmetic before the
//    conversion to long long; the subtraction is now signed.
//  - The padded path counted valid cells with a full O(pool_size^2) loop
//    before the scatter loop; the count is now the closed-form area of the
//    window clamped to the input. For padding == 0 the window always lies
//    fully inside the input (OUTPUT_DIM guarantees this), so the clamped
//    count equals pool_size * pool_size and the former fast path is
//    subsumed — both branches collapse into one.
//  - in[idx] is read once into a register instead of per scatter.
__global__ void AveragePoolingBackwardKernel(float* __restrict__ out, const float* __restrict__ in,
                                             const size_t pool_size, const size_t stride, const size_t padding,
                                             const size_t batches, const size_t channels,
                                             const size_t H_in, const size_t W_in,
                                             const size_t H_out, const size_t W_out) {
    const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= batches * channels * H_out * W_out) {
        return;
    }
    const size_t plane = H_out * W_out;
    const size_t currentBatch = idx / (channels * plane);
    const size_t currentChannel = (idx % (channels * plane)) / plane;
    const size_t h = (idx % plane) / W_out;
    const size_t w = (idx % plane) % W_out;
    const long long h_start = (long long)(h * stride) - (long long)padding;
    const long long w_start = (long long)(w * stride) - (long long)padding;
    // Clamp the window to the input extents; the valid-sample count is the
    // clamped window's area.
    const long long h_lo = h_start > 0 ? h_start : 0;
    const long long w_lo = w_start > 0 ? w_start : 0;
    const long long h_end = h_start + (long long)pool_size;
    const long long w_end = w_start + (long long)pool_size;
    const long long h_hi = h_end < (long long)H_in ? h_end : (long long)H_in;
    const long long w_hi = w_end < (long long)W_in ? w_end : (long long)W_in;
    const long long count = (h_hi - h_lo) * (w_hi - w_lo);
    if (count <= 0) {
        return;  // window entirely in padding: nothing to scatter
    }
    const float share = in[idx] / (float)count;  // equal slice per covered cell
    float* channelBase = out + currentBatch * (channels * H_in * W_in)
                             + currentChannel * (H_in * W_in);
    for (long long h_in = h_lo; h_in < h_hi; h_in++) {
        for (long long w_in = w_lo; w_in < w_hi; w_in++) {
            atomicAdd(channelBase + h_in * W_in + w_in, share);
        }
    }
}
1483+
1484+
// Submits the average-pooling backward kernel on the shared float stream
// manager. gridDim/blockDim must cover batches * channels * H_out * W_out
// threads (one per output-gradient element).
// NOTE(review): the kernel accumulates into `out` with atomicAdd and
// assumes it is zero-initialized — confirm at the call sites.
void AveragePoolingBackward(const dim3 gridDim, const dim3 blockDim, float* out, float* in,
                            const size_t pool_size, const size_t stride, const size_t padding,
                            const size_t batches, const size_t channels, const size_t H_in, const size_t W_in,
                            const size_t H_out, const size_t W_out) {
    StreamManager<float>::Instance().submit(AveragePoolingBackwardKernel, gridDim, blockDim, 0, out, in,
                                            pool_size, stride, padding, batches, channels, H_in, W_in, H_out, W_out);
}
14031491
}

src/TensorOperations.cu

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,20 @@ namespace nz::data {
160160
const dim3 grid((H_out * W_out * C_out * batches + BLOCKSIZE - 1) / BLOCKSIZE);
161161
krnl::col2imgBackward(grid, block, out, in, H_out, W_out, C_out, batches);
162162
}
163+
164+
// Host-side launch of average-pooling forward: one thread per output
// element, rounded up to whole blocks of BLOCKSIZE threads.
void iAveragePooling(float* out, float* in, const size_t pool_size, const size_t stride, const size_t padding,
                     const size_t batches, const size_t channels, const size_t H_in, const size_t W_in, const size_t H_out,
                     const size_t W_out) {
    const size_t totalOutputs = batches * channels * H_out * W_out;
    const dim3 block(BLOCKSIZE);
    const dim3 grid((totalOutputs + BLOCKSIZE - 1) / BLOCKSIZE);
    krnl::AveragePooling(grid, block, out, in, pool_size, stride, padding, batches, channels, H_in, W_in, H_out, W_out);
}
171+
172+
// Host-side launch of average-pooling backward: one thread per
// output-gradient element, rounded up to whole blocks of BLOCKSIZE threads.
void iAveragePoolingBackward(float* out, float* in, const size_t pool_size, const size_t stride, const size_t padding, const size_t batches,
                             const size_t channels, const size_t H_in, const size_t W_in, const size_t H_out, const size_t W_out) {
    const size_t totalOutputs = batches * channels * H_out * W_out;
    const dim3 block(BLOCKSIZE);
    const dim3 grid((totalOutputs + BLOCKSIZE - 1) / BLOCKSIZE);
    krnl::AveragePoolingBackward(grid, block, out, in, pool_size, stride, padding, batches, channels, H_in, W_in,
                                 H_out, W_out);
}
163179
}

0 commit comments

Comments
 (0)