Skip to content

Commit 781d190

Browse files
committed
Fixed two bugs: incorrect list packing in the new functionality, and broken TupleUnpack handling in the old functionality.
1 parent 1a5cd57 commit 781d190

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

src/bindings/python/src/openvino/frontend/pytorch/inlined_extension.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ def __init__(self, value):
1515
self.value = value
1616
def __eq__(self, x):
1717
return self.value == x
18+
def __repr__(self):
19+
return f'ConstWrap({str(self.value)})'
1820

1921

2022
def unpack(packed, types, index=0):
@@ -29,6 +31,7 @@ def unpack(packed, types, index=0):
2931
packer_result = []
3032
for el in packed:
3133
unpacked, packer, index = unpack(el, types, index)
34+
unpacked_result += unpacked
3235
packer_result.append(packer)
3336
elif isinstance(packed, dict):
3437
packer_result = {}
@@ -175,7 +178,9 @@ def pack_outputs(result):
175178
@staticmethod
176179
def convert(node_context):
177180
inputs = [node_context.get_input(i) for i in range(node_context.get_input_size())]
178-
return __class__.op(*inputs, **op_attrs).outputs()
181+
node = __class__.op(*inputs, **op_attrs)
182+
node.get_rt_info()['__torch_tuple_unpackable__'] = True # to trigger prim::TupleUnpack bypass in PyTorch FrontEnd transformation
183+
return node.outputs()
179184

180185
return Trampoline
181186

src/frontends/pytorch/src/transforms/tuple_unpack_replacer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ PrimTupleUnpackReplacer::PrimTupleUnpackReplacer() {
2929
auto input_node = tuple_unpack->get_input_node_shared_ptr(0);
3030
auto tuple_construct = cast_fw_node(input_node, "prim::TupleConstruct");
3131
if (!tuple_construct) {
32-
if(!ov::as_type_ptr<ov::op::util::FrameworkNode>(input_node)) {
33-
// remove TupleUnpack just bypassing it with all outputs from any op except FrameworkNode
34-
// We are leaving FrameworkNode case for further processing
32+
if(input_node->get_rt_info().count("__torch_tuple_unpackable__")) {
33+
input_node->get_rt_info().erase("__torch_tuple_unpackable__");
34+
// remove TupleUnpack just bypassing it with all outputs from a custom operation which returns tuple
3535
replace_node(tuple_unpack, input_node->outputs());
3636
return true;
3737
}

0 commit comments

Comments
 (0)