diff --git a/src/plugins/intel_cpu/src/nodes/col2im.cpp b/src/plugins/intel_cpu/src/nodes/col2im.cpp index 96cd54cd2da75b..413e90bf813ea8 100644 --- a/src/plugins/intel_cpu/src/nodes/col2im.cpp +++ b/src/plugins/intel_cpu/src/nodes/col2im.cpp @@ -4,6 +4,7 @@ #include "col2im.h" +#include #include #include #include @@ -76,7 +77,39 @@ bool Col2Im::needPrepareParams() const { } void Col2Im::executeDynamicImpl(const dnnl::stream& strm) { - execute(strm); + // 1. get data shape + auto data_shape = getSrcMemoryAtPort(0)->getStaticDims(); + size_t data_rank = data_shape.size(); + + // 2. get output_size + auto output_size_mem = getSrcMemoryAtPort(1); + const auto* output_size_ptr = output_size_mem->getDataAs(); + + // 3. get kernel_size + auto kernel_size_mem = getSrcMemoryAtPort(2); + const auto* kernel_size_ptr = kernel_size_mem->getDataAs(); + + // 4. calculate output_shape + auto kernel_prod = static_cast(kernel_size_ptr[0]) * static_cast(kernel_size_ptr[1]); + + auto H = static_cast(output_size_ptr[0]); + auto W = static_cast(output_size_ptr[1]); + + ov::Shape output_shape; + if (data_rank == 2) { // Case of Non-batched inputs + size_t C = data_shape[0] / kernel_prod; + output_shape = {C, H, W}; + redefineOutputMemory({output_shape}); + execute(strm); + } else if (data_rank == 3) { // Case of Batched inputs + size_t N = data_shape[0]; + size_t C = data_shape[1] / kernel_prod; + output_shape = {N, C, H, W}; + redefineOutputMemory({output_shape}); + execute(strm); + } else { + OPENVINO_THROW("Col2Im node supports only 2D(Non-Batched) or 3D(Batched) input tensors"); + } } template diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/col2im.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/col2im.cpp index 216d9fa4f230fc..1251fb63a820aa 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/col2im.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/col2im.cpp @@ -3,11 +3,12 @@ // #include "col2im.hpp" -#include "utils/cpu_test_utils.hpp" + #include "common_test_utils/ov_tensor_utils.hpp" +#include "openvino/op/col2im.hpp" #include "shared_test_classes/base/ov_subgraph.hpp" +#include "utils/cpu_test_utils.hpp" #include "utils/precision_support.h" -#include "openvino/op/col2im.hpp" using namespace CPUTestUtils; @@ -26,34 +27,34 @@ std::string Col2ImLayerCPUTest::getTestCaseName(const testing::TestParamInfo col2ImParamsVector = { - Col2ImSpecificParams { - InputShape{{}, {{1, 12, 9}}}, - std::vector{4, 4}, - std::vector{2, 2}, - ov::Strides{1, 1}, - ov::Strides{1, 1}, - ov::Shape{0, 0}, - ov::Shape{0, 0} - }, - Col2ImSpecificParams { - InputShape{{}, {{3, 12, 81}}}, - std::vector{16, 16}, - std::vector{2, 2}, - ov::Strides{2, 2}, - ov::Strides{2, 2}, - ov::Shape{2, 2}, - ov::Shape{2, 2} - }, - Col2ImSpecificParams { - InputShape{{}, {{12, 81}}}, - std::vector{16, 16}, - std::vector{2, 2}, - ov::Strides{2, 2}, - ov::Strides{2, 2}, - ov::Shape{2, 2}, - ov::Shape{2, 2} - }, - Col2ImSpecificParams { - InputShape{{}, {{3, 12, 225}}}, - std::vector{16, 16}, - std::vector{2, 2}, - ov::Strides{1, 1}, - ov::Strides{1, 1}, - ov::Shape{0, 0}, - ov::Shape{0, 0} - }, - Col2ImSpecificParams { - InputShape{{}, {{1, 27, 49}}}, - std::vector{16, 16}, - std::vector{3, 3}, - ov::Strides{2, 2}, - ov::Strides{2, 2}, - ov::Shape{1, 1}, - ov::Shape{1, 1} - }, - Col2ImSpecificParams { - InputShape{{}, {{1, 18, 104}}}, - std::vector{16, 16}, - std::vector{2, 3}, - ov::Strides{2, 1}, - ov::Strides{2, 2}, - ov::Shape{1, 0}, - ov::Shape{0, 1} - }, - Col2ImSpecificParams { - InputShape{{-1, -1, -1}, {{1, 12, 120}, {3, 12, 120}}}, - std::vector{16, 16}, - std::vector{2, 2}, - ov::Strides{2, 1}, - ov::Strides{2, 2}, - ov::Shape{1, 0}, - ov::Shape{0, 1} - }, - Col2ImSpecificParams { - InputShape{{}, {{12, 12, 324}}}, - std::vector{32, 32}, - std::vector{2, 2}, - ov::Strides{2, 2}, - ov::Strides{2, 2}, - ov::Shape{3, 3}, - ov::Shape{3, 3} - }, - Col2ImSpecificParams { - InputShape{{-1, 12, 324}, {{12, 12, 324}}}, - std::vector{32, 32}, - std::vector{2, 2}, - ov::Strides{2, 2}, - ov::Strides{2, 2}, - ov::Shape{3, 3}, - ov::Shape{3, 3} - }, - Col2ImSpecificParams { - InputShape{{-1, -1, -1}, {{12, 12, 324}}}, - std::vector{32, 32}, - std::vector{2, 2}, - ov::Strides{2, 2}, - ov::Strides{2, 2}, - ov::Shape{3, 3}, - ov::Shape{3, 3} - }, - Col2ImSpecificParams { - InputShape{{12, -1, -1}, {{12, 12, 324}}}, - std::vector{32, 32}, - std::vector{2, 2}, - ov::Strides{2, 2}, - ov::Strides{2, 2}, - ov::Shape{3, 3}, - ov::Shape{3, 3} - }, - Col2ImSpecificParams { - InputShape{{12, 12, -1}, {{12, 12, 324}}}, - std::vector{32, 32}, - std::vector{2, 2}, - ov::Strides{2, 2}, - ov::Strides{2, 2}, - ov::Shape{3, 3}, - ov::Shape{3, 3} - }, - Col2ImSpecificParams { - InputShape{{12, -1, 324}, {{12, 12, 324}}}, - std::vector{32, 32}, - std::vector{2, 2}, - ov::Strides{2, 2}, - ov::Strides{2, 2}, - ov::Shape{3, 3}, - ov::Shape{3, 3} - }, - Col2ImSpecificParams { - InputShape{{-1, -1}, {{12, 324}}}, - std::vector{32, 32}, - std::vector{2, 2}, - ov::Strides{2, 2}, - ov::Strides{2, 2}, - ov::Shape{3, 3}, - ov::Shape{3, 3} - } -}; + Col2ImSpecificParams{InputShape{{}, {{1, 12, 9}}}, + std::vector{4, 4}, + std::vector{2, 2}, + ov::Strides{1, 1}, + ov::Strides{1, 1}, + ov::Shape{0, 0}, + ov::Shape{0, 0}}, + Col2ImSpecificParams{InputShape{{}, {{3, 12, 81}}}, + std::vector{16, 16}, + std::vector{2, 2}, + ov::Strides{2, 2}, + ov::Strides{2, 2}, + ov::Shape{2, 2}, + ov::Shape{2, 2}}, + Col2ImSpecificParams{InputShape{{}, {{12, 81}}}, + std::vector{16, 16}, + std::vector{2, 2}, + ov::Strides{2, 2}, + ov::Strides{2, 2}, + ov::Shape{2, 2}, + ov::Shape{2, 2}}, + Col2ImSpecificParams{InputShape{{}, {{3, 12, 225}}}, + std::vector{16, 16}, + std::vector{2, 2}, + ov::Strides{1, 1}, + ov::Strides{1, 1}, + ov::Shape{0, 0}, + ov::Shape{0, 0}}, + Col2ImSpecificParams{InputShape{{}, {{1, 27, 49}}}, + std::vector{16, 16}, + std::vector{3, 3}, + ov::Strides{2, 2}, + ov::Strides{2, 2}, + ov::Shape{1, 1}, + ov::Shape{1, 1}}, + Col2ImSpecificParams{InputShape{{}, {{1, 18, 104}}}, + std::vector{16, 16}, + std::vector{2, 3}, + ov::Strides{2, 1}, + ov::Strides{2, 2}, + ov::Shape{1, 0}, + ov::Shape{0, 1}}, + Col2ImSpecificParams{InputShape{{-1, -1, -1}, {{1, 12, 120}, {3, 12, 120}}}, + std::vector{16, 16}, + std::vector{2, 2}, + ov::Strides{2, 1}, + ov::Strides{2, 2}, + ov::Shape{1, 0}, + ov::Shape{0, 1}}, + Col2ImSpecificParams{InputShape{{}, {{12, 12, 324}}}, + std::vector{32, 32}, + std::vector{2, 2}, + ov::Strides{2, 2}, + ov::Strides{2, 2}, + ov::Shape{3, 3}, + ov::Shape{3, 3}}, + Col2ImSpecificParams{InputShape{{-1, 12, 324}, {{12, 12, 324}}}, + std::vector{32, 32}, + std::vector{2, 2}, + ov::Strides{2, 2}, + ov::Strides{2, 2}, + ov::Shape{3, 3}, + ov::Shape{3, 3}}, + Col2ImSpecificParams{InputShape{{-1, -1, -1}, {{12, 12, 324}}}, + std::vector{32, 32}, + std::vector{2, 2}, + ov::Strides{2, 2}, + ov::Strides{2, 2}, + ov::Shape{3, 3}, + ov::Shape{3, 3}}, + Col2ImSpecificParams{InputShape{{12, -1, -1}, {{12, 12, 324}}}, + std::vector{32, 32}, + std::vector{2, 2}, + ov::Strides{2, 2}, + ov::Strides{2, 2}, + ov::Shape{3, 3}, + ov::Shape{3, 3}}, + Col2ImSpecificParams{InputShape{{12, 12, -1}, {{12, 12, 324}}}, + std::vector{32, 32}, + std::vector{2, 2}, + ov::Strides{2, 2}, + ov::Strides{2, 2}, + ov::Shape{3, 3}, + ov::Shape{3, 3}}, + Col2ImSpecificParams{InputShape{{12, -1, 324}, {{12, 12, 324}}}, + std::vector{32, 32}, + std::vector{2, 2}, + ov::Strides{2, 2}, + ov::Strides{2, 2}, + ov::Shape{3, 3}, + ov::Shape{3, 3}}, + Col2ImSpecificParams{InputShape{{-1, -1}, {{12, 324}}}, + std::vector{32, 32}, + std::vector{2, 2}, + ov::Strides{2, 2}, + ov::Strides{2, 2}, + ov::Shape{3, 3}, + ov::Shape{3, 3}}, + Col2ImSpecificParams{// Batched default + InputShape{{-1, -1, -1}, {{1, 4, 4}, {1, 8, 4}, {1, 12, 4}}}, + std::vector{3, 3}, + std::vector{2, 2}, + ov::Strides{1, 1}, + ov::Strides{1, 1}, + ov::Shape{0, 0}, + ov::Shape{0, 0}}, + Col2ImSpecificParams{// Batched dilations + InputShape{{-1, -1, -1}, {{1, 4, 9}, {1, 8, 9}, {1, 12, 9}}}, + std::vector{5, 5}, + std::vector{2, 2}, + ov::Strides{1, 1}, + ov::Strides{2, 2}, + ov::Shape{0, 0}, + ov::Shape{0, 0}}, + Col2ImSpecificParams{// Batched pads + InputShape{{-1, -1, -1}, {{1, 4, 9}, {1, 8, 9}, {1, 12, 9}}}, + std::vector{2, 2}, + std::vector{2, 2}, + ov::Strides{1, 1}, + ov::Strides{1, 1}, + ov::Shape{1, 1}, + ov::Shape{1, 1}}, + Col2ImSpecificParams{// Batched strides + InputShape{{-1, -1, -1}, {{1, 4, 4}, {1, 8, 4}, {1, 12, 4}}}, + std::vector{4, 4}, + std::vector{2, 2}, + ov::Strides{2, 2}, + ov::Strides{1, 1}, + ov::Shape{0, 0}, + ov::Shape{0, 0}}, + Col2ImSpecificParams{// Non-batched default + InputShape{{-1, -1}, {{4, 4}, {8, 4}, {12, 4}}}, + std::vector{3, 3}, + std::vector{2, 2}, + ov::Strides{1, 1}, + ov::Strides{1, 1}, + ov::Shape{0, 0}, + ov::Shape{0, 0}}, + Col2ImSpecificParams{// Non-batched dilations + InputShape{{-1, -1}, {{4, 9}, {8, 9}, {12, 9}}}, + std::vector{5, 5}, + std::vector{2, 2}, + ov::Strides{1, 1}, + ov::Strides{2, 2}, + ov::Shape{0, 0}, + ov::Shape{0, 0}}, + Col2ImSpecificParams{// Non-batched pads + InputShape{{-1, -1}, {{4, 9}, {8, 9}, {12, 9}}}, + std::vector{2, 2}, + std::vector{2, 2}, + ov::Strides{1, 1}, + ov::Strides{1, 1}, + ov::Shape{1, 1}, + ov::Shape{1, 1}}, + Col2ImSpecificParams{// Non-batched strides + InputShape{{-1, -1}, {{4, 4}, {8, 4}, {12, 4}}}, + std::vector{4, 4}, + std::vector{2, 2}, + ov::Strides{2, 2}, + ov::Strides{1, 1}, + ov::Shape{0, 0}, + ov::Shape{0, 0}}}; } // namespace Col2Im } // namespace test