Skip to content

Commit 7b13371

Browse files
authored
8.0 Release (#2342)
1 parent 40f6705 commit 7b13371

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+4797
-1857
lines changed

coremlpython/CoreMLPython.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ namespace CoreML {
5757
Model(const Model&) = delete;
5858
Model& operator=(const Model&) = delete;
5959
~Model();
60-
explicit Model(const std::string& urlStr, const std::string& computeUnits, const std::string& functionName);
60+
explicit Model(const std::string& urlStr, const std::string& computeUnits, const std::string& functionName, const py::dict& optimizationHints);
6161
explicit Model(MLModel* m_model, NSURL* compiledUrl, bool deleteCompiledModelOnExit);
6262

6363
py::list batchPredict(const py::list& batch) const;
@@ -67,6 +67,7 @@ namespace CoreML {
6767
py::dict predict(const py::dict& input, State* state=NULL) const;
6868

6969
#if BUILT_WITH_MACOS15_SDK
70+
static void setOptimizationHints(MLModelConfiguration *configuration, const py::dict& optimizationHints);
7071
State newState() const;
7172
#endif
7273

coremlpython/CoreMLPython.mm

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,12 @@ bool usingMacOS13OrHigher() {
4242
}
4343
}
4444

45-
Model::Model(const std::string& urlStr, const std::string& computeUnits, const std::string& functionName) {
45+
Model::Model(
46+
const std::string& urlStr,
47+
const std::string& computeUnits,
48+
const std::string& functionName,
49+
const py::dict& optimizationHints
50+
) {
4651
@autoreleasepool {
4752
NSError *error = nil;
4853

@@ -80,6 +85,10 @@ bool usingMacOS13OrHigher() {
8085
MLModelConfiguration *configuration = [MLModelConfiguration new];
8186
setComputeUnit(configuration, computeUnits);
8287

88+
#if BUILT_WITH_MACOS15_SDK
89+
setOptimizationHints(configuration, optimizationHints);
90+
#endif
91+
8392
if (!functionName.empty()) {
8493
#if BUILT_WITH_MACOS15_SDK
8594
configuration.functionName = [NSString stringWithUTF8String:functionName.c_str()];
@@ -148,6 +157,37 @@ bool usingMacOS13OrHigher() {
148157
}
149158

150159

160+
#if BUILT_WITH_MACOS15_SDK
161+
void Model::setOptimizationHints(MLModelConfiguration *configuration, const py::dict& optimizationHints) {
    // Copies the optional "reshapeFrequency" and "specializationStrategy" entries of
    // `optimizationHints` onto `configuration.optimizationHints`. A key that is absent
    // leaves the corresponding setting untouched.
    //
    // The Python layer is expected to have validated the values already, so validation
    // here is minimal. Unknown values are nevertheless rejected with an exception rather
    // than an `assert`: `assert` is compiled out when NDEBUG is defined, which would let
    // any unrecognized string silently fall through to the second option.

    // Reshape frequency optimization hint
    if (optimizationHints.contains("reshapeFrequency")) {
        const std::string val = optimizationHints["reshapeFrequency"].cast<std::string>();
        if (val == "Frequent") {
            configuration.optimizationHints.reshapeFrequency = MLReshapeFrequencyHintFrequent;
        } else if (val == "Infrequent") {
            configuration.optimizationHints.reshapeFrequency = MLReshapeFrequencyHintInfrequent;
        } else {
            throw py::value_error("Unsupported reshapeFrequency optimization hint: " + val);
        }
    }

    // Specialization strategy optimization hint
    if (optimizationHints.contains("specializationStrategy")) {
        const std::string val = optimizationHints["specializationStrategy"].cast<std::string>();
        if (val == "Default") {
            configuration.optimizationHints.specializationStrategy = MLSpecializationStrategyDefault;
        } else if (val == "FastPrediction") {
            configuration.optimizationHints.specializationStrategy = MLSpecializationStrategyFastPrediction;
        } else {
            throw py::value_error("Unsupported specializationStrategy optimization hint: " + val);
        }
    }
}
188+
#endif
189+
190+
151191
py::list Model::batchPredict(const py::list& batch) const {
152192
@autoreleasepool {
153193
NSError* error = nil;
@@ -237,7 +277,7 @@ bool usingMacOS13OrHigher() {
237277
py::module m("libcoremlpython", "CoreML.Framework Python bindings");
238278

239279
py::class_<Model>(m, "_MLModelProxy")
240-
.def(py::init<const std::string&, const std::string&, const std::string&>())
280+
.def(py::init<const std::string&, const std::string&, const std::string&, const py::dict&>())
241281
.def("predict", &Model::predict)
242282
.def("batchPredict", &Model::batchPredict)
243283
.def("get_compiled_model_path", &Model::getCompiledModelPath)

coremltools/__init__.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,35 @@ class ComputeUnit(_Enum):
7272
'''
7373
The set of processing-unit configurations the model can use to make predictions.
7474
'''
75-
ALL = 1 # Allows the model to use all compute units available, including the neural engine
76-
CPU_AND_GPU = 2 # Allows the model to use both the CPU and GPU, but not the neural engine
77-
CPU_ONLY = 3 # Limit the model to only use the CPU
78-
CPU_AND_NE = 4 # Allows the model to use both the CPU and neural engine, but not the GPU.
79-
# Only available on macOS >= 13.0
75+
ALL = 1 # Allows model to use all compute units available, including the neural engine.
76+
CPU_AND_GPU = 2 # Allows model to use both the CPU and GPU, but not the neural engine.
77+
CPU_ONLY = 3 # Limits model to only use the CPU.
78+
CPU_AND_NE = 4 # Allows model to use both the CPU and neural engine, but not the GPU.
79+
# Only available on macOS >= 13.0
80+
81+
82+
class ReshapeFrequency(_Enum):
    '''
    Reshape-frequency optimization hint.

    For the meaning of each option, see Apple's documentation:
    https://developer.apple.com/documentation/coreml/mlreshapefrequencyhint?language=objc
    '''

    # NOTE(review): the member names look like the strings the native binding compares
    # against ("Frequent" / "Infrequent") -- confirm before renaming either member.
    Frequent = 1
    Infrequent = 2
88+
89+
90+
class SpecializationStrategy(_Enum):
    '''
    Optimization strategy used when the model is specialized.

    Reference:
    https://developer.apple.com/documentation/coreml/mlspecializationstrategy?language=objc
    '''

    # General-purpose strategy; works well for most applications.
    Default = 1

    # Favors prediction latency, potentially at the expense of specialization time,
    # memory footprint, and the disk space used by specialized artifacts.
    FastPrediction = 2
103+
80104

81105
# A dictionary that maps the CoreML model specification version to the MLProgram/MIL opset string
82106
_OPSET = {

coremltools/_deps/__init__.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,18 +153,33 @@ def __get_sklearn_version(version):
153153

154154
# ---------------------------------------------------------------------------------------
155155
_HAS_TORCH = True
156-
_TORCH_MAX_VERSION = "2.3.0"
156+
_TORCH_MAX_VERSION = "2.4.0"
157157
_HAS_TORCH_EXPORT_API = False
158+
_CT_OPTIMIZE_TORCH_MIN_VERSION = "2.1.0"
159+
_IMPORT_CT_OPTIMIZE_TORCH = False
158160
try:
159161
import torch
160162
_warn_if_above_max_supported_version("Torch", torch.__version__, _TORCH_MAX_VERSION)
161163

162-
if _get_version(torch.__version__) >= Version("2.1.0"):
164+
torch_version = _get_version(torch.__version__)
165+
166+
if torch_version >= Version("2.1.0"):
163167
_HAS_TORCH_EXPORT_API = True
164168

169+
if torch_version >= Version(_CT_OPTIMIZE_TORCH_MIN_VERSION):
170+
_IMPORT_CT_OPTIMIZE_TORCH = True
171+
else:
172+
logger.warning(
173+
(
174+
f"Minimum required torch version for importing coremltools.optimize.torch is {_CT_OPTIMIZE_TORCH_MIN_VERSION}. "
175+
f"Got torch version {torch_version}."
176+
)
177+
)
178+
165179
except:
166180
_HAS_TORCH = False
167181
MSG_TORCH_NOT_FOUND = "PyTorch not found."
182+
MSG_TORCH_EXPORT_API_NOT_FOUND = "Torch.Export API not found."
168183

169184

170185
_HAS_TORCH_VISION = True
@@ -189,6 +204,13 @@ def __get_sklearn_version(version):
189204
_HAS_EXECUTORCH = False
190205
MSG_EXECUTORCH_NOT_FOUND = "Executorch not found."
191206

207+
_HAS_TORCHAO = True
208+
try:
209+
import torchao
210+
except:
211+
_HAS_TORCHAO = False
212+
MSG_TORCHAO_NOT_FOUND = "Torchao not found."
213+
192214
# ---------------------------------------------------------------------------------------
193215
try:
194216
import scipy

coremltools/converters/mil/frontend/_utils.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,13 @@ def _concat_dims(dims, none_if_empty=False):
513513

514514

515515
def _decompose_scaled_dot_product_attention(
516-
q: Var, k: Var, v: Var, mask: Var, name: str, before_op: Optional[Operation] = None
516+
q: Var,
517+
k: Var,
518+
v: Var,
519+
mask: Var,
520+
name: str,
521+
scale: Optional[Var] = None,
522+
before_op: Optional[Operation] = None,
517523
) -> Var:
518524
# scale the query input
519525
embed_size = q.shape[-1]
@@ -524,9 +530,12 @@ def _decompose_scaled_dot_product_attention(
524530
)
525531

526532
q, k, v = promote_input_dtypes([q, k, v])
527-
multiplicative_scale_factor = 1 / math.sqrt(embed_size)
528-
if types.builtin_to_string(q.dtype) == "fp16":
529-
multiplicative_scale_factor = np.float16(multiplicative_scale_factor)
533+
if scale is None:
534+
multiplicative_scale_factor = 1 / math.sqrt(embed_size)
535+
if types.builtin_to_string(q.dtype) == "fp16":
536+
multiplicative_scale_factor = np.float16(multiplicative_scale_factor)
537+
else:
538+
multiplicative_scale_factor = scale
530539
q = mb.mul(x=q, y=multiplicative_scale_factor, before_op=before_op)
531540

532541
# multiply query and key input tensors
@@ -583,6 +592,11 @@ def _construct_constexpr_dequant_op(
583592
scale = np.squeeze(scale)
584593
if isinstance(zero_point, (np.ndarray, np.generic)):
585594
zero_point = np.squeeze(zero_point)
595+
if len(scale.shape) > 1 or len(zero_point.shape) > 1:
596+
raise ValueError(
597+
"The more fine-grained quantization (such as blockwise) is only supported since iOS18."
598+
"Please set minimum_deployment_target to iOS18 for using it."
599+
)
586600

587601
kwargs = {
588602
"quantized_data": quantized_weights,
@@ -631,7 +645,10 @@ def _construct_constexpr_dequant_op(
631645
}
632646
if zero_point is not None and np.any(zero_point):
633647
# Only pass the offset parameter when not all elements in `zero_point` are zeroes.
634-
zero_point = zero_point.reshape(scale.shape).astype(quantized_weights.dtype)
648+
zero_point = zero_point.reshape(scale.shape)
649+
# When zero_point is integer, it's required to have the same dtype as the quantized weight.
650+
if np.issubdtype(zero_point.dtype, np.integer):
651+
zero_point = zero_point.astype(quantized_weights.dtype)
635652
kwargs["offset"] = zero_point
636653
if name is not None:
637654
kwargs["name"] = name

coremltools/converters/mil/frontend/tensorflow/test/test_ops.py

Lines changed: 8 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2622,15 +2622,6 @@ def test_ios17_resize_bilinear_dynamic_shape(
26222622
target_shape,
26232623
align_corners,
26242624
):
2625-
if (
2626-
backend == ("mlprogram", "fp16")
2627-
and input_shape == (2, 5, 2, 3)
2628-
and target_shape == (20, 60)
2629-
):
2630-
pytest.xfail(
2631-
"rdar://116060011: re-activate coremltools tests blocked by Core ML regressions"
2632-
)
2633-
26342625
"""
26352626
Since iOS17, dynamic shape is supported by lowering to `resize` MIL op.
26362627
"""
@@ -2732,15 +2723,6 @@ def test_ios17_resize_nearest_neighbor_dynamic_shape(
27322723
input_shape,
27332724
target_shape,
27342725
):
2735-
if (
2736-
backend == ("mlprogram", "fp16")
2737-
and input_shape == (2, 5, 2, 3)
2738-
and target_shape == (20, 60)
2739-
):
2740-
pytest.xfail(
2741-
"rdar://116060011: re-activate coremltools tests blocked by Core ML regressions"
2742-
)
2743-
27442726
"""
27452727
Since iOS17, dynamic shape is supported by lowering to `resize` MIL op.
27462728
"""
@@ -5706,10 +5688,8 @@ def test_sort(self, compute_unit, backend, rank, dynamic):
57065688
"""
57075689
tf.sort dispatches to tf.math.top_k, and k = size of the axis to be sorted
57085690
"""
5709-
if backend[0] == "mlprogram" and dynamic:
5710-
pytest.xfail(
5711-
"rdar://116060011: re-activate coremltools tests blocked by Core ML regressions"
5712-
)
5691+
if platform.machine() == "x86_64" and dynamic:
5692+
pytest.xfail("rdar://135843153 ([Bug] Models failed on x86_64 platform)")
57135693

57145694
# Here we test the conversion of tf.sort(x, axis=0)
57155695
# If dynamic, we prepend None to x shape as the dynamic shape axis
@@ -6720,7 +6700,6 @@ def build_model(x):
67206700
def test_programmatic(
67216701
self, compute_unit, backend, input_block_rank, dynamic_input, dynamic_paddings
67226702
):
6723-
67246703
input_rank, block_rank = input_block_rank
67256704

67266705
# generate data
@@ -6733,6 +6712,9 @@ def test_programmatic(
67336712
if block_shape[0] == 1:
67346713
pytest.skip("neuralnetwork backend doesn't support unity block shape.")
67356714

6715+
if input_block_rank == (4, 1) and dynamic_input and not dynamic_paddings:
6716+
pytest.xfail("rdar://133558007 shape deduction failure")
6717+
67366718
paddings = []
67376719
for i in range(block_rank):
67386720
while True:
@@ -6832,14 +6814,12 @@ def test_programmatic(
68326814
self, compute_unit, backend, input_block_rank, dynamic_input, dynamic_crops
68336815
):
68346816
if (
6835-
backend == ("mlprogram", "fp16")
6836-
and input_block_rank == (3, 1) or (3,2)
6817+
platform.machine() == "x86_64"
6818+
and input_block_rank == (3, 1)
68376819
and dynamic_input
68386820
and not dynamic_crops
68396821
):
6840-
pytest.xfail(
6841-
"rdar://116060011: re-activate coremltools tests blocked by Core ML regressions"
6842-
)
6822+
pytest.xfail("rdar://135843153 ([Bug] Models failed on x86_64 platform)")
68436823

68446824
input_rank, block_rank = input_block_rank
68456825

@@ -6939,16 +6919,6 @@ def test_smoke_new_op(
69396919
input_shape, block_shape, crops = shape_block_crops
69406920
crops = np.array(crops, dtype=np.int32)
69416921

6942-
if (
6943-
backend == ("mlprogram", "fp16")
6944-
and shape_block_crops == [(4, 4, 6, 1), [1, 2], [[2, 1], [3, 3]]]
6945-
and dynamic_input
6946-
and not dynamic_crops
6947-
):
6948-
pytest.xfail(
6949-
"rdar://116060011: re-activate coremltools tests blocked by Core ML regressions"
6950-
)
6951-
69526922
# The neuralnetwork backend doesn't support these tests
69536923
if backend[0] == "neuralnetwork":
69546924
return

coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops_tf_keras.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,11 +1389,6 @@ def test_lstm_time_distributed_dense(self, compute_unit, backend):
13891389
"compute_unit, backend", itertools.product(compute_units, backends)
13901390
)
13911391
def test_lstm_dynamic_batch(self, compute_unit, backend):
1392-
if backend == ("mlprogram", "fp16"):
1393-
pytest.xfail(
1394-
"rdar://116060011: re-activate coremltools tests blocked by Core ML regressions"
1395-
)
1396-
13971392
input_shape = (1, 1280)
13981393
inp = tf.keras.layers.Input(shape=input_shape)
13991394
out, hn, cn = tf.keras.layers.LSTM(512,

0 commit comments

Comments
 (0)