Skip to content

Commit 4409789

Browse files
committed
Fix CoreML reshape to support dynamic shapes
CoreML MIL supports -1 in shape dimensions and can infer at runtime. Change GetStaticShape to GetShape to allow dynamic input shapes, and only apply ReshapeHelper when the shape is fully static. Fixes #26328
1 parent 2b08a0c commit 4409789

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,19 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
4040
const logging::Logger& logger) const {
4141
const auto& input_defs = node.InputDefs();
4242
std::vector<int64_t> input_shape;
43-
ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Cannot get shape of data");
43+
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape of data");
4444

4545
const auto& data_name = input_defs[0]->Name();
4646
const auto& new_shape_name = input_defs[1]->Name();
4747
Initializer unpacked_tensor(model_builder.GetGraphViewer().GetGraph(), *model_builder.GetConstantInitializer(new_shape_name));
4848
TensorShapeVector new_shape = ToShapeVector(unpacked_tensor.DataAsSpan<int64_t>());
4949

50-
// ReshapeHelper applies the ONNX rules to create the concrete output shape
51-
ReshapeHelper helper(TensorShape(input_shape), new_shape);
50+
// ReshapeHelper applies the ONNX rules to create the concrete output shape.
51+
// Only use it if the input shape is static (no dynamic dimensions).
52+
// CoreML MIL supports -1 in the shape and can infer at runtime.
53+
if (IsStaticShape(input_shape)) {
54+
ReshapeHelper helper(TensorShape(input_shape), new_shape);
55+
}
5256

5357
if (model_builder.CreateMLProgram()) {
5458
using namespace CoreML::Specification::MILSpec;
@@ -96,7 +100,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node,
96100
}
97101

98102
std::vector<int64_t> input_shape;
99-
if (!GetStaticShape(*input_defs[0], input_shape, logger))
103+
if (!GetShape(*input_defs[0], input_shape, logger))
100104
return false;
101105

102106
if (input_shape.empty()) {

0 commit comments

Comments
 (0)