@@ -377,20 +377,22 @@ Status EinsumTypedComputeProcessor<T>::Run() {
377377 Tensor& output = *context_->Output (0 , output_dims);
378378
379379 if (output.Location ().device .Type () != OrtDevice::CPU) {
380+ // Get CPU allocator to allocate a staging buffer on CPU
380381 AllocatorPtr cpu_allocator;
381382 ORT_RETURN_IF_ERROR (context_->GetTempSpaceCPUAllocator (&cpu_allocator));
382383
383- // If this Einsum node is partitioned to a non-CPU EP, we will use an intermediate CPU
384- // buffer to stage the zero buffer results which we will then copy over to the op's output
385- // allocated on the non-CPU device using the device data copy abstraction
386- Tensor candidate_output (raw_inputs[0 ]->DataType (), output_dims, cpu_allocator);
387- ZeroInputBuffer<T>(candidate_output);
388-
389- auto status = device_data_copy_func_ (candidate_output, output, einsum_ep_assets_);
390- ORT_ENFORCE (status.IsOK (), " Einsum op: Could not copy the intermediate output's buffer into the op's output buffer. Error: " ,
391- status.ErrorMessage ());
392- } else { // Zero out the op's output buffer
393- ZeroInputBuffer<T>(output);
384+ // If this Einsum node is partitioned to a non-CPU EP, we will use an intermediate CPU
385+ // buffer to stage the zero buffer results which we will then copy over to the op's output
386+ // allocated on the non-CPU device using the device data copy abstraction
387+ Tensor candidate_output (raw_inputs[0 ]->DataType (), output_dims, cpu_allocator);
388+ ZeroInputBuffer<T>(candidate_output);
389+
390+ // Copy zeroed buffer to the output buffer
391+ auto status = device_data_copy_func_ (candidate_output, output, einsum_ep_assets_);
392+ ORT_ENFORCE (status.IsOK (), " Einsum op: Could not copy the intermediate output's buffer into the op's output buffer. Error: " ,
393+ status.ErrorMessage ());
394+ } else { // Zero out the op's output buffer
395+ ZeroInputBuffer<T>(output);
394396 }
395397
396398 return Status::OK ();
0 commit comments