Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,19 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Cannot get shape of data");
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape of data");

const auto& data_name = input_defs[0]->Name();
const auto& new_shape_name = input_defs[1]->Name();
Initializer unpacked_tensor(model_builder.GetGraphViewer().GetGraph(), *model_builder.GetConstantInitializer(new_shape_name));
TensorShapeVector new_shape = ToShapeVector(unpacked_tensor.DataAsSpan<int64_t>());

// ReshapeHelper applies the ONNX rules to create the concrete output shape
ReshapeHelper helper(TensorShape(input_shape), new_shape);
// ReshapeHelper applies the ONNX rules to create the concrete output shape.
// Only use it if the input shape is static (no dynamic dimensions).
// CoreML MIL supports -1 in the shape and can infer at runtime.
if (IsStaticShape(input_shape)) {
ReshapeHelper helper(TensorShape(input_shape), new_shape);
}

if (model_builder.CreateMLProgram()) {
using namespace CoreML::Specification::MILSpec;
Expand Down Expand Up @@ -96,7 +100,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node,
}

std::vector<int64_t> input_shape;
if (!GetStaticShape(*input_defs[0], input_shape, logger))
if (!GetShape(*input_defs[0], input_shape, logger))
return false;

if (input_shape.empty()) {
Expand Down
35 changes: 35 additions & 0 deletions onnxruntime/test/providers/coreml/dynamic_input_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,41 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, MobileNetExcerpt) {
}
}

TEST(CoreMLExecutionProviderDynamicInputShapeTest, Reshape) {
constexpr auto model_path = ORT_TSTR("testdata/reshape_with_dynamic_input_shape.onnx");

auto test = [&](const size_t M) {
SCOPED_TRACE(MakeString("M=", M));
std::unordered_map<std::string, std::string> options;
auto coreml_ep = CoreMLProviderFactoryCreator::Create(options)->CreateProvider();

const auto ep_verification_params = EPVerificationParams{
ExpectedEPNodeAssignment::All,
1e-4f,
};

#if defined(__APPLE__)
RandomValueGenerator gen{1234};
// Input is [M, 512], reshape to [-1, 2048] so M must be a multiple of 4.
const auto X_shape = std::vector<int64_t>{static_cast<int64_t>(M * 4), 512};
const auto X_data = gen.Uniform<float>(X_shape, 0.0f, 1.0f);

OrtValue X = CreateInputOrtValueOnCPU<float>(X_shape, X_data);

RunAndVerifyOutputsWithEP(model_path, CurrentTestName(),
std::move(coreml_ep),
{{"X", X}},
ep_verification_params);
#else
TestModelLoad(model_path, std::move(coreml_ep), ep_verification_params.ep_node_assignment);
#endif
};

for (size_t i = 1; i <= 3; ++i) {
test(i);
}
}

TEST(CoreMLExecutionProviderDynamicInputShapeTest, EmptyInputFails) {
constexpr auto model_path = ORT_TSTR("testdata/matmul_with_dynamic_input_shape.onnx");

Expand Down
Binary file not shown.
33 changes: 33 additions & 0 deletions onnxruntime/test/testdata/reshape_with_dynamic_input_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from pathlib import Path

import onnx
from onnx import TensorProto, helper

# This model contains a Reshape where:
# - X has shape [M, 512] and `M` is a dynamic dimension.
# - shape is a constant initializer with value [-1, 2048].
# CoreML MIL supports -1 in the shape and can infer the dimension at runtime.

M = "M"
K = 512
N = 2048

graph = helper.make_graph(
[ # nodes
helper.make_node("Reshape", ["X", "shape"], ["Y"], "Reshape"),
],
"ReshapeWithDynamicInputShape", # name
[ # inputs
helper.make_tensor_value_info("X", TensorProto.FLOAT, [M, K]),
],
[ # outputs
helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, N]),
],
[ # initializers
helper.make_tensor("shape", TensorProto.INT64, [2], [-1, N]),
],
)

opset_imports = [helper.make_operatorsetid("", 19)]
model = helper.make_model(graph, opset_imports=opset_imports)
onnx.save(model, str(Path(__file__).parent / "reshape_with_dynamic_input_shape.onnx"))