diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index b3f62bd13a24d..4637ee749af2c 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -186,9 +186,9 @@ struct ProviderHostCPUImpl : ProviderHostCPU { std::unique_ptr> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); } std::unique_ptr> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); } std::unique_ptr> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); } - void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); } - void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); } - void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); } + void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); } + void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); } + void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); } Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 15baf7309070d..98845b8186f11 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -117,9 +117,9 @@ struct ProviderHostCPU { virtual std::unique_ptr> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0; virtual std::unique_ptr> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0; virtual std::unique_ptr> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0; - virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) = 0; - virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) = 0; - virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) = 0; + virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) = 0; + virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) = 0; + virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) = 0; virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) = 0; virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) = 0; virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) = 0; @@ -296,8 +296,9 @@ struct EinsumTypedComputeProcessor { void SetDeviceHelpers(const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, - const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) { - g_host_cpu.EinsumTypedComputeProcessor__SetDeviceHelpers(this, device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); + const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, + const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) { + g_host_cpu.EinsumTypedComputeProcessor__SetDeviceHelpers(this, device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); } Status Run() { return g_host_cpu.EinsumTypedComputeProcessor__Run(this); } diff --git a/onnxruntime/core/providers/cpu/math/einsum.cc b/onnxruntime/core/providers/cpu/math/einsum.cc index 789f6645230d8..114ec2d6cfe44 100644 --- a/onnxruntime/core/providers/cpu/math/einsum.cc +++ b/onnxruntime/core/providers/cpu/math/einsum.cc @@ -65,7 +65,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector, EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy, + EinsumOp::DeviceHelpers::CpuDeviceHelpers::ZeroBuffer); return einsum_compute_processor.Run(); } else if (inputs[0]->IsDataType()) { auto einsum_compute_processor = EinsumTypedComputeProcessor(context, @@ -78,7 +79,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector, EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy, + EinsumOp::DeviceHelpers::CpuDeviceHelpers::ZeroBuffer); return einsum_compute_processor.Run(); } else if (inputs[0]->IsDataType()) { @@ -92,7 +94,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector, EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy, + EinsumOp::DeviceHelpers::CpuDeviceHelpers::ZeroBuffer); return einsum_compute_processor.Run(); } else if (inputs[0]->IsDataType()) { auto einsum_compute_processor = EinsumTypedComputeProcessor(context, @@ -104,7 +107,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector, EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy, + EinsumOp::DeviceHelpers::CpuDeviceHelpers::ZeroBuffer); return einsum_compute_processor.Run(); } diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc index a602a85fc2737..f54e345e48a5e 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc @@ -22,6 +22,12 @@ Status DataCopy(const Tensor& input, Tensor& output, void* /*einsum_cuda_assets* return Status::OK(); } +// CPU specific Zero buffer helper +Status ZeroBuffer(Tensor& input, void* /*einsum_cuda_assets*/) { + memset(input.MutableDataRaw(), 0, input.SizeInBytes()); + return Status::OK(); +} + // CPU specific Transpose helper Status Transpose(const gsl::span& permutation, const Tensor& input, Tensor& output, const TensorShape* input_shape_override, void* /*einsum_cuda_assets*/) { diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h index a858a49e3e881..75aa4849bf81f 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h @@ -27,6 +27,9 @@ namespace DeviceHelpers { // Data copy op - Copies raw data from the source tensor's buffer to the destination tensor's buffer using DataCopy = std::function; +// Zero buffer op - Sets all bytes in the tensor's buffer to zero +using ZeroBuffer = std::function; + // Transpose op - Transposes given input based on data in `permutation` using Transpose = std::function& permutation, const Tensor& input, Tensor& output, const TensorShape* input_shape_override, @@ -63,6 +66,8 @@ namespace CpuDeviceHelpers { Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets); +Status ZeroBuffer(Tensor& input, void* einsum_cuda_assets); + Status Transpose(const gsl::span& permutation, const Tensor& input, Tensor& output, const TensorShape* input_shape_override, void* einsum_cuda_assets); diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc index 096e07eb8e272..d77d61d8036be 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc @@ -336,11 +336,13 @@ template void EinsumTypedComputeProcessor::SetDeviceHelpers(const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, - const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) { + const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, + const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) { device_transpose_func_ = device_transpose_func; device_matmul_func_ = device_matmul_func; device_reduce_sum_func_ = device_reduce_sum_func; device_data_copy_func_ = device_data_copy_func; + device_zero_buffer_func_ = device_zero_buffer_func; } template @@ -357,6 +359,20 @@ Status EinsumTypedComputeProcessor::Run() { auto num_inputs = context_->InputCount(); + { + bool has_empty_input = std::any_of(raw_inputs.begin(), raw_inputs.end(), [](const auto& input) { + return input->Shape().Size() == 0; + }); + + // Skip all the work, fill with zeros if needed + if (has_empty_input) { + const auto output_dims = einsum_compute_preprocessor_.GetOutputDims(); + Tensor& output = *context_->Output(0, output_dims); + + return device_zero_buffer_func_(output, einsum_ep_assets_); + } + } + // Pre-process the first input so as to reduce any dims that only it has std::unique_ptr result; diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h index f858019c6329e..666c7c0006663 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h @@ -31,7 +31,8 @@ class EinsumTypedComputeProcessor { void SetDeviceHelpers(const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, - const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func); + const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, + const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func); Status Run(); @@ -64,6 +65,7 @@ class EinsumTypedComputeProcessor { EinsumOp::DeviceHelpers::MatMul device_matmul_func_; EinsumOp::DeviceHelpers::ReduceSum device_reduce_sum_func_; EinsumOp::DeviceHelpers::DataCopy device_data_copy_func_; + EinsumOp::DeviceHelpers::ZeroBuffer device_zero_buffer_func_; // Holds EP-specific assets required for (auxiliary) ops that need to be executed on non-CPU EPs void* einsum_ep_assets_; diff --git a/onnxruntime/core/providers/cuda/math/einsum.cc b/onnxruntime/core/providers/cuda/math/einsum.cc index b7c0d99a5390e..405f4f362bcb1 100644 --- a/onnxruntime/core/providers/cuda/math/einsum.cc +++ b/onnxruntime/core/providers/cuda/math/einsum.cc @@ -52,7 +52,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vectorSetDeviceHelpers(EinsumOp::DeviceHelpers::CudaDeviceHelpers::Transpose, EinsumOp::DeviceHelpers::CudaDeviceHelpers::MatMul, EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy, + EinsumOp::DeviceHelpers::CudaDeviceHelpers::ZeroBuffer); return einsum_compute_processor->Run(); } else if (inputs[0]->IsDataType()) { auto einsum_compute_processor = EinsumTypedComputeProcessor::Create(context, allocator, tp, @@ -63,7 +64,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vectorSetDeviceHelpers(EinsumOp::DeviceHelpers::CudaDeviceHelpers::Transpose, EinsumOp::DeviceHelpers::CudaDeviceHelpers::MatMul, EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy, + EinsumOp::DeviceHelpers::CudaDeviceHelpers::ZeroBuffer); return einsum_compute_processor->Run(); } else if (inputs[0]->IsDataType()) { auto einsum_compute_processor = EinsumTypedComputeProcessor::Create(context, allocator, tp, @@ -73,7 +75,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vectorSetDeviceHelpers(EinsumOp::DeviceHelpers::CudaDeviceHelpers::Transpose, EinsumOp::DeviceHelpers::CudaDeviceHelpers::MatMul, EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy, + EinsumOp::DeviceHelpers::CudaDeviceHelpers::ZeroBuffer); return einsum_compute_processor->Run(); } diff --git a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc index ee0334e552022..82238ab0ea1e0 100644 --- a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc +++ b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc @@ -31,6 +31,13 @@ Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets) { return Status::OK(); } +// CUDA EP specific Zero buffer helper +Status ZeroBuffer(Tensor& input, void* einsum_cuda_assets) { + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(input.MutableDataRaw(), 0, input.SizeInBytes(), + static_cast(einsum_cuda_assets)->GetCudaStream())); + return Status::OK(); +} + // CUDA EP specific Transpose helper Status Transpose(const gsl::span& permutation, const Tensor& input, Tensor& output, const TensorShape* input_shape_override, void* einsum_cuda_assets) { diff --git a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.h b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.h index b152cb3cc1f9b..b42cf79a857cf 100644 --- a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.h +++ b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.h @@ -47,6 +47,8 @@ Status Transpose(const gsl::span& permutation, const Tensor& input Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets); +Status ZeroBuffer(Tensor& input, void* einsum_cuda_assets); + template Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data, size_t left_stride, size_t right_stride, size_t output_stride, diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index d3ea8552f60f4..c0833effcc786 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -751,6 +751,62 @@ TEST(Einsum, ExplicitEinsumAsElementwiseMulOpWithOneScalar_Half) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); } +// Theme: Empty inputs +TEST(Einsum, EinsumEmptyInputOuterProduct) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "i,j->ij"); + test.AddInput("x", {0}, {}); + test.AddInput("y", {0}, {}); + test.AddOutput("o", {0, 0}, {}); + // Empty inputs/outputs seem to cause some issue in the WebGpu EP. + // Disable for now. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider}); +} + +TEST(Einsum, EinsumEmptyInputTranspose) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "ab,ba->ab"); + test.AddInput("x", {0, 1}, {}); + test.AddInput("y", {1, 0}, {}); + test.AddOutput("o", {0, 1}, {}); + // Empty inputs/outputs seem to cause some issue in the WebGpu EP. + // Disable for now. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider}); +} + +TEST(Einsum, EinsumEmptyInputVanish) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "ab,ba->a"); + test.AddInput("x", {1, 0}, {}); + test.AddInput("y", {0, 1}, {}); + test.AddOutput("o", {1}, {0.f}); + // Empty inputs/outputs seem to cause some issue in the WebGpu EP. + // Disable for now. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider}); +} + +TEST(Einsum, EinsumEmptyInputVanish3d) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abc,bad->ad"); + test.AddInput("x", {10, 0, 10}, {}); + test.AddInput("y", {0, 10, 1}, {}); + test.AddOutput("o", {10, 1}, std::vector(10, 0.f)); + // Empty inputs/outputs seem to cause some issue in the WebGpu EP. + // Disable for now. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider}); +} + +TEST(Einsum, EinsumEmptyInputVanish3d2empty) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abc,bcd->ad"); + test.AddInput("x", {0, 0, 0}, {}); + test.AddInput("y", {0, 0, 1}, {}); + test.AddOutput("o", {0, 1}, {}); + // Empty inputs/outputs seem to cause some issue in the WebGpu EP. + // Disable for now. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider}); +} + TEST(Einsum, ExplicitEinsumAsTensorContraction_Half) { #if !defined(USE_WEBGPU) if (!HasCudaEnvironment(600)) {