diff --git a/src/frontends/tensorflow_common/src/op/gather.cpp b/src/frontends/tensorflow_common/src/op/gather.cpp index a566a3166bfbb6..1a1ff5fe8ca13a 100644 --- a/src/frontends/tensorflow_common/src/op/gather.cpp +++ b/src/frontends/tensorflow_common/src/op/gather.cpp @@ -5,8 +5,12 @@ #include "openvino/op/gather.hpp" #include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/gather_nd.hpp" +#include "openvino/op/less.hpp" +#include "openvino/op/select.hpp" +#include "openvino/op/subtract.hpp" using namespace std; using namespace ov::op; @@ -28,8 +32,21 @@ OutputVector translate_basic_gather_op(const NodeContext& node, const ov::Output OutputVector translate_gather_op(const NodeContext& node) { // Gather has two inputs: data and indices // axis by which data is sliced is always equal to 0, batch_dims is always equal to 0 - default_op_checks(node, 2, {"Gather"}); + default_op_checks(node, 2, {"Gather"}, true); + auto params = node.get_input(0); auto axis = make_shared(element::i64, Shape{}, 0); + auto complex_type_mark = as_type_ptr(params.get_node_shared_ptr()); + + if (complex_type_mark) { + auto complex_part_type = complex_type_mark->get_complex_part_type(); + params = complex_type_mark->get_data(); + auto indices = node.get_input(1); + auto gather = make_shared(params, indices, axis, 0); + set_node_name(node.get_name(), gather); + auto complex_gather = make_shared(gather, complex_part_type); + return {complex_gather->output(0)}; + } + return translate_basic_gather_op(node, axis, 0); } @@ -45,9 +62,30 @@ OutputVector translate_resource_gather_op(const NodeContext& node) { OutputVector translate_gather_v2_op(const NodeContext& node) { // GatherV2 has three inputs: data, indices, and axis by which data is sliced // batch_dims is an attribute and can vary - default_op_checks(node, 3, {"GatherV2"}); + default_op_checks(node, 3, {"GatherV2"}, true); + auto params = node.get_input(0); auto axis = node.get_input(2); auto batch_dims = node.get_attribute("batch_dims", 0); + auto complex_type_mark = as_type_ptr(params.get_node_shared_ptr()); + + if (complex_type_mark) { + auto complex_part_type = complex_type_mark->get_complex_part_type(); + params = complex_type_mark->get_data(); + auto indices = node.get_input(1); + + // The axis can be negative + auto zero = make_shared(ov::element::i32, Shape{}, 0); + auto one = make_shared(ov::element::i32, Shape{}, 1); + auto condition = make_shared(axis, zero); + auto axis_subtract_one = make_shared(axis, one); + auto updated_axis = make_shared(condition, axis_subtract_one, axis); + + auto gather = make_shared(params, indices, updated_axis, batch_dims); + set_node_name(node.get_name(), gather); + auto complex_gather = make_shared(gather, complex_part_type); + return {complex_gather->output(0)}; + } + return translate_basic_gather_op(node, axis, batch_dims); } diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Gather.py b/tests/layer_tests/tensorflow_tests/test_tf_Gather.py index dd41cc6971e6ef..bba8e857940b1f 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_Gather.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_Gather.py @@ -77,3 +77,77 @@ def test_gather(self, params, params_type, indices_type, ie_device, precision, i pytest.skip("timeout issue on GPU") self._test(*self.create_gather_net(**params, params_type=params_type, indices_type=indices_type), ie_device, precision, ir_version, temp_dir=temp_dir) + + +class TestComplexGather(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'real_params:0' in inputs_info + assert 'imag_params:0' in inputs_info + assert 'indices:0' in inputs_info + real_params_shape = inputs_info['real_params:0'] + imag_params_shape = inputs_info['imag_params:0'] + indices_shape = inputs_info['indices:0'] + inputs_data = {} + inputs_data['real_params:0'] = rng.integers(-50, 50, real_params_shape).astype(self.params_type) + inputs_data['imag_params:0'] = rng.integers(-50, 50, imag_params_shape).astype(self.params_type) + + inputs_data['indices:0'] = rng.integers(0, self.max_index, indices_shape).astype(self.indices_type) + return inputs_data + + + def create_complex_gather_net(self, params_shape, params_type, indices_shape, indices_type, axis_value, batch_dims, + operation_type): + self.params_type = params_type + if params_type == str or params_type == np.str_: + params_type = tf.string + self.indices_type = indices_type + if batch_dims is None: + batch_dims = 0 + if axis_value is None: + axis_value = 0 + axis_norm = axis_value + if axis_norm < 0: + axis_norm += len(params_shape) + assert 0 <= axis_norm < len(params_shape), "Incorrect `axis` value for the test case" + self.max_index = params_shape[axis_norm] + + tf.compat.v1.reset_default_graph() + with tf.compat.v1.Session() as sess: + real_params = tf.compat.v1.placeholder(params_type, params_shape, 'real_params') + imag_params = tf.compat.v1.placeholder(params_type, params_shape, 'imag_params') + complex = tf.raw_ops.Complex(real=real_params, imag=imag_params) + + indices = tf.compat.v1.placeholder(indices_type, indices_shape, 'indices') + if operation_type == "Gather": + transpose = tf.raw_ops.Gather(params=complex, indices=indices) + elif operation_type == "GatherV2": + axis = tf.constant(axis_value, dtype=tf.int32) + transpose = tf.raw_ops.GatherV2(params=complex, indices=indices, axis=axis, batch_dims=batch_dims) + else: + assert False, "Incorrect operation type is tested" + + tf.raw_ops.Real(input=transpose) + tf.raw_ops.Imag(input=transpose) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + ref_net = None + + return tf_net, ref_net + + test_data_precommit = [ + dict(params_shape=[4, 6], indices_shape=[], axis_value=None, batch_dims=None, operation_type='Gather'), + dict(params_shape=[3, 4, 6], indices_shape=[3, 4], axis_value=None, batch_dims=None, operation_type='Gather'), + dict(params_shape=[5, 4, 3], indices_shape=[5, 2, 1], axis_value=2, batch_dims=1, operation_type='GatherV2'), + dict(params_shape=[3, 2, 6, 4], indices_shape=[3, 2, 1, 3], axis_value=-1, batch_dims=-2, + operation_type='GatherV2'), + ] + + @pytest.mark.parametrize("params", test_data_precommit) + @pytest.mark.parametrize("params_type", [np.float32]) + @pytest.mark.parametrize("indices_type", [np.int32, np.int64]) + @pytest.mark.precommit + @pytest.mark.nightly + def test_complex_gather(self, params, params_type, indices_type, ie_device, precision, ir_version, temp_dir): + self._test(*self.create_complex_gather_net(**params, params_type=params_type, indices_type=indices_type), + ie_device, precision, ir_version, temp_dir=temp_dir)