@@ -798,8 +798,8 @@ namespace nz::data {
798798 for (auto j = 0 ; j < out.shape ()[1 ]; j++) {
799799 offsetC.push_back (i * out.shape ().getStride (0 ) + j * out.shape ().getStride (1 ));
800800 offsetA.push_back (i * (lhs.shape ().N () > 1 ? lhs.shape ().getStride (0 ) : 0 ) + j * (lhs.shape ().C () > 1
801- ? lhs.shape ().getStride (1 )
802- : 0 ));
801+ ? lhs.shape ().getStride (1 )
802+ : 0 ));
803803 offsetB.push_back (i * (rhs.shape ().N () > 1 ? rhs.shape ().getStride (0 ) : 0 ) + j * (
804804 rhs.shape ().C () > 1 ? rhs.shape ().getStride (1 ) : 0 ));
805805 }
@@ -869,8 +869,8 @@ namespace nz::data {
869869 for (auto j = 0 ; j < out.shape ()[1 ]; j++) {
870870 offsetC.push_back (i * out.shape ().getStride (0 ) + j * out.shape ().getStride (1 ));
871871 offsetA.push_back (i * (lhs.shape ().N () > 1 ? lhs.shape ().getStride (0 ) : 0 ) + j * (lhs.shape ().C () > 1
872- ? lhs.shape ().getStride (1 )
873- : 0 ));
872+ ? lhs.shape ().getStride (1 )
873+ : 0 ));
874874 offsetB.push_back (i * (rhs.shape ().N () > 1 ? rhs.shape ().getStride (0 ) : 0 ) + j * (
875875 rhs.shape ().C () > 1 ? rhs.shape ().getStride (1 ) : 0 ));
876876 }
@@ -939,8 +939,8 @@ namespace nz::data {
939939 for (auto j = 0 ; j < out.shape ()[1 ]; j++) {
940940 offsetC.push_back (i * out.shape ().getStride (0 ) + j * out.shape ().getStride (1 ));
941941 offsetA.push_back (i * (lhs.shape ().N () > 1 ? lhs.shape ().getStride (0 ) : 0 ) + j * (lhs.shape ().C () > 1
942- ? lhs.shape ().getStride (1 )
943- : 0 ));
942+ ? lhs.shape ().getStride (1 )
943+ : 0 ));
944944 offsetB.push_back (i * (rhs.shape ().N () > 1 ? rhs.shape ().getStride (0 ) : 0 ) + j * (
945945 rhs.shape ().C () > 1 ? rhs.shape ().getStride (1 ) : 0 ));
946946 }
@@ -1008,8 +1008,8 @@ namespace nz::data {
10081008 for (auto j = 0 ; j < out.shape ()[1 ]; j++) {
10091009 offsetC.push_back (i * out.shape ().getStride (0 ) + j * out.shape ().getStride (1 ));
10101010 offsetA.push_back (i * (lhs.shape ().N () > 1 ? lhs.shape ().getStride (0 ) : 0 ) + j * (lhs.shape ().C () > 1
1011- ? lhs.shape ().getStride (1 )
1012- : 0 ));
1011+ ? lhs.shape ().getStride (1 )
1012+ : 0 ));
10131013 offsetB.push_back (i * (rhs.shape ().N () > 1 ? rhs.shape ().getStride (0 ) : 0 ) + j * (
10141014 rhs.shape ().C () > 1 ? rhs.shape ().getStride (1 ) : 0 ));
10151015 }
@@ -1085,7 +1085,8 @@ namespace nz::data {
10851085 return result;
10861086 }
10871087
1088- DL_API void iSoftmaxJacobian (float * out, float * in, size_t n, const std::vector<size_t >& offset_o, const std::vector<size_t >& offset_i);
1088+ DL_API void iSoftmaxJacobian (float * out, float * in, size_t n, const std::vector<size_t >& offset_o,
1089+ const std::vector<size_t >& offset_i);
10891090
10901091 template <typename T>
10911092 std::enable_if_t <is_valid_tensor_type<T>::value, T>
@@ -1103,5 +1104,21 @@ namespace nz::data {
11031104 iSoftmaxJacobian (result.data (), in.data (), n, offset_o, offset_i);
11041105 return result;
11051106 }
1107+
1108+ DL_API void iImg2col (float * out, float * in, const size_t H_out,
1109+ const size_t W_out, const size_t C, const size_t K_h, const size_t K_w, const size_t stride,
1110+ const size_t pad, const size_t H_in, const size_t W_in, const size_t batch);
1111+
1112+ template <typename T>
1113+ std::enable_if_t <is_valid_tensor_type<T>::value, T>
1114+ tensorImg2col (const T& in, const size_t K_h, const size_t K_w, const size_t stride,
1115+ const size_t pad) {
1116+ const size_t H_out = (in.shape ().H () + 2 * pad - K_h) / stride + 1 ;
1117+ const size_t W_out = (in.shape ().W () + 2 * pad - K_w) / stride + 1 ;
1118+ T result ({in.shape ()[0 ], 1 , H_out * W_out, in.shape ().C () * K_h * K_w});
1119+ iImg2col (result.data (), in.data (), H_out, W_out, in.shape ().C (), K_h, K_w, stride, pad,
1120+ in.shape ().H (), in.shape ().W (), in.shape ()[0 ]);
1121+ return result;
1122+ }
11061123}
11071124#endif // TENSOROPERATIONS_CUH
0 commit comments