diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc
index 3da5ebe5e1e4d..8fa1753d2f86c 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc
@@ -56,7 +56,7 @@ const std::vector<int64_t>& EinsumComputePreprocessor::GetMappedSubscriptIndices
   return subscript_indices_to_output_indices_;
 }
 
-int64_t EinsumComputePreprocessor::GetNumSubscriptIndices() const {
+size_t EinsumComputePreprocessor::GetNumSubscriptIndices() const {
   return num_subscript_indices_;
 }
 
@@ -73,7 +73,7 @@ Status EinsumComputePreprocessor::ProcessSubscripts() {
                            "Number of subscripts in the input equation does not match number of input tensors");
   }
 
-  int64_t input_index = 0;
+  size_t input_index = 0;
 
   // Holds mapping between input indices to its corresponding subscript labels for each input
   input_subscript_indices_.reserve(inputs_.size());
@@ -84,7 +84,7 @@ Status EinsumComputePreprocessor::ProcessSubscripts() {
   subscript_indices_to_dim_value_.reserve(10);
 
   for (const auto& subscript : left_equation_split) {
-    const auto& shape = inputs_[onnxruntime::narrow<size_t>(input_index)]->Shape();
+    const auto& shape = inputs_[input_index]->Shape();
     const auto& dims = shape.GetDims();
     size_t rank = dims.size();
     size_t dim_counter = 0;
@@ -237,13 +237,13 @@ Status EinsumComputePreprocessor::PostProcessBroadcastedDims() {
     }
   }
 
-  std::vector<int64_t> temp_index_to_last_input(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
+  std::vector<int64_t> temp_index_to_last_input(num_subscript_indices_, -1);
   for (size_t i = 0; i < subscript_indices_to_last_input_.size(); ++i) {
     temp_index_to_last_input[i + num_of_ellipsis_dims_] = subscript_indices_to_last_input_[i];
   }
   subscript_indices_to_last_input_ = std::move(temp_index_to_last_input);
 
-  std::vector<int64_t> temp_index_to_dim_value(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
+  std::vector<int64_t> temp_index_to_dim_value(num_subscript_indices_, -1);
   for (size_t i = 0; i < subscript_indices_to_dim_value_.size(); ++i) {
     temp_index_to_dim_value[i + num_of_ellipsis_dims_] = subscript_indices_to_dim_value_[i];
   }
@@ -338,7 +338,7 @@ Status EinsumComputePreprocessor::CalculateOutputShape() {
   bool is_in_middle_of_ellipsis = false;
   int64_t ellipsis_char_count = 0;
 
-  subscript_indices_to_output_indices_.resize(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
+  subscript_indices_to_output_indices_.resize(num_subscript_indices_, -1);
 
   std::array<int64_t, EinsumOp::num_of_letters> output_letter_to_count;
   output_letter_to_count.fill(0);
@@ -407,13 +407,13 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
   // As part of input preprocessing we "homogenize" them by
   // 1) Making them all of the same rank
   // 2) The axes order in all the inputs are to be made the same
-  int64_t input_iter = 0;
+  size_t input_iter = 0;
   for (const auto* input : inputs_) {
     // Eventually will hold the "preprocessed" version of the original input
     std::unique_ptr<Tensor> preprocessed;
 
     const auto& input_dims = input->Shape().GetDims();
-    const auto& current_subscript_indices = input_subscript_indices_[onnxruntime::narrow<size_t>(input_iter)];
+    const auto& current_subscript_indices = input_subscript_indices_[input_iter];
 
     // If all has gone well, we will have a subscript index (subscript label) for each dim of the input
     if (input_dims.size() != current_subscript_indices.size()) {
@@ -421,10 +421,10 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
                              "Rank of the input must match number of subscript labels corresponding to the input");
     }
 
-    std::vector<int64_t> subscript_indices_to_input_index(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
+    std::vector<int64_t> subscript_indices_to_input_index(num_subscript_indices_, -1);
 
     // This is the input dims after re-ordering so that all inputs have same axes order
-    TensorShapeVector homogenized_input_dims(onnxruntime::narrow<size_t>(num_subscript_indices_), 1);
+    TensorShapeVector homogenized_input_dims(num_subscript_indices_, 1);
 
     // Preprocessed dim rank may not be the same as original input rank if we need to parse diagonals along the way
     // (which reduces rank in the preprocessed input by 1 for each diagonal we parse)
@@ -437,7 +437,7 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
         subscript_indices_to_input_index[onnxruntime::narrow<size_t>(subscript_index)] = dim_index_in_preprocessed_input++;
         homogenized_input_dims[onnxruntime::narrow<size_t>(subscript_index)] = input_dims[onnxruntime::narrow<size_t>(dim_index_in_original_input)];
       } else {  // Diagonal needs to be parsed along the repeated axes
-        preprocessed = device_diagonal_func_(preprocessed ? *preprocessed : *inputs_[onnxruntime::narrow<size_t>(input_iter)],
+        preprocessed = device_diagonal_func_(preprocessed ? *preprocessed : *inputs_[input_iter],
                                              subscript_indices_to_input_index[onnxruntime::narrow<size_t>(subscript_index)],
                                              dim_index_in_preprocessed_input,
                                              allocator_, einsum_ep_assets_);
@@ -454,10 +454,10 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
     }
 
     // (Identify no-op transpose and prevent triggering the transpose)
-    if (EinsumOp::IsTransposeRequired(preprocessed ? preprocessed->Shape().GetDims().size() : inputs_[onnxruntime::narrow<size_t>(input_iter)]->Shape().GetDims().size(),
+    if (EinsumOp::IsTransposeRequired(preprocessed ? preprocessed->Shape().GetDims().size() : inputs_[input_iter]->Shape().GetDims().size(),
                                       permutation)) {
-      preprocessed = EinsumOp::Transpose(preprocessed ? *preprocessed : *inputs_[onnxruntime::narrow<size_t>(input_iter)],
-                                         preprocessed ? preprocessed->Shape().GetDims() : inputs_[onnxruntime::narrow<size_t>(input_iter)]->Shape().GetDims(),
+      preprocessed = EinsumOp::Transpose(preprocessed ? *preprocessed : *inputs_[input_iter],
+                                         preprocessed ? preprocessed->Shape().GetDims() : inputs_[input_iter]->Shape().GetDims(),
                                          permutation, allocator_,
                                          einsum_ep_assets_, device_transpose_func_);
     }
diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h
index 3c02ddf612ec3..e3d6d09cd52ec 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h
@@ -140,7 +140,7 @@ class EinsumComputePreprocessor final {
   const std::vector<int64_t>& GetMappedSubscriptIndicesToOutputindices() const;
 
   // Get the number of subscript indices (subscript labels) in the einsum equation
-  int64_t GetNumSubscriptIndices() const;
+  size_t GetNumSubscriptIndices() const;
 
   // Pass-in device specific functions
   // (Pass-in CPU implementation or CUDA implementation function depending on the kernel using this class)
@@ -185,7 +185,7 @@ class EinsumComputePreprocessor final {
   //          num_subscript_indices_ = 3 (i, j, k)
   // E.g. 2 : With equation -> '...ij', 'jk' -> '...ik'
   //          num_subscript_indices_ = 3 (i, j, k) + number of dims specified by an ellipsis (across all inputs)
-  int64_t num_subscript_indices_ = 0;
+  size_t num_subscript_indices_ = 0;
 
   // Hold the count corresponding to the letter seen
   // `0` means the corresponding letter wasn't seen at all
diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
index 096e07eb8e272..6bd0988ee470c 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
@@ -10,6 +10,12 @@ namespace onnxruntime {
 template <typename T>
 void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_output,
                                                     const gsl::span<const int64_t>& ordered_subscript_indices_in_candidate) {
+  ORT_ENFORCE(candidate_output.Shape().NumDimensions() == ordered_subscript_indices_in_candidate.size(),
+              "Einsum op: The candidate output's rank must match the number of "
+              "ordered subscript indices in the candidate output. Hitting this error points to an "
+              "internal bug in the Einsum op's implementation. "
+              "Please open a bug report with appropriate repro steps");
+
   const std::vector<int64_t>& subscript_indices_to_output_indices =
       einsum_compute_preprocessor_.GetMappedSubscriptIndicesToOutputindices();
   const auto output_dims = einsum_compute_preprocessor_.GetOutputDims();
@@ -75,7 +81,7 @@ void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_outp
 static bool IsTransposeReshapeForEinsum(const gsl::span<const size_t>& perm,
                                         gsl::span<const int64_t> input_dims,
                                         TensorShapeVector& new_shape) {
-  // As long as the dims with values > 1 stay in the same order, it's a reshape.
+  // As long as the dims with values > 1 stay in the same relative order, it's a reshape.
   // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1).
   size_t last_permuted_axis = 0;
   for (size_t i = 0; i < perm.size(); ++i) {
@@ -361,17 +367,19 @@ Status EinsumTypedComputeProcessor<T>::Run() {
   std::unique_ptr<Tensor> result;
 
   {
-    TensorShapeVector reduced_dims;
-    TensorShapeVector preserved_dims;                                           // dims which were not reduced
-    reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));    // num_subscript_labels is the upper bound. No harm in over-reserving.
-    preserved_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));  // num_subscript_labels is the upper bound. No harm in over-reserving.
+    TensorShapeVector reduced_dims;              // All dims of the input that are reduced using the `ReduceSum` op
+    reduced_dims.reserve(num_subscript_labels);  // num_subscript_labels is the upper bound. No harm in over-reserving
+
+    TensorShapeVector all_dims;              // All dimension indices from 0 to num_subscript_labels - 1
+    all_dims.reserve(num_subscript_labels);  // num_subscript_labels is the number of elements
 
-    for (size_t i = 0; i < onnxruntime::narrow<size_t>(num_subscript_labels); ++i) {
+    for (size_t i = 0; i < num_subscript_labels; ++i) {
       if (mapped_indices_to_last_input_index[i] == 0) {
         reduced_dims.push_back(i);
-      } else {
-        preserved_dims.push_back(i);
       }
+
+      // The ReduceSum operation keeps even the reduced dims (their dim value simply becomes 1)
+      all_dims.push_back(i);
     }
 
     // Reduce the dims that are last seen in the first input alone
@@ -391,7 +399,7 @@ Status EinsumTypedComputeProcessor<T>::Run() {
     if (num_inputs == 1) {
       // Finalize the output by applying any transpose required to get
       // it to the required output ordering and move it to the op's output
-      FinalizeOutput(result ? *result : *raw_inputs[0], preserved_dims);
+      FinalizeOutput(result ? *result : *raw_inputs[0], all_dims);
 
       return Status::OK();
     }
@@ -403,9 +411,9 @@ Status EinsumTypedComputeProcessor<T>::Run() {
   // Keep processing each input pair-wise
   for (int input = 1; input < num_inputs; ++input) {
     TensorShapeVector reduced_dims;
-    reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));  // num_subscript_labels is the upper bound. No harm in over-reserving by a small margin.
-    for (int64_t dim = 0; dim < num_subscript_labels; ++dim) {
-      if (mapped_indices_to_last_input_index[onnxruntime::narrow<size_t>(dim)] == input) {
+    reduced_dims.reserve(num_subscript_labels);  // num_subscript_labels is the upper bound. No harm in over-reserving by a small margin.
+    for (size_t dim = 0; dim < num_subscript_labels; ++dim) {
+      if (mapped_indices_to_last_input_index[dim] == input) {
         // This is the last input in which we see this dimension (and it doesn't occur in the output), so reduce along it
         reduced_dims.push_back(dim);
       }
diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc
index d3ea8552f60f4..30b89f6e86ed9 100644
--- a/onnxruntime/test/providers/cpu/math/einsum_test.cc
+++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc
@@ -114,6 +114,14 @@ TEST(Einsum, ExplicitEinsumAsBatchedReduceOp_3D_input_1) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
+TEST(Einsum, ExplicitEinsumAsReduceWithTransposeOp_3D_input_0) {
+  OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
+  test.AddAttribute<std::string>("equation", "ijk->ki");
+  test.AddInput<float>("x", {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f});
+  test.AddOutput<float>("y", {4, 2}, {15.f, 15.f, 18.f, 18.f, 21.f, 21.f, 24.f, 24.f});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
+}
+
 // Implicit
 // Cannot do implicit reduction
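
The new test exercises einsum("ijk->ki") on a single input: a sum-reduction over the j axis followed by a transpose of the surviving (i, k) axes into (k, i) order, which is exactly the single-input path that now hands all_dims to FinalizeOutput. As a quick cross-check of the expected values, here is a minimal standalone C++ sketch (illustrative only, not part of this change or the ONNX Runtime sources) that computes the same contraction with plain loops:

// verify_einsum_ijk_ki.cc -- hypothetical standalone file for illustration.
// Computes einsum("ijk->ki") over the test input: reduce over j, then
// reorder the surviving (i, k) axes to (k, i).
#include <cstdio>

int main() {
  constexpr int I = 2, J = 3, K = 4;

  // The test input: the values 1..12 laid out over (j, k), repeated for each i.
  float x[I][J][K];
  for (int i = 0; i < I; ++i)
    for (int j = 0; j < J; ++j)
      for (int k = 0; k < K; ++k)
        x[i][j][k] = static_cast<float>(j * K + k + 1);

  // einsum("ijk->ki"): y[k][i] = sum over j of x[i][j][k]
  float y[K][I] = {};
  for (int i = 0; i < I; ++i)
    for (int j = 0; j < J; ++j)
      for (int k = 0; k < K; ++k)
        y[k][i] += x[i][j][k];

  // Prints "15 15 18 18 21 21 24 24", the expected output in the test above.
  for (int k = 0; k < K; ++k)
    for (int i = 0; i < I; ++i)
      std::printf("%g ", y[k][i]);
  std::printf("\n");
  return 0;
}

The row-major flattening of y (shape {4, 2}) matches the AddOutput values in the new test.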