@@ -2943,6 +2943,135 @@ TEST(TensorBasic, img2colTest) {
29432943 EXPECT_EQ (expected, result);
29442944}
29452945
2946+ TEST (NodeBasic, img2colForward) {
2947+ const size_t n = 2 ;
2948+ const size_t c = 3 ;
2949+ const size_t h = 4 ;
2950+ const size_t w = 5 ;
2951+ const size_t k_h = 3 ;
2952+ const size_t k_w = 3 ;
2953+ const size_t stride = 1 ;
2954+ const size_t pad = 1 ;
2955+ const size_t H_out = (h + 2 * pad - k_h) / stride + 1 ;
2956+ const size_t W_out = (w + 2 * pad - k_w) / stride + 1 ;
2957+
2958+ std::vector<float > inputData ({n*c*h*w});
2959+ std::vector<float > expectedData ({n*H_out*W_out*k_h*k_w*c});
2960+
2961+ std::random_device rd;
2962+ std::mt19937 gen (rd ());
2963+ std::uniform_real_distribution<float > dist (0 .1f , 0 .9f );
2964+
2965+ for (auto & i : inputData) {
2966+ i = dist (gen);
2967+ }
2968+
2969+ for (size_t b = 0 ; b < n; ++b) {
2970+ for (size_t i = 0 ; i < H_out; ++i) {
2971+ for (size_t j = 0 ; j < W_out; ++j) {
2972+ const int h_start = static_cast <int >(i * stride) - pad;
2973+ const int w_start = static_cast <int >(j * stride) - pad;
2974+
2975+ for (size_t r = 0 ; r < k_h; ++r) {
2976+ const int h_in = h_start + r;
2977+ for (size_t s = 0 ; s < k_w; ++s) {
2978+ const int w_in = w_start + s;
2979+ for (size_t c_in = 0 ; c_in < c; ++c_in) {
2980+ float val = 0 .0f ;
2981+ if (h_in >= 0 && h_in < h && w_in >= 0 && w_in < w) {
2982+ const size_t input_idx =
2983+ b * (c * h * w) +
2984+ c_in * (h * w) +
2985+ h_in * w +
2986+ w_in;
2987+ val = inputData[input_idx];
2988+ }
2989+ const size_t expected_idx =
2990+ b * (H_out * W_out * k_h * k_w * c) +
2991+ (i * W_out + j) * (k_h * k_w * c) +
2992+ c_in * (k_h * k_w) +
2993+ r * k_w +
2994+ s;
2995+ expectedData[expected_idx] = val;
2996+ }
2997+ }
2998+ }
2999+ }
3000+ }
3001+ }
3002+
3003+ InputNode input ({n, c, h, w});
3004+ input.dataInject (inputData.begin (), inputData.end ());
3005+ Img2ColNode result (&input, k_h, k_w, stride, pad);
3006+ result.forward ();
3007+ Tensor expected ({n, 1 , H_out * W_out, k_h * k_w * c});
3008+ expected.dataInject (expectedData.begin (), expectedData.end ());
3009+ EXPECT_EQ (expected, *result.output );
3010+ }
3011+
3012+ TEST (NodeBasic, img2colBackward) {
3013+ const size_t n = 2 ;
3014+ const size_t c = 3 ;
3015+ const size_t h = 4 ;
3016+ const size_t w = 5 ;
3017+ const size_t k_h = 3 ;
3018+ const size_t k_w = 3 ;
3019+ const size_t stride = 1 ;
3020+ const size_t pad = 1 ;
3021+ const size_t H_out = (h + 2 * pad - k_h) / stride + 1 ;
3022+ const size_t W_out = (w + 2 * pad - k_w) / stride + 1 ;
3023+
3024+ std::vector<float > gradData ({n*H_out*W_out*k_h*k_w*c});
3025+ std::vector<float > expectedGradData ({n*c*h*w});
3026+
3027+ std::random_device rd;
3028+ std::mt19937 gen (rd ());
3029+ std::uniform_real_distribution<float > dist (0 .1f , 0 .9f );
3030+
3031+ for (auto & i : gradData) {
3032+ i = dist (gen);
3033+ }
3034+
3035+ for (size_t b = 0 ; b < n; ++b) {
3036+ for (size_t i = 0 ; i < H_out; ++i) {
3037+ for (size_t j = 0 ; j < W_out; ++j) {
3038+ const int h_start = static_cast <int >(i * stride) - pad;
3039+ const int w_start = static_cast <int >(j * stride) - pad;
3040+ for (size_t r = 0 ; r < k_h; ++r) {
3041+ const int h_in = h_start + r;
3042+ for (size_t s = 0 ; s < k_w; ++s) {
3043+ const int w_in = w_start + s;
3044+ for (size_t c_in = 0 ; c_in < c; ++c_in) {
3045+ if (h_in >= 0 && h_in < h && w_in >= 0 && w_in < w) {
3046+ const size_t input_idx =
3047+ b * (c * h * w) +
3048+ c_in * (h * w) +
3049+ h_in * w +
3050+ w_in;
3051+ const size_t grad_idx =
3052+ b * (H_out * W_out * k_h * k_w * c) +
3053+ (i * W_out + j) * (k_h * k_w * c) +
3054+ c_in * (k_h * k_w) +
3055+ r * k_w +
3056+ s;
3057+ expectedGradData[input_idx] += gradData[grad_idx];
3058+ }
3059+ }
3060+ }
3061+ }
3062+ }
3063+ }
3064+ }
3065+
3066+ InputNode input ({n, c, h, w}, true );
3067+ Img2ColNode result (&input, k_h, k_w, stride, pad);
3068+ result.dataInject (gradData.begin (), gradData.end (), true );
3069+ result.backward ();
3070+ Tensor expected ({n, c, h, w}, true );
3071+ expected.dataInject (expectedGradData.begin (), expectedGradData.end (), true );
3072+ EXPECT_EQ (expected, *input.output );
3073+ }
3074+
29463075TEST (TenorBasic, col2imgTest) {
29473076 const size_t n = 2 ;
29483077 const size_t c = 3 ;
0 commit comments