# -*- coding: utf-8 -*-

#  Copyright (c) 2021, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import warnings as _warnings

from coremltools.converters.mil.mil import Builder as _mb
from coremltools.converters.mil.mil import types as _types
from coremltools.converters.mil.mil.ops import defs as _ops
from coremltools.converters.mil.mil.passes.pass_registry import register_pass as _register_pass

@_register_pass(namespace="mil_backend")
def adjust_io_to_supported_types(prog):
    """
    Converts all dtypes to types that are supported by the CoreML runtime.
    The runtime supports only fp16, fp32, int32, str, and bool variables.

    General rules:
        * Integer vars that are not 32 bit are replaced with int32 types.
        * All other types not in the list of runtime-supported types are replaced with the fp32
          dtype. No casts are inserted; the previous type is replaced. The assumption is that all
          remaining types are numerical and can be reasonably replaced with 32 bit float types.

    The "main" function has additional rules, since its I/O is mapped to CoreML model I/O:
        * fp16 I/O is replaced with fp32 I/O.
          Casts (fp32 input -> fp16) are inserted at the beginning of the program to preserve 16 bit inputs.
          Casts (fp16 -> fp32 output) are inserted at the end of the program to preserve 16 bit computations.

        * All non-integer I/O that is not fp32 is replaced with fp32 I/O.
          A cast (previous input type -> fp32) is inserted at the beginning of the program to preserve non-fp32 inputs.
          A cast (previous type -> fp32 output) is inserted at the end of the program to preserve non-fp32 computations.
          The assumption is that all remaining types are numerical and can validly be cast to and from fp32.

        * The only exception: int64 outputs are allowed for the classifier op. This keeps consistency with
          the CoreML API, which uses 64 bit integers to represent classifier labels.

    ------

    func main(bool x, int32 y, fp32 z) {
        bool  out = logical_not(x)
    } -> (out, y, z)

    becomes

    func main(fp32 x, int32 y, fp32 z) {
        bool  x_casted = cast(x)
        bool  out__pre__output__fp32__cast = logical_not(x_casted)
        fp32  out = cast(out__pre__output__fp32__cast)
    } -> (out, y, z)

    ------

    func not_main(bool x, int32 y, fp32 z) {
        bool  out = logical_not(x)
    } -> (out, y, z)

    is unchanged.
    """
    for name, func in prog.functions.items():
        _adjust_io_to_supported_types(func, is_main=(name == "main"))
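
# Usage sketch (not part of the pass itself). Assuming the coremltools pass
# registry API, the registered pass can be fetched by its namespaced name and
# applied to a toy program that mirrors the docstring example:
#
#     from coremltools.converters.mil.mil import Builder as mb
#     from coremltools.converters.mil.mil import types
#     from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY
#
#     @mb.program(input_specs=[mb.TensorSpec(shape=(2,), dtype=types.bool)])
#     def prog(x):
#         return mb.logical_not(x=x)
#
#     PASS_REGISTRY["mil_backend::adjust_io_to_supported_types"](prog)
#
# Afterwards, main takes an fp32 input, a cast back to bool feeds logical_not,
# and the bool result is cast to fp32 at the output.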


__RUNTIME_SUPPORTED_TYPES = [_types.fp16, _types.fp32, _types.int32, _types.str, _types.bool]


#####
# Main Function
#####
def _adjust_main_inputs(func):
    first_op = func.operations[0] if len(func.operations) > 0 else None
    for input_var in func.inputs.values():
        if (_types.is_tensor(input_var.sym_type) or _types.is_scalar(input_var.sym_type)) \
                and input_var.dtype != _types.fp32 \
                and input_var.dtype != _types.int32:
            input_dtype_str = _types.builtin_to_string(input_var.dtype)
            if _types.is_int(input_var.dtype):
                # Replace non-int32 input type with int32.
                _warnings.warn("Input " + input_var.name + " is of dtype " + input_dtype_str +
                               ". Only integer variables of bit width 32 are supported by the CoreML runtime. " +
                               "This input will be assigned a dtype of int32. " +
                               "No cast will be inserted; the previous dtype will be replaced.")
                input_var._sym_type = _types.tensor(_types.int32, input_var.sym_type.get_shape())
            elif input_var.dtype == _types.fp64:
                # Replace fp64 input type with fp32.
                _warnings.warn("Input " + input_var.name + " is of dtype fp64. 64 bit float inputs are " +
                               "not supported by ML program models. This input will be assigned a dtype " +
                               "of fp32. No cast will be inserted; the previous dtype will be replaced.")
                input_var._sym_type = _types.tensor(_types.fp32, input_var.sym_type.get_shape())
            else:
                # This is some other dtype. Change the type to fp32 and add a cast.
                # This is only a limitation of main--other functions do not represent CoreML model
                # inputs and do not have the same limitation on input types.
                _warnings.warn("Input " + input_var.name + " is of dtype " + input_dtype_str + ". The " +
                               "CoreML runtime does not support inputs with this dtype (only fp32 and " +
                               "int32 inputs are supported). This input will be assigned a dtype of " +
                               "fp32. A cast will be inserted at the beginning of the program to " +
                               "convert the input to the originally defined dtype.")
                with func:
                    casted_input_var = _mb.cast(x=input_var, dtype=input_dtype_str, before_op=first_op)
                    func.replace_uses_of_var_after_op(anchor_op=casted_input_var.op,
                                                      old_var=input_var,
                                                      new_var=casted_input_var)
                    input_var._sym_type = _types.tensor(_types.fp32, input_var.sym_type.get_shape())
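
# Illustration of the cast-inserting branch above: an fp16 input "x" keeps its
# name but is retyped to fp32 in the function signature, and a cast(x, dtype="fp16")
# is inserted before the first op so downstream ops still see fp16 values (this
# matches the docstring example, where the bool input x is recovered via
# "x_casted = cast(x)").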


def _adjust_main_outputs(func):
    new_outputs = []
    for output_var in func.outputs:
        output_type = output_var.sym_type
        if (_types.is_tensor(output_type) or _types.is_scalar(output_type)) \
                and output_var.dtype != _types.fp32 \
                and output_var.dtype != _types.int32:
            output_dtype_str = _types.builtin_to_string(output_var.dtype)
            _warnings.warn("Output " + output_var.name + " is of dtype " + output_dtype_str + ". The " +
                           "CoreML runtime does not support outputs with this dtype (only int32 and " +
                           "fp32 are supported for outputs). This output will be assigned a dtype " +
                           "of fp32. A cast will be inserted at the end of the program to convert " +
                           "the original output dtype to the dtype supported by the CoreML runtime.")

            output_var_name = output_var.name
            output_var.set_name(output_var_name + "__pre__output__fp32__cast")
            # Convert the output to fp32 and add a cast.
            with func:
                output_var = _mb.cast(x=output_var, dtype="fp32")
                output_var.set_name(output_var_name)
        new_outputs.append(output_var)
    func.set_outputs(new_outputs)
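
# Illustration: an fp16 output named "out" is renamed to
# "out__pre__output__fp32__cast", and the newly inserted cast-to-fp32 op takes
# over the name "out", so the model-facing output name is preserved (this is
# the renaming shown in the docstring example at the top of this file).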


#####
# General Functions and Blocks
#####
def _adjust_var(var):
    """
    Changes the dtype of the provided variable according
    to the rules outlined in the top level pass comment
    (see adjust_io_to_supported_types).
    """
    if (_types.is_tensor(var.sym_type) or _types.is_scalar(var.sym_type)) \
            and var.dtype not in __RUNTIME_SUPPORTED_TYPES:
        dtype_str = _types.builtin_to_string(var.dtype)
        if _types.is_int(var.dtype):
            # Replace non-int32 integer type with int32.
            _warnings.warn("Var " + var.name + " is of dtype " + dtype_str +
                           ". Only integer variables of bit width 32 are supported by the CoreML runtime. " +
                           "This var will be assigned a dtype of int32. " +
                           "No cast will be inserted; the previous dtype will be replaced.")
            var._sym_type = _types.tensor(_types.int32, var.sym_type.get_shape())
        else:
            # This is some other unsupported dtype. Change the var's type to fp32.
            _warnings.warn("Var " + var.name + " is of dtype " + dtype_str + ". The CoreML runtime " +
                           "does not support this dtype (only fp16, fp32, bool, int32, and str are " +
                           "supported). This var will be assigned a dtype of fp32. No cast will be " +
                           "inserted; the previous dtype will be replaced.")
            var._sym_type = _types.tensor(_types.fp32, var.sym_type.get_shape())
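
# Example of the integer rule above: a var typed tensor(int64, (2, 3)) is
# retyped in place to tensor(int32, (2, 3)); no cast op is created.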


def _adjust_func_inputs(func):
    for input_var in func.inputs.values():
        _adjust_var(input_var)


def _adjust_block_inputs(block):
    for input_var in block.inputs:
        _adjust_var(input_var)


def _adjust_ops(block):
    len_block = len(block.operations)
    i = 0
    while i < len_block:
        op = block.operations[i]

        # Classifier is a special exception to this rule: it may output 64 bit integer labels.
        # The classify op must therefore be inserted after this pass has run.
        if op.op_type == "classify":
            raise ValueError("ML Program backend pass adjust_io_to_supported_types does not " +
                             "support programs that have already added a classify op.")

        for subblock in op.blocks:
            _adjust_block_inputs(subblock)
            _adjust_ops(subblock)

        for var in op.outputs:
            _adjust_var(var)

        # Cast ops have a param (dtype) that should match the output dtype.
        # If the output dtype or input dtype was previously adjusted,
        # the cast op must be changed or removed in kind.
        if op.op_type == "cast":
            output_type_str = _types.builtin_to_string(op.outputs[0].dtype)
            if op.outputs[0].dtype == op.x.dtype:
                # The type of the input or output of this cast op was changed per the rules
                # defined in the top level comment for adjust_io_to_supported_types.
                #
                # The changed output type is now the same type as the input to the cast op.
                # Therefore, regardless of whether the user created this cast or not, it is
                # now redundant (a no-op) and should be removed.
                #
                # The removal isn't covered by the main cast optimization pass, since that
                # pass runs before this one.
                block.replace_uses_of_var_after_op(
                    anchor_op=op, old_var=op.outputs[0], new_var=op.x
                )
                block.remove_ops([op])
                len_block = len(block.operations)
                i -= 1
            elif output_type_str != op.dtype.val:
                # The type of the output of this cast op was changed per the rules
                # defined in the top level comment for adjust_io_to_supported_types.
                #
                # This cast is meaningful, but its "dtype" param now differs from the output
                # type. Replace it with a new cast op whose dtype param matches the output.
                with block:
                    new_cast_out = _mb.cast(x=op.x, dtype=output_type_str, before_op=op)
                    block.replace_uses_of_var_after_op(
                        anchor_op=op, old_var=op.outputs[0], new_var=new_cast_out
                    )
                block.remove_ops([op])
                len_block = len(block.operations)
        i += 1
    return block
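
# Illustration of the cast fixup above (assumed scenario): if a var was retyped
# from int64 to int32 by this pass, a cast op whose input is now already int32
# becomes a no-op and is removed, while a cast whose "dtype" param still says
# "int64" but whose output var is now int32 is rebuilt as cast(..., dtype="int32")
# so the param and the output type agree again.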

#####
# The Pass
#####
def _adjust_io_to_supported_types(func, is_main):
    if is_main:
        _adjust_main_inputs(func)
        _adjust_ops(func)
        _adjust_main_outputs(func)
    else:
        _adjust_func_inputs(func)
        _adjust_ops(func)