@@ -570,8 +570,10 @@ namespace nz::krnl {
570570 }
571571
572572 void SoftmaxJacobian (const dim3 gridDim , const dim3 blockDim , float * out, float * in,
573- const unsigned long long n, const std::vector<size_t >& offset_o, const std::vector<size_t >& offset_i) {
574- StreamManager<float >::Instance ().submitParallel (SoftmaxJacobianKernel, gridDim , blockDim , 0 , out, in, offset_o, offset_i, n);
573+ const unsigned long long n, const std::vector<size_t >& offset_o,
574+ const std::vector<size_t >& offset_i) {
575+ StreamManager<float >::Instance ().submitParallel (SoftmaxJacobianKernel, gridDim , blockDim , 0 , out, in, offset_o,
576+ offset_i, n);
575577 }
576578
577579 __global__ void MeanSquaredErrorKernel (float * out, const float * predict, const float * real,
@@ -1286,7 +1288,7 @@ namespace nz::krnl {
12861288 }
12871289
12881290 void Expand (const dim3 gridDim , const dim3 blockDim , float * out, float * in, const size_t n,
1289- const size_t total) {
1291+ const size_t total) {
12901292 StreamManager<float >::Instance ().submit (ExpandKernel, gridDim , blockDim , 0 , out, in, n, total);
12911293 }
12921294
@@ -1304,7 +1306,8 @@ namespace nz::krnl {
13041306 }
13051307
13061308 __global__ void img2colKernel (float * out, const float * in, const size_t H_out, const size_t W_out, const size_t C,
1307- const size_t K_h, const size_t K_w, const size_t stride, const size_t pad, const size_t H_in, const size_t W_in, const size_t batch) {
1309+ const size_t K_h, const size_t K_w, const size_t stride, const size_t pad,
1310+ const size_t H_in, const size_t W_in, const size_t batch) {
13081311 const size_t idx = blockIdx .x * blockDim .x + threadIdx .x ;
13091312 if (idx >= H_out * W_out * C * K_h * K_w * batch) {
13101313 return ;
@@ -1325,14 +1328,16 @@ namespace nz::krnl {
13251328 }
13261329
13271330 void img2col (const dim3 gridDim , const dim3 blockDim , float * out, float * in, const size_t H_out,
1328- const size_t W_out, const size_t C, const size_t K_h, const size_t K_w, const size_t stride,
1329- const size_t pad, const size_t H_in, const size_t W_in, const size_t batch) {
1331+ const size_t W_out, const size_t C, const size_t K_h, const size_t K_w, const size_t stride,
1332+ const size_t pad, const size_t H_in, const size_t W_in, const size_t batch) {
13301333 StreamManager<float >::Instance ().submit (img2colKernel, gridDim , blockDim , 0 , out, in, H_out, W_out, C,
13311334 K_h, K_w, stride, pad, H_in, W_in, batch);
13321335 }
13331336
1334- __global__ void img2colBackwardKernel (float * out, const float * in, const size_t H_out, const size_t W_out, const size_t C,
1335- const size_t K_h, const size_t K_w, const size_t stride, const size_t pad, const size_t H_in, const size_t W_in, const size_t batch) {
1337+ __global__ void img2colBackwardKernel (float * out, const float * in, const size_t H_out, const size_t W_out,
1338+ const size_t C,
1339+ const size_t K_h, const size_t K_w, const size_t stride, const size_t pad,
1340+ const size_t H_in, const size_t W_in, const size_t batch) {
13361341 const size_t idx = blockIdx .x * blockDim .x + threadIdx .x ;
13371342 if (idx >= H_out * W_out * C * K_h * K_w * batch) {
13381343 return ;
@@ -1350,13 +1355,14 @@ namespace nz::krnl {
13501355 }
13511356
13521357 void img2colBackward (const dim3 gridDim , const dim3 blockDim , float * out, float * in, const size_t H_out,
1353- const size_t W_out, const size_t C, const size_t K_h, const size_t K_w, const size_t stride,
1354- const size_t pad, const size_t H_in, const size_t W_in, const size_t batch) {
1358+ const size_t W_out, const size_t C, const size_t K_h, const size_t K_w, const size_t stride,
1359+ const size_t pad, const size_t H_in, const size_t W_in, const size_t batch) {
13551360 StreamManager<float >::Instance ().submit (img2colBackwardKernel, gridDim , blockDim , 0 , out, in, H_out,
13561361 W_out, C, K_h, K_w, stride, pad, H_in, W_in, batch);
13571362 }
13581363
1359- __global__ void col2imgKernel (float * out, const float * in, const size_t H_out, const size_t W_out, const size_t C_out, const size_t batches) {
1364+ __global__ void col2imgKernel (float * out, const float * in, const size_t H_out, const size_t W_out,
1365+ const size_t C_out, const size_t batches) {
13601366 const size_t idx = blockDim .x * blockIdx .x + threadIdx .x ;
13611367 if (idx >= H_out * W_out * C_out * batches) {
13621368 return ;
@@ -1374,4 +1380,24 @@ namespace nz::krnl {
13741380 StreamManager<float >::Instance ().submit (col2imgKernel, gridDim , blockDim , 0 , out, in, H_out, W_out, C_out,
13751381 batches);
13761382 }
1383+
1384+ __global__ void col2imgBackwardKernel (float * out, const float * in, const size_t H_out, const size_t W_out,
1385+ const size_t C_out, const size_t batches) {
1386+ const size_t idx = blockDim .x * blockIdx .x + threadIdx .x ;
1387+ if (idx >= H_out * W_out * C_out * batches) {
1388+ return ;
1389+ }
1390+ const size_t batch = idx / (C_out * H_out * W_out);
1391+ const size_t fixedIdx = idx % (C_out * H_out * W_out);
1392+ const size_t c = fixedIdx / (H_out * W_out);
1393+ const size_t h = (fixedIdx % (H_out * W_out)) / W_out;
1394+ const size_t w = (fixedIdx % (H_out * W_out)) % W_out;
1395+ out[batch * (C_out * H_out * W_out) + (h * W_out + w) * C_out + c] = in[idx];
1396+ }
1397+
1398+ void col2imgBackward (const dim3 gridDim , const dim3 blockDim , float * out, float * in, const size_t H_out,
1399+ const size_t W_out, const size_t C_out, const size_t batches) {
1400+ StreamManager<float >::Instance ().submit (col2imgBackwardKernel, gridDim , blockDim , 0 , out, in, H_out, W_out,
1401+ C_out, batches);
1402+ }
13771403}
0 commit comments