Skip to content

Commit 4be3aac

Browse files
committed
feat(convolution - ops): add backpropagation for img2col, computational graph nodes, and related test cases
- Implemented the backpropagation algorithm for the img2col operation to support gradient computation. - Added computational graph nodes for the img2col operation and its backpropagation to facilitate automatic differentiation. - Developed a comprehensive set of test cases to verify the correctness of the backpropagation and computational graph nodes. - These test cases cover different input scenarios and gradients to ensure the stability and accuracy of the implementation.
1 parent 7f828ff commit 4be3aac

7 files changed

Lines changed: 223 additions & 3 deletions

File tree

include/NeuZephyr/Nodes.cuh

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3299,6 +3299,23 @@ namespace nz::nodes {
32993299

33003300
void backward() override;
33013301
};
3302+
3303+
class DL_API Img2ColNode : public Node {
3304+
public:
3305+
Tensor::size_type kernelHeight;
3306+
Tensor::size_type kernelWidth;
3307+
Tensor::size_type stride;
3308+
Tensor::size_type padding;
3309+
Tensor::size_type outputHeight;
3310+
Tensor::size_type outputWidth;
3311+
3312+
Img2ColNode(Node* input, Tensor::size_type kernelHeight, Tensor::size_type kernelWidth,
3313+
Tensor::size_type stride, Tensor::size_type padding);
3314+
3315+
void forward() override;
3316+
3317+
void backward() override;
3318+
};
33023319
}
33033320

33043321
/**

include/NeuZephyr/OperationKernels.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,10 @@ namespace nz::krnl {
10101010
const size_t W_out, const size_t C, const size_t K_h, const size_t K_w, const size_t stride,
10111011
const size_t pad, const size_t H_in, const size_t W_in, const size_t batch);
10121012

1013+
void img2colBackward(const dim3 gridDim, const dim3 blockDim, float* out, float* in, const size_t H_out,
1014+
const size_t W_out, const size_t C, const size_t K_h, const size_t K_w, const size_t stride,
1015+
const size_t pad, const size_t H_in, const size_t W_in, const size_t batch);
1016+
10131017
void col2img(const dim3 gridDim, const dim3 blockDim, float* out, float* in, const size_t H_out,
10141018
const size_t W_out, const size_t C_out, const size_t batches);
10151019
#endif

include/NeuZephyr/TensorOperations.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,6 +1109,10 @@ namespace nz::data {
11091109
const size_t W_out, const size_t C, const size_t K_h, const size_t K_w, const size_t stride,
11101110
const size_t pad, const size_t H_in, const size_t W_in, const size_t batch);
11111111

1112+
DL_API void iImg2colBackward(float* out, float* in, const size_t H_out,
1113+
const size_t W_out, const size_t C, const size_t K_h, const size_t K_w, const size_t stride,
1114+
const size_t pad, const size_t H_in, const size_t W_in, const size_t batch);
1115+
11121116
template <typename T>
11131117
std::enable_if_t<is_valid_tensor_type<T>::value, T>
11141118
tensorImg2col(const T& in, const size_t K_h, const size_t K_w, const size_t stride,

src/Nodes.cu

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ namespace nz::nodes {
584584

585585
void ExpandNode::forward() {
586586
const auto size = inputs[0]->output->shape()[1] * inputs[0]->output->shape()[2] *
587-
inputs[0]->output->shape()[3];
587+
inputs[0]->output->shape()[3];
588588
const auto total = size * newBatch;
589589
const dim3 block(BLOCKSIZE);
590590
const dim3 grid((total + block.x - 1) / block.x);
@@ -594,13 +594,46 @@ namespace nz::nodes {
594594
void ExpandNode::backward() {
595595
if (inputs[0]->output->requiresGrad()) {
596596
const auto size = inputs[0]->output->shape()[1] * inputs[0]->output->shape()[2] *
597-
inputs[0]->output->shape()[3];
597+
inputs[0]->output->shape()[3];
598598
const auto total = size * newBatch;
599599
const dim3 block(BLOCKSIZE);
600600
const dim3 grid((total + block.x - 1) / block.x);
601601
Compress(grid, block, inputs[0]->output->grad(), output->grad(), size, total);
602602
}
603603
}
604+
605+
Img2ColNode::Img2ColNode(Node* input, const Tensor::size_type kernelHeight, const Tensor::size_type kernelWidth,
606+
const Tensor::size_type stride,
607+
const Tensor::size_type padding) : kernelHeight(kernelHeight),
608+
kernelWidth(kernelWidth),
609+
stride(stride), padding(padding),
610+
outputHeight(
611+
(input->output->shape().H() + 2 * padding -
612+
kernelHeight) / stride + 1),
613+
outputWidth(
614+
(input->output->shape().W() + 2 * padding -
615+
kernelWidth) / stride + 1) {
616+
inputs.push_back(input);
617+
output = std::make_shared<Tensor>(Tensor::shape_type{
618+
input->output->shape()[0], 1, outputHeight * outputWidth,
619+
kernelHeight * kernelWidth * input->output->shape()[1]
620+
}, input->output->requiresGrad());
621+
type = "Img2Col";
622+
}
623+
624+
void Img2ColNode::forward() {
625+
iImg2col(output->data(), inputs[0]->output->data(), outputHeight, outputWidth, inputs[0]->output->shape()[1],
626+
kernelHeight, kernelWidth, stride, padding, inputs[0]->output->shape()[2], inputs[0]->output->shape()[3],
627+
inputs[0]->output->shape()[0]);
628+
}
629+
630+
void Img2ColNode::backward() {
631+
if (inputs[0]->output->requiresGrad()) {
632+
iImg2colBackward(inputs[0]->output->grad(), output->grad(), outputHeight, outputWidth, inputs[0]->output->shape()[1],
633+
kernelHeight, kernelWidth, stride, padding, inputs[0]->output->shape()[2], inputs[0]->output->shape()[3],
634+
inputs[0]->output->shape()[0]);
635+
}
636+
}
604637
}
605638

606639
namespace loss {

src/OperationKernels.cu

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,6 +1331,31 @@ namespace nz::krnl {
13311331
K_h, K_w, stride, pad, H_in, W_in, batch);
13321332
}
13331333

1334+
__global__ void img2colBackwardKernel(float* out, const float* in, const size_t H_out, const size_t W_out, const size_t C,
1335+
const size_t K_h, const size_t K_w, const size_t stride, const size_t pad, const size_t H_in, const size_t W_in, const size_t batch) {
1336+
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
1337+
if (idx >= H_out * W_out * C * K_h * K_w * batch) {
1338+
return;
1339+
}
1340+
const size_t fixedIdx = idx % (H_out * W_out * C * K_h * K_w);
1341+
const size_t currentBatch = idx / (H_out * W_out * C * K_h * K_w);
1342+
const size_t k = fixedIdx / (C * K_h * K_w);
1343+
const size_t m = fixedIdx % (C * K_h * K_w);
1344+
const size_t c = m / (K_h * K_w);
1345+
const long long h = (k / W_out) * stride - pad + (m % (K_h * K_w)) / K_w;
1346+
const long long w = (k % W_out) * stride - pad + m % K_w;
1347+
if (h >= 0 && h < H_in && w >= 0 && w < W_in) {
1348+
atomicAdd(out + currentBatch * (C * H_in * W_in) + c * (H_in * W_in) + h * W_in + w, in[idx]);
1349+
}
1350+
}
1351+
1352+
void img2colBackward(const dim3 gridDim, const dim3 blockDim, float* out, float* in, const size_t H_out,
1353+
const size_t W_out, const size_t C, const size_t K_h, const size_t K_w, const size_t stride,
1354+
const size_t pad, const size_t H_in, const size_t W_in, const size_t batch) {
1355+
StreamManager<float>::Instance().submit(img2colBackwardKernel, gridDim, blockDim, 0, out, in, H_out,
1356+
W_out, C, K_h, K_w, stride, pad, H_in, W_in, batch);
1357+
}
1358+
13341359
__global__ void col2imgKernel(float* out, const float* in, const size_t H_out, const size_t W_out, const size_t C_out, const size_t batches) {
13351360
const size_t idx = blockDim.x * blockIdx.x + threadIdx.x;
13361361
if (idx >= H_out * W_out * C_out * batches) {

src/TensorOperations.cu

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,16 @@ namespace nz::data {
140140
krnl::img2col(grid, block, out, in, H_out, W_out, C, K_h, K_w, stride, pad, H_in, W_in, batch);
141141
}
142142

143+
void iImg2colBackward(float* out, float* in, const size_t H_out, const size_t W_out, const size_t C,
144+
const size_t K_h, const size_t K_w, const size_t stride, const size_t pad, const size_t H_in, const size_t W_in,
145+
const size_t batch) {
146+
const dim3 block(BLOCKSIZE);
147+
const dim3 grid((H_out * W_out * C * K_h * K_w * batch + BLOCKSIZE - 1) / BLOCKSIZE);
148+
krnl::img2colBackward(grid, block, out, in, H_out, W_out, C, K_h, K_w, stride, pad, H_in, W_in, batch);
149+
}
150+
143151
void iCol2img(float* out, float* in, const size_t H_out, const size_t W_out, const size_t C_out,
144-
const size_t batches) {
152+
const size_t batches) {
145153
const dim3 block(BLOCKSIZE);
146154
const dim3 grid((H_out * W_out * C_out * batches + BLOCKSIZE - 1) / BLOCKSIZE);
147155
krnl::col2img(grid, block, out, in, H_out, W_out, C_out, batches);

test/Test.cpp

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2943,6 +2943,135 @@ TEST(TensorBasic, img2colTest) {
29432943
EXPECT_EQ(expected, result);
29442944
}
29452945

2946+
TEST(NodeBasic, img2colForward) {
2947+
const size_t n = 2;
2948+
const size_t c = 3;
2949+
const size_t h = 4;
2950+
const size_t w = 5;
2951+
const size_t k_h = 3;
2952+
const size_t k_w = 3;
2953+
const size_t stride = 1;
2954+
const size_t pad = 1;
2955+
const size_t H_out = (h + 2 * pad - k_h) / stride + 1;
2956+
const size_t W_out = (w + 2 * pad - k_w) / stride + 1;
2957+
2958+
std::vector<float> inputData({n*c*h*w});
2959+
std::vector<float> expectedData({n*H_out*W_out*k_h*k_w*c});
2960+
2961+
std::random_device rd;
2962+
std::mt19937 gen(rd());
2963+
std::uniform_real_distribution<float> dist(0.1f, 0.9f);
2964+
2965+
for (auto& i : inputData) {
2966+
i = dist(gen);
2967+
}
2968+
2969+
for (size_t b = 0; b < n; ++b) {
2970+
for (size_t i = 0; i < H_out; ++i) {
2971+
for (size_t j = 0; j < W_out; ++j) {
2972+
const int h_start = static_cast<int>(i * stride) - pad;
2973+
const int w_start = static_cast<int>(j * stride) - pad;
2974+
2975+
for (size_t r = 0; r < k_h; ++r) {
2976+
const int h_in = h_start + r;
2977+
for (size_t s = 0; s < k_w; ++s) {
2978+
const int w_in = w_start + s;
2979+
for (size_t c_in = 0; c_in < c; ++c_in) {
2980+
float val = 0.0f;
2981+
if (h_in >= 0 && h_in < h && w_in >= 0 && w_in < w) {
2982+
const size_t input_idx =
2983+
b * (c * h * w) +
2984+
c_in * (h * w) +
2985+
h_in * w +
2986+
w_in;
2987+
val = inputData[input_idx];
2988+
}
2989+
const size_t expected_idx =
2990+
b * (H_out * W_out * k_h * k_w * c) +
2991+
(i * W_out + j) * (k_h * k_w * c) +
2992+
c_in * (k_h * k_w) +
2993+
r * k_w +
2994+
s;
2995+
expectedData[expected_idx] = val;
2996+
}
2997+
}
2998+
}
2999+
}
3000+
}
3001+
}
3002+
3003+
InputNode input({n, c, h, w});
3004+
input.dataInject(inputData.begin(), inputData.end());
3005+
Img2ColNode result(&input, k_h, k_w, stride, pad);
3006+
result.forward();
3007+
Tensor expected({n, 1, H_out * W_out, k_h * k_w * c});
3008+
expected.dataInject(expectedData.begin(), expectedData.end());
3009+
EXPECT_EQ(expected, *result.output);
3010+
}
3011+
3012+
TEST(NodeBasic, img2colBackward) {
3013+
const size_t n = 2;
3014+
const size_t c = 3;
3015+
const size_t h = 4;
3016+
const size_t w = 5;
3017+
const size_t k_h = 3;
3018+
const size_t k_w = 3;
3019+
const size_t stride = 1;
3020+
const size_t pad = 1;
3021+
const size_t H_out = (h + 2 * pad - k_h) / stride + 1;
3022+
const size_t W_out = (w + 2 * pad - k_w) / stride + 1;
3023+
3024+
std::vector<float> gradData({n*H_out*W_out*k_h*k_w*c});
3025+
std::vector<float> expectedGradData({n*c*h*w});
3026+
3027+
std::random_device rd;
3028+
std::mt19937 gen(rd());
3029+
std::uniform_real_distribution<float> dist(0.1f, 0.9f);
3030+
3031+
for (auto& i : gradData) {
3032+
i = dist(gen);
3033+
}
3034+
3035+
for (size_t b = 0; b < n; ++b) {
3036+
for (size_t i = 0; i < H_out; ++i) {
3037+
for (size_t j = 0; j < W_out; ++j) {
3038+
const int h_start = static_cast<int>(i * stride) - pad;
3039+
const int w_start = static_cast<int>(j * stride) - pad;
3040+
for (size_t r = 0; r < k_h; ++r) {
3041+
const int h_in = h_start + r;
3042+
for (size_t s = 0; s < k_w; ++s) {
3043+
const int w_in = w_start + s;
3044+
for (size_t c_in = 0; c_in < c; ++c_in) {
3045+
if (h_in >= 0 && h_in < h && w_in >= 0 && w_in < w) {
3046+
const size_t input_idx =
3047+
b * (c * h * w) +
3048+
c_in * (h * w) +
3049+
h_in * w +
3050+
w_in;
3051+
const size_t grad_idx =
3052+
b * (H_out * W_out * k_h * k_w * c) +
3053+
(i * W_out + j) * (k_h * k_w * c) +
3054+
c_in * (k_h * k_w) +
3055+
r * k_w +
3056+
s;
3057+
expectedGradData[input_idx] += gradData[grad_idx];
3058+
}
3059+
}
3060+
}
3061+
}
3062+
}
3063+
}
3064+
}
3065+
3066+
InputNode input({n, c, h, w}, true);
3067+
Img2ColNode result(&input, k_h, k_w, stride, pad);
3068+
result.dataInject(gradData.begin(), gradData.end(), true);
3069+
result.backward();
3070+
Tensor expected({n, c, h, w}, true);
3071+
expected.dataInject(expectedGradData.begin(), expectedGradData.end(), true);
3072+
EXPECT_EQ(expected, *input.output);
3073+
}
3074+
29463075
TEST(TenorBasic, col2imgTest) {
29473076
const size_t n = 2;
29483077
const size_t c = 3;

0 commit comments

Comments
 (0)