|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +#include "reshape.h" |
| 5 | + |
| 6 | +#include <gsl/span> |
| 7 | +#include <vector> |
| 8 | +#include "utils.h" |
| 9 | + |
| 10 | +// ONNX Reshape version 21 |
| 11 | +ONNX_OPERATOR_VERSIONED_KERNEL_EX( |
| 12 | + Reshape, |
| 13 | + kOnnxDomain, |
| 14 | + /*start_version*/ 21, /*end_version (inclusive)*/ 22, |
| 15 | + (Ort::KernelDefBuilder() |
| 16 | + .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)) |
| 17 | + .AddTypeConstraint("shape", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) |
| 18 | + .AddInputOutputAlias(0, 0) |
| 19 | + .SetInputMemType(1, OrtMemTypeCPU)), |
| 20 | + Reshape) |
| 21 | + |
| 22 | +// ONNX Reshape version 23 |
| 23 | +ONNX_OPERATOR_KERNEL_EX( |
| 24 | + Reshape, |
| 25 | + kOnnxDomain, |
| 26 | + /*version*/ 23, // Equivalent to start_version: 23, end_version: 23 |
| 27 | + (Ort::KernelDefBuilder() |
| 28 | + .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)) |
| 29 | + .AddTypeConstraint("shape", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) |
| 30 | + .AddInputOutputAlias(0, 0) |
| 31 | + .SetInputMemType(1, OrtMemTypeCPU)), |
| 32 | + Reshape) |
| 33 | + |
| 34 | +// ONNX Reshape version 24 |
| 35 | +ONNX_OPERATOR_KERNEL_EX( |
| 36 | + Reshape, |
| 37 | + kOnnxDomain, |
| 38 | + /*version*/ 24, // Equivalent start_version: 24, end_version: 24 |
| 39 | + (Ort::KernelDefBuilder() |
| 40 | + .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)) |
| 41 | + .AddTypeConstraint("shape", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) |
| 42 | + .AddInputOutputAlias(0, 0) |
| 43 | + .SetInputMemType(1, OrtMemTypeCPU)), |
| 44 | + Reshape) |
| 45 | + |
| 46 | +Reshape::Reshape(const OrtKernelInfo* info, void* state, bool allow_zero, PrivateTag) |
| 47 | + : OrtKernelImpl{}, // Initialize all OrtKernelImpl functions to NULL |
| 48 | + info_{info}, |
| 49 | + data_transfer_impl_{reinterpret_cast<OrtDataTransferImpl*>(state)}, |
| 50 | + allow_zero_{allow_zero} { |
| 51 | + ort_version_supported = ORT_API_VERSION; |
| 52 | + Compute = ComputeImpl; |
| 53 | + Release = ReleaseImpl; |
| 54 | +} |
| 55 | + |
| 56 | +/*static*/ |
| 57 | +OrtStatus* Reshape::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr<Reshape>& kernel) noexcept { |
| 58 | + EXCEPTION_TO_RETURNED_STATUS_BEGIN |
| 59 | + Ort::ConstKernelInfo kernel_info(info); |
| 60 | + bool allow_zero = kernel_info.GetAttribute<int64_t>("allowzero") == 1; |
| 61 | + |
| 62 | + kernel = std::make_unique<Reshape>(info, state, allow_zero, PrivateTag{}); |
| 63 | + return nullptr; |
| 64 | + EXCEPTION_TO_RETURNED_STATUS_END |
| 65 | +} |
| 66 | + |
| 67 | +// Computes the requested shape for the reshape operation. |
| 68 | +// Implementation is based on ReshapeHelper in onnxruntime/core/providers/cpu/tensor/reshape_helper.h |
| 69 | +static OrtStatus* GetRequestedShape(gsl::span<const int64_t> input_shape, bool allow_zero, |
| 70 | + /*out*/ std::vector<int64_t>& requested_shape) { |
| 71 | + EXCEPTION_TO_RETURNED_STATUS_BEGIN |
| 72 | + const OrtApi& ort_api = Ort::GetApi(); |
| 73 | + |
| 74 | + int64_t num_input_elems = 1; |
| 75 | + for (auto dim_val : input_shape) { |
| 76 | + num_input_elems *= dim_val; |
| 77 | + } |
| 78 | + RETURN_IF(num_input_elems == -1, ort_api, "Input tensor must not have dynamic (-1) dimensions."); |
| 79 | + |
| 80 | + size_t num_dims = requested_shape.size(); |
| 81 | + int64_t unknown_dim = -1; |
| 82 | + int64_t size = 1; |
| 83 | + |
| 84 | + for (size_t i = 0; i < num_dims; i++) { |
| 85 | + RETURN_IF(requested_shape[i] < -1, ort_api, "A dimension cannot be less than -1"); |
| 86 | + |
| 87 | + if (requested_shape[i] == -1) { |
| 88 | + RETURN_IF(unknown_dim != -1, ort_api, "At most one dimension can be -1"); |
| 89 | + unknown_dim = static_cast<int64_t>(i); |
| 90 | + } else { |
| 91 | + if (!allow_zero && requested_shape[i] == 0) { |
| 92 | + RETURN_IF(i >= input_shape.size(), ort_api, |
| 93 | + "The dimension with value zero exceeds the dimension size of the input"); |
| 94 | + requested_shape[i] = input_shape[i]; |
| 95 | + } |
| 96 | + |
| 97 | + size *= requested_shape[i]; |
| 98 | + } |
| 99 | + } |
| 100 | + |
| 101 | + if (unknown_dim != -1) { |
| 102 | + // Calculate unknown dimension. |
| 103 | + RETURN_IF(size == 0 || (num_input_elems % size) != 0, ort_api, |
| 104 | + "The input cannot be reshaped to the requested shape"); |
| 105 | + requested_shape[unknown_dim] = num_input_elems / size; |
| 106 | + } else { |
| 107 | + // Check if the output shape is valid. |
| 108 | + RETURN_IF(num_input_elems != size, ort_api, "The input cannot be reshaped to the requested shape"); |
| 109 | + } |
| 110 | + |
| 111 | + return nullptr; |
| 112 | + EXCEPTION_TO_RETURNED_STATUS_END |
| 113 | +} |
| 114 | + |
| 115 | +/*static*/ |
| 116 | +OrtStatus* ORT_API_CALL Reshape::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept { |
| 117 | + EXCEPTION_TO_RETURNED_STATUS_BEGIN |
| 118 | + Reshape* reshape_kernel = static_cast<Reshape*>(this_ptr); |
| 119 | + static_cast<void>(reshape_kernel->info_); // NOTE: Unused in this example. |
| 120 | + |
| 121 | + Ort::KernelContext kernel_context(kernel_ctx); |
| 122 | + |
| 123 | + // Input[0] has the data to reshape. |
| 124 | + Ort::ConstValue input = kernel_context.GetInput(0); |
| 125 | + auto type_shape_info = input.GetTensorTypeAndShapeInfo(); |
| 126 | + std::vector<int64_t> input_shape = type_shape_info.GetShape(); |
| 127 | + |
| 128 | + // Input[1] has the requested shape for the reshape operation. |
| 129 | + Ort::ConstValue shape_input = kernel_context.GetInput(1); |
| 130 | + gsl::span<const int64_t> shape_input_data; |
| 131 | + std::vector<int64_t> final_shape; |
| 132 | + |
| 133 | + RETURN_IF_ERROR(GetValueDataAndShape(shape_input, shape_input_data, final_shape)); |
| 134 | + RETURN_IF(final_shape.size() != 1, Ort::GetApi(), "A shape tensor must have one dimension"); |
| 135 | + RETURN_IF_ERROR(GetRequestedShape(input_shape, reshape_kernel->allow_zero_, final_shape)); |
| 136 | + |
| 137 | + Ort::UnownedValue output = kernel_context.GetOutput(0, final_shape); |
| 138 | + |
| 139 | + // This kernel aliases the input and output, so a copy is not really necessary. |
| 140 | + // CopyTensor() will not do a copy if the source and destination buffers are the same. |
| 141 | + RETURN_IF_ERROR(CopyTensor(*reshape_kernel->data_transfer_impl_, input, output)); |
| 142 | + return nullptr; |
| 143 | + EXCEPTION_TO_RETURNED_STATUS_END |
| 144 | +} |
| 145 | + |
| 146 | +/*static*/ |
| 147 | +void ORT_API_CALL Reshape::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { |
| 148 | + delete static_cast<Reshape*>(this_ptr); |
| 149 | +} |
0 commit comments