From 6fb223f2913e5ecf3f73f7a12361b90bb8791ae1 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 2 Feb 2026 20:22:49 -0800 Subject: [PATCH 1/9] Dupe of 23379 --- .../einsum_typed_compute_processor.cc | 33 +++++++++++++ .../test/providers/cpu/math/einsum_test.cc | 46 +++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc index 096e07eb8e272..642378e96ef25 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc @@ -357,6 +357,39 @@ Status EinsumTypedComputeProcessor::Run() { auto num_inputs = context_->InputCount(); + { + bool has_empty_input = std::any_of(raw_inputs.begin(), raw_inputs.end(), [](const auto& input) { + return input->Shape().Size() == 0; + }); + + // Skip all the work, fill with zeros if needed + if (has_empty_input) { + AllocatorPtr cpu_allocator; + // TODO(hasesh): I think this is the bug - need to use GetTempSpaceCPUAllocator + ORT_RETURN_IF_ERROR(context_->GetTempSpaceAllocator(&cpu_allocator)); + + const auto output_dims = einsum_compute_preprocessor_.GetOutputDims(); + Tensor& output = *context_->Output(0, output_dims); + + // If this Einsum node is partitioned to a non-CPU EP, we will use an intermediate CPU + // buffer to stage the zero buffer results which we will then copy over to the op's output + // allocated on the non-CPU device using the device data copy abstraction + Tensor candidate_output(raw_inputs[0]->DataType(), output_dims, cpu_allocator); + + if constexpr (std::is_integral::value) { + std::fill_n(reinterpret_cast(candidate_output.MutableDataRaw()), candidate_output.Shape().Size(), T(0)); + } else { + std::fill_n(reinterpret_cast(candidate_output.MutableDataRaw()), candidate_output.Shape().Size(), T(0.f)); + } + + auto status = device_data_copy_func_(candidate_output, output, einsum_ep_assets_); + ORT_ENFORCE(status.IsOK(), "Einsum op: Could not copy the intermediate output's buffer into the op's output buffer. 
Error: ", + status.ErrorMessage()); + + return status; + } + } + // Pre-process the first input so as to reduce any dims that only it has std::unique_ptr result; diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index d3ea8552f60f4..656897663885c 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -751,6 +751,52 @@ TEST(Einsum, ExplicitEinsumAsElementwiseMulOpWithOneScalar_Half) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); } +// Theme: Empty inputs +TEST(Einsum, EinsumEmptyInputOuterProduct) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "i,j->ij"); + test.AddInput("x", {0}, {}); + test.AddInput("y", {0}, {}); + test.AddOutput("o", {0, 0}, {}); + test.Run(); +} + +TEST(Einsum, EinsumEmptyInputTranspose) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "ab,ba->ab"); + test.AddInput("x", {0, 1}, {}); + test.AddInput("y", {1, 0}, {}); + test.AddOutput("o", {0, 1}, {}); + test.Run(); +} + +TEST(Einsum, EinsumEmptyInputVanish) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "ab,ba->a"); + test.AddInput("x", {1, 0}, {}); + test.AddInput("y", {0, 1}, {}); + test.AddOutput("o", {1}, {0.f}); + test.Run(); +} + +TEST(Einsum, EinsumEmptyInputVanish3d) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abc,bad->ad"); + test.AddInput("x", {10, 0, 10}, {}); + test.AddInput("y", {0, 10, 1}, {}); + test.AddOutput("o", {10, 1}, std::vector(10, 0.f)); + test.Run(); +} + +TEST(Einsum, EinsumEmptyInputVanish3d2empty) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "abc,bcd->ad"); + test.AddInput("x", {0, 0, 0}, {}); + test.AddInput("y", {0, 0, 1}, {}); + test.AddOutput("o", {0, 1}, {}); + test.Run(); +} + TEST(Einsum, ExplicitEinsumAsTensorContraction_Half) { #if !defined(USE_WEBGPU) if (!HasCudaEnvironment(600)) { From 7e83f1db011fef99dcc202e6af9faa8eb1b4b04c Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 2 Feb 2026 20:35:12 -0800 Subject: [PATCH 2/9] Avoid unnecessary copy on CPU --- .../einsum_typed_compute_processor.cc | 43 +++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc index 642378e96ef25..9d29c9bdc8820 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc @@ -92,6 +92,15 @@ static bool IsTransposeReshapeForEinsum(const gsl::span& perm, return true; } +template +static void ZeroInputBuffer(Tensor& buffer_to_be_zeroed) { + if constexpr (std::is_integral::value) { + std::fill_n(reinterpret_cast(buffer_to_be_zeroed.MutableDataRaw()), buffer_to_be_zeroed.Shape().Size(), T(0)); + } else { + std::fill_n(reinterpret_cast(buffer_to_be_zeroed.MutableDataRaw()), buffer_to_be_zeroed.Shape().Size(), T(0.f)); + } +} + template std::unique_ptr EinsumTypedComputeProcessor::PairwiseOperandProcess(const Tensor& left, const TensorShape& left_shape_override, @@ -364,29 +373,29 @@ Status EinsumTypedComputeProcessor::Run() { // Skip all the work, fill with zeros if needed if (has_empty_input) { - AllocatorPtr 
cpu_allocator; - // TODO(hasesh): I think this is the bug - need to use GetTempSpaceCPUAllocator - ORT_RETURN_IF_ERROR(context_->GetTempSpaceAllocator(&cpu_allocator)); - const auto output_dims = einsum_compute_preprocessor_.GetOutputDims(); Tensor& output = *context_->Output(0, output_dims); - // If this Einsum node is partitioned to a non-CPU EP, we will use an intermediate CPU - // buffer to stage the zero buffer results which we will then copy over to the op's output - // allocated on the non-CPU device using the device data copy abstraction - Tensor candidate_output(raw_inputs[0]->DataType(), output_dims, cpu_allocator); + if (output.Location().device.Type() != OrtDevice::CPU) { + AllocatorPtr cpu_allocator; + // TODO(hasesh): I think this is the bug - need to use GetTempSpaceCPUAllocator + ORT_RETURN_IF_ERROR(context_->GetTempSpaceAllocator(&cpu_allocator)); - if constexpr (std::is_integral::value) { - std::fill_n(reinterpret_cast(candidate_output.MutableDataRaw()), candidate_output.Shape().Size(), T(0)); - } else { - std::fill_n(reinterpret_cast(candidate_output.MutableDataRaw()), candidate_output.Shape().Size(), T(0.f)); - } - auto status = device_data_copy_func_(candidate_output, output, einsum_ep_assets_); - ORT_ENFORCE(status.IsOK(), "Einsum op: Could not copy the intermediate output's buffer into the op's output buffer. Error: ", - status.ErrorMessage()); + // If this Einsum node is partitioned to a non-CPU EP, we will use an intermediate CPU + // buffer to stage the zero buffer results which we will then copy over to the op's output + // allocated on the non-CPU device using the device data copy abstraction + Tensor candidate_output(raw_inputs[0]->DataType(), output_dims, cpu_allocator); + ZeroInputBuffer(candidate_output); - return status; + auto status = device_data_copy_func_(candidate_output, output, einsum_ep_assets_); + ORT_ENFORCE(status.IsOK(), "Einsum op: Could not copy the intermediate output's buffer into the op's output buffer. 
Error: ", + status.ErrorMessage()); + } else { // Zero out the op's output buffer + ZeroInputBuffer(output); + } + + return Status::OK(); } } From c74b516adb377d0039e1c871677e1f0531007dcf Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Mon, 2 Feb 2026 20:38:36 -0800 Subject: [PATCH 3/9] Update onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../einsum_typed_compute_processor.cc | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc index 9d29c9bdc8820..0dd85079be557 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc @@ -377,22 +377,21 @@ Status EinsumTypedComputeProcessor::Run() { Tensor& output = *context_->Output(0, output_dims); if (output.Location().device.Type() != OrtDevice::CPU) { - AllocatorPtr cpu_allocator; - // TODO(hasesh): I think this is the bug - need to use GetTempSpaceCPUAllocator - ORT_RETURN_IF_ERROR(context_->GetTempSpaceAllocator(&cpu_allocator)); - - - // If this Einsum node is partitioned to a non-CPU EP, we will use an intermediate CPU - // buffer to stage the zero buffer results which we will then copy over to the op's output - // allocated on the non-CPU device using the device data copy abstraction - Tensor candidate_output(raw_inputs[0]->DataType(), output_dims, cpu_allocator); - ZeroInputBuffer(candidate_output); - - auto status = device_data_copy_func_(candidate_output, output, einsum_ep_assets_); - ORT_ENFORCE(status.IsOK(), "Einsum op: Could not copy the intermediate output's buffer into the op's output buffer. Error: ", - status.ErrorMessage()); - } else { // Zero out the op's output buffer - ZeroInputBuffer(output); + AllocatorPtr cpu_allocator; + // TODO(hasesh): I think this is the bug - need to use GetTempSpaceCPUAllocator + ORT_RETURN_IF_ERROR(context_->GetTempSpaceAllocator(&cpu_allocator)); + + // If this Einsum node is partitioned to a non-CPU EP, we will use an intermediate CPU + // buffer to stage the zero buffer results which we will then copy over to the op's output + // allocated on the non-CPU device using the device data copy abstraction + Tensor candidate_output(raw_inputs[0]->DataType(), output_dims, cpu_allocator); + ZeroInputBuffer(candidate_output); + + auto status = device_data_copy_func_(candidate_output, output, einsum_ep_assets_); + ORT_ENFORCE(status.IsOK(), "Einsum op: Could not copy the intermediate output's buffer into the op's output buffer. 
Error: ", + status.ErrorMessage()); + } else { // Zero out the op's output buffer + ZeroInputBuffer(output); } return Status::OK(); From e0e6e01b49b9d018d988625ba6f32a45aea462f7 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Tue, 3 Feb 2026 00:36:23 -0800 Subject: [PATCH 4/9] More changes --- .../einsum_typed_compute_processor.cc | 4 +--- .../test/providers/cpu/math/einsum_test.cc | 20 ++++++++++++++----- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc index 9d29c9bdc8820..636045fde64a4 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc @@ -378,9 +378,7 @@ Status EinsumTypedComputeProcessor::Run() { if (output.Location().device.Type() != OrtDevice::CPU) { AllocatorPtr cpu_allocator; - // TODO(hasesh): I think this is the bug - need to use GetTempSpaceCPUAllocator - ORT_RETURN_IF_ERROR(context_->GetTempSpaceAllocator(&cpu_allocator)); - + ORT_RETURN_IF_ERROR(context_->GetTempSpaceCPUAllocator(&cpu_allocator)); // If this Einsum node is partitioned to a non-CPU EP, we will use an intermediate CPU // buffer to stage the zero buffer results which we will then copy over to the op's output diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index 656897663885c..c0833effcc786 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -758,7 +758,9 @@ TEST(Einsum, EinsumEmptyInputOuterProduct) { test.AddInput("x", {0}, {}); test.AddInput("y", {0}, {}); test.AddOutput("o", {0, 0}, {}); - test.Run(); + // Empty inputs/outputs seem to cause some issue in the WebGpu EP. + // Disable for now. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider}); } TEST(Einsum, EinsumEmptyInputTranspose) { @@ -767,7 +769,9 @@ TEST(Einsum, EinsumEmptyInputTranspose) { test.AddInput("x", {0, 1}, {}); test.AddInput("y", {1, 0}, {}); test.AddOutput("o", {0, 1}, {}); - test.Run(); + // Empty inputs/outputs seem to cause some issue in the WebGpu EP. + // Disable for now. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider}); } TEST(Einsum, EinsumEmptyInputVanish) { @@ -776,7 +780,9 @@ TEST(Einsum, EinsumEmptyInputVanish) { test.AddInput("x", {1, 0}, {}); test.AddInput("y", {0, 1}, {}); test.AddOutput("o", {1}, {0.f}); - test.Run(); + // Empty inputs/outputs seem to cause some issue in the WebGpu EP. + // Disable for now. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider}); } TEST(Einsum, EinsumEmptyInputVanish3d) { @@ -785,7 +791,9 @@ TEST(Einsum, EinsumEmptyInputVanish3d) { test.AddInput("x", {10, 0, 10}, {}); test.AddInput("y", {0, 10, 1}, {}); test.AddOutput("o", {10, 1}, std::vector(10, 0.f)); - test.Run(); + // Empty inputs/outputs seem to cause some issue in the WebGpu EP. + // Disable for now. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider}); } TEST(Einsum, EinsumEmptyInputVanish3d2empty) { @@ -794,7 +802,9 @@ TEST(Einsum, EinsumEmptyInputVanish3d2empty) { test.AddInput("x", {0, 0, 0}, {}); test.AddInput("y", {0, 0, 1}, {}); test.AddOutput("o", {0, 1}, {}); - test.Run(); + // Empty inputs/outputs seem to cause some issue in the WebGpu EP. 
+ // Disable for now. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kWebGpuExecutionProvider}); } TEST(Einsum, ExplicitEinsumAsTensorContraction_Half) { From b845a7cdc5634be86ddb3b6156afa3ceb38fc5e5 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Tue, 3 Feb 2026 10:45:57 -0800 Subject: [PATCH 5/9] Adding infra for zeroing output buffer --- .../core/providers/cpu/cpu_provider_shared.cc | 6 ++-- .../core/providers/cpu/cpu_provider_shared.h | 6 ++-- onnxruntime/core/providers/cpu/math/einsum.cc | 12 ++++--- .../math/einsum_utils/einsum_auxiliary_ops.cc | 6 ++++ .../math/einsum_utils/einsum_auxiliary_ops.h | 5 +++ .../einsum_typed_compute_processor.cc | 34 +++---------------- .../einsum_typed_compute_processor.h | 4 ++- .../core/providers/cuda/math/einsum.cc | 9 +++-- .../math/einsum_utils/einsum_auxiliary_ops.cc | 8 +++++ .../math/einsum_utils/einsum_auxiliary_ops.h | 2 ++ 10 files changed, 48 insertions(+), 44 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index b3f62bd13a24d..ca1287a4355e8 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -186,9 +186,9 @@ struct ProviderHostCPUImpl : ProviderHostCPU { std::unique_ptr> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); } std::unique_ptr> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); } std::unique_ptr> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); } - void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); } - void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); } - void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const 
EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); } + void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zeroing_func); } + void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zeroing_func); } + void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zeroing_func); } Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 15baf7309070d..a3ed329f66e4b 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -117,9 +117,9 @@ struct ProviderHostCPU { virtual std::unique_ptr> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0; virtual std::unique_ptr> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0; virtual std::unique_ptr> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0; - virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) = 0; - virtual void 
EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) = 0; - virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) = 0; + virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) = 0; + virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) = 0; + virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) = 0; virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) = 0; virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) = 0; virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) = 0; diff --git a/onnxruntime/core/providers/cpu/math/einsum.cc b/onnxruntime/core/providers/cpu/math/einsum.cc index 789f6645230d8..f9d108b212384 100644 --- a/onnxruntime/core/providers/cpu/math/einsum.cc +++ b/onnxruntime/core/providers/cpu/math/einsum.cc @@ -65,7 +65,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector, EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy, + EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing); return einsum_compute_processor.Run(); } else if (inputs[0]->IsDataType()) { auto einsum_compute_processor = EinsumTypedComputeProcessor(context, @@ -78,7 +79,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector, EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy, + EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing); return einsum_compute_processor.Run(); } else if (inputs[0]->IsDataType()) { @@ -92,7 +94,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector, EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy, + 
EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing); return einsum_compute_processor.Run(); } else if (inputs[0]->IsDataType()) { auto einsum_compute_processor = EinsumTypedComputeProcessor(context, @@ -104,7 +107,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector, EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy, + EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing); return einsum_compute_processor.Run(); } diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc index a602a85fc2737..5e23bf558911a 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc @@ -22,6 +22,12 @@ Status DataCopy(const Tensor& input, Tensor& output, void* /*einsum_cuda_assets* return Status::OK(); } +// CPU specific Zeroing helper +Status Zeroing(Tensor& input, void* /*einsum_cuda_assets*/) { + memset(input.MutableDataRaw(), 0, input.SizeInBytes()); + return Status::OK(); +} + // CPU specific Transpose helper Status Transpose(const gsl::span& permutation, const Tensor& input, Tensor& output, const TensorShape* input_shape_override, void* /*einsum_cuda_assets*/) { diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h index a858a49e3e881..a26b5bd4f046d 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h @@ -27,6 +27,9 @@ namespace DeviceHelpers { // Data copy op - Copies raw data from the source tensor's buffer to the destination tensor's buffer using DataCopy = std::function; +// Zeroing op - Sets all bytes in the tensor's buffer to zero +using Zeroing = std::function; + // Transpose op - Transposes given input based on data in `permutation` using Transpose = std::function& permutation, const Tensor& input, Tensor& output, const TensorShape* input_shape_override, @@ -63,6 +66,8 @@ namespace CpuDeviceHelpers { Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets); +Status Zeroing(Tensor& input, void* einsum_cuda_assets); + Status Transpose(const gsl::span& permutation, const Tensor& input, Tensor& output, const TensorShape* input_shape_override, void* einsum_cuda_assets); diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc index e3b2523129ed3..41ab2cb8d0a4d 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc @@ -92,15 +92,6 @@ static bool IsTransposeReshapeForEinsum(const gsl::span& perm, return true; } -template -static void ZeroInputBuffer(Tensor& buffer_to_be_zeroed) { - if constexpr (std::is_integral::value) { - std::fill_n(reinterpret_cast(buffer_to_be_zeroed.MutableDataRaw()), buffer_to_be_zeroed.Shape().Size(), T(0)); - } else { - std::fill_n(reinterpret_cast(buffer_to_be_zeroed.MutableDataRaw()), buffer_to_be_zeroed.Shape().Size(), T(0.f)); - } -} - template std::unique_ptr EinsumTypedComputeProcessor::PairwiseOperandProcess(const Tensor& left, const TensorShape& 
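Patch 5 replaces the templated ZeroInputBuffer from patch 2 (std::fill_n with T(0)) with an untyped helper behind the new Zeroing device hook: memset on CPU here, and cudaMemsetAsync on CUDA later in the patch. The two are interchangeable because zero has an all-zero byte pattern for the element types involved here (IEEE-754 float and double, two's-complement integers, and MLFloat16, which wraps a uint16_t). A small self-contained check of that property; it is an aside, not part of the patch:

#include <cassert>
#include <cstdint>
#include <cstring>

// Checks that value-initialized T ("zero") is all zero bytes, the property
// that makes a raw memset interchangeable with a typed std::fill_n(..., T(0)).
template <typename T>
bool ZeroIsAllZeroBytes() {
  T zero{};
  unsigned char bytes[sizeof(T)];
  std::memcpy(bytes, &zero, sizeof(T));
  for (unsigned char b : bytes) {
    if (b != 0) return false;
  }
  return true;
}

int main() {
  assert(ZeroIsAllZeroBytes<float>());    // IEEE-754 +0.0f
  assert(ZeroIsAllZeroBytes<double>());   // IEEE-754 +0.0
  assert(ZeroIsAllZeroBytes<int32_t>());
  assert(ZeroIsAllZeroBytes<int64_t>());
  // MLFloat16 (ONNX Runtime's uint16_t-backed half type) has the same property.
  return 0;
}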
left_shape_override, @@ -345,11 +336,13 @@ template void EinsumTypedComputeProcessor::SetDeviceHelpers(const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, - const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) { + const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, + const EinsumOp::DeviceHelpers::Zeroing& zero_input_buffer_func) { device_transpose_func_ = device_transpose_func; device_matmul_func_ = device_matmul_func; device_reduce_sum_func_ = device_reduce_sum_func; device_data_copy_func_ = device_data_copy_func; + zero_input_buffer_func_ = zero_input_buffer_func; } template @@ -376,26 +369,7 @@ Status EinsumTypedComputeProcessor::Run() { const auto output_dims = einsum_compute_preprocessor_.GetOutputDims(); Tensor& output = *context_->Output(0, output_dims); - if (output.Location().device.Type() != OrtDevice::CPU) { - // Get CPU allocator to allocate a staging buffer on CPU - AllocatorPtr cpu_allocator; - ORT_RETURN_IF_ERROR(context_->GetTempSpaceCPUAllocator(&cpu_allocator)); - - // If this Einsum node is partitioned to a non-CPU EP, we will use an intermediate CPU - // buffer to stage the zero buffer results which we will then copy over to the op's output - // allocated on the non-CPU device using the device data copy abstraction - Tensor candidate_output(raw_inputs[0]->DataType(), output_dims, cpu_allocator); - ZeroInputBuffer(candidate_output); - - // Copy zeroed buffer to the output buffer - auto status = device_data_copy_func_(candidate_output, output, einsum_ep_assets_); - ORT_ENFORCE(status.IsOK(), "Einsum op: Could not copy the intermediate output's buffer into the op's output buffer. 
Error: ", - status.ErrorMessage()); - } else { // Zero out the op's output buffer - ZeroInputBuffer(output); - } - - return Status::OK(); + return zero_input_buffer_func_(output, einsum_ep_assets_); } } diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h index f858019c6329e..cf289f16fe7ab 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h @@ -31,7 +31,8 @@ class EinsumTypedComputeProcessor { void SetDeviceHelpers(const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, - const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func); + const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, + const EinsumOp::DeviceHelpers::Zeroing& zero_input_buffer_func); Status Run(); @@ -64,6 +65,7 @@ class EinsumTypedComputeProcessor { EinsumOp::DeviceHelpers::MatMul device_matmul_func_; EinsumOp::DeviceHelpers::ReduceSum device_reduce_sum_func_; EinsumOp::DeviceHelpers::DataCopy device_data_copy_func_; + EinsumOp::DeviceHelpers::Zeroing zero_input_buffer_func_; // Holds EP-specific assets required for (auxiliary) ops that need to be executed on non-CPU EPs void* einsum_ep_assets_; diff --git a/onnxruntime/core/providers/cuda/math/einsum.cc b/onnxruntime/core/providers/cuda/math/einsum.cc index b7c0d99a5390e..4e84b029ff588 100644 --- a/onnxruntime/core/providers/cuda/math/einsum.cc +++ b/onnxruntime/core/providers/cuda/math/einsum.cc @@ -52,7 +52,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vectorSetDeviceHelpers(EinsumOp::DeviceHelpers::CudaDeviceHelpers::Transpose, EinsumOp::DeviceHelpers::CudaDeviceHelpers::MatMul, EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy, + EinsumOp::DeviceHelpers::CudaDeviceHelpers::Zeroing); return einsum_compute_processor->Run(); } else if (inputs[0]->IsDataType()) { auto einsum_compute_processor = EinsumTypedComputeProcessor::Create(context, allocator, tp, @@ -63,7 +64,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vectorSetDeviceHelpers(EinsumOp::DeviceHelpers::CudaDeviceHelpers::Transpose, EinsumOp::DeviceHelpers::CudaDeviceHelpers::MatMul, EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy, + EinsumOp::DeviceHelpers::CudaDeviceHelpers::Zeroing); return einsum_compute_processor->Run(); } else if (inputs[0]->IsDataType()) { auto einsum_compute_processor = EinsumTypedComputeProcessor::Create(context, allocator, tp, @@ -73,7 +75,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vectorSetDeviceHelpers(EinsumOp::DeviceHelpers::CudaDeviceHelpers::Transpose, EinsumOp::DeviceHelpers::CudaDeviceHelpers::MatMul, EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy); + EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy, + EinsumOp::DeviceHelpers::CudaDeviceHelpers::Zeroing); return einsum_compute_processor->Run(); } diff --git a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc 
b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc index ee0334e552022..fd60b22fc838c 100644 --- a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc +++ b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc @@ -31,6 +31,14 @@ Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets) { return Status::OK(); } + +// CUDA EP specific Zeroing helper +Status Zeroing(Tensor& input, void* einsum_cuda_assets) { + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(input.MutableDataRaw(), 0, input.SizeInBytes(), + static_cast(einsum_cuda_assets)->GetCudaStream())); + return Status::OK(); +} + // CUDA EP specific Transpose helper Status Transpose(const gsl::span& permutation, const Tensor& input, Tensor& output, const TensorShape* input_shape_override, void* einsum_cuda_assets) { diff --git a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.h b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.h index b152cb3cc1f9b..5101d2bffaa9c 100644 --- a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.h +++ b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.h @@ -47,6 +47,8 @@ Status Transpose(const gsl::span& permutation, const Tensor& input Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets); +Status Zeroing(Tensor& input, void* einsum_cuda_assets); + template Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data, size_t left_stride, size_t right_stride, size_t output_stride, From 987d64795274eadcf07693fd5bab1f86ae2d782b Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 3 Feb 2026 11:00:31 -0800 Subject: [PATCH 6/9] Update onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc index fd60b22fc838c..73734fdc21d2f 100644 --- a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc +++ b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc @@ -31,7 +31,6 @@ Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets) { return Status::OK(); } - // CUDA EP specific Zeroing helper Status Zeroing(Tensor& input, void* einsum_cuda_assets) { CUDA_RETURN_IF_ERROR(cudaMemsetAsync(input.MutableDataRaw(), 0, input.SizeInBytes(), From 1522f1d8d2782f2b26aaff1c1f5e8047a6229aa7 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Tue, 3 Feb 2026 11:36:12 -0800 Subject: [PATCH 7/9] Fixes --- onnxruntime/core/providers/cpu/cpu_provider_shared.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index a3ed329f66e4b..24ea3457f2b0b 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -296,8 +296,9 @@ struct EinsumTypedComputeProcessor { void SetDeviceHelpers(const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, - const EinsumOp::DeviceHelpers::DataCopy& 
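Patch 5's structure follows the existing device-helper design: EinsumTypedComputeProcessor holds std::function hooks (Transpose, MatMul, ReduceSum, DataCopy, and now Zeroing), and each execution provider wires in its own implementations through SetDeviceHelpers; the CPU EP supplies a plain memset and the CUDA EP a cudaMemsetAsync on the kernel's stream, as shown above. A stripped-down, host-only model of that injection pattern; Buffer, ZeroBufferFn and the helper names are stand-ins, not ONNX Runtime types:

#include <algorithm>
#include <cstring>
#include <functional>
#include <iostream>
#include <vector>

struct Buffer {
  std::vector<unsigned char> bytes;
};

using ZeroBufferFn = std::function<void(Buffer&)>;

// The processor only knows the hook, not which EP provided it.
struct Processor {
  ZeroBufferFn zero_buffer;
  void RunEmptyInputFastPath(Buffer& output) { zero_buffer(output); }
};

// CPU-style helper: plain memset on host memory.
void CpuZeroBuffer(Buffer& b) { std::memset(b.bytes.data(), 0, b.bytes.size()); }

// A CUDA-style helper would instead issue cudaMemsetAsync on its stream;
// modeled here with std::fill so the example stays host-only.
void FakeCudaZeroBuffer(Buffer& b) { std::fill(b.bytes.begin(), b.bytes.end(), 0); }

int main() {
  Buffer out{std::vector<unsigned char>(8, 0xFF)};

  Processor p;
  p.zero_buffer = CpuZeroBuffer;       // what the CPU EP wires in
  p.RunEmptyInputFastPath(out);
  std::cout << static_cast<int>(out.bytes[0]) << "\n";  // prints 0

  p.zero_buffer = FakeCudaZeroBuffer;  // another EP swaps in its own helper
  p.RunEmptyInputFastPath(out);
  return 0;
}

Swapping the hook rather than branching on device type inside Run() keeps the empty-input fast path identical across EPs, which is the design the later ZeroBuffer rename settles on.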
device_data_copy_func) { - g_host_cpu.EinsumTypedComputeProcessor__SetDeviceHelpers(this, device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); + const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, + const EinsumOp::DeviceHelpers::Zeroing& zero_input_buffer_func) { + g_host_cpu.EinsumTypedComputeProcessor__SetDeviceHelpers(this, device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, zero_input_buffer_func); } Status Run() { return g_host_cpu.EinsumTypedComputeProcessor__Run(this); } From edb352db5c6ee942397e4cfa6804b9d2f7d97f6e Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Thu, 5 Feb 2026 10:28:32 -0800 Subject: [PATCH 8/9] Consistency --- onnxruntime/core/providers/cpu/cpu_provider_shared.cc | 6 +++--- onnxruntime/core/providers/cpu/cpu_provider_shared.h | 10 +++++----- .../einsum_utils/einsum_typed_compute_processor.cc | 6 +++--- .../math/einsum_utils/einsum_typed_compute_processor.h | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index ca1287a4355e8..36d2aefa6c0c2 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -186,9 +186,9 @@ struct ProviderHostCPUImpl : ProviderHostCPU { std::unique_ptr> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); } std::unique_ptr> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); } std::unique_ptr> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); } - void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zeroing_func); } - void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zeroing_func); } - void 
EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zeroing_func); } + void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); } + void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); } + void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); } Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 24ea3457f2b0b..f95a6114bbe9a 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -117,9 +117,9 @@ struct ProviderHostCPU { virtual std::unique_ptr> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0; virtual std::unique_ptr> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0; virtual std::unique_ptr> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* 
einsum_cuda_assets) = 0; - virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) = 0; - virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) = 0; - virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) = 0; + virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) = 0; + virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) = 0; + virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) = 0; virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) = 0; virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) = 0; virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) = 0; @@ -297,8 +297,8 @@ struct EinsumTypedComputeProcessor { const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, - const EinsumOp::DeviceHelpers::Zeroing& zero_input_buffer_func) { - g_host_cpu.EinsumTypedComputeProcessor__SetDeviceHelpers(this, device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, zero_input_buffer_func); + const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) { + g_host_cpu.EinsumTypedComputeProcessor__SetDeviceHelpers(this, device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); } Status Run() { return 
g_host_cpu.EinsumTypedComputeProcessor__Run(this); } diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc index 41ab2cb8d0a4d..0584893e7bda0 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc @@ -337,12 +337,12 @@ void EinsumTypedComputeProcessor::SetDeviceHelpers(const EinsumOp::DeviceHelp const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, - const EinsumOp::DeviceHelpers::Zeroing& zero_input_buffer_func) { + const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) { device_transpose_func_ = device_transpose_func; device_matmul_func_ = device_matmul_func; device_reduce_sum_func_ = device_reduce_sum_func; device_data_copy_func_ = device_data_copy_func; - zero_input_buffer_func_ = zero_input_buffer_func; + device_zero_buffer_func_ = device_zero_buffer_func; } template @@ -369,7 +369,7 @@ Status EinsumTypedComputeProcessor::Run() { const auto output_dims = einsum_compute_preprocessor_.GetOutputDims(); Tensor& output = *context_->Output(0, output_dims); - return zero_input_buffer_func_(output, einsum_ep_assets_); + return device_zero_buffer_func_(output, einsum_ep_assets_); } } diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h index cf289f16fe7ab..316cac4d98b5e 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h @@ -32,7 +32,7 @@ class EinsumTypedComputeProcessor { const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, - const EinsumOp::DeviceHelpers::Zeroing& zero_input_buffer_func); + const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func); Status Run(); @@ -65,7 +65,7 @@ class EinsumTypedComputeProcessor { EinsumOp::DeviceHelpers::MatMul device_matmul_func_; EinsumOp::DeviceHelpers::ReduceSum device_reduce_sum_func_; EinsumOp::DeviceHelpers::DataCopy device_data_copy_func_; - EinsumOp::DeviceHelpers::Zeroing zero_input_buffer_func_; + EinsumOp::DeviceHelpers::Zeroing device_zero_buffer_func_; // Holds EP-specific assets required for (auxiliary) ops that need to be executed on non-CPU EPs void* einsum_ep_assets_; From 4ef7a8135df8d079d5a99aadd190f17337967956 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Thu, 5 Feb 2026 11:45:02 -0800 Subject: [PATCH 9/9] Rename Zeroing --- onnxruntime/core/providers/cpu/cpu_provider_shared.cc | 6 +++--- onnxruntime/core/providers/cpu/cpu_provider_shared.h | 8 ++++---- onnxruntime/core/providers/cpu/math/einsum.cc | 8 ++++---- .../cpu/math/einsum_utils/einsum_auxiliary_ops.cc | 4 ++-- .../cpu/math/einsum_utils/einsum_auxiliary_ops.h | 6 +++--- .../math/einsum_utils/einsum_typed_compute_processor.cc | 2 +- .../math/einsum_utils/einsum_typed_compute_processor.h | 4 ++-- onnxruntime/core/providers/cuda/math/einsum.cc | 6 +++--- .../cuda/math/einsum_utils/einsum_auxiliary_ops.cc | 4 ++-- .../cuda/math/einsum_utils/einsum_auxiliary_ops.h | 2 +- 10 files changed, 
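Patches 7 through 9 look repetitive because the same signature change has to be threaded through three layers of the provider bridge: the pure-virtual declaration in ProviderHostCPU, the forwarding override in ProviderHostCPUImpl, and the shim struct further down in cpu_provider_shared.h that shared-library providers call through g_host_cpu. A minimal model of that layering; the class and function names below are illustrative, only the shape of the indirection follows the patch:

#include <iostream>

// The real implementation that lives with the CPU provider code.
struct RealProcessor {
  void SetHelpers(int helper_count) {
    std::cout << "helpers set: " << helper_count << "\n";
  }
};

// 1) Pure-virtual interface exposed to shared-provider code.
struct ProviderHostCpuIface {
  virtual ~ProviderHostCpuIface() = default;
  virtual void Processor_SetHelpers(RealProcessor* p, int helper_count) = 0;
};

// 2) Implementation that forwards to the real class.
struct ProviderHostCpuImpl : ProviderHostCpuIface {
  void Processor_SetHelpers(RealProcessor* p, int helper_count) override {
    p->SetHelpers(helper_count);
  }
};

ProviderHostCpuImpl g_host_cpu_impl;
ProviderHostCpuIface& g_host_cpu = g_host_cpu_impl;

// 3) Shim used by a shared provider: same surface, routed through the bridge.
struct ProcessorShim {
  RealProcessor real;
  void SetHelpers(int helper_count) {
    g_host_cpu.Processor_SetHelpers(&real, helper_count);
  }
};

int main() {
  ProcessorShim shim;
  shim.SetHelpers(5);  // Transpose, MatMul, ReduceSum, DataCopy, ZeroBuffer
  return 0;
}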
25 insertions(+), 25 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
index 36d2aefa6c0c2..4637ee749af2c 100644
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
@@ -186,9 +186,9 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
 std::unique_ptr> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); }
 std::unique_ptr> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); }
 std::unique_ptr> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); }
- void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); }
- void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); }
- void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); }
+ void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); }
+ void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); }
+ void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); }
 Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); }
 Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); }
 Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); }
diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
index f95a6114bbe9a..98845b8186f11 100644
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
@@ -117,9 +117,9 @@ struct ProviderHostCPU {
 virtual std::unique_ptr> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0;
 virtual std::unique_ptr> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0;
 virtual std::unique_ptr> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0;
- virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) = 0;
- virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) = 0;
- virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) = 0;
+ virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) = 0;
+ virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) = 0;
+ virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) = 0;
 virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) = 0;
 virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) = 0;
 virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) = 0;
@@ -297,7 +297,7 @@ struct EinsumTypedComputeProcessor {
 const EinsumOp::DeviceHelpers::MatMul& device_matmul_func,
 const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func,
 const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func,
- const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) {
+ const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) {
 g_host_cpu.EinsumTypedComputeProcessor__SetDeviceHelpers(this, device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func);
 }
diff --git a/onnxruntime/core/providers/cpu/math/einsum.cc b/onnxruntime/core/providers/cpu/math/einsum.cc
index f9d108b212384..114ec2d6cfe44 100644
--- a/onnxruntime/core/providers/cpu/math/einsum.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum.cc
@@ -66,7 +66,7 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector,
 EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum,
 EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy,
- EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing);
+ EinsumOp::DeviceHelpers::CpuDeviceHelpers::ZeroBuffer);
 return einsum_compute_processor.Run();
 } else if (inputs[0]->IsDataType()) {
 auto einsum_compute_processor = EinsumTypedComputeProcessor(context,
@@ -80,7 +80,7 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector,
 EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum,
 EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy,
- EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing);
+ EinsumOp::DeviceHelpers::CpuDeviceHelpers::ZeroBuffer);
 return einsum_compute_processor.Run();
 } else if (inputs[0]->IsDataType()) {
@@ -95,7 +95,7 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector,
 EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum,
 EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy,
- EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing);
+ EinsumOp::DeviceHelpers::CpuDeviceHelpers::ZeroBuffer);
 return einsum_compute_processor.Run();
 } else if (inputs[0]->IsDataType()) {
 auto einsum_compute_processor = EinsumTypedComputeProcessor(context,
@@ -108,7 +108,7 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector,
 EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum,
 EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy,
- EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing);
+ EinsumOp::DeviceHelpers::CpuDeviceHelpers::ZeroBuffer);
 return einsum_compute_processor.Run();
 }
diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc
index 5e23bf558911a..f54e345e48a5e 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc
@@ -22,8 +22,8 @@ Status DataCopy(const Tensor& input, Tensor& output, void* /*einsum_cuda_assets*
 return Status::OK();
 }
-// CPU specific Zeroing helper
-Status Zeroing(Tensor& input, void* /*einsum_cuda_assets*/) {
+// CPU specific Zero buffer helper
+Status ZeroBuffer(Tensor& input, void* /*einsum_cuda_assets*/) {
 memset(input.MutableDataRaw(), 0, input.SizeInBytes());
 return Status::OK();
 }
diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h
index a26b5bd4f046d..75aa4849bf81f 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h
@@ -27,8 +27,8 @@ namespace DeviceHelpers {
 // Data copy op - Copies raw data from the source tensor's buffer to the destination tensor's buffer
 using DataCopy = std::function;
-// Zeroing op - Sets all bytes in the tensor's buffer to zero
-using Zeroing = std::function;
+// Zero buffer op - Sets all bytes in the tensor's buffer to zero
+using ZeroBuffer = std::function;
 // Transpose op - Transposes given input based on data in `permutation`
 using Transpose = std::function& permutation, const Tensor& input,
@@ -66,7 +66,7 @@ namespace CpuDeviceHelpers {
 Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets);
-Status Zeroing(Tensor& input, void* einsum_cuda_assets);
+Status ZeroBuffer(Tensor& input, void* einsum_cuda_assets);
 Status Transpose(const gsl::span& permutation, const Tensor& input, Tensor& output, const TensorShape* input_shape_override, void* einsum_cuda_assets);
diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
index 0584893e7bda0..d77d61d8036be 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc
@@ -337,7 +337,7 @@ void EinsumTypedComputeProcessor::SetDeviceHelpers(const EinsumOp::DeviceHelp
 const EinsumOp::DeviceHelpers::MatMul& device_matmul_func,
 const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func,
 const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func,
- const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func) {
+ const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) {
 device_transpose_func_ = device_transpose_func;
 device_matmul_func_ = device_matmul_func;
 device_reduce_sum_func_ = device_reduce_sum_func;
diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h
index 316cac4d98b5e..666c7c0006663 100644
--- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h
+++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h
@@ -32,7 +32,7 @@ class EinsumTypedComputeProcessor {
 const EinsumOp::DeviceHelpers::MatMul& device_matmul_func,
 const EinsumOp::DeviceHelpers::ReduceSum& device_reduce_sum_func,
 const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func,
- const EinsumOp::DeviceHelpers::Zeroing& device_zero_buffer_func);
+ const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func);
 Status Run();
@@ -65,7 +65,7 @@ class EinsumTypedComputeProcessor {
 EinsumOp::DeviceHelpers::MatMul device_matmul_func_;
 EinsumOp::DeviceHelpers::ReduceSum device_reduce_sum_func_;
 EinsumOp::DeviceHelpers::DataCopy device_data_copy_func_;
- EinsumOp::DeviceHelpers::Zeroing device_zero_buffer_func_;
+ EinsumOp::DeviceHelpers::ZeroBuffer device_zero_buffer_func_;
 // Holds EP-specific assets required for (auxiliary) ops that need to be executed on non-CPU EPs
 void* einsum_ep_assets_;
diff --git a/onnxruntime/core/providers/cuda/math/einsum.cc b/onnxruntime/core/providers/cuda/math/einsum.cc
index 4e84b029ff588..405f4f362bcb1 100644
--- a/onnxruntime/core/providers/cuda/math/einsum.cc
+++ b/onnxruntime/core/providers/cuda/math/einsum.cc
@@ -53,7 +53,7 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector,
 EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum,
 EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy,
- EinsumOp::DeviceHelpers::CudaDeviceHelpers::Zeroing);
+ EinsumOp::DeviceHelpers::CudaDeviceHelpers::ZeroBuffer);
 return einsum_compute_processor->Run();
 } else if (inputs[0]->IsDataType()) {
 auto einsum_compute_processor = EinsumTypedComputeProcessor::Create(context, allocator, tp,
@@ -65,7 +65,7 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector,
 EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum,
 EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy,
- EinsumOp::DeviceHelpers::CudaDeviceHelpers::Zeroing);
+ EinsumOp::DeviceHelpers::CudaDeviceHelpers::ZeroBuffer);
 return einsum_compute_processor->Run();
 } else if (inputs[0]->IsDataType()) {
 auto einsum_compute_processor = EinsumTypedComputeProcessor::Create(context, allocator, tp,
@@ -76,7 +76,7 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector,
 EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum,
 EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy,
- EinsumOp::DeviceHelpers::CudaDeviceHelpers::Zeroing);
+ EinsumOp::DeviceHelpers::CudaDeviceHelpers::ZeroBuffer);
 return einsum_compute_processor->Run();
 }
diff --git a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc
index 73734fdc21d2f..82238ab0ea1e0 100644
--- a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc
+++ b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc
@@ -31,8 +31,8 @@ Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets) {
 return Status::OK();
 }
-// CUDA EP specific Zeroing helper
-Status Zeroing(Tensor& input, void* einsum_cuda_assets) {
+// CUDA EP specific Zero buffer helper
+Status ZeroBuffer(Tensor& input, void* einsum_cuda_assets) {
 CUDA_RETURN_IF_ERROR(cudaMemsetAsync(input.MutableDataRaw(), 0, input.SizeInBytes(), static_cast(einsum_cuda_assets)->GetCudaStream()));
 return Status::OK();
diff --git a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.h b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.h
index 5101d2bffaa9c..b42cf79a857cf 100644
--- a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.h
+++ b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.h
@@ -47,7 +47,7 @@ Status Transpose(const gsl::span& permutation, const Tensor& input
 Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets);
-Status Zeroing(Tensor& input, void* einsum_cuda_assets);
+Status ZeroBuffer(Tensor& input, void* einsum_cuda_assets);
 template
 Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data,