diff --git a/src/frontends/tensorflow_common/src/op/gather.cpp b/src/frontends/tensorflow_common/src/op/gather.cpp index a566a3166bfbb6..811f0ac51b5d93 100644 --- a/src/frontends/tensorflow_common/src/op/gather.cpp +++ b/src/frontends/tensorflow_common/src/op/gather.cpp @@ -5,8 +5,12 @@ #include "openvino/op/gather.hpp" #include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/gather_nd.hpp" +#include "openvino/op/less.hpp" +#include "openvino/op/select.hpp" +#include "openvino/op/subtract.hpp" using namespace std; using namespace ov::op; @@ -26,34 +30,61 @@ OutputVector translate_basic_gather_op(const NodeContext& node, const ov::Output } OutputVector translate_gather_op(const NodeContext& node) { - // Gather has two inputs: data and indices - // axis by which data is sliced is always equal to 0, batch_dims is always equal to 0 - default_op_checks(node, 2, {"Gather"}); + default_op_checks(node, 2, {"Gather"}, true); + auto params = node.get_input(0); auto axis = make_shared(element::i64, Shape{}, 0); + auto complex_type_mark = as_type_ptr(params.get_node_shared_ptr()); + + if (complex_type_mark) { + auto complex_part_type = complex_type_mark->get_complex_part_type(); + params = complex_type_mark->get_data(); + auto indices = node.get_input(1); + auto gather = make_shared(params, indices, axis, 0); + set_node_name(node.get_name(), gather); + // Return the Gather result directly without ComplexTypeMark + return {gather}; + } + return translate_basic_gather_op(node, axis, 0); } OutputVector translate_resource_gather_op(const NodeContext& node) { - // ResourceGather has two inputs: data and indices - // axis by which data is sliced is always equal to 0, batch_dims is an attribute and can vary default_op_checks(node, 2, {"ResourceGather"}); auto axis = make_shared(element::i64, Shape{}, 0); auto batch_dims = node.get_attribute("batch_dims", 0); return translate_basic_gather_op(node, axis, batch_dims); } +// Update the translate_gather_v2_op function OutputVector translate_gather_v2_op(const NodeContext& node) { - // GatherV2 has three inputs: data, indices, and axis by which data is sliced - // batch_dims is an attribute and can vary - default_op_checks(node, 3, {"GatherV2"}); + default_op_checks(node, 3, {"GatherV2"}, true); + auto params = node.get_input(0); auto axis = node.get_input(2); auto batch_dims = node.get_attribute("batch_dims", 0); + auto complex_type_mark = as_type_ptr(params.get_node_shared_ptr()); + + if (complex_type_mark) { + auto complex_part_type = complex_type_mark->get_complex_part_type(); + params = complex_type_mark->get_data(); + auto indices = node.get_input(1); + + // Calculate the axis without subtracting 1 for complex tensors + auto input_rank = params.get_partial_shape().rank().get_length(); + auto axis_val = axis.get_node_shared_ptr()->get_attribute("value")[0]; + if (axis_val < 0) { + axis_val += input_rank; // Use input_rank directly, not input_rank - 1 + } + + auto updated_axis = std::make_shared(element::i64, Shape{}, axis_val); + auto gather = make_shared(params, indices, updated_axis, batch_dims); + set_node_name(node.get_name(), gather); + return {gather}; // Remove ComplexTypeMark from output + } + return translate_basic_gather_op(node, axis, batch_dims); } OutputVector translate_gather_nd_op(const NodeContext& node) { - // GatherND has two inputs: data and indices - // batch_dims is always equal to 0 default_op_checks(node, 2, {"GatherNd", "GATHER_ND"}); auto input = node.get_input(0); auto input_indices = node.get_input(1); @@ -66,4 +97,4 @@ OutputVector translate_gather_nd_op(const NodeContext& node) { } // namespace op } // namespace tensorflow } // namespace frontend -} // namespace ov +} // namespace ov \ No newline at end of file diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Gather.py b/tests/layer_tests/tensorflow_tests/test_tf_Gather.py index 702db00272d27f..43799ddbbf493d 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_Gather.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_Gather.py @@ -74,8 +74,74 @@ def create_gather_net(self, params_shape, params_type, indices_shape, indices_ty @pytest.mark.nightly def test_gather(self, params, params_type, indices_type, ie_device, precision, ir_version, temp_dir, use_legacy_frontend): - if ie_device == 'GPU': - pytest.skip("timeout issue on GPU") + # GPU skip removed self._test(*self.create_gather_net(**params, params_type=params_type, indices_type=indices_type), ie_device, precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) + + +class TestComplexGather(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'real_params:0' in inputs_info + assert 'imag_params:0' in inputs_info + assert 'indices:0' in inputs_info + real_params_shape = inputs_info['real_params:0'] + imag_params_shape = inputs_info['imag_params:0'] + indices_shape = inputs_info['indices:0'] + inputs_data = {} + inputs_data['real_params:0'] = rng.integers(-50, 50, real_params_shape).astype(self.params_type) + inputs_data['imag_params:0'] = rng.integers(-50, 50, imag_params_shape).astype(self.params_type) + inputs_data['indices:0'] = rng.integers(0, self.max_index, indices_shape).astype(self.indices_type) + return inputs_data + + def create_complex_gather_net(self, params_shape, params_type, indices_shape, indices_type, axis_value, batch_dims, + operation_type): + self.params_type = params_type + self.indices_type = indices_type + if batch_dims is None: + batch_dims = 0 + axis_norm = axis_value if axis_value is not None else 0 + if axis_norm < 0: + axis_norm += len(params_shape) + assert 0 <= axis_norm < len(params_shape), "Incorrect `axis` value" + self.max_index = params_shape[axis_norm] + + tf.compat.v1.reset_default_graph() + with tf.compat.v1.Session() as sess: + real_params = tf.compat.v1.placeholder(params_type, params_shape, 'real_params') + imag_params = tf.compat.v1.placeholder(params_type, params_shape, 'imag_params') + complex = tf.raw_ops.Complex(real=real_params, imag=imag_params) + + indices = tf.compat.v1.placeholder(indices_type, indices_shape, 'indices') + if operation_type == "Gather": + gather = tf.raw_ops.Gather(params=complex, indices=indices) + elif operation_type == "GatherV2": + axis = tf.constant(axis_value, dtype=tf.int32) + gather = tf.raw_ops.GatherV2(params=complex, indices=indices, axis=axis, batch_dims=batch_dims) + else: + assert False, "Invalid operation type" + + # Directly validate complex output + tf.identity(gather, name='output') + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_precommit = [ + dict(params_shape=[4, 6], indices_shape=[], axis_value=None, batch_dims=None, operation_type='Gather'), + dict(params_shape=[3, 4, 6], indices_shape=[3, 4], axis_value=None, batch_dims=None, operation_type='Gather'), + dict(params_shape=[5, 4, 3], indices_shape=[5, 2, 1], axis_value=2, batch_dims=1, operation_type='GatherV2'), + ] + + @pytest.mark.parametrize("params", test_data_precommit) + @pytest.mark.parametrize("params_type", [np.float32]) + @pytest.mark.parametrize("indices_type", [np.int32, np.int64]) + @pytest.mark.precommit + @pytest.mark.nightly + def test_complex_gather(self, params, params_type, indices_type, ie_device, precision, ir_version, temp_dir, + use_legacy_frontend): + # GPU skip removed and direct complex validation + self._test(*self.create_complex_gather_net(**params, params_type=params_type, indices_type=indices_type), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_legacy_frontend=use_legacy_frontend) \ No newline at end of file