[TF FE] Support complex tensors for Gather, GatherV2 operations #23493

Open
wants to merge 23 commits into base: master
Changes from 9 commits
Commits (23)
85f4e0d
Support complex tensors for Gather, GatherV2 operations
MonalSD Mar 16, 2024
53ce0fc
Fix logic changes
MonalSD Mar 24, 2024
ec35304
Update tests/layer_tests/tensorflow_tests/test_tf_Gather.py
rkazants Mar 29, 2024
ba1adb7
Merge branch 'master' into 22951issue
rkazants Mar 29, 2024
d25dbc3
Update gather.cpp
MonalSD Mar 31, 2024
f9082d1
Update src/frontends/tensorflow_common/src/op/gather.cpp
rkazants Apr 7, 2024
ca2b83b
Update src/frontends/tensorflow_common/src/op/gather.cpp
rkazants Apr 7, 2024
18f4320
Update src/frontends/tensorflow_common/src/op/gather.cpp
rkazants Apr 7, 2024
3b5c2a9
Update src/frontends/tensorflow_common/src/op/gather.cpp
rkazants Apr 7, 2024
39fb2f5
Fix
MonalSD Apr 8, 2024
58055e3
Update src/frontends/tensorflow_common/src/op/gather.cpp
rkazants Apr 12, 2024
25d50ad
Fixes
MonalSD Apr 21, 2024
d61e7ce
Fix
MonalSD Apr 30, 2024
8ac3bc7
Merge branch 'master' into 22951issue
rkazants May 1, 2024
d804937
Build fix
MonalSD May 1, 2024
db61377
Merge branch 'master' into 22951issue
rkazants May 4, 2024
5e2e10c
Merge branch 'master' into 22951issue
mlukasze Jun 18, 2024
b88bb15
Merge branch 'master' into 22951issue
mlukasze Jul 3, 2024
ac0e5ce
Merge branch 'master' into 22951issue
mlukasze Jul 18, 2024
f63dd19
Merge branch 'master' into 22951issue
mlukasze Jul 25, 2024
ca31e22
Merge branch 'master' into 22951issue
mlukasze Sep 18, 2024
00dc5ad
Merge branch 'master' into 22951issue
mlukasze Nov 6, 2024
254d362
Merge branch 'master' into 22951issue
mlukasze Apr 15, 2025
49 changes: 46 additions & 3 deletions src/frontends/tensorflow_common/src/op/gather.cpp
@@ -5,8 +5,14 @@
#include "openvino/op/gather.hpp"

#include "common_op_table.hpp"
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/gather_nd.hpp"
#include "openvino/op/less.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"

using namespace std;
using namespace ov::op;
@@ -28,8 +34,20 @@ OutputVector translate_basic_gather_op(const NodeContext& node, const ov::Output
OutputVector translate_gather_op(const NodeContext& node) {
    // Gather has two inputs: data and indices
    // axis by which data is sliced is always equal to 0, batch_dims is always equal to 0
    default_op_checks(node, 2, {"Gather"}, true);
    auto params = node.get_input(0);
    auto complex_type_mark = as_type_ptr<ComplexTypeMark>(params.get_node_shared_ptr());
    auto axis = make_shared<v0::Constant>(element::i64, Shape{}, 0);

    if (complex_type_mark) {
        // gather on the decomposed tensor (real and imaginary parts are stored in the last dimension)
        // and propagate the complex type mark to the output
        params = complex_type_mark->input_value(0);
        auto indices = node.get_input(1);
        auto gather = make_shared<v8::Gather>(params, indices, axis, 0);
        set_node_name(node.get_name(), gather);
        auto complex_gather = make_shared<ComplexTypeMark>(gather, complex_type_mark->get_complex_part_type());
        return {complex_gather->output(0)};
    }

    return translate_basic_gather_op(node, axis, 0);
}

@@ -45,9 +63,35 @@ OutputVector translate_resource_gather_op(const NodeContext& node) {
OutputVector translate_gather_v2_op(const NodeContext& node) {
    // GatherV2 has three inputs: data, indices, and axis by which data is sliced
    // batch_dims is an attribute and can vary
    default_op_checks(node, 3, {"GatherV2"}, true);
    auto params = node.get_input(0);
    auto indices = node.get_input(1);
    auto axis = node.get_input(2);
    auto batch_dims = node.get_attribute<int64_t>("batch_dims", 0);

    auto complex_type_mark = as_type_ptr<ComplexTypeMark>(params.get_node_shared_ptr());

    if (complex_type_mark) {
        params = complex_type_mark->input_value(0);

        // the decomposed tensor has an auxiliary dimension of size 2 at the end (real and imaginary parts),
        // so a negative axis must be normalized against the rank of the original complex tensor
        auto axis_i32 = make_shared<v0::Convert>(axis, ov::element::i32);
        auto zero = make_shared<v0::Constant>(ov::element::i32, Shape{}, 0);

        // create a condition for the Select operation: true if axis is negative
        auto is_negative = make_shared<v1::Less>(axis_i32, zero);

        // calculate the updated value for the axis: axis + rank of the complex tensor
        auto params_shape = make_shared<v3::ShapeOf>(params, ov::element::i32);
        auto params_rank = make_shared<v3::ShapeOf>(params_shape, ov::element::i32);
        auto one = make_shared<v0::Constant>(ov::element::i32, Shape{}, 1);
        auto complex_rank = make_shared<v1::Subtract>(params_rank, one);
        auto updated_axis = make_shared<v1::Add>(axis_i32, complex_rank);

        // create Select operation to choose between the original axis and the updated axis
        auto selected_axis = make_shared<v1::Select>(is_negative, updated_axis, axis_i32);

        auto gather = make_shared<v8::Gather>(params, indices, selected_axis, batch_dims);
        set_node_name(node.get_name(), gather);
        auto complex_gather = make_shared<ComplexTypeMark>(gather, complex_type_mark->get_complex_part_type());
        return {complex_gather->output(0)};
    }

    return translate_basic_gather_op(node, axis, batch_dims);
}

@@ -62,7 +106,6 @@ OutputVector translate_gather_nd_op(const NodeContext& node) {
    set_node_name(node.get_name(), gather_nd);
    return {gather_nd};
}

} // namespace op
} // namespace tensorflow
} // namespace frontend
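For reference, a minimal NumPy sketch (not part of the patch; names are illustrative) of the property the translator relies on: the frontend carries a complex tensor as a floating-point tensor with a trailing dimension of size 2 holding the real and imaginary parts, so gathering along a non-negative axis of the complex tensor equals gathering along the same axis of the decomposed tensor, while a negative axis must be shifted by the complex rank so it does not land on the (real, imag) dimension.

import numpy as np

rng = np.random.default_rng(42)
data = rng.standard_normal((3, 4)) + 1j * rng.standard_normal((3, 4))  # complex tensor of rank 2
indices = np.array([2, 0])
axis = -1  # negative axis on the complex tensor

# decomposed representation used by the frontend: trailing dimension of size 2 = (real, imag)
decomposed = np.stack([data.real, data.imag], axis=-1)  # shape (3, 4, 2)

# non-negative axes are used as-is; negative ones are shifted by the complex rank
gather_axis = axis if axis >= 0 else axis + data.ndim  # -1 -> 1, not -1 (the (real, imag) dimension)

expected = np.take(data, indices, axis=axis)
result = np.take(decomposed, indices, axis=gather_axis)
assert np.allclose(result[..., 0], expected.real)
assert np.allclose(result[..., 1], expected.imag)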
78 changes: 78 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_Gather.py
@@ -79,3 +79,81 @@ def test_gather(self, params, params_type, indices_type, ie_device, precision, i
        self._test(*self.create_gather_net(**params, params_type=params_type, indices_type=indices_type),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_legacy_frontend=use_legacy_frontend)

class TestComplexGather(CommonTFLayerTest):
    def _prepare_input(self, inputs_info):
        assert 'real_params:0' in inputs_info
        assert 'imag_params:0' in inputs_info
        assert 'indices:0' in inputs_info
        real_params_shape = inputs_info['real_params:0']
        imag_params_shape = inputs_info['imag_params:0']
        indices_shape = inputs_info['indices:0']
        inputs_data = {}
        # tf.raw_ops.Complex accepts only floating-point inputs, so the test data is always numeric
        inputs_data['real_params:0'] = rng.integers(-50, 50, real_params_shape).astype(self.params_type)
        inputs_data['imag_params:0'] = rng.integers(-50, 50, imag_params_shape).astype(self.params_type)
        inputs_data['indices:0'] = rng.integers(0, self.max_index, indices_shape).astype(self.indices_type)
        return inputs_data

    def create_gather_net(self, params_shape, params_type, indices_shape, indices_type, axis_value, batch_dims,
                          operation_type):
        self.params_type = params_type
        self.indices_type = indices_type
        if batch_dims is None:
            batch_dims = 0
        if axis_value is None:
            axis_value = 0
        axis_norm = axis_value
        if axis_norm < 0:
            axis_norm += len(params_shape)
        assert 0 <= axis_norm < len(params_shape), "Incorrect `axis` value for the test case"
        self.max_index = params_shape[axis_norm]

        tf.compat.v1.reset_default_graph()
        with tf.compat.v1.Session() as sess:
            real_params = tf.compat.v1.placeholder(params_type, params_shape, 'real_params')
            imag_params = tf.compat.v1.placeholder(params_type, params_shape, 'imag_params')
            complex_params = tf.raw_ops.Complex(real=real_params, imag=imag_params)

            indices = tf.compat.v1.placeholder(indices_type, indices_shape, 'indices')
            if operation_type == "Gather":
                complex_gather = tf.raw_ops.Gather(params=complex_params, indices=indices)
            elif operation_type == "GatherV2":
                axis = tf.constant(axis_value, dtype=tf.int32)
                complex_gather = tf.raw_ops.GatherV2(params=complex_params, indices=indices, axis=axis,
                                                     batch_dims=batch_dims)
            else:
                assert False, "Incorrect operation type is tested"

            # the model must not produce complex outputs, so split the result into real and imaginary parts
            tf.raw_ops.Real(input=complex_gather)
            tf.raw_ops.Imag(input=complex_gather)

            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        ref_net = None

        return tf_net, ref_net

    test_data_precommit = [
        dict(params_shape=[4, 6], indices_shape=[], axis_value=None, batch_dims=None, operation_type='Gather'),
        dict(params_shape=[3, 4, 6], indices_shape=[3, 4], axis_value=None, batch_dims=None, operation_type='Gather'),
        dict(params_shape=[5, 4, 3], indices_shape=[5, 2, 1], axis_value=2, batch_dims=1, operation_type='GatherV2'),
        dict(params_shape=[3, 2, 6, 4], indices_shape=[3, 2, 1, 3], axis_value=-1, batch_dims=-2,
             operation_type='GatherV2'),
    ]

    @pytest.mark.parametrize("params", test_data_precommit)
    @pytest.mark.parametrize("params_type", [np.float32, np.float64])
    @pytest.mark.parametrize("indices_type", [np.int32, np.int64])
    @pytest.mark.precommit
    @pytest.mark.nightly
    def test_complex_gather(self, params, params_type, indices_type, ie_device, precision, ir_version, temp_dir,
                            use_legacy_frontend):
        self._test(*self.create_gather_net(**params, params_type=params_type, indices_type=indices_type),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_legacy_frontend=use_legacy_frontend)

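Beyond the layer tests, a rough end-to-end sketch of how the new support could be exercised from user code (not part of the PR; it assumes an OpenVINO build that contains this change and that openvino.convert_model accepts a tf.function directly; the model returns the real and imaginary parts separately because OpenVINO tensors have no complex element type):

import numpy as np
import tensorflow as tf
import openvino as ov


@tf.function(input_signature=[tf.TensorSpec([3, 4], tf.float32),
                              tf.TensorSpec([3, 4], tf.float32),
                              tf.TensorSpec([2], tf.int32)])
def gather_complex(real, imag, indices):
    data = tf.complex(real, imag)
    gathered = tf.gather(data, indices, axis=-1)  # exercises the negative-axis path of GatherV2
    return tf.math.real(gathered), tf.math.imag(gathered)


real = np.random.rand(3, 4).astype(np.float32)
imag = np.random.rand(3, 4).astype(np.float32)
indices = np.array([3, 1], dtype=np.int32)

compiled = ov.compile_model(ov.convert_model(gather_complex), "CPU")
ov_real, ov_imag = compiled([real, imag, indices]).values()

tf_real, tf_imag = gather_complex(real, imag, indices)
assert np.allclose(ov_real, tf_real.numpy()) and np.allclose(ov_imag, tf_imag.numpy())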