| 
8 | 8 | import numpy as np  | 
9 | 9 | 
 
  | 
10 | 10 | from coremltools import _logger as logger  | 
11 |  | -from coremltools.converters.mil._deployment_compatibility import \  | 
12 |  | -    AvailableTarget as _target  | 
 | 11 | +from coremltools.converters.mil._deployment_compatibility import AvailableTarget as _target  | 
 | 12 | +from coremltools.converters.mil.backend.mil import helper  | 
13 | 13 | from coremltools.converters.mil.mil import Block  | 
14 | 14 | from coremltools.converters.mil.mil import Builder as mb  | 
15 |  | -from coremltools.converters.mil.mil import (Function, ListVar, Placeholder,  | 
16 |  | -                                            Program, TupleInputType, Var,  | 
17 |  | -                                            mil_list, types)  | 
 | 15 | +from coremltools.converters.mil.mil import (  | 
 | 16 | +    Function,  | 
 | 17 | +    ListVar,  | 
 | 18 | +    Placeholder,  | 
 | 19 | +    Program,  | 
 | 20 | +    TupleInputType,  | 
 | 21 | +    Var,  | 
 | 22 | +    mil_list,  | 
 | 23 | +    types,  | 
 | 24 | +)  | 
18 | 25 | from coremltools.converters.mil.mil.block import curr_block  | 
19 |  | -from coremltools.converters.mil.mil.ops.registry import \  | 
20 |  | -    SSAOpRegistry as _SSAOpRegistry  | 
 | 26 | +from coremltools.converters.mil.mil.ops.registry import SSAOpRegistry as _SSAOpRegistry  | 
21 | 27 | from coremltools.proto import MIL_pb2 as pm  | 
22 | 28 | from coremltools.proto import Model_pb2 as ml  | 
23 | 29 | 
 
  | 
24 | 30 | from .helper import proto_to_types  | 
25 | 31 | 
 
  | 
26 | 32 | try:  | 
27 | 33 |     from coremltools.libmilstoragepython import _BlobStorageReader as BlobReader  | 
28 |  | -except:  | 
 | 34 | +except Exception as e:  | 
 | 35 | +    logger.warning(f"Fail to import BlobReader from libmilstoragepython. {e}")  | 
29 | 36 |     BlobReader = None  | 
30 | 37 | 
 
  | 
31 | 38 | 
 
  | 
@@ -145,7 +152,7 @@ def _load_value(context, value_spec):  | 
145 | 152 |         else:  | 
146 | 153 |             value = _load_file_value(context, value_spec.blobFileValue, dtype)  | 
147 | 154 | 
 
  | 
148 |  | -        if dtype in (types.fp16, types.int8, types.uint8, types.uint32):  | 
 | 155 | +        if dtype in helper.IMMEDIATE_VALUE_TYPES_IN_BYTES:  | 
149 | 156 |             value = np.frombuffer(value, types.nptype_from_builtin(dtype)).reshape(  | 
150 | 157 |                 shape  | 
151 | 158 |             )  | 
@@ -246,20 +253,23 @@ def _dummy_false_fn(*loop_vars):  | 
246 | 253 |         inputs["_false_fn"] = _dummy_false_fn  | 
247 | 254 | 
 
  | 
248 | 255 | 
 
  | 
 | 256 | +def _load_const_op(context, op_spec):  | 
 | 257 | +    inputs = {k: _load_value(context, v) for k, v in op_spec.attributes.items()}  | 
 | 258 | +    pymil_var = getattr(mb, op_spec.type)(**inputs)  | 
 | 259 | +    context.register_var_with_name(op_spec.outputs[0].name, pymil_var)  | 
 | 260 | + | 
 | 261 | + | 
249 | 262 | def _load_operation(context, op_spec):  | 
250 | 263 |     if not isinstance(op_spec, pm.Operation):  | 
251 | 264 |         raise TypeError("Invalid Operation spec object")  | 
252 | 265 | 
 
  | 
253 | 266 |     op_type = op_spec.type  | 
254 |  | -    if op_type == "const" or op_type.startswith("constexpr_"):  | 
 | 267 | +    if op_type == "const" or "constexpr_" in op_type:  | 
255 | 268 |         if op_spec.blocks:  | 
256 | 269 |             raise ValueError("const / constexpr operation can't have any block")  | 
257 | 270 |         if op_spec.inputs:  | 
258 | 271 |             raise ValueError("const / constexpr operation can't have any input")  | 
259 |  | - | 
260 |  | -        inputs = {k: _load_value(context, v) for k, v in op_spec.attributes.items()}  | 
261 |  | -        pymil_var = getattr(mb, op_type)(**inputs)  | 
262 |  | -        context.register_var_with_name(op_spec.outputs[0].name, pymil_var)  | 
 | 272 | +        _load_const_op(context, op_spec)  | 
263 | 273 | 
 
  | 
264 | 274 |     else:  | 
265 | 275 |         if op_type == "custom_layer":  | 
@@ -402,11 +412,19 @@ def _load_function(context, func_spec, spec_version):  | 
402 | 412 | 
 
  | 
403 | 413 | 
 
  | 
404 | 414 | def load(model_spec, specification_version, file_weights_dir="", **kwargs):  | 
 | 415 | +    """  | 
 | 416 | +    Load MILProto to Pymil.  | 
 | 417 | +
  | 
 | 418 | +    Set force_spec_version to force override the spec version.  | 
 | 419 | +    """  | 
405 | 420 |     if not isinstance(model_spec, ml.Model):  | 
406 | 421 |         raise TypeError("Invalid Model sepc object")  | 
407 | 422 | 
 
  | 
408 | 423 |     if specification_version < model_spec.specificationVersion:  | 
409 |  | -        raise ValueError("specification_version must be greater or equal to the input model spec version")  | 
 | 424 | +        if not kwargs.get("force_spec_version", False):  | 
 | 425 | +            raise ValueError(  | 
 | 426 | +                "specification_version must be greater or equal to the input model spec version"  | 
 | 427 | +            )  | 
410 | 428 | 
 
  | 
411 | 429 |     if model_spec.WhichOneof("Type") != "mlProgram":  | 
412 | 430 |         raise ValueError("Only MIL proto based mlmodels can be loaded")  | 
 | 
0 commit comments