 from openvino.frontend.pytorch.utils import pt_to_ov_type_map
 import openvino as ov
 
+
+# Wraps a non-tensor constant captured by `unpack` so that `pack` can restore it by value
+class ConstWrap:
+    def __init__(self, value):
+        self.value = value
+    def __eq__(self, x):
+        return self.value == x
+
+
+# Flattens nested tuples/lists/dicts: leaves of type(s) `types` are collected into a flat tuple, while the returned
+# `packer` mirrors the nesting with integer indices for the collected leaves and ConstWrap entries for constants
+def unpack(packed, types, index=0):
+    unpacked_result = ()
+    if isinstance(packed, tuple):
+        packer_result = ()
+        for el in packed:
+            unpacked, packer, index = unpack(el, types, index)
+            unpacked_result += unpacked
+            packer_result += (packer,)
+    elif isinstance(packed, list):
+        packer_result = []
+        for el in packed:
+            unpacked, packer, index = unpack(el, types, index)
+            unpacked_result += unpacked
+            packer_result.append(packer)
+    elif isinstance(packed, dict):
+        packer_result = {}
+        for k, v in packed.items():
+            unpacked, packer, index = unpack(v, types, index)
+            unpacked_result += unpacked
+            packer_result[k] = packer
+    elif isinstance(packed, types):
+        unpacked_result = (packed,)
+        packer_result = index
+        index += 1
+    else:
+        packer_result = ConstWrap(packed)
+    return unpacked_result, packer_result, index
+
+
+# Inverse of `unpack`: rebuilds the nested structure described by `packer` from the flat tuple `unpacked`
+def pack(unpacked, packer):
+    if isinstance(packer, tuple):
+        packed_result = ()
+        for el in packer:
+            packed = pack(unpacked, el)
+            packed_result += (packed,)
+    elif isinstance(packer, list):
+        packed_result = []
+        for el in packer:
+            packed = pack(unpacked, el)
+            packed_result.append(packed)
+    elif isinstance(packer, dict):
+        packed_result = {}
+        for k, v in packer.items():
+            packed = pack(unpacked, v)
+            packed_result[k] = packed
+    elif isinstance(packer, ConstWrap):
+        packed_result = packer.value
+    else:
+        packed_result = unpacked[packer]
+    return packed_result
+
+
 global_counter_id = 0
 
 # makes a custom op class from a func and input/output signatures
-def make_custom_op_class(func, input_signature, output_signature):
+def make_custom_op_class(func, input_signature, output_signature, input_packer, output_packer):
     import torch, numpy
     global global_counter_id
     # print('make_custom_op_class, id =', global_counter_id)
@@ -35,12 +94,11 @@ def __init__(self, *args):
 
         def evaluate(self, outputs, inputs):
             # print("called evaluate")
-            inputs_torch = (torch.from_numpy(input.data) for input in inputs)  # TODO: Check memory sharing
-            result = func(*inputs_torch)
-            if result is None:
-                result = ()
-            if not isinstance(result, tuple):
-                result = (result,)
+            inputs_torch = tuple(torch.from_numpy(input.data) for input in inputs)  # TODO: Check memory sharing
+            args, kwargs = pack(inputs_torch, input_packer)
+            result = func(*args, **kwargs)
+            result, result_packer, _ = unpack(result, torch.Tensor)
+            assert result_packer == output_packer
             for i, tensor in enumerate(result):
                 ov.Tensor(numpy.array(tensor), shared_memory=True).copy_to(outputs[i])  # TODO: set the output tensor directly without copying
             return True
@@ -73,11 +131,10 @@ def make_signature(args):
 
 # Returns a tuple of tuples (element_type, partial_shape) for each argument, flattening nested structures if needed, setting all dimensions dynamic preserving rank
 # Currently assumes that all input arguments are torch.Tensor objects
-def make_input_signature(args, kwargs):
+def make_input_signature(args):
     # TODO: Avoid the current limitation: kwargs parameters should be passed in the same order as the function signature without gaps
     # flatten kwargs relying on the order of the keys
-    assert not kwargs, "Keyword arguments are not supported yet"
-    return make_signature(args + tuple(kwargs.values()))
+    return make_signature(args)
 
 
 def make_output_signature(args):
@@ -90,43 +147,46 @@ def make_output_signature(args):
     return make_signature(args)
 
 
-def is_class_method(obj):
-    if not inspect.isfunction(obj) and not inspect.ismethod(obj):
-        return False
-    argspec = inspect.getfullargspec(obj)
-    if argspec.args and argspec.args[0] == 'self':
-        return True
-    else:
-        return False
-
-
 def make_trampoline_class(func, op, op_attrs):
     import torch
     class Trampoline(torch.autograd.Function):
         target_extension = InlineConversionExtension()  # this is a marker for this type of extension
 
         # This function defines how the operation behaves when called as a part of PyTorch model code in eager execution or while jit.trace
         @staticmethod
-        def forward(ctx, *call_args, **call_kwargs):  # TODO: what is `ctx`?
+        def forward(ctx, *call_args):  # TODO: what is `ctx`?
             # print('Called through the trampoline')
-            func_target = func
             if not op:
-                if is_class_method(func):
-                    self_obj = call_args[0]
-                    call_args = call_args[1:]
-                    wrapped = lambda *distil_args, **distil_kwargs: func(self_obj, *distil_args, **distil_kwargs)
-                    func_target = wrapped
-                input_signature = make_input_signature(call_args, call_kwargs)
+                input_signature = make_input_signature(call_args)
             # TODO: Try to trace `func` with the hope to obtain tracable shapes to build more precise `validate_and_infer_types` automatically (unlikely possible)
-            result = func_target(*call_args, **call_kwargs)
+            print('about to call func with call_args:', call_args)
+            packed_args, packed_kwargs = pack(call_args, __class__.input_packer)
+            assert isinstance(packed_args, tuple)
+            assert isinstance(packed_kwargs, dict)
+            result = func(*packed_args, **packed_kwargs)
+            result, __class__.output_packer, _ = unpack(result, torch.Tensor)
+            print('output_packer:', __class__.output_packer)
             if not op:
                 output_signature = make_output_signature(result)
                 #print('about to make custom op class with output signature', output_signature)
-                __class__.op = make_custom_op_class(func_target, input_signature, output_signature)
+                __class__.op = make_custom_op_class(func, input_signature, output_signature, __class__.input_packer, __class__.output_packer)
             else:
                 __class__.op = op
             return result
 
+        # Unpack each input element that is a tuple, a list, or a dict into a flat tuple of its tensor values and concatenate them together.
+        # Build a `packer` structure that can pack the flat values back into the original nested data types, and save it as a class member
+        # so it can be used later in `forward`.
+        @staticmethod
+        def unpack_inputs(args, kwargs):
+            unpacked, __class__.input_packer, _ = unpack((args, kwargs), torch.Tensor)
+            print('input_packer:', __class__.input_packer)
+            return unpacked
+
+        @staticmethod
+        def pack_outputs(result):
+            return pack(result, __class__.output_packer)
+
         # This function defines how the operation is represented in OpenVINO model graph
         @staticmethod
         def convert(node_context):
@@ -143,7 +203,11 @@ def trampoline(*args, **kwargs):
         # It is required because `func` is fused inside Trampoline class and can have different behaviour from call to call in PyTorch world even if
         # the same op is specified to wrap multiple different functions.
         trampoline = make_trampoline_class(func, op, op_attrs)
-        result = trampoline.apply(*args, **kwargs)
+        print('calling trampoline.apply with args:', args, 'kwargs:', kwargs)
+        args = trampoline.unpack_inputs(args, kwargs)
+        result = trampoline.apply(*args)
+        result = trampoline.pack_outputs(result)
+        # pack to the expected nested data types here
         #print('just called trampoline with result:', result)
         return result
     return trampoline
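
For context, here is a minimal sketch of the flatten/rebuild roundtrip these helpers implement. It assumes `unpack` and `pack` from this patch are in scope and `torch` is installed; the sample tensors and the nested structure are invented for illustration and are not part of the patch:

```python
import torch

# A nested (args, kwargs) structure mixing tensors with plain Python constants,
# similar to what Trampoline.unpack_inputs receives.
args = (torch.ones(2), [torch.zeros(3), 5], {"scale": torch.full((1,), 2.0)})
kwargs = {"flag": True}

# Flatten: tensors are collected into `flat` in traversal order; `packer` mirrors the nesting
# with integer indices for the collected tensors and ConstWrap entries for the constants.
flat, packer, _ = unpack((args, kwargs), torch.Tensor)
assert len(flat) == 3
# packer == ((0, [1, ConstWrap(5)], {'scale': 2}), {'flag': ConstWrap(True)})

# Rebuild: `pack` is the inverse of `unpack` for the same packer.
packed_args, packed_kwargs = pack(flat, packer)
assert torch.equal(packed_args[0], args[0])
assert packed_args[1][1] == 5 and packed_kwargs["flag"] is True
```

The flat tuple is what `Trampoline.apply` and the custom op's `evaluate` actually see, while the saved packers restore the user-visible nesting on both the PyTorch and the OpenVINO side.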