6 changes: 3 additions & 3 deletions onnxruntime/core/providers/cpu/cpu_provider_shared.cc
@@ -186,9 +186,9 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
std::unique_ptr<EinsumTypedComputeProcessor<float>> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique<EinsumTypedComputeProcessor<float>>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); }
std::unique_ptr<EinsumTypedComputeProcessor<double>> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique<EinsumTypedComputeProcessor<double>>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); }
std::unique_ptr<EinsumTypedComputeProcessor<MLFloat16>> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique<EinsumTypedComputeProcessor<MLFloat16>>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); }
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<float>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<float>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<float>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); }
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<double>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<double>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<double>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); }
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<MLFloat16>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<MLFloat16>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<MLFloat16>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); }
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<float>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<float>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<float>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zeroing_func); }
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<double>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<double>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<double>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zeroing_func); }
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<MLFloat16>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<MLFloat16>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<MLFloat16>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zeroing_func); }
Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<float>* p) override { return p->Run(); }
Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<double>* p) override { return p->Run(); }
Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<MLFloat16>* p) override { return p->Run(); }
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/cpu/cpu_provider_shared.h
@@ -117,9 +117,9 @@ struct ProviderHostCPU {
virtual std::unique_ptr<EinsumTypedComputeProcessor<float>> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0;
virtual std::unique_ptr<EinsumTypedComputeProcessor<double>> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0;
virtual std::unique_ptr<EinsumTypedComputeProcessor<MLFloat16>> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0;
virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<float>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<float>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<float>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) = 0;
virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<double>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<double>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<double>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) = 0;
virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<MLFloat16>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<MLFloat16>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<MLFloat16>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) = 0;
virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<float>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<float>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<float>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) = 0;
virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<double>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<double>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<double>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) = 0;
virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<MLFloat16>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<MLFloat16>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<MLFloat16>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) = 0;
virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<float>* p) = 0;
virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<double>* p) = 0;
virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<MLFloat16>* p) = 0;
12 changes: 8 additions & 4 deletions onnxruntime/core/providers/cpu/math/einsum.cc
@@ -65,7 +65,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor.SetDeviceHelpers(EinsumOp::DeviceHelpers::CpuDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::MatMul<float>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum<float>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing);
return einsum_compute_processor.Run();
} else if (inputs[0]->IsDataType<int32_t>()) {
auto einsum_compute_processor = EinsumTypedComputeProcessor<int32_t>(context,
@@ -78,7 +79,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor.SetDeviceHelpers(EinsumOp::DeviceHelpers::CpuDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::MatMul<int32_t>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum<int32_t>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing);

return einsum_compute_processor.Run();
} else if (inputs[0]->IsDataType<double>()) {
@@ -92,7 +94,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor.SetDeviceHelpers(EinsumOp::DeviceHelpers::CpuDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::MatMul<double>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum<double>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing);
return einsum_compute_processor.Run();
} else if (inputs[0]->IsDataType<int64_t>()) {
auto einsum_compute_processor = EinsumTypedComputeProcessor<int64_t>(context,
@@ -104,7 +107,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor.SetDeviceHelpers(EinsumOp::DeviceHelpers::CpuDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::MatMul<int64_t>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum<int64_t>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing);

return einsum_compute_processor.Run();
}
@@ -22,6 +22,12 @@ Status DataCopy(const Tensor& input, Tensor& output, void* /*einsum_cuda_assets*
return Status::OK();
}

// CPU specific Zeroing helper
Status Zeroing(Tensor& input, void* /*einsum_cuda_assets*/) {
memset(input.MutableDataRaw(), 0, input.SizeInBytes());
return Status::OK();
}

// CPU specific Transpose helper
Status Transpose(const gsl::span<const size_t>& permutation, const Tensor& input,
Tensor& output, const TensorShape* input_shape_override, void* /*einsum_cuda_assets*/) {
@@ -27,6 +27,9 @@ namespace DeviceHelpers {
// Data copy op - Copies raw data from the source tensor's buffer to the destination tensor's buffer
using DataCopy = std::function<Status(const Tensor& input, Tensor& output, void* einsum_cuda_assets)>;

// Zeroing op - Sets all bytes in the tensor's buffer to zero
using Zeroing = std::function<Status(Tensor& input, void* einsum_cuda_assets)>;

// Transpose op - Transposes given input based on data in `permutation`
using Transpose = std::function<Status(const gsl::span<const size_t>& permutation, const Tensor& input,
Tensor& output, const TensorShape* input_shape_override,
@@ -63,6 +66,8 @@ namespace CpuDeviceHelpers {

Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets);

Status Zeroing(Tensor& input, void* einsum_cuda_assets);

Status Transpose(const gsl::span<const size_t>& permutation, const Tensor& input,
Tensor& output, const TensorShape* input_shape_override, void* einsum_cuda_assets);

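The header hunks above register Zeroing as another std::function-based device helper alongside DataCopy and Transpose. Below is a minimal standalone sketch of that dispatch pattern, using invented stand-in types rather than onnxruntime's (FakeTensor and FakeStatus are illustrative only): each execution provider binds its own zeroing routine to the same alias, and the compute processor invokes it without knowing which device it targets.

#include <cstdio>
#include <cstring>
#include <functional>
#include <vector>

struct FakeTensor {  // illustrative stand-in for onnxruntime::Tensor
  std::vector<float> data;
  void* MutableDataRaw() { return data.data(); }
  size_t SizeInBytes() const { return data.size() * sizeof(float); }
};

using FakeStatus = int;  // illustrative stand-in for a Status type (0 == OK)
using Zeroing = std::function<FakeStatus(FakeTensor&, void* /*ep_assets*/)>;

FakeStatus CpuZeroing(FakeTensor& tensor, void* /*ep_assets*/) {
  std::memset(tensor.MutableDataRaw(), 0, tensor.SizeInBytes());
  return 0;
}

int main() {
  Zeroing zero_fn = CpuZeroing;  // a GPU provider would bind its own async routine instead
  FakeTensor t{{1.0f, 2.0f, 3.0f}};
  zero_fn(t, nullptr);
  std::printf("%.1f %.1f %.1f\n", t.data[0], t.data[1], t.data[2]);  // 0.0 0.0 0.0
  return 0;
}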
@@ -336,11 +336,13 @@ template <typename T>
void EinsumTypedComputeProcessor<T>::SetDeviceHelpers(const EinsumOp::DeviceHelpers::Transpose& device_transpose_func,
const EinsumOp::DeviceHelpers::MatMul<T>& device_matmul_func,
const EinsumOp::DeviceHelpers::ReduceSum<T>& device_reduce_sum_func,
const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) {
const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func,
const EinsumOp::DeviceHelpers::Zeroing& zero_input_buffer_func) {
device_transpose_func_ = device_transpose_func;
device_matmul_func_ = device_matmul_func;
device_reduce_sum_func_ = device_reduce_sum_func;
device_data_copy_func_ = device_data_copy_func;
zero_input_buffer_func_ = zero_input_buffer_func;
}

template <typename T>
@@ -357,6 +359,20 @@ Status EinsumTypedComputeProcessor<T>::Run() {

auto num_inputs = context_->InputCount();

{
bool has_empty_input = std::any_of(raw_inputs.begin(), raw_inputs.end(), [](const auto& input) {
return input->Shape().Size() == 0;
});

// Skip all the work, fill with zeros if needed
if (has_empty_input) {
const auto output_dims = einsum_compute_preprocessor_.GetOutputDims();
Tensor& output = *context_->Output(0, output_dims);

return zero_input_buffer_func_(output, einsum_ep_assets_);
}
}

// Pre-process the first input so as to reduce any dims that only it has
std::unique_ptr<const Tensor> result;

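The empty-input early exit added to Run() above skips the contraction entirely when any input has zero elements: the output is allocated at the preprocessed output shape and handed to the new zeroing helper. Zero is the right fill value because contracting over an empty axis is an empty sum. A self-contained sketch of that reasoning (plain C++, no onnxruntime types) for "ij,jk->ik" with shapes (2, 0) and (0, 3):

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  // The contraction axis K is empty, so every output element is a sum over zero
  // terms, i.e. 0. An implementation can therefore allocate the (2, 3) output
  // and zero the buffer instead of running the einsum.
  const std::size_t M = 2, K = 0, N = 3;
  std::vector<float> output(M * N, -1.0f);  // pretend the buffer starts out dirty
  const bool has_empty_input = (M * K == 0) || (K * N == 0);
  if (has_empty_input) {
    std::memset(output.data(), 0, output.size() * sizeof(float));  // CPU Zeroing equivalent
  }
  std::printf("output[0] = %.1f\n", output[0]);  // 0.0
  return 0;
}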
@@ -31,7 +31,8 @@ class EinsumTypedComputeProcessor {
void SetDeviceHelpers(const EinsumOp::DeviceHelpers::Transpose& device_transpose_func,
const EinsumOp::DeviceHelpers::MatMul<T>& device_matmul_func,
const EinsumOp::DeviceHelpers::ReduceSum<T>& device_reduce_sum_func,
const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func);
const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func,
const EinsumOp::DeviceHelpers::Zeroing& zero_input_buffer_func);

Status Run();

@@ -64,6 +65,7 @@ class EinsumTypedComputeProcessor {
EinsumOp::DeviceHelpers::MatMul<T> device_matmul_func_;
EinsumOp::DeviceHelpers::ReduceSum<T> device_reduce_sum_func_;
EinsumOp::DeviceHelpers::DataCopy device_data_copy_func_;
EinsumOp::DeviceHelpers::Zeroing zero_input_buffer_func_;

// Holds EP-specific assets required for (auxiliary) ops that need to be executed on non-CPU EPs
void* einsum_ep_assets_;
9 changes: 6 additions & 3 deletions onnxruntime/core/providers/cuda/math/einsum.cc
@@ -52,7 +52,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor->SetDeviceHelpers(EinsumOp::DeviceHelpers::CudaDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::MatMul<float>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum<float>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::Zeroing);
return einsum_compute_processor->Run();
} else if (inputs[0]->IsDataType<double>()) {
auto einsum_compute_processor = EinsumTypedComputeProcessor<double>::Create(context, allocator, tp,
@@ -63,7 +64,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor->SetDeviceHelpers(EinsumOp::DeviceHelpers::CudaDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::MatMul<double>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum<double>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::Zeroing);
return einsum_compute_processor->Run();
} else if (inputs[0]->IsDataType<MLFloat16>()) {
auto einsum_compute_processor = EinsumTypedComputeProcessor<MLFloat16>::Create(context, allocator, tp,
@@ -73,7 +75,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor->SetDeviceHelpers(EinsumOp::DeviceHelpers::CudaDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::MatMul<MLFloat16>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum<MLFloat16>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::Zeroing);
return einsum_compute_processor->Run();
}

@@ -31,6 +31,14 @@ Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets) {
return Status::OK();
}


// CUDA EP specific Zeroing helper
Status Zeroing(Tensor& input, void* einsum_cuda_assets) {
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(input.MutableDataRaw(), 0, input.SizeInBytes(),
static_cast<EinsumCudaAssets*>(einsum_cuda_assets)->GetCudaStream()));
return Status::OK();
}

// CUDA EP specific Transpose helper
Status Transpose(const gsl::span<const size_t>& permutation, const Tensor& input,
Tensor& output, const TensorShape* input_shape_override, void* einsum_cuda_assets) {
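In the CUDA helper above, the zero-fill is enqueued with cudaMemsetAsync on the stream obtained from the EinsumCudaAssets passed through einsum_cuda_assets, so it stays ordered with the rest of the work on that stream rather than forcing a synchronization. A standalone sketch of that call pattern, managing its own stream and buffer purely for illustration (error checks omitted):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  float* device_buffer = nullptr;
  cudaMalloc(&device_buffer, 6 * sizeof(float));

  // Asynchronous zero-fill: the memset is queued on `stream` and executes in
  // order with any other work already submitted to that stream.
  cudaMemsetAsync(device_buffer, 0, 6 * sizeof(float), stream);
  cudaStreamSynchronize(stream);

  float host[6];
  cudaMemcpy(host, device_buffer, 6 * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("host[0] = %.1f\n", host[0]);  // 0.0

  cudaFree(device_buffer);
  cudaStreamDestroy(stream);
  return 0;
}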
@@ -47,6 +47,8 @@ Status Transpose(const gsl::span<const size_t>& permutation, const Tensor& input

Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets);

Status Zeroing(Tensor& input, void* einsum_cuda_assets);

template <typename T>
Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data,
size_t left_stride, size_t right_stride, size_t output_stride,