Skip to content

Commit 35d13e1

Browse files
committed
Support arbitrary nested tuples, lists and dicts in inputs and outputs of inlined custom ops. Also supported kwargs for original functions.
1 parent 2ea382a commit 35d13e1

File tree

1 file changed

+95
-31
lines changed

1 file changed

+95
-31
lines changed

src/bindings/python/src/openvino/frontend/pytorch/inlined_extension.py

Lines changed: 95 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,69 @@
99
from openvino.frontend.pytorch.utils import pt_to_ov_type_map
1010
import openvino as ov
1111

12+
13+
class ConstWrap:
14+
def __init__(self, value):
15+
self.value = value
16+
def __eq__(self, x):
17+
return self.value == x
18+
19+
20+
def unpack(packed, types, index=0):
21+
unpacked_result = ()
22+
if isinstance(packed, tuple):
23+
packer_result = ()
24+
for el in packed:
25+
unpacked, packer, index = unpack(el, types, index)
26+
unpacked_result += unpacked
27+
packer_result += (packer,)
28+
elif isinstance(packed, list):
29+
packer_result = []
30+
for el in packed:
31+
unpacked, packer, index = unpack(el, types, index)
32+
packer_result.append(packer)
33+
elif isinstance(packed, dict):
34+
packer_result = {}
35+
for k, v in packed.items():
36+
unpacked, packer, index = unpack(v, types, index)
37+
unpacked_result += unpacked
38+
packer_result[k] = packer
39+
elif isinstance(packed, types):
40+
unpacked_result = (packed,)
41+
packer_result = index
42+
index += 1
43+
else:
44+
packer_result = ConstWrap(packed)
45+
return unpacked_result, packer_result, index
46+
47+
48+
def pack(unpacked, packer):
49+
if isinstance(packer, tuple):
50+
packed_result = ()
51+
for el in packer:
52+
packed = pack(unpacked, el)
53+
packed_result += (packed,)
54+
elif isinstance(packer, list):
55+
packed_result = []
56+
for el in packer:
57+
packed = pack(unpacked, el)
58+
packed_result.append(packed)
59+
elif isinstance(packer, dict):
60+
packed_result = {}
61+
for k, v in packer.items():
62+
packed = pack(unpacked, v)
63+
packed_result[k] = packed
64+
elif isinstance(packer, ConstWrap):
65+
packed_result = packer.value
66+
else:
67+
packed_result = unpacked[packer]
68+
return packed_result
69+
70+
1271
global_counter_id = 0
1372

1473
# makes a custom op class from a func and input/output signatures
15-
def make_custom_op_class(func, input_signature, output_signature):
74+
def make_custom_op_class(func, input_signature, output_signature, input_packer, output_packer):
1675
import torch, numpy
1776
global global_counter_id
1877
# print('make_custom_op_class, id =', global_counter_id)
@@ -35,12 +94,11 @@ def __init__(self, *args):
3594

3695
def evaluate(self, outputs, inputs):
3796
# print("called evaluate")
38-
inputs_torch = (torch.from_numpy(input.data) for input in inputs) # TODO: Check memory sharing
39-
result = func(*inputs_torch)
40-
if result is None:
41-
result = ()
42-
if not isinstance(result, tuple):
43-
result = (result,)
97+
inputs_torch = tuple(torch.from_numpy(input.data) for input in inputs) # TODO: Check memory sharing
98+
args, kwargs = pack(inputs_torch, input_packer)
99+
result = func(*args, **kwargs)
100+
result, result_packer, _ = unpack(result, torch.Tensor)
101+
assert result_packer == output_packer
44102
for i, tensor in enumerate(result):
45103
ov.Tensor(numpy.array(tensor), shared_memory=True).copy_to(outputs[i]) # TODO: set the output tensor directly without copying
46104
return True
@@ -73,11 +131,10 @@ def make_signature(args):
73131

74132
# Returns a tuple of tuples (element_type, partial_shape) for each argument, flattening nested structures if needed, setting all dimensions dynamic preserving rank
75133
# Currently assumes that all input arguments are torch.Tensor objects
76-
def make_input_signature(args, kwargs):
134+
def make_input_signature(args):
77135
# TODO: Avoid the current limitation: kwargs parameters should be passed in the same order as the function signature without gaps
78136
# flatten kwargs relying on the order of the keys
79-
assert not kwargs, "Keyword arguments are not supported yet"
80-
return make_signature(args + tuple(kwargs.values()))
137+
return make_signature(args)
81138

82139

83140
def make_output_signature(args):
@@ -90,43 +147,46 @@ def make_output_signature(args):
90147
return make_signature(args)
91148

92149

93-
def is_class_method(obj):
94-
if not inspect.isfunction(obj) and not inspect.ismethod(obj):
95-
return False
96-
argspec = inspect.getfullargspec(obj)
97-
if argspec.args and argspec.args[0] == 'self':
98-
return True
99-
else:
100-
return False
101-
102-
103150
def make_trampoline_class(func, op, op_attrs):
104151
import torch
105152
class Trampoline(torch.autograd.Function):
106153
target_extension = InlineConversionExtension() # this is a marker for this type of extension
107154

108155
# This function defines how the operation behaves when called as a part of PyTorch model code in eager execution or while jit.trace
109156
@staticmethod
110-
def forward(ctx, *call_args, **call_kwargs): #TODO: what is `ctx`?
157+
def forward(ctx, *call_args): #TODO: what is `ctx`?
111158
# print('Called through the trampoline')
112-
func_target = func
113159
if not op:
114-
if is_class_method(func):
115-
self_obj = call_args[0]
116-
call_args = call_args[1:]
117-
wrapped = lambda *distil_args, **distil_kwargs: func(self_obj, *distil_args, **distil_kwargs)
118-
func_target = wrapped
119-
input_signature = make_input_signature(call_args, call_kwargs)
160+
input_signature = make_input_signature(call_args)
120161
# TODO: Try to trace `func` with the hope to obtain tracable shapes to build more precise `validate_and_infer_types` automatically (unlikely possible)
121-
result = func_target(*call_args, **call_kwargs)
162+
print('about to call func_target with call_args:', call_args)
163+
packed_args, packed_kwargs = pack(call_args, __class__.input_packer)
164+
assert isinstance(packed_args, tuple)
165+
assert isinstance(packed_kwargs, dict)
166+
result = func(*packed_args, **packed_kwargs)
167+
result, __class__.output_packer, _ = unpack(result, torch.Tensor)
168+
print('output_packer:', __class__.output_packer)
122169
if not op:
123170
output_signature = make_output_signature(result)
124171
#print('about to make custom op class with output signature', output_signature)
125-
__class__.op = make_custom_op_class(func_target, input_signature, output_signature)
172+
__class__.op = make_custom_op_class(func, input_signature, output_signature, __class__.input_packer, __class__.output_packer)
126173
else:
127174
__class__.op = op
128175
return result
129176

177+
# Unpack each element that is a tuple, a list, or a dict to a tuple of their values and concatenate together
178+
# Build `packer` function to pack the result back to the original nested data types, save this `packer` to
179+
# class member to use it later in `forward`.
180+
@staticmethod
181+
def unpack_inputs(args, kwargs):
182+
unpacked, __class__.input_packer, _ = unpack((args, kwargs), torch.Tensor)
183+
print('input_packer:', __class__.input_packer)
184+
return unpacked
185+
186+
@staticmethod
187+
def pack_outputs(result):
188+
return pack(result, __class__.output_packer)
189+
130190
# This function defines how the operation is represented in OpenVINO model graph
131191
@staticmethod
132192
def convert(node_context):
@@ -143,7 +203,11 @@ def trampoline(*args, **kwargs):
143203
# It is required because `func` is fused inside Trampoline class and can have different behaviour from call to call in PyTorch world even if
144204
# the same op is specified to wrap multiple different functions.
145205
trampoline = make_trampoline_class(func, op, op_attrs)
146-
result = trampoline.apply(*args, **kwargs)
206+
print('calling trampoline.apply with args:', args, 'kwargs:', kwargs)
207+
args = trampoline.unpack_inputs(args, kwargs)
208+
result = trampoline.apply(*args)
209+
result = trampoline.pack_outputs(result)
210+
# pack to the expected nested data types here
147211
#print('just called trampoline with result:', result)
148212
return result
149213
return trampoline

0 commit comments

Comments
 (0)