2222    create_immediate_value ,
2323    create_list_scalarvalue ,
2424    create_scalar_value ,
25-     types_to_proto ,
25+     create_valuetype_list ,
26+     create_valuetype_scalar ,
27+     create_valuetype_tensor ,
2628    types_to_proto_primitive ,
2729)
2830from  coremltools .converters .mil .backend .nn .load  import  _set_optional_inputs 
@@ -158,7 +160,7 @@ def translate_const(self, op: Operation) -> proto.MIL_pb2.Operation:
158160            attributes = {"name" : create_scalar_value (op .name ), "val" : value },
159161            outputs = [
160162                proto .MIL_pb2 .NamedValueType (
161-                     name = output_var .name , type = types_to_proto (output_var .sym_type )
163+                     name = output_var .name , type = self . types_to_proto (output_var .sym_type )
162164                )
163165            ],
164166        )
@@ -190,12 +192,58 @@ def translate_constexpr(self, op: Operation) -> proto.MIL_pb2.Operation:
190192            attributes = attributes ,
191193            outputs = [
192194                proto .MIL_pb2 .NamedValueType (
193-                     name = output_var .name , type = types_to_proto (output_var .sym_type )
195+                     name = output_var .name , type = self . types_to_proto (output_var .sym_type )
194196                )
195197                for  output_var  in  op .outputs 
196198            ],
197199        )
198200
201+     def  create_valuetype_dict (self , key_type : type , value_type : type ) ->  proto .MIL_pb2 .ValueType :
202+         """ 
203+         Return proto.MIL_pb2.ValueType with dict (dictionaryType) set 
204+         """ 
205+         v_type  =  proto .MIL_pb2 .ValueType ()
206+         v_type .dictionaryType .keyType .CopyFrom (self .types_to_proto (key_type ))
207+         v_type .dictionaryType .valueType .CopyFrom (self .types_to_proto (value_type ))
208+         return  v_type 
209+ 
210+     def  types_to_proto (self , valuetype : type ) ->  proto .MIL_pb2 .ValueType :
211+         """ 
212+         Return proto.MIL_pb2.ValueType from PyMIL types. 
213+         """ 
214+         if  types .is_tensor (valuetype ):
215+             primitive  =  types_to_proto_primitive (valuetype .get_primitive ())
216+             return  create_valuetype_tensor (valuetype .get_shape (), primitive )
217+         elif  types .is_tuple (valuetype ):
218+             v_type  =  proto .MIL_pb2 .ValueType ()
219+             t_type  =  v_type .tupleType 
220+             for  t  in  valuetype .T :
221+                 new_v_type  =  t_type .types .add ()
222+                 new_v_type .CopyFrom (self .types_to_proto (t ))
223+             return  v_type 
224+         elif  types .is_list (valuetype ):
225+             elem  =  valuetype .T [0 ]
226+             length  =  valuetype .T [1 ]
227+             if  types .is_tensor (elem ):
228+                 dtype  =  types_to_proto_primitive (elem .get_primitive ())
229+                 elem_shape  =  elem .get_shape ()
230+             elif  types .is_scalar (elem ):
231+                 dtype  =  types_to_proto_primitive (valuetype )
232+                 elem_shape  =  ()
233+             elif  types .is_str (elem ):
234+                 dtype  =  types_to_proto_primitive (elem )
235+                 elem_shape  =  ()
236+             else :
237+                 raise  NotImplementedError (
238+                     "Only list of either tensors or scalars supported. " 
239+                     "Got element of type {}" .format (elem .__type_info__ ())
240+                 )
241+             return  create_valuetype_list (length = length , elem_shape = elem_shape , dtype = dtype )
242+         elif  types .is_dict (valuetype ):
243+             return  self .create_valuetype_dict (valuetype .T [0 ], valuetype .T [1 ])
244+         else :
245+             return  create_valuetype_scalar (types_to_proto_primitive (valuetype ))
246+ 
199247    def  translate_generic_op (
200248        self , op : Operation , literal_params : Optional [List [str ]] =  None 
201249    ) ->  proto .MIL_pb2 .Operation :
@@ -228,7 +276,7 @@ def translate_generic_op(
228276            inputs [param_name ] =  args 
229277
230278        outputs  =  [
231-             proto .MIL_pb2 .NamedValueType (name = v .name , type = types_to_proto (v .sym_type ))
279+             proto .MIL_pb2 .NamedValueType (name = v .name , type = self . types_to_proto (v .sym_type ))
232280            for  v  in  op .outputs 
233281        ]
234282        blocks  =  None 
@@ -311,14 +359,18 @@ def feeds_to_only_constexprs(op: Operation) -> bool:
311359                literal_params  =  ["begins" , "ends" , "end_masks" ]
312360                proto_ops .append (self .translate_generic_op (op , literal_params ))
313361            else :
314-                 proto_ops .append (self .translate_generic_op (op ))
362+                 # A single pymil op might be decomposed into multiple ops 
363+                 ops  =  self .translate_generic_op (op )
364+                 if  not  isinstance (ops , list ):
365+                     ops  =  [ops ]
366+                 proto_ops .extend (ops )
315367
316368        inputs  =  []
317369        if  not  isinstance (block , Function ):
318370            # Function is subclass of Block, but function's block has no input, 
319371            # and hence skipping reading the block inputs. 
320372            for  var  in  block .inputs :
321-                 proto_type  =  types_to_proto (var .sym_type )
373+                 proto_type  =  self . types_to_proto (var .sym_type )
322374                inputs .append (proto .MIL_pb2 .NamedValueType (name = var .name , type = proto_type ))
323375        output_names  =  [v .name  for  v  in  block .outputs ]
324376        return  proto .MIL_pb2 .Block (inputs = inputs , outputs = output_names , operations = proto_ops )
@@ -331,7 +383,7 @@ def convert_function(self, function: Function, opset: str) -> proto.MIL_pb2.Func
331383
332384        inputs  =  []
333385        for  name , var  in  function .inputs .items ():
334-             proto_type  =  types_to_proto (var .sym_type )
386+             proto_type  =  self . types_to_proto (var .sym_type )
335387            inputs .append (proto .MIL_pb2 .NamedValueType (name = name , type = proto_type ))
336388
337389        return  proto .MIL_pb2 .Function (
@@ -467,6 +519,15 @@ def get_additional_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
467519        """ 
468520        return  {}
469521
522+     @staticmethod  
523+     def  _try_convert_other_input_type (
524+         input_var : Var , input_features : List [proto .Model_pb2 .FeatureDescription ]
525+     ) ->  bool :
526+         """ 
527+         Try to convert an input var with additional type. 
528+         """ 
529+         return  False 
530+ 
470531    def  get_func_input (self , func : mil .Function ) ->  List [proto .Model_pb2 .FeatureDescription ]:
471532        """ 
472533        Utils to get function input feature description. 
@@ -554,7 +615,7 @@ def get_func_input(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDesc
554615                input_features .append (
555616                    proto .Model_pb2 .FeatureDescription (name = var .name , type = input_feature_type )
556617                )
557-             else :
618+             elif   not   self . _try_convert_other_input_type ( var ,  input_features ) :
558619                raise  NotImplementedError (f"Unsupported input type { var .sym_type }  ." )
559620
560621            if  not  is_input_shape_symbolic :
@@ -746,6 +807,16 @@ def get_func_output(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDes
746807
747808        return  output_features 
748809
810+     def  create_model_description (
811+         self ,
812+         input_features : List [proto .Model_pb2 .FeatureDescription ],
813+         output_features : List [proto .Model_pb2 .FeatureDescription ],
814+     ) ->  proto .Model_pb2 .ModelDescription :
815+         """ 
816+         Create model description from input and output features 
817+         """ 
818+         return  proto .Model_pb2 .ModelDescription (input = input_features , output = output_features )
819+ 
749820    def  get_coreml_model (
750821        self ,
751822        input : Dict [str , List [proto .Model_pb2 .FeatureDescription ]],
@@ -758,7 +829,7 @@ def get_coreml_model(
758829        # Model description 
759830        input_features  =  input [self ._DEFAULT_FUNCTION_NAME ]
760831        output_features  =  output [self ._DEFAULT_FUNCTION_NAME ]
761-         desc  =  proto . Model_pb2 . ModelDescription ( input = input_features , output = output_features )
832+         desc  =  self . create_model_description ( input_features , output_features )
762833
763834        if  self .classifier_config  is  not   None :
764835            desc .predictedFeatureName  =  self .predicted_feature_name 
0 commit comments