Cross Host Data Transfers: Add CrossHostTransferBuffers; modify CrossHostSendBuffers to call into it #41055
rao-ashish wants to merge 2 commits into openxla:main from
Conversation
->GetLocalDeviceState();
if (!local_device_state.ok()) {
  SetEventAsError(transfer_event, local_device_state.status());
if (!maybe_local_device_state.ok()) {
nit: Prefer keeping the statusor and dereferencing it vs creating a new variable: https://abseil.io/tips/181
Thanks, modified to keep it as a statusor
if (local_device_state->async_dispatch_thread()) {
  local_device_state->async_dispatch_thread()->Schedule(
      tsl::WithCurrentContext(
          [this, local_device_state, device_id, transfer_dependency_avs,
std::move the transfer_dependency_avs? (and below)
Also, to make sure I understand, in this PR each buffer newly waits on all transfer dependency AVs, not just its own, right?
We can't move transfer_dependency_avs here because they are reused in each iteration of the loop over transfers_by_device.
Yes, we're now waiting on all transfer dependencies instead of per-group dependencies. If per-transfer dependencies become necessary later on, we could add a field into CrossHostTransferSpec to represent those, but for now the user could split up the arrays passed into device_put so that if different arrays are known to become ready at different times, they are batched into separate device_put calls.
Oops I lost track of the loop. Got it, thanks!
Force-pushed from 8d492f3 to d45e24a
📝 Summary of Changes
This PR builds on XLA #40919, and is the next in a sequence of PRs that will refactor cross-host data transfer implementations to eventually rely on a shared helper function CrossHostTransferBuffers. CrossHostTransferBuffers is planned to eventually be integrated into the PJRT APIs to enable receiving data into preallocated receive buffers (this feature is being planned in collaboration with @gspschmid, @emilyfertig, and @pschuh).
This PR implements CrossHostTransferBuffers as a helper function inside StreamExecutorGpuClient, and modifies StreamExecutorGpuClient::CrossHostSendBuffers to call into it. This PR also ensures that if async dispatch is used (PJRT_GPU_ENABLE_ASYNC_DISPATCH=1), CrossHostTransferBuffers schedules transfers through the async_dispatch_thread to prevent deadlocks.
🎯 Justification
It is difficult to achieve good comm/compute overlap with cross-host data transfers as the current implementation always allocates receive-buffers 'just-in-time', and because the GPU memory allocator blocks on the compute stream. CrossHostTransferBuffers will enable users to receive into preallocated receive buffers, making it easier to avoid the allocator blocking issue. This PR introduces CrossHostTransferBuffers as a StreamExecutorGpuClient helper function; it will eventually be exposed through the PJRT APIs.
🚀 Kind of Contribution
♻️ Cleanup (eventually ✨ New Feature)
🧪 Unit Tests:
This PR only refactors the implementation of CrossHost{Send/Receive}Buffers, so the pre-existing unit tests for those methods already test this PR.
🧪 Execution Tests:
Verified that these 4 correctness tests continue to pass.