From fda956967b7413213bf6456b460f71b30bac1326 Mon Sep 17 00:00:00 2001
From: Hari Seshadri
Date: Mon, 2 Feb 2026 19:52:00 -0800
Subject: [PATCH 01/10] Fix bug in Einsum impl when a lone operand had a
 reduction operation

---
 .../einsum_typed_compute_processor.cc      | 25 +++++++++++++------
 .../test/providers/cpu/math/einsum_test.cc |  9 +++++++
 2 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
index 096e07eb8e272..ae893f105f382 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
@@ -10,6 +10,13 @@ namespace onnxruntime {
 template <typename T>
 void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_output,
                                                     const gsl::span<const int64_t>& ordered_subscript_indices_in_candidate) {
+
+  ORT_ENFORCE(candidate_output.Shape().NumDimensions() == ordered_subscript_indices_in_candidate.size(),
+              "Einsum op: The candidate output's rank has to be the as same number of elements as "
+              "the ordered subscript indices in the candidate output. Hitting this error points to an "
+              "internal bug in the Einsum op's implementation. "
+              "Please open a bug report with appropriate repro steps");
+
   const std::vector<int64_t>& subscript_indices_to_output_indices =
       einsum_compute_preprocessor_.GetMappedSubscriptIndicesToOutputindices();
   const auto output_dims = einsum_compute_preprocessor_.GetOutputDims();
@@ -75,7 +82,7 @@ void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_outp
 static bool IsTransposeReshapeForEinsum(const gsl::span<const size_t>& perm,
                                         gsl::span<const int64_t> input_dims,
                                         TensorShapeVector& new_shape) {
-  // As long as the dims with values > 1 stay in the same order, it's a reshape.
+  // As long as the dims with values > 1 stay in the same relative order, it's a reshape.
   // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1).
   size_t last_permuted_axis = 0;
   for (size_t i = 0; i < perm.size(); ++i) {
@@ -361,17 +368,19 @@ Status EinsumTypedComputeProcessor<T>::Run() {
   std::unique_ptr<Tensor> result;
 
   {
-    TensorShapeVector reduced_dims;
-    TensorShapeVector preserved_dims;  // dims which were not reduced
-    reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));    // num_subscript_labels is the upper bound. No harm in over-reserving.
-    preserved_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));  // num_subscript_labels is the upper bound. No harm in over-reserving.
+    TensorShapeVector reduced_dims;  // All dims of the input that are reduced using the `ReduceSum` op
+    reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));  // num_subscript_labels is the upper bound. No harm in over-reserving
+
+    TensorShapeVector all_dims;  // Expanded dimensions of `num_subscript_labels` {0, 1, ..., num_subscript_labels}
+    all_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));  // num_subscript_labels is the number of elements
 
     for (size_t i = 0; i < onnxruntime::narrow<size_t>(num_subscript_labels); ++i) {
       if (mapped_indices_to_last_input_index[i] == 0) {
         reduced_dims.push_back(i);
-      } else {
-        preserved_dims.push_back(i);
       }
+
+      // ReduceSum operation preserves even the reduced dims with reduced dim shape value being 1
+      all_dims.push_back(i);
     }
 
     // Reduce the dims that are last seen in the first input alone
@@ -391,7 +400,7 @@ Status EinsumTypedComputeProcessor<T>::Run() {
   if (num_inputs == 1) {
     // Finalize the output by applying any transpose required to get
     // it to the required output ordering and move it to the op's output
-    FinalizeOutput(result ? *result : *raw_inputs[0], preserved_dims);
+    FinalizeOutput(result ? *result : *raw_inputs[0], all_dims);
     return Status::OK();
   }
 
diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc
index d3ea8552f60f4..b116ea128fcc6 100644
--- a/onnxruntime/test/providers/cpu/math/einsum_test.cc
+++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc
@@ -114,6 +114,15 @@ TEST(Einsum, ExplicitEinsumAsBatchedReduceOp_3D_input_1) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
+TEST(Einsum, ExplicitEinsumAsReduceWithTransposeOp_3D_input_0) {
+  OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
+  test.AddAttribute("equation", "ijk->ki");
+  test.AddInput<float>("x", {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f,
+                                        1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f});
+  test.AddOutput<float>("y", {4, 2}, {15.f, 15.f, 18.f, 18.f, 21.f, 21.f, 24.f, 24.f});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
+}
+
 // Implicit
 
 // Cannot do implicit reduction

From f0b961e4f2851c6ee74a5e4000dfa12c46a77d33 Mon Sep 17 00:00:00 2001
From: Hariharan Seshadri
Date: Mon, 2 Feb 2026 20:08:20 -0800
Subject: [PATCH 02/10] Update
 onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 .../cpu/math/einsum_utils/einsum_typed_compute_processor.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
index ae893f105f382..b138af592954c 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
@@ -10,7 +10,6 @@ namespace onnxruntime {
 template <typename T>
 void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_output,
                                                     const gsl::span<const int64_t>& ordered_subscript_indices_in_candidate) {
-
   ORT_ENFORCE(candidate_output.Shape().NumDimensions() == ordered_subscript_indices_in_candidate.size(),
               "Einsum op: The candidate output's rank has to be the as same number of elements as "
               "the ordered subscript indices in the candidate output. Hitting this error points to an "

From 06ee5cc56380f433d8e0135299d7292cd8968659 Mon Sep 17 00:00:00 2001
From: Hariharan Seshadri
Date: Mon, 2 Feb 2026 20:08:27 -0800
Subject: [PATCH 03/10] Update
 onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 .../cpu/math/einsum_utils/einsum_typed_compute_processor.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
index b138af592954c..cf07073da1c0c 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
@@ -370,8 +370,8 @@ Status EinsumTypedComputeProcessor<T>::Run() {
     TensorShapeVector reduced_dims;  // All dims of the input that are reduced using the `ReduceSum` op
     reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));  // num_subscript_labels is the upper bound. No harm in over-reserving
 
-    TensorShapeVector all_dims;  // Expanded dimensions of `num_subscript_labels` {0, 1, ..., num_subscript_labels}
-    all_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));  // num_subscript_labels is the number of elements
+    TensorShapeVector all_dims;                                            // Expanded dimensions of `num_subscript_labels` {0, 1, ..., num_subscript_labels}
+    all_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));   // num_subscript_labels is the number of elements
 
     for (size_t i = 0; i < onnxruntime::narrow<size_t>(num_subscript_labels); ++i) {
       if (mapped_indices_to_last_input_index[i] == 0) {

From d5396691d6b2eb806f402f7abdaa5854adee056c Mon Sep 17 00:00:00 2001
From: Hariharan Seshadri
Date: Mon, 2 Feb 2026 20:08:34 -0800
Subject: [PATCH 04/10] Update onnxruntime/test/providers/cpu/math/einsum_test.cc

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 onnxruntime/test/providers/cpu/math/einsum_test.cc | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc
index b116ea128fcc6..30b89f6e86ed9 100644
--- a/onnxruntime/test/providers/cpu/math/einsum_test.cc
+++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc
@@ -117,8 +117,7 @@ TEST(Einsum, ExplicitEinsumAsReduceWithTransposeOp_3D_input_0) {
   OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
   test.AddAttribute("equation", "ijk->ki");
-  test.AddInput<float>("x", {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f,
-                                        1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f});
+  test.AddInput<float>("x", {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f});
   test.AddOutput<float>("y", {4, 2}, {15.f, 15.f, 18.f, 18.f, 21.f, 21.f, 24.f, 24.f});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }

From 50582acf79a22b2b68a777e5e1b51aafc06a227f Mon Sep 17 00:00:00 2001
From: Hariharan Seshadri
Date: Mon, 2 Feb 2026 20:08:45 -0800
Subject: [PATCH 05/10] Update
 onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 .../cpu/math/einsum_utils/einsum_typed_compute_processor.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
index cf07073da1c0c..9c0bec5c427d2 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
@@ -11,7 +11,7 @@ template <typename T>
 void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_output,
                                                     const gsl::span<const int64_t>& ordered_subscript_indices_in_candidate) {
   ORT_ENFORCE(candidate_output.Shape().NumDimensions() == ordered_subscript_indices_in_candidate.size(),
-              "Einsum op: The candidate output's rank has to be the as same number of elements as "
+              "Einsum op: The candidate output's rank has to be the same number of elements as "
               "the ordered subscript indices in the candidate output. Hitting this error points to an "
               "internal bug in the Einsum op's implementation. "
               "Please open a bug report with appropriate repro steps");

From 3bedf32bc37035c9cf2f2424fd97cd84168978d3 Mon Sep 17 00:00:00 2001
From: Hariharan Seshadri
Date: Mon, 2 Feb 2026 20:20:49 -0800
Subject: [PATCH 06/10] Update
 onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 .../cpu/math/einsum_utils/einsum_typed_compute_processor.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
index 9c0bec5c427d2..e6dc7f9690d4e 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
@@ -370,7 +370,7 @@ Status EinsumTypedComputeProcessor<T>::Run() {
     TensorShapeVector reduced_dims;  // All dims of the input that are reduced using the `ReduceSum` op
     reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));  // num_subscript_labels is the upper bound. No harm in over-reserving
 
-    TensorShapeVector all_dims;                                            // Expanded dimensions of `num_subscript_labels` {0, 1, ..., num_subscript_labels}
+    TensorShapeVector all_dims;                                            // All dimension indices from 0 to num_subscript_labels - 1
     all_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));   // num_subscript_labels is the number of elements
 
     for (size_t i = 0; i < onnxruntime::narrow<size_t>(num_subscript_labels); ++i) {

From 99a39283736455271d2876db33a0e1f89e8a4498 Mon Sep 17 00:00:00 2001
From: Hari Seshadri
Date: Thu, 5 Feb 2026 10:10:13 -0800
Subject: [PATCH 07/10] Address nit

---
 .../math/einsum_utils/einsum_compute_preprocessor.h |  2 +-
 .../einsum_utils/einsum_typed_compute_processor.cc  | 12 ++++++------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h
index 3c02ddf612ec3..4d700c4559931 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h
@@ -140,7 +140,7 @@ class EinsumComputePreprocessor final {
   const std::vector<int64_t>& GetMappedSubscriptIndicesToOutputindices() const;
 
   // Get the number of subscript indices (subscript labels) in the einsum equation
-  int64_t GetNumSubscriptIndices() const;
+  size_t GetNumSubscriptIndices() const;
 
   // Pass-in device specific functions
   // (Pass-in CPU implementation or CUDA implementation function depending on the kernel using this class)
diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
index e6dc7f9690d4e..5085f185130f0 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
@@ -368,12 +368,12 @@ Status EinsumTypedComputeProcessor<T>::Run() {
 
   {
     TensorShapeVector reduced_dims;  // All dims of the input that are reduced using the `ReduceSum` op
-    reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));  // num_subscript_labels is the upper bound. No harm in over-reserving
+    reduced_dims.reserve(num_subscript_labels);  // num_subscript_labels is the upper bound. No harm in over-reserving
 
     TensorShapeVector all_dims;                                            // All dimension indices from 0 to num_subscript_labels - 1
-    all_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));   // num_subscript_labels is the number of elements
+    all_dims.reserve(num_subscript_labels);  // num_subscript_labels is the number of elements
 
-    for (size_t i = 0; i < onnxruntime::narrow<size_t>(num_subscript_labels); ++i) {
+    for (size_t i = 0; i < num_subscript_labels; ++i) {
       if (mapped_indices_to_last_input_index[i] == 0) {
         reduced_dims.push_back(i);
       }
@@ -411,9 +411,9 @@ Status EinsumTypedComputeProcessor<T>::Run() {
   // Keep processing each input pair-wise
   for (int input = 1; input < num_inputs; ++input) {
     TensorShapeVector reduced_dims;
-    reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels));  // num_subscript_labels is the upper bound. No harm in over-reserving by a small margin.
-    for (int64_t dim = 0; dim < num_subscript_labels; ++dim) {
-      if (mapped_indices_to_last_input_index[onnxruntime::narrow<size_t>(dim)] == input) {
+    reduced_dims.reserve(num_subscript_labels);  // num_subscript_labels is the upper bound. No harm in over-reserving by a small margin.
+    for (size_t dim = 0; dim < num_subscript_labels; ++dim) {
+      if (mapped_indices_to_last_input_index[dim] == input) {
         // This is the last input we are seeing this dimension (and it doesn't occur in the output), so reduce along the dimension
         reduced_dims.push_back(dim);
       }

From dbc0117359b5090f7863c3dfa8266b537710d9e4 Mon Sep 17 00:00:00 2001
From: Hari Seshadri
Date: Thu, 5 Feb 2026 10:14:54 -0800
Subject: [PATCH 08/10] Format

---
 .../cpu/math/einsum_utils/einsum_typed_compute_processor.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
index 5085f185130f0..d3a902024a154 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
@@ -367,10 +367,10 @@ Status EinsumTypedComputeProcessor<T>::Run() {
   std::unique_ptr<Tensor> result;
 
   {
-    TensorShapeVector reduced_dims;  // All dims of the input that are reduced using the `ReduceSum` op
+    TensorShapeVector reduced_dims;               // All dims of the input that are reduced using the `ReduceSum` op
     reduced_dims.reserve(num_subscript_labels);  // num_subscript_labels is the upper bound. No harm in over-reserving
 
-    TensorShapeVector all_dims;                                            // All dimension indices from 0 to num_subscript_labels - 1
+    TensorShapeVector all_dims;              // All dimension indices from 0 to num_subscript_labels - 1
     all_dims.reserve(num_subscript_labels);  // num_subscript_labels is the number of elements
 
     for (size_t i = 0; i < num_subscript_labels; ++i) {

From 4a904cc422a46bebec85343c420b44c885e9c420 Mon Sep 17 00:00:00 2001
From: Hariharan Seshadri
Date: Thu, 5 Feb 2026 10:29:36 -0800
Subject: [PATCH 09/10] Update
 onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 .../cpu/math/einsum_utils/einsum_typed_compute_processor.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
index d3a902024a154..6bd0988ee470c 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
@@ -367,7 +367,7 @@ Status EinsumTypedComputeProcessor<T>::Run() {
   std::unique_ptr<Tensor> result;
 
   {
-    TensorShapeVector reduced_dims;               // All dims of the input that are reduced using the `ReduceSum` op
+    TensorShapeVector reduced_dims;              // All dims of the input that are reduced using the `ReduceSum` op
     reduced_dims.reserve(num_subscript_labels);  // num_subscript_labels is the upper bound. No harm in over-reserving
 
     TensorShapeVector all_dims;              // All dimension indices from 0 to num_subscript_labels - 1

From 3325ac22749409f260119a0ed1d0a92661781cb8 Mon Sep 17 00:00:00 2001
From: Hari Seshadri
Date: Thu, 5 Feb 2026 10:39:13 -0800
Subject: [PATCH 10/10] Avoid as many narrows

---
 .../einsum_compute_preprocessor.cc | 28 +++++++++++++++-------------
 .../einsum_compute_preprocessor.h  |  2 +-
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc
index 3da5ebe5e1e4d..8fa1753d2f86c 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc
@@ -56,7 +56,7 @@ const std::vector<int64_t>& EinsumComputePreprocessor::GetMappedSubscriptIndices
   return subscript_indices_to_output_indices_;
 }
 
-int64_t EinsumComputePreprocessor::GetNumSubscriptIndices() const {
+size_t EinsumComputePreprocessor::GetNumSubscriptIndices() const {
   return num_subscript_indices_;
 }
 
@@ -73,7 +73,7 @@ Status EinsumComputePreprocessor::ProcessSubscripts() {
                            "Number of subscripts in the input equation does not match number of input tensors");
   }
 
-  int64_t input_index = 0;
+  size_t input_index = 0;
 
   // Holds mapping between input indices to its corresponding subscript labels for each input
   input_subscript_indices_.reserve(inputs_.size());
@@ -84,7 +84,7 @@ Status EinsumComputePreprocessor::ProcessSubscripts() {
   subscript_indices_to_dim_value_.reserve(10);
 
   for (const auto& subscript : left_equation_split) {
-    const auto& shape = inputs_[onnxruntime::narrow<size_t>(input_index)]->Shape();
+    const auto& shape = inputs_[input_index]->Shape();
     const auto& dims = shape.GetDims();
     size_t rank = dims.size();
     size_t dim_counter = 0;
@@ -237,13 +237,13 @@ Status EinsumComputePreprocessor::PostProcessBroadcastedDims() {
     }
   }
 
-  std::vector<int64_t> temp_index_to_last_input(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
+  std::vector<int64_t> temp_index_to_last_input(num_subscript_indices_, -1);
   for (size_t i = 0; i < subscript_indices_to_last_input_.size(); ++i) {
     temp_index_to_last_input[i + num_of_ellipsis_dims_] = subscript_indices_to_last_input_[i];
   }
   subscript_indices_to_last_input_ = std::move(temp_index_to_last_input);
 
-  std::vector<int64_t> temp_index_to_dim_value(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
+  std::vector<int64_t> temp_index_to_dim_value(num_subscript_indices_, -1);
   for (size_t i = 0; i < subscript_indices_to_dim_value_.size(); ++i) {
     temp_index_to_dim_value[i + num_of_ellipsis_dims_] = subscript_indices_to_dim_value_[i];
   }
@@ -338,7 +338,7 @@ Status EinsumComputePreprocessor::CalculateOutputShape() {
   bool is_in_middle_of_ellipsis = false;
   int64_t ellipsis_char_count = 0;
 
-  subscript_indices_to_output_indices_.resize(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
+  subscript_indices_to_output_indices_.resize(num_subscript_indices_, -1);
 
   std::array<int64_t, EinsumOp::num_of_letters> output_letter_to_count;
   output_letter_to_count.fill(0);
@@ -407,13 +407,13 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
   // As part of input preprocessing we "homogenize" them by
   // 1) Making them all of the same rank
   // 2) The axes order in all the inputs are to be made the same
-  int64_t input_iter = 0;
+  size_t input_iter = 0;
 
   for (const auto* input : inputs_) {
     // Eventually will hold the "preprocessed" version of the original input
     std::unique_ptr<Tensor> preprocessed;
 
     const auto& input_dims = input->Shape().GetDims();
-    const auto& current_subscript_indices = input_subscript_indices_[onnxruntime::narrow<size_t>(input_iter)];
+    const auto& current_subscript_indices = input_subscript_indices_[input_iter];
 
     // If all has gone well, we will have a subscript index (subscript label) for each dim of the input
     if (input_dims.size() != current_subscript_indices.size()) {
@@ -421,10 +421,10 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
                              "Rank of the input must match number of subscript labels corresponding to the input");
     }
 
-    std::vector<int64_t> subscript_indices_to_input_index(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
+    std::vector<int64_t> subscript_indices_to_input_index(num_subscript_indices_, -1);
 
     // This is the input dims after re-ordering so that all inputs have same axes order
-    TensorShapeVector homogenized_input_dims(onnxruntime::narrow<size_t>(num_subscript_indices_), 1);
+    TensorShapeVector homogenized_input_dims(num_subscript_indices_, 1);
 
     // Preprocessed dim rank may not be the same as original input rank if we need to parse diagonals along the way
     // (which reduces rank in the preprocessed input by 1 for each diagonal we parse)
@@ -437,7 +437,7 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
         subscript_indices_to_input_index[onnxruntime::narrow<size_t>(subscript_index)] = dim_index_in_preprocessed_input++;
         homogenized_input_dims[onnxruntime::narrow<size_t>(subscript_index)] = input_dims[onnxruntime::narrow<size_t>(dim_index_in_original_input)];
       } else {  // Diagonal needs to be parsed along the repeated axes
-        preprocessed = device_diagonal_func_(preprocessed ? *preprocessed : *inputs_[onnxruntime::narrow<size_t>(input_iter)],
+        preprocessed = device_diagonal_func_(preprocessed ? *preprocessed : *inputs_[input_iter],
                                              subscript_indices_to_input_index[onnxruntime::narrow<size_t>(subscript_index)],
                                              dim_index_in_preprocessed_input,
                                              allocator_, einsum_ep_assets_);
@@ -454,10 +454,10 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
     }
 
     // (Identify no-op transpose and prevent triggering the transpose)
-    if (EinsumOp::IsTransposeRequired(preprocessed ? preprocessed->Shape().GetDims().size() : inputs_[onnxruntime::narrow<size_t>(input_iter)]->Shape().GetDims().size(),
+    if (EinsumOp::IsTransposeRequired(preprocessed ? preprocessed->Shape().GetDims().size() : inputs_[input_iter]->Shape().GetDims().size(),
                                       permutation)) {
-      preprocessed = EinsumOp::Transpose(preprocessed ? *preprocessed : *inputs_[onnxruntime::narrow<size_t>(input_iter)],
-                                         preprocessed ? preprocessed->Shape().GetDims() : inputs_[onnxruntime::narrow<size_t>(input_iter)]->Shape().GetDims(),
+      preprocessed = EinsumOp::Transpose(preprocessed ? *preprocessed : *inputs_[input_iter],
+                                         preprocessed ? preprocessed->Shape().GetDims() : inputs_[input_iter]->Shape().GetDims(),
                                          permutation, allocator_, einsum_ep_assets_, device_transpose_func_);
     }
 
diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h
index 4d700c4559931..e3d6d09cd52ec 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h
@@ -185,7 +185,7 @@ class EinsumComputePreprocessor final {
   // num_subscript_indices_ = 3 (i, j, k)
   // E.g. 2 : With equation -> '...ij', 'jk' -> '...ik'
   // num_subscript_indices_ = 3 (i, j, k) + number of dims specified by an ellipsis (across all inputs)
-  int64_t num_subscript_indices_ = 0;
+  size_t num_subscript_indices_ = 0;
 
   // Hold the count corresponding to the letter seen
   // `0` means the corresponding letter wasn't seen at all