Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ static bool IsTransposeReshapeForEinsum(const gsl::span<const size_t>& perm,
return true;
}

template <typename T>
static void ZeroInputBuffer(Tensor& buffer_to_be_zeroed) {
if constexpr (std::is_integral<T>::value) {
std::fill_n(reinterpret_cast<T*>(buffer_to_be_zeroed.MutableDataRaw()), buffer_to_be_zeroed.Shape().Size(), T(0));
} else {
std::fill_n(reinterpret_cast<T*>(buffer_to_be_zeroed.MutableDataRaw()), buffer_to_be_zeroed.Shape().Size(), T(0.f));
}
}

template <typename T>
std::unique_ptr<Tensor> EinsumTypedComputeProcessor<T>::PairwiseOperandProcess(const Tensor& left,
const TensorShape& left_shape_override,
Expand Down Expand Up @@ -357,6 +366,39 @@ Status EinsumTypedComputeProcessor<T>::Run() {

auto num_inputs = context_->InputCount();

{
bool has_empty_input = std::any_of(raw_inputs.begin(), raw_inputs.end(), [](const auto& input) {
return input->Shape().Size() == 0;
});

// Skip all the work, fill with zeros if needed
if (has_empty_input) {
const auto output_dims = einsum_compute_preprocessor_.GetOutputDims();
Tensor& output = *context_->Output(0, output_dims);

if (output.Location().device.Type() != OrtDevice::CPU) {
// Get CPU allocator to allocate a staging buffer on CPU
AllocatorPtr cpu_allocator;
ORT_RETURN_IF_ERROR(context_->GetTempSpaceCPUAllocator(&cpu_allocator));

// If this Einsum node is partitioned to a non-CPU EP, we will use an intermediate CPU
// buffer to stage the zero buffer results which we will then copy over to the op's output
// allocated on the non-CPU device using the device data copy abstraction
Tensor candidate_output(raw_inputs[0]->DataType(), output_dims, cpu_allocator);
ZeroInputBuffer<T>(candidate_output);

// Copy zeroed buffer to the output buffer
auto status = device_data_copy_func_(candidate_output, output, einsum_ep_assets_);
ORT_ENFORCE(status.IsOK(), "Einsum op: Could not copy the intermediate output's buffer into the op's output buffer. Error: ",
status.ErrorMessage());
} else { // Zero out the op's output buffer
ZeroInputBuffer<T>(output);
}

return Status::OK();
}
}

// Pre-process the first input so as to reduce any dims that only it has
std::unique_ptr<const Tensor> result;

Expand Down
56 changes: 56 additions & 0 deletions onnxruntime/test/providers/cpu/math/einsum_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,62 @@ TEST(Einsum, ExplicitEinsumAsElementwiseMulOpWithOneScalar_Half) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
}

// Theme: Empty inputs
TEST(Einsum, EinsumEmptyInputOuterProduct) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "i,j->ij");
test.AddInput<float>("x", {0}, {});
test.AddInput<float>("y", {0}, {});
test.AddOutput<float>("o", {0, 0}, {});
// Empty inputs/outputs seem to cause some issue in the WebGpu EP.
// Disable for now.
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider});
}

TEST(Einsum, EinsumEmptyInputTranspose) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "ab,ba->ab");
test.AddInput<float>("x", {0, 1}, {});
test.AddInput<float>("y", {1, 0}, {});
test.AddOutput<float>("o", {0, 1}, {});
// Empty inputs/outputs seem to cause some issue in the WebGpu EP.
// Disable for now.
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider});
}

TEST(Einsum, EinsumEmptyInputVanish) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "ab,ba->a");
test.AddInput<float>("x", {1, 0}, {});
test.AddInput<float>("y", {0, 1}, {});
test.AddOutput<float>("o", {1}, {0.f});
// Empty inputs/outputs seem to cause some issue in the WebGpu EP.
// Disable for now.
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider});
}

TEST(Einsum, EinsumEmptyInputVanish3d) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "abc,bad->ad");
test.AddInput<float>("x", {10, 0, 10}, {});
test.AddInput<float>("y", {0, 10, 1}, {});
test.AddOutput<float>("o", {10, 1}, std::vector<float>(10, 0.f));
// Empty inputs/outputs seem to cause some issue in the WebGpu EP.
// Disable for now.
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider});
}

TEST(Einsum, EinsumEmptyInputVanish3d2empty) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "abc,bcd->ad");
test.AddInput<float>("x", {0, 0, 0}, {});
test.AddInput<float>("y", {0, 0, 1}, {});
test.AddOutput<float>("o", {0, 1}, {});
// Empty inputs/outputs seem to cause some issue in the WebGpu EP.
// Disable for now.
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider});
}

TEST(Einsum, ExplicitEinsumAsTensorContraction_Half) {
#if !defined(USE_WEBGPU)
if (!HasCudaEnvironment(600)) {
Expand Down
Loading