Skip to content

Commit 54ef859

Browse files
committed
feat(convolution - ops): add backpropagation for col2img, computational graph nodes, and related test cases
- Added backpropagation functionality for the col2img operation to enable gradient flow.
- Integrated computational-graph nodes for col2img backpropagation to support automatic differentiation.
- Developed a suite of test cases to validate the correctness of the backpropagation and graph nodes.
- The test cases cover various input sizes and configurations to ensure the robustness of the implementation.
1 parent 4be3aac commit 54ef859

7 files changed

Lines changed: 164 additions & 17 deletions

File tree

include/NeuZephyr/Nodes.cuh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3316,6 +3316,19 @@ namespace nz::nodes {
33163316

33173317
void backward() override;
33183318
};
3319+
3320+
/**
 * @brief Computational-graph node that reassembles column-form data
 *        (the layout produced by Img2ColNode) back into an image tensor.
 *
 * Forward maps an input tensor laid out as (N, 1, H*W, C) to an output of
 * shape (N, C, H, W); backward scatters the output gradient back into the
 * column layout. Channel count is taken from the input's last dimension
 * at construction time.
 */
class DL_API Col2ImgNode : public Node {
public:
    Tensor::size_type outputHeight;   ///< Height H of the reassembled image.
    Tensor::size_type outputWidth;    ///< Width W of the reassembled image.
    Tensor::size_type outputChannels; ///< Channel count C (input shape's last dim).

    /// @param input        Producer node; its output must be in column layout.
    /// @param outputHeight Target image height H.
    /// @param outputWidth  Target image width W.
    Col2ImgNode(Node* input, Tensor::size_type outputHeight, Tensor::size_type outputWidth);

    void forward() override;

    void backward() override;
};
33193332
}
33203333

33213334
/**

include/NeuZephyr/OperationKernels.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,9 @@ namespace nz::krnl {
10161016

10171017
void col2img(const dim3 gridDim, const dim3 blockDim, float* out, float* in, const size_t H_out,
10181018
const size_t W_out, const size_t C_out, const size_t batches);
/// Launch wrapper for the col2img backward kernel: scatters an (N, C, H, W)
/// gradient back to the column layout. Expects a 1-D grid with one thread
/// per element of `in` (H_out * W_out * C_out * batches total).
void col2imgBackward(const dim3 gridDim, const dim3 blockDim, float* out, float* in, const size_t H_out,
                     const size_t W_out, const size_t C_out, const size_t batches);
10191022
#endif
10201023
}
10211024

include/NeuZephyr/TensorOperations.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,5 +1142,7 @@ namespace nz::data {
11421142
}
11431143
return result;
11441144
}
/// Host-side entry point for col2img backpropagation: computes the launch
/// configuration and dispatches krnl::col2imgBackward. `out` receives the
/// gradient in column layout; `in` is the (N, C, H, W) output gradient.
DL_API void iCol2imgBackward(float* out, float* in, size_t H_out, size_t W_out, size_t C_out, size_t batches);
11451147
}
11461148
#endif //TENSOROPERATIONS_CUH

src/Nodes.cu

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -622,18 +622,45 @@ namespace nz::nodes {
622622
}
623623

624624
void Img2ColNode::forward() {
    // Unfold the input tensor into column form for convolution-as-matmul.
    const auto& src = inputs[0]->output;
    iImg2col(output->data(), src->data(), outputHeight, outputWidth, src->shape()[1],
             kernelHeight, kernelWidth, stride, padding, src->shape()[2], src->shape()[3],
             src->shape()[0]);
}
629631

630632
void Img2ColNode::backward() {
    // Propagate the column-form gradient back to the input, but only when
    // the input actually tracks gradients.
    const auto& src = inputs[0]->output;
    if (src->requiresGrad()) {
        iImg2colBackward(src->grad(), output->grad(), outputHeight, outputWidth, src->shape()[1],
                         kernelHeight, kernelWidth, stride, padding, src->shape()[2], src->shape()[3],
                         src->shape()[0]);
    }
}
641+
642+
Col2ImgNode::Col2ImgNode(Node* input, const Tensor::size_type outputHeight,
                         const Tensor::size_type outputWidth)
    : outputHeight(outputHeight),
      outputWidth(outputWidth),
      outputChannels(input->output->shape()[3]) { // channels come from the input's last dim
    // Register the producer and allocate the (N, C, H, W) output tensor.
    inputs.push_back(input);
    const auto batches = input->output->shape()[0];
    output = std::make_shared<Tensor>(
        Tensor::shape_type(batches, outputChannels, outputHeight, outputWidth),
        input->output->requiresGrad());
    type = "Col2Img";
}
654+
655+
void Col2ImgNode::forward() {
    // Reassemble the column-layout input into the (N, C, H, W) output.
    const auto& src = inputs[0]->output;
    iCol2img(output->data(), src->data(), outputHeight, outputWidth, outputChannels,
             src->shape()[0]);
}
659+
660+
void Col2ImgNode::backward() {
    // BUG FIX: guard on requiresGrad() before touching the input's gradient
    // buffer, mirroring Img2ColNode::backward. The original called grad()
    // unconditionally even when the input does not track gradients.
    if (inputs[0]->output->requiresGrad()) {
        iCol2imgBackward(inputs[0]->output->grad(), output->grad(), outputHeight, outputWidth, outputChannels,
                         inputs[0]->output->shape()[0]);
    }
}
637664
}
638665

639666
namespace loss {

src/OperationKernels.cu

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -570,8 +570,10 @@ namespace nz::krnl {
570570
}
571571

572572
void SoftmaxJacobian(const dim3 gridDim, const dim3 blockDim, float* out, float* in,
                     const unsigned long long n, const std::vector<size_t>& offset_o,
                     const std::vector<size_t>& offset_i) {
    // Dispatch the softmax Jacobian kernel through the per-type stream manager.
    auto& mgr = StreamManager<float>::Instance();
    mgr.submitParallel(SoftmaxJacobianKernel, gridDim, blockDim, 0, out, in, offset_o, offset_i, n);
}
576578

577579
__global__ void MeanSquaredErrorKernel(float* out, const float* predict, const float* real,
@@ -1286,7 +1288,7 @@ namespace nz::krnl {
12861288
}
12871289

12881290
void Expand(const dim3 gridDim, const dim3 blockDim, float* out, float* in, const size_t n,
            const size_t total) {
    // Launch the expand kernel with no dynamic shared memory.
    auto& mgr = StreamManager<float>::Instance();
    mgr.submit(ExpandKernel, gridDim, blockDim, 0, out, in, n, total);
}
12921294

@@ -1304,7 +1306,8 @@ namespace nz::krnl {
13041306
}
13051307

13061308
__global__ void img2colKernel(float* out, const float* in, const size_t H_out, const size_t W_out, const size_t C,
1307-
const size_t K_h, const size_t K_w, const size_t stride, const size_t pad, const size_t H_in, const size_t W_in, const size_t batch) {
1309+
const size_t K_h, const size_t K_w, const size_t stride, const size_t pad,
1310+
const size_t H_in, const size_t W_in, const size_t batch) {
13081311
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
13091312
if (idx >= H_out * W_out * C * K_h * K_w * batch) {
13101313
return;
@@ -1325,14 +1328,16 @@ namespace nz::krnl {
13251328
}
13261329

13271330
void img2col(const dim3 gridDim, const dim3 blockDim, float* out, float* in, const size_t H_out,
             const size_t W_out, const size_t C, const size_t K_h, const size_t K_w, const size_t stride,
             const size_t pad, const size_t H_in, const size_t W_in, const size_t batch) {
    // Forward all launch parameters straight to the img2col kernel.
    auto& mgr = StreamManager<float>::Instance();
    mgr.submit(img2colKernel, gridDim, blockDim, 0, out, in, H_out, W_out, C,
               K_h, K_w, stride, pad, H_in, W_in, batch);
}
13331336

1334-
__global__ void img2colBackwardKernel(float* out, const float* in, const size_t H_out, const size_t W_out, const size_t C,
1335-
const size_t K_h, const size_t K_w, const size_t stride, const size_t pad, const size_t H_in, const size_t W_in, const size_t batch) {
1337+
__global__ void img2colBackwardKernel(float* out, const float* in, const size_t H_out, const size_t W_out,
1338+
const size_t C,
1339+
const size_t K_h, const size_t K_w, const size_t stride, const size_t pad,
1340+
const size_t H_in, const size_t W_in, const size_t batch) {
13361341
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
13371342
if (idx >= H_out * W_out * C * K_h * K_w * batch) {
13381343
return;
@@ -1350,13 +1355,14 @@ namespace nz::krnl {
13501355
}
13511356

13521357
void img2colBackward(const dim3 gridDim, const dim3 blockDim, float* out, float* in, const size_t H_out,
                     const size_t W_out, const size_t C, const size_t K_h, const size_t K_w, const size_t stride,
                     const size_t pad, const size_t H_in, const size_t W_in, const size_t batch) {
    // Forward all launch parameters straight to the img2col backward kernel.
    auto& mgr = StreamManager<float>::Instance();
    mgr.submit(img2colBackwardKernel, gridDim, blockDim, 0, out, in, H_out, W_out, C,
               K_h, K_w, stride, pad, H_in, W_in, batch);
}
13581363

1359-
__global__ void col2imgKernel(float* out, const float* in, const size_t H_out, const size_t W_out, const size_t C_out, const size_t batches) {
1364+
__global__ void col2imgKernel(float* out, const float* in, const size_t H_out, const size_t W_out,
1365+
const size_t C_out, const size_t batches) {
13601366
const size_t idx = blockDim.x * blockIdx.x + threadIdx.x;
13611367
if (idx >= H_out * W_out * C_out * batches) {
13621368
return;
@@ -1374,4 +1380,24 @@ namespace nz::krnl {
13741380
StreamManager<float>::Instance().submit(col2imgKernel, gridDim, blockDim, 0, out, in, H_out, W_out, C_out,
13751381
batches);
13761382
}
1383+
1384+
// Backward of col2img: one thread per element of the incoming (N, C, H, W)
// gradient `in`; each thread scatters its value to the corresponding slot of
// the column-layout gradient `out` (row-major (N, H*W, C) — the inverse of
// the permutation applied by col2imgKernel). Expects a 1-D launch covering
// at least H_out * W_out * C_out * batches threads.
__global__ void col2imgBackwardKernel(float* out, const float* in, const size_t H_out, const size_t W_out,
                                      const size_t C_out, const size_t batches) {
    const size_t total = H_out * W_out * C_out * batches;
    const size_t idx = blockDim.x * blockIdx.x + threadIdx.x;
    if (idx >= total) {
        return; // tail guard: grid rarely divides the element count evenly
    }
    const size_t perBatch = C_out * H_out * W_out; // elements per batch
    const size_t plane = H_out * W_out;            // elements per channel plane
    const size_t batch = idx / perBatch;
    const size_t rest = idx % perBatch;
    const size_t c = rest / plane;
    const size_t h = (rest % plane) / W_out;
    const size_t w = rest % W_out;
    // Destination: (batch, h * W_out + w, c) in the column layout.
    out[batch * perBatch + (h * W_out + w) * C_out + c] = in[idx];
}
1397+
1398+
void col2imgBackward(const dim3 gridDim, const dim3 blockDim, float* out, float* in, const size_t H_out,
                     const size_t W_out, const size_t C_out, const size_t batches) {
    // Launch the col2img backward scatter kernel; no dynamic shared memory.
    auto& mgr = StreamManager<float>::Instance();
    mgr.submit(col2imgBackwardKernel, gridDim, blockDim, 0, out, in, H_out, W_out, C_out, batches);
}
13771403
}

src/TensorOperations.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,10 @@ namespace nz::data {
154154
const dim3 grid((H_out * W_out * C_out * batches + BLOCKSIZE - 1) / BLOCKSIZE);
155155
krnl::col2img(grid, block, out, in, H_out, W_out, C_out, batches);
156156
}
157+
158+
void iCol2imgBackward(float* out, float* in, size_t H_out, size_t W_out, size_t C_out, size_t batches) {
    // One thread per gradient element; ceil-divide into BLOCKSIZE-wide blocks.
    const size_t total = H_out * W_out * C_out * batches;
    const dim3 block(BLOCKSIZE);
    const dim3 grid((total + BLOCKSIZE - 1) / BLOCKSIZE);
    krnl::col2imgBackward(grid, block, out, in, H_out, W_out, C_out, batches);
}
157163
}

test/Test.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3105,4 +3105,74 @@ TEST(TenorBasic, col2imgTest) {
31053105
Tensor expected({n, c, h, w});
31063106
expected.dataInject(expectedData.begin(), expectedData.end());
31073107
EXPECT_EQ(expected, result);
3108+
}
3109+
3110+
// Verifies Col2ImgNode::forward reassembles a column-layout (N, 1, H*W, C)
// input into an (N, C, H, W) image tensor, against a CPU reference permutation.
TEST(NodeBasic, col2imgForward) {
    const size_t n = 2;
    const size_t c = 3;
    const size_t h = 4;
    const size_t w = 5;

    // BUG FIX: the original `({n*c*h*w})` brace form constructed a ONE-element
    // vector via initializer_list, so every indexed write below was out of
    // bounds. Use the size constructor instead.
    std::vector<float> inputData(n * c * h * w);
    std::vector<float> expectedData(n * c * h * w);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> dist(0.1f, 0.9f);
    for (auto& v : inputData) {
        v = dist(gen);
    }
    // CPU reference: image[n][c][h][w] = column[n][(h*w_idx)][c].
    // size_t loop indices avoid signed/unsigned comparison warnings.
    for (size_t i = 0; i < n; i++) {
        for (size_t j = 0; j < c; j++) {
            for (size_t k = 0; k < h; k++) {
                for (size_t l = 0; l < w; l++) {
                    expectedData[i * (c * h * w) + j * (h * w) + k * w + l] =
                        inputData[i * (c * h * w) + (k * w + l) * c + j];
                }
            }
        }
    }

    InputNode input({n, 1, h * w, c});
    input.dataInject(inputData.begin(), inputData.end());
    Col2ImgNode result(&input, h, w);
    result.forward();
    Tensor expected({n, c, h, w});
    expected.dataInject(expectedData.begin(), expectedData.end());
    EXPECT_EQ(expected, *result.output);
}
3144+
3145+
// Verifies Col2ImgNode::backward scatters an (N, C, H, W) output gradient
// back into the column layout: injecting the permuted data as the output
// gradient must yield the original data as the input gradient.
TEST(NodeBasic, Col2imgBackward) {
    const size_t n = 2;
    const size_t c = 3;
    const size_t h = 4;
    const size_t w = 5;

    // BUG FIX: the original `({n*c*h*w})` brace form constructed a ONE-element
    // vector via initializer_list, so every indexed write below was out of
    // bounds. Use the size constructor instead.
    std::vector<float> inputData(n * c * h * w);
    std::vector<float> expectedData(n * c * h * w);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> dist(0.1f, 0.9f);
    for (auto& v : inputData) {
        v = dist(gen);
    }
    // Same permutation as the forward test; backward must invert it exactly.
    for (size_t i = 0; i < n; i++) {
        for (size_t j = 0; j < c; j++) {
            for (size_t k = 0; k < h; k++) {
                for (size_t l = 0; l < w; l++) {
                    expectedData[i * (c * h * w) + j * (h * w) + k * w + l] =
                        inputData[i * (c * h * w) + (k * w + l) * c + j];
                }
            }
        }
    }

    InputNode input({n, 1, h * w, c}, true); // requiresGrad so backward runs
    Col2ImgNode result(&input, h, w);
    result.dataInject(expectedData.begin(), expectedData.end(), true); // inject output grad
    result.backward();
    Tensor expected({n, 1, h * w, c}, true);
    expected.dataInject(inputData.begin(), inputData.end(), true);
    EXPECT_EQ(expected, *input.output);
}

0 commit comments

Comments
 (0)