Skip to content

Commit 9473763

Browse files
authored
release 4.0b3 (#858)
1 parent db55e57 commit 9473763

File tree

91 files changed

+5337
-1851
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+5337
-1851
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
[![Build Status](https://img.shields.io/gitlab/pipeline/zach_nation/coremltools/master)](#)
1+
[![Build Status](https://travis-ci.com/apple/coremltools.svg?branch=master)](#)
22
[![PyPI Release](https://img.shields.io/pypi/v/coremltools.svg)](#)
33
[![Python Versions](https://img.shields.io/pypi/pyversions/coremltools.svg)](#)
44

@@ -29,7 +29,7 @@ With coremltools, you can do the following:
2929
To get the latest version of coremltools:
3030

3131
```shell
32-
pip install coremltools==4.0b2
32+
pip install coremltools==4.0b3
3333
```
3434

3535
For the latest changes please see the [release notes](https://github.com/apple/coremltools/releases/).

coremltools/_deps/__init__.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import platform as _platform
1313
import re as _re
1414
import sys as _sys
15+
from packaging import version
1516

1617

1718
def __get_version(version):
@@ -104,15 +105,15 @@ def __get_sklearn_version(version):
104105

105106
if _HAS_TF_1:
106107
if tf_ver < _StrictVersion(_TF_1_MIN_VERSION):
107-
_logging.warn(
108+
_logging.warning(
108109
(
109110
"TensorFlow version %s is not supported. Minimum required version: %s ."
110111
"TensorFlow conversion will be disabled."
111112
)
112113
% (tensorflow.__version__, _TF_1_MIN_VERSION)
113114
)
114115
elif tf_ver > _StrictVersion(_TF_1_MAX_VERSION):
115-
_logging.warn(
116+
_logging.warning(
116117
"TensorFlow version %s detected. Last version known to be fully compatible is %s ."
117118
% (tensorflow.__version__, _TF_1_MAX_VERSION)
118119
)
@@ -126,7 +127,7 @@ def __get_sklearn_version(version):
126127
% (tensorflow.__version__, _TF_2_MIN_VERSION)
127128
)
128129
elif tf_ver > _StrictVersion(_TF_2_MAX_VERSION):
129-
_logging.warn(
130+
_logging.warning(
130131
"TensorFlow version %s detected. Last version known to be fully compatible is %s ."
131132
% (tensorflow.__version__, _TF_2_MAX_VERSION)
132133
)
@@ -232,6 +233,7 @@ def __get_sklearn_version(version):
232233
_HAS_TORCH = False
233234
MSG_TORCH_NOT_FOUND = "PyTorch not found."
234235

236+
235237
# ---------------------------------------------------------------------------------------
236238
_HAS_ONNX = True
237239
try:
@@ -246,3 +248,17 @@ def __get_sklearn_version(version):
246248
except:
247249
_HAS_GRAPHVIZ = False
248250
MSG_ONNX_NOT_FOUND = "ONNX not found."
251+
252+
# General utils
253+
def version_ge(module, target_version):
254+
"""
255+
Example usage:
256+
257+
>>> import torch # v1.5.0
258+
>>> version_ge(torch, '1.6.0') # False
259+
"""
260+
return version.parse(module.__version__) >= version.parse(target_version)
261+
262+
def version_lt(module, target_version):
263+
"""See version_ge"""
264+
return version.parse(module.__version__) < version.parse(target_version)

coremltools/converters/_converters_entry.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,21 @@ def convert(
7373
to determine the source framework.
7474
7575
inputs: list of `TensorType` or `ImageType`
76-
- Inputs are required for PyTorch model, but optional for TensorFlow.
77-
- For PyTorch models, the inputs may be nested list or tuple, but for
78-
TensorFlow models it must be a flat list.
79-
- For TensorFlow, if inputs is `None`, the inputs are `Placeholder`
80-
nodes in the model (if model is frozen graph) or function inputs (if
81-
model is tf function).
82-
- For TensorFlow, if inputs is not `None`, inputs may contain only a
83-
subset of all Placeholder in the TF model.
76+
TensorFlow 1 and 2:
77+
- `inputs` are optional. If not provided, the inputs are
78+
`Placeholder` nodes in the model (if model is frozen graph) or
79+
function inputs (if model is tf function)
80+
- `inputs` must corresponds to all or some of the Placeholder
81+
nodes in the TF model
82+
- `TensorType` and `ImageType` in `inputs` must have `name`
83+
specified. `shape` is optional.
84+
- If `inputs` is provided, it must be a flat list.
85+
86+
PyTorch:
87+
- `inputs` are required.
88+
- `inputs` may be nested list or tuple.
89+
- `TensorType` and `ImageType` in `inputs` must have `name`
90+
and `shape` specified.
8491
8592
outputs: list[str] (optional)
8693
@@ -263,6 +270,10 @@ def raise_if_duplicated(input_list):
263270
msg = 'Unexpected argument "example_inputs" found'
264271
raise ValueError(msg)
265272

273+
if inputs is None:
274+
msg = 'Expected argument for pytorch "inputs" not provided'
275+
raise ValueError(msg)
276+
266277
def _flatten_list(_inputs):
267278
ret = []
268279
for _input in _inputs:
@@ -313,6 +324,9 @@ def _flatten_list(_inputs):
313324
**kwargs
314325
)
315326

327+
if convert_to == 'mil':
328+
return proto_spec # Returns the MIL program
329+
316330
model = coremltools.models.MLModel(proto_spec, useCPUOnly=True)
317331

318332
if minimum_deployment_target is not None:

coremltools/converters/mil/_deployment_compatibility.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ def iOS14Features(spec):
9393
if layer_type == "sliceStatic" and layer.sliceDynamic.squeezeMasks:
9494
msg = "Squeeze mask for static slice operation"
9595

96+
if layer_type == "concatND" and layer.concatND.interleave:
97+
msg = "Concat layer with interleave operation"
98+
9699
if msg != "" and (msg not in features_list):
97100
features_list.append(msg)
98101

coremltools/converters/mil/backend/nn/load.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import logging
1212
from collections import defaultdict
1313

14+
import coremltools as ct
1415
from coremltools.converters.mil.input_types import (
1516
ClassifierConfig,
1617
ImageType,
@@ -171,25 +172,25 @@ def _set_optional_inputs(proto, input_types):
171172
for input_type in input_types:
172173
if isinstance(input_type, ImageType):
173174
continue
174-
is_optional = input_type.is_optional
175-
optional_value = input_type.optional_value
176-
shape = input_type.shape
177-
if not is_optional:
178-
continue
179-
msg = "Not support optional inputs for flexible input shape."
180-
if not isinstance(shape, Shape):
181-
raise NotImplementedError(msg)
182-
if any([isinstance(s, RangeDim) for s in shape.shape]):
183-
raise NotImplementedError(msg)
184-
185-
default_map[input_type.name] = optional_value
175+
if input_type.default_value is not None:
176+
default_map[input_type.name] = input_type.default_value
186177

187178
for idx, input in enumerate(proto.description.input):
188179
name = proto.description.input[idx].name
189180
if name in default_map:
181+
default_value = default_map[name]
190182
proto.description.input[idx].type.isOptional = True
191-
default_value = default_map[name] if default_map[name] is not None else 0.
192-
proto.description.input[idx].type.multiArrayType.floatDefaultValue = default_value
183+
array_t = proto.description.input[idx].type.multiArrayType
184+
default_fill_val = default_value.flatten()[0]
185+
array_t.floatDefaultValue = default_fill_val
186+
if default_fill_val != 0 or list(default_value.shape) != \
187+
array_t.shape:
188+
# promote spec version to 5 and set the default value
189+
proto.specificationVersion = max(proto.specificationVersion,
190+
ct._SPECIFICATION_VERSION_IOS_14)
191+
# array_t.shape is not empty.
192+
array_t.ClearField('shape')
193+
array_t.shape.extend(list(default_value.shape))
193194

194195

195196
@_profile

0 commit comments

Comments
 (0)