1515from  coremltools  import  ComputeUnit  as  _ComputeUnit 
1616from  coremltools  import  __version__  as  _ct_version 
1717from  coremltools  import  _logger  as  logger 
18- from  coremltools ._deps  import  _HAS_TF_1 , _HAS_TF_2 , _HAS_TORCH 
18+ from  coremltools ._deps  import  _HAS_TF_1 , _HAS_TF_2 , _HAS_TORCH ,  _HAS_TORCH_EXPORT_API 
1919from  coremltools .converters ._profile_utils  import  _profile 
2020from  coremltools .converters .mil ._deployment_compatibility  import  (
2121    AvailableTarget ,
3636from  coremltools .converters .mil .mil .passes .defs .quantization  import  FP16ComputePrecision 
3737from  coremltools .converters .mil .mil .passes .graph_pass  import  PassOption  as  _PassOption 
3838from  coremltools .converters .mil .mil .passes .pass_pipeline  import  PassPipeline 
39- from  coremltools .models  import  _METADATA_SOURCE , _METADATA_VERSION 
39+ from  coremltools .models  import  _METADATA_SOURCE , _METADATA_SOURCE_DIALECT ,  _METADATA_VERSION 
4040from  coremltools .models .utils  import  _MLPACKAGE_EXTENSION 
4141
4242if  _HAS_TF_1 :
5151if  _HAS_TORCH :
5252    import  torch 
5353
54-     from  coremltools .converters .mil .frontend .torch .load  import  \
55-         _torchscript_from_model  as  pytorch_load 
54+     from  coremltools .converters .mil .frontend .torch .load  import  (
55+         _torchscript_from_spec  as  try_load_torchscript ,
56+     )
57+ 
58+     if  _HAS_TORCH_EXPORT_API :
59+         from  torch .export  import  ExportedProgram 
60+ 
5661
5762
5863@_profile  
@@ -102,8 +107,12 @@ def convert(
102107
103108        * PyTorch 
104109
105-             - A `TorchScript <https://pytorch.org/docs/stable/jit.html>`_ object 
106-             - Path to a ``.pt`` file 
110+             - TorchScript Models: 
111+                 - A `TorchScript <https://pytorch.org/docs/stable/jit.html>`_ object 
112+                 - Path to a ``.pt`` file 
113+ 
114+             - Torch Exported Models: 
115+                 - A `ExportedProgram <https://pytorch.org/docs/stable/export.html#torch.export.ExportedProgram> ` object with `EDGE` dialect 
107116
108117    source : str (optional) 
109118
@@ -161,18 +170,23 @@ def convert(
161170              When ``inputs`` not provided or ``dtype`` not specified. The float 32 inputs defaults to float 16. 
162171
163172        * PyTorch: 
164-             - The ``inputs`` parameter is required. 
165-             - Number of elements in ``inputs`` must match the number of inputs 
166-               of the PyTorch model. 
167-             - ``inputs`` may be a nested list or tuple. 
168-             - ``TensorType`` and ``ImageType`` must have the ``shape`` specified. 
169-             - If the ``name`` argument is specified with ``TensorType`` or 
170-               ``ImageType``, the converted Core ML model will have inputs with 
171-               the same name. 
172-             - If ``dtype`` is missing: 
173-               * For ``minimum_deployment_target <= ct.target.macOS12``, it defaults to float 32. 
174-               * For ``minimum_deployment_target >= ct.target.macOS13``, and with ``compute_precision`` in float 16 precision. 
175-                 It defaults to float 16. 
173+ 
174+             - TorchScript Models: 
175+                 - The ``inputs`` parameter is required. 
176+                 - Number of elements in ``inputs`` must match the number of inputs 
177+                   of the PyTorch model. 
178+                 - ``inputs`` may be a nested list or tuple. 
179+                 - ``TensorType`` and ``ImageType`` must have the ``shape`` specified. 
180+                 - If the ``name`` argument is specified with ``TensorType`` or 
181+                   ``ImageType``, the converted Core ML model will have inputs with 
182+                   the same name. 
183+                 - If ``dtype`` is missing: 
184+                   * For ``minimum_deployment_target <= ct.target.macOS12``, it defaults to float 32. 
185+                   * For ``minimum_deployment_target >= ct.target.macOS13``, and with ``compute_precision`` in float 16 precision. 
186+                     It defaults to float 16. 
187+ 
188+             - Torch Exported Models: 
189+                 - The ``inputs`` parameter is not supported. ``inputs`` parameter is inferred from Torch ExportedProgram. 
176190
177191    outputs : list of ``TensorType`` or ``ImageType`` (optional) 
178192
@@ -218,13 +232,17 @@ def convert(
218232
219233        * PyTorch: 
220234
221-             - If specified, the length of the list must match the number of 
222-               outputs returned by the PyTorch model. 
223-             - If ``name`` is specified, it is applied to the output names of the 
224-               converted Core ML model. 
225-             - For ``minimum_deployment_target >= ct.target.macOS13``, and with ``compute_precision`` in float 16 precision. 
226-               If ``dtype`` not specified, the outputs inferred of type float 32 
227-               defaults to float 16. 
235+             - TorchScript Models: 
236+                 - If specified, the length of the list must match the number of 
237+                 outputs returned by the PyTorch model. 
238+                 - If ``name`` is specified, it is applied to the output names of the 
239+                 converted Core ML model. 
240+                 - For ``minimum_deployment_target >= ct.target.macOS13``, and with ``compute_precision`` in float 16 precision. 
241+                 If ``dtype`` not specified, the outputs inferred of type float 32 
242+                 defaults to float 16. 
243+ 
244+             - Torch Exported Models: 
245+                 - The ``outputs`` parameter is not supported. ``outputs`` parameter is inferred from Torch ExportedProgram. 
228246
229247
230248    classifier_config : ClassifierConfig class (optional) 
@@ -308,7 +326,7 @@ def convert(
308326          The above transform iterates through all the ops, looking at each op's 
309327          inputs and outputs. If they are of type float 32, ``cast`` 
310328          ops are injected to convert those tensors (also known as `vars`) to 
311-           type float 16. 
329+           type float 16. Similarly, int32 vars will also be cast to int16.  
312330
313331        - ``coremltools.precision.FLOAT32`` enum: No transform is applied. 
314332
@@ -489,15 +507,17 @@ def skip_real_div_ops(op):
489507
490508    PyTorch: 
491509
492-         >>> model = torchvision.models.mobilenet_v2() 
493-         >>> model.eval() 
494-         >>> example_input = torch.rand(1, 3, 256, 256) 
495-         >>> traced_model = torch.jit.trace(model, example_input) 
510+         TorchScript Models: 
496511
497-         >>> input = ct.TensorType(name='input_name', shape=(1, 3, 256, 256)) 
498-         >>> mlmodel = ct.convert(traced_model, inputs=[input]) 
499-         >>> results = mlmodel.predict({"input": example_input.numpy()}) 
500-         >>> print(results['1651']) # 1651 is the node name given by PyTorch's JIT 
512+             >>> model = torchvision.models.mobilenet_v2() 
513+             >>> model.eval() 
514+             >>> example_input = torch.rand(1, 3, 256, 256) 
515+             >>> traced_model = torch.jit.trace(model, example_input) 
516+ 
517+             >>> input = ct.TensorType(name='input_name', shape=(1, 3, 256, 256)) 
518+             >>> mlmodel = ct.convert(traced_model, inputs=[input]) 
519+             >>> results = mlmodel.predict({"input": example_input.numpy()}) 
520+             >>> print(results['1651']) # 1651 is the node name given by PyTorch's JIT 
501521
502522    See `Conversion Options <https://coremltools.readme.io/docs/neural-network-conversion>`_ for 
503523    more advanced options. 
@@ -508,6 +528,7 @@ def skip_real_div_ops(op):
508528                                     outputs_as_strings ,
509529                                     outputs_as_tensor_or_image_types ,
510530                                     outputs )
531+     source_dialect  =  _determine_source_dialect (model , exact_source )
511532    exact_target  =  _determine_target (convert_to , minimum_deployment_target )
512533    _validate_conversion_arguments (
513534        model ,
@@ -525,7 +546,7 @@ def skip_real_div_ops(op):
525546    if  pass_pipeline  is  None :
526547        pass_pipeline  =  PassPipeline ()
527548    if  not  need_fp16_cast_pass :
528-         pass_pipeline .remove_passes ({"common::add_fp16_cast" })
549+         pass_pipeline .remove_passes ({"common::add_fp16_cast" ,  "common::add_int16_cast" })
529550    if  isinstance (compute_precision , FP16ComputePrecision ):
530551        # For backward compatibility with the `op_selector` param in FP16ComputePrecision. 
531552        pass_pipeline ._pass_options ["common::add_fp16_cast" ] =  [
@@ -584,7 +605,7 @@ def skip_real_div_ops(op):
584605
585606    gc .collect ()
586607
587-     mlmodel  =  _record_build_metadata (mlmodel , exact_source )
608+     mlmodel  =  _record_build_metadata (mlmodel , exact_source ,  source_dialect = source_dialect )
588609
589610    return  mlmodel 
590611
@@ -819,16 +840,45 @@ def _flatten_list(_inputs):
819840            raise  ValueError ("Input should be a list of TensorType or ImageType" )
820841
821842    elif  exact_source  ==  "pytorch" :
822-         if  inputs  is  None :
823-             raise  ValueError ('Expected argument for pytorch "inputs" not provided' )
843+         if  _HAS_TORCH_EXPORT_API  and  isinstance (model , ExportedProgram ):
844+             if  model .dialect  !=  "EDGE" :
845+                 raise  NotImplementedError (
846+                     f"Conversion for models with only EDGE dialect is supported/tested. Provided Dialect: { model .dialect }  " 
847+                 )
824848
825-         raise_if_duplicated (flat_inputs )
826-         if  inputs  is  not   None  and  not  all (
827-             [isinstance (_input , InputType ) for  _input  in  flat_inputs ]
828-         ):
829-             raise  ValueError (
830-                 "Input should be a list/tuple (or nested lists/tuples) of TensorType or ImageType" 
831-             )
849+             # TODO: rdar://115845792 ([Executorch] Handle user provided inputs/outputs in the convert API) 
850+             if  inputs  is  not   None :
851+                 raise  AssertionError ("'inputs' argument should be None for ExportedProgram" )
852+ 
853+             if  outputs  is  not   None :
854+                 raise  AssertionError ("'outputs' argument should be None for ExportedProgram" )
855+ 
856+         else :
857+             is_torch_load_successful  =  False 
858+             try :
859+                 try_load_torchscript (model )
860+                 is_torch_load_successful  =  True 
861+             except :
862+                 pass 
863+             if  is_torch_load_successful :
864+                 if  inputs  is  None :
865+                     raise  ValueError (
866+                         'Expected argument "inputs" for TorchScript models not provided' 
867+                     )
868+ 
869+                 raise_if_duplicated (flat_inputs )
870+                 if  inputs  is  not   None  and  not  all (
871+                     [isinstance (_input , InputType ) for  _input  in  flat_inputs ]
872+                 ):
873+                     raise  ValueError (
874+                         "Input should be a list/tuple (or nested lists/tuples) of TensorType or ImageType" 
875+                     )
876+             else :
877+                 raise  TypeError (
878+                     "@model must either be a TorchScript object (or .pt or .pth file) or an ExportedProgram object (if using torch.export based API), received: {}" .format (
879+                         type (model )
880+                     )
881+                 )
832882
833883    elif  exact_source  ==  "milinternal" :
834884        if  not  isinstance (model , Program ):
@@ -837,6 +887,19 @@ def _flatten_list(_inputs):
837887            )
838888
839889
890+ def  _determine_source_dialect (model , exact_source ):
891+ 
892+     source_dialect  =  None 
893+     if  exact_source  ==  "pytorch" :
894+ 
895+         if  _HAS_TORCH_EXPORT_API  and  isinstance (model , ExportedProgram ):
896+             return  f"TorchExport::{ model .dialect }  " 
897+         else :
898+             return  "TorchScript" 
899+ 
900+     return  source_dialect 
901+ 
902+ 
840903def  _determine_source (model , source ,
841904                      output_names ,
842905                      outputs_as_tensor_or_image_types ,
@@ -875,9 +938,13 @@ def _determine_source(model, source,
875938            pass 
876939
877940    if  source  ==  "auto"  and  _HAS_TORCH :
941+ 
942+         if  _HAS_TORCH_EXPORT_API  and  isinstance (model , ExportedProgram ):
943+             return  "pytorch" 
944+ 
878945        is_torch_load_successful  =  False 
879946        try :
880-             pytorch_load (model )
947+             try_load_torchscript (model )
881948            is_torch_load_successful  =  True 
882949        except :
883950            pass 
@@ -953,6 +1020,12 @@ def _get_metadata_from_mlmodel(mlmodel):
9531020    src_pkg_version  =  mlmodel .user_defined_metadata [_METADATA_SOURCE ]
9541021    coremltools_version  =  mlmodel .user_defined_metadata [_METADATA_VERSION ]
9551022
1023+     src_dialect  =  (
1024+         None 
1025+         if  _METADATA_SOURCE_DIALECT  not  in   mlmodel .user_defined_metadata 
1026+         else  mlmodel .user_defined_metadata [_METADATA_SOURCE_DIALECT ]
1027+     )
1028+ 
9561029    src_pkg_version_list  =  src_pkg_version .split ("==" )
9571030    if  len (src_pkg_version_list ) ==  0 :
9581031        src_pkg , pkg_ver  =  None , None 
@@ -969,10 +1042,13 @@ def _get_metadata_from_mlmodel(mlmodel):
9691042    if  src_pkg  is  not   None  and  pkg_ver  is  not   None :
9701043        build_info ['coremltools-component-'  +  src_pkg ] =  str (pkg_ver )
9711044
1045+     if  src_dialect  is  not   None :
1046+         build_info ["coremltools-source-dialect" ] =  src_dialect 
1047+ 
9721048    return  build_info 
9731049
9741050
975- def  _record_build_metadata (mlmodel , exact_source ):
1051+ def  _record_build_metadata (mlmodel , exact_source ,  source_dialect = None ):
9761052    # recording metadata: coremltools version, source framework and version 
9771053    if  exact_source  in  {"tensorflow" , "tensorflow2" } and  (_HAS_TF_1  or  _HAS_TF_2 ):
9781054        src_pkg_version  =  "tensorflow=={0}" .format (tf .__version__ )
@@ -986,6 +1062,9 @@ def _record_build_metadata(mlmodel, exact_source):
9861062    mlmodel .user_defined_metadata [_METADATA_SOURCE ] =  src_pkg_version 
9871063    mlmodel .user_defined_metadata [_METADATA_VERSION ] =  _ct_version 
9881064
1065+     if  source_dialect  is  not   None :
1066+         mlmodel .user_defined_metadata [_METADATA_SOURCE_DIALECT ] =  source_dialect 
1067+ 
9891068    build_info  =  _get_metadata_from_mlmodel (mlmodel )
9901069
9911070    mlmodel ._set_build_info_mil_attributes (build_info )
0 commit comments