Skip to content

Commit 9b9634f

Browse files
committed
Fixes
2 parents e0e6e01 + c74b516 commit 9b9634f

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -377,20 +377,22 @@ Status EinsumTypedComputeProcessor<T>::Run() {
377377
Tensor& output = *context_->Output(0, output_dims);
378378

379379
if (output.Location().device.Type() != OrtDevice::CPU) {
380+
// Get CPU allocator to allocate a staging buffer on CPU
380381
AllocatorPtr cpu_allocator;
381382
ORT_RETURN_IF_ERROR(context_->GetTempSpaceCPUAllocator(&cpu_allocator));
382383

383-
// If this Einsum node is partitioned to a non-CPU EP, we will use an intermediate CPU
384-
// buffer to stage the zero buffer results which we will then copy over to the op's output
385-
// allocated on the non-CPU device using the device data copy abstraction
386-
Tensor candidate_output(raw_inputs[0]->DataType(), output_dims, cpu_allocator);
387-
ZeroInputBuffer<T>(candidate_output);
388-
389-
auto status = device_data_copy_func_(candidate_output, output, einsum_ep_assets_);
390-
ORT_ENFORCE(status.IsOK(), "Einsum op: Could not copy the intermediate output's buffer into the op's output buffer. Error: ",
391-
status.ErrorMessage());
392-
} else { // Zero out the op's output buffer
393-
ZeroInputBuffer<T>(output);
384+
// If this Einsum node is partitioned to a non-CPU EP, we will use an intermediate CPU
385+
// buffer to stage the zero buffer results which we will then copy over to the op's output
386+
// allocated on the non-CPU device using the device data copy abstraction
387+
Tensor candidate_output(raw_inputs[0]->DataType(), output_dims, cpu_allocator);
388+
ZeroInputBuffer<T>(candidate_output);
389+
390+
// Copy zeroed buffer to the output buffer
391+
auto status = device_data_copy_func_(candidate_output, output, einsum_ep_assets_);
392+
ORT_ENFORCE(status.IsOK(), "Einsum op: Could not copy the intermediate output's buffer into the op's output buffer. Error: ",
393+
status.ErrorMessage());
394+
} else { // Zero out the op's output buffer
395+
ZeroInputBuffer<T>(output);
394396
}
395397

396398
return Status::OK();

0 commit comments

Comments
 (0)