Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class EinsumComputePreprocessor final {
const std::vector<int64_t>& GetMappedSubscriptIndicesToOutputindices() const;

// Get the number of subscript indices (subscript labels) in the einsum equation
int64_t GetNumSubscriptIndices() const;
size_t GetNumSubscriptIndices() const;

// Pass-in device specific functions
// (Pass-in CPU implementation or CUDA implementation function depending on the kernel using this class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ namespace onnxruntime {
template <typename T>
void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_output,
const gsl::span<const int64_t>& ordered_subscript_indices_in_candidate) {
ORT_ENFORCE(candidate_output.Shape().NumDimensions() == ordered_subscript_indices_in_candidate.size(),
"Einsum op: The candidate output's rank has to be the same number of elements as "
"the ordered subscript indices in the candidate output. Hitting this error points to an "
"internal bug in the Einsum op's implementation. "
"Please open a bug report with appropriate repro steps");

const std::vector<int64_t>& subscript_indices_to_output_indices =
einsum_compute_preprocessor_.GetMappedSubscriptIndicesToOutputindices();
const auto output_dims = einsum_compute_preprocessor_.GetOutputDims();
Expand Down Expand Up @@ -75,7 +81,7 @@ void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_outp
static bool IsTransposeReshapeForEinsum(const gsl::span<const size_t>& perm,
gsl::span<const int64_t> input_dims,
TensorShapeVector& new_shape) {
// As long as the dims with values > 1 stay in the same order, it's a reshape.
// As long as the dims with values > 1 stay in the same relative order, it's a reshape.
// Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1).
size_t last_permuted_axis = 0;
for (size_t i = 0; i < perm.size(); ++i) {
Expand Down Expand Up @@ -361,17 +367,19 @@ Status EinsumTypedComputeProcessor<T>::Run() {
std::unique_ptr<const Tensor> result;

{
TensorShapeVector reduced_dims;
TensorShapeVector preserved_dims; // dims which were not reduced
reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels)); // num_subscript_labels is the upper bound. No harm in over-reserving.
preserved_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels)); // num_subscript_labels is the upper bound. No harm in over-reserving.
TensorShapeVector reduced_dims; // All dims of the input that are reduced using the `ReduceSum` op
reduced_dims.reserve(num_subscript_labels); // num_subscript_labels is the upper bound. No harm in over-reserving

TensorShapeVector all_dims; // All dimension indices from 0 to num_subscript_labels - 1
all_dims.reserve(num_subscript_labels); // num_subscript_labels is the number of elements

for (size_t i = 0; i < onnxruntime::narrow<size_t>(num_subscript_labels); ++i) {
for (size_t i = 0; i < num_subscript_labels; ++i) {
if (mapped_indices_to_last_input_index[i] == 0) {
reduced_dims.push_back(i);
} else {
preserved_dims.push_back(i);
}

// The ReduceSum operation keeps every reduced dim in its result (with a dim value of 1), so all dims are recorded here
all_dims.push_back(i);
}

// Reduce the dims that are last seen in the first input alone
Expand All @@ -391,7 +399,7 @@ Status EinsumTypedComputeProcessor<T>::Run() {
if (num_inputs == 1) {
// Finalize the output by applying any transpose required to get
// it to the required output ordering and move it to the op's output
FinalizeOutput(result ? *result : *raw_inputs[0], preserved_dims);
FinalizeOutput(result ? *result : *raw_inputs[0], all_dims);

return Status::OK();
}
Expand All @@ -403,9 +411,9 @@ Status EinsumTypedComputeProcessor<T>::Run() {
// Keep processing each input pair-wise
for (int input = 1; input < num_inputs; ++input) {
TensorShapeVector reduced_dims;
reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels)); // num_subscript_labels is the upper bound. No harm in over-reserving by a small margin.
for (int64_t dim = 0; dim < num_subscript_labels; ++dim) {
if (mapped_indices_to_last_input_index[onnxruntime::narrow<size_t>(dim)] == input) {
reduced_dims.reserve(num_subscript_labels); // num_subscript_labels is the upper bound. No harm in over-reserving by a small margin.
for (size_t dim = 0; dim < num_subscript_labels; ++dim) {
if (mapped_indices_to_last_input_index[dim] == input) {
// This is the last input we are seeing this dimension (and it doesn't occur in the output), so reduce along the dimension
reduced_dims.push_back(dim);
}
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/test/providers/cpu/math/einsum_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ TEST(Einsum, ExplicitEinsumAsBatchedReduceOp_3D_input_1) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
}

// Explicit-form Einsum on a single 3-D input where the equation both drops a
// subscript (`j` is absent from the output, so it is summed away) and reorders
// the surviving subscripts (`i`, `k` in the input become `k`, `i` in the output).
TEST(Einsum, ExplicitEinsumAsReduceWithTransposeOp_3D_input_0) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "ijk->ki");
// Input dims are {2, 3, 4} (i, j, k). Both i-slices carry the same 3x4 block of values 1..12.
test.AddInput<float>("x", {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f});
// Summing over j yields 15, 18, 21, 24 for k = 0..3 (e.g. 1 + 5 + 9 = 15).
// Output dims are {4, 2} (k, i); each sum appears twice because the two i-slices are identical.
test.AddOutput<float>("y", {4, 2}, {15.f, 15.f, 18.f, 18.f, 21.f, 21.f, 24.f, 24.f});
// The run is expected to succeed.
// NOTE(review): ExcludeTrtOnA100() is defined outside this view; presumably it lists execution providers to skip - confirm against the helper.
test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
}

// Implicit
// Cannot do implicit reduction

Expand Down
Loading