Skip to content

Commit 7f828ff

Browse files
committed
feat(convolution - ops): add col2img operation and related test cases
- Added the col2img operation to handle the reverse transformation in convolution operations. - Developed a series of test cases to validate the functionality and accuracy of the col2img operation. - The test cases include different input sizes and configurations to ensure the robustness of the implementation.
1 parent fc45b32 commit 7f828ff

5 files changed

Lines changed: 83 additions & 1 deletion

File tree

include/NeuZephyr/OperationKernels.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,9 @@ namespace nz::krnl {
10091009
void img2col(const dim3 gridDim, const dim3 blockDim, float* out, float* in, const size_t H_out,
10101010
const size_t W_out, const size_t C, const size_t K_h, const size_t K_w, const size_t stride,
10111011
const size_t pad, const size_t H_in, const size_t W_in, const size_t batch);
1012+
1013+
void col2img(const dim3 gridDim, const dim3 blockDim, float* out, float* in, const size_t H_out,
1014+
const size_t W_out, const size_t C_out, const size_t batches);
10121015
#endif
10131016
}
10141017

include/NeuZephyr/TensorOperations.cuh

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1115,9 +1115,27 @@ namespace nz::data {
11151115
const size_t pad) {
11161116
const size_t H_out = (in.shape().H() + 2 * pad - K_h) / stride + 1;
11171117
const size_t W_out = (in.shape().W() + 2 * pad - K_w) / stride + 1;
1118-
T result({in.shape()[0], 1, H_out * W_out, in.shape().C() * K_h * K_w});
1118+
T result({in.shape()[0], 1, H_out * W_out, in.shape().C() * K_h * K_w}, in.requiresGrad());
11191119
iImg2col(result.data(), in.data(), H_out, W_out, in.shape().C(), K_h, K_w, stride, pad,
11201120
in.shape().H(), in.shape().W(), in.shape()[0]);
1121+
if (in.requiresGrad()) {
1122+
iImg2col(result.grad(), in.grad(), H_out, W_out, in.shape().C(), K_h, K_w, stride, pad,
1123+
in.shape().H(), in.shape().W(), in.shape()[0]);
1124+
}
1125+
return result;
1126+
}
1127+
1128+
DL_API void iCol2img(float* out, float* in, size_t H_out,
1129+
size_t W_out, size_t C_out, size_t batches);
1130+
1131+
template <typename T>
1132+
std::enable_if_t<is_valid_tensor_type<T>::value, T>
1133+
tensorCol2img(const T& in, const size_t H_out, const size_t W_out) {
1134+
T result({in.shape()[0], in.shape()[3], H_out, W_out}, in.requiresGrad());
1135+
iCol2img(result.data(), in.data(), H_out, W_out, in.shape()[3], in.shape()[0]);
1136+
if (in.requiresGrad()) {
1137+
iCol2img(result.grad(), in.grad(), H_out, W_out, in.shape()[3], in.shape()[0]);
1138+
}
11211139
return result;
11221140
}
11231141
}

src/OperationKernels.cu

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,4 +1330,23 @@ namespace nz::krnl {
13301330
StreamManager<float>::Instance().submit(img2colKernel, gridDim, blockDim, 0, out, in, H_out, W_out, C,
13311331
K_h, K_w, stride, pad, H_in, W_in, batch);
13321332
}
1333+
1334+
__global__ void col2imgKernel(float* out, const float* in, const size_t H_out, const size_t W_out, const size_t C_out, const size_t batches) {
1335+
const size_t idx = blockDim.x * blockIdx.x + threadIdx.x;
1336+
if (idx >= H_out * W_out * C_out * batches) {
1337+
return;
1338+
}
1339+
const size_t batch = idx / (C_out * H_out * W_out);
1340+
const size_t fixedIdx = idx % (C_out * H_out * W_out);
1341+
const size_t c = fixedIdx / (H_out * W_out);
1342+
const size_t h = (fixedIdx % (H_out * W_out)) / W_out;
1343+
const size_t w = (fixedIdx % (H_out * W_out)) % W_out;
1344+
out[idx] = in[batch * (C_out * H_out * W_out) + (h * W_out + w) * C_out + c];
1345+
}
1346+
1347+
void col2img(const dim3 gridDim, const dim3 blockDim, float* out, float* in, const size_t H_out,
1348+
const size_t W_out, const size_t C_out, const size_t batches) {
1349+
StreamManager<float>::Instance().submit(col2imgKernel, gridDim, blockDim, 0, out, in, H_out, W_out, C_out,
1350+
batches);
1351+
}
13331352
}

src/TensorOperations.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,11 @@ namespace nz::data {
139139
const dim3 grid((H_out * W_out * C * K_h * K_w * batch + BLOCKSIZE - 1) / BLOCKSIZE);
140140
krnl::img2col(grid, block, out, in, H_out, W_out, C, K_h, K_w, stride, pad, H_in, W_in, batch);
141141
}
142+
143+
void iCol2img(float* out, float* in, const size_t H_out, const size_t W_out, const size_t C_out,
144+
const size_t batches) {
145+
const dim3 block(BLOCKSIZE);
146+
const dim3 grid((H_out * W_out * C_out * batches + BLOCKSIZE - 1) / BLOCKSIZE);
147+
krnl::col2img(grid, block, out, in, H_out, W_out, C_out, batches);
148+
}
142149
}

test/Test.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2941,4 +2941,39 @@ TEST(TensorBasic, img2colTest) {
29412941
Tensor expected({n, 1, H_out * W_out, k_h * k_w * c});
29422942
expected.dataInject(expectedData.begin(), expectedData.end());
29432943
EXPECT_EQ(expected, result);
2944+
}
2945+
2946+
TEST(TenorBasic, col2imgTest) {
2947+
const size_t n = 2;
2948+
const size_t c = 3;
2949+
const size_t h = 4;
2950+
const size_t w = 5;
2951+
2952+
std::vector<float> inputData({n*c*h*w});
2953+
std::vector<float> expectedData({n*c*h*w});
2954+
2955+
std::random_device rd;
2956+
std::mt19937 gen(rd());
2957+
std::uniform_real_distribution<float> dist(0.1f, 0.9f);
2958+
2959+
for (auto& i : inputData) {
2960+
i = dist(gen);
2961+
}
2962+
for (auto i = 0; i < n; i++) {
2963+
for (auto j = 0; j < c; j++) {
2964+
for (auto k = 0; k < h; k++) {
2965+
for (auto l = 0; l < w; l++) {
2966+
expectedData[i * (c*h*w) + j * (h*w) + k * w + l] =
2967+
inputData[i * (c*h*w) + (k * w + l) * c + j];
2968+
}
2969+
}
2970+
}
2971+
}
2972+
2973+
Tensor input({n ,1, h*w, c});
2974+
input.dataInject(inputData.begin(), inputData.end());
2975+
auto result = tensorCol2img(input, h, w);
2976+
Tensor expected({n, c, h, w});
2977+
expected.dataInject(expectedData.begin(), expectedData.end());
2978+
EXPECT_EQ(expected, result);
29442979
}

0 commit comments

Comments
 (0)