Cross Host Data Transfers: Add CrossHostTransferBuffers; modify CrossHostSendBuffers to call into it#41055

Open
rao-ashish wants to merge 2 commits into openxla:main from rao-ashish:asrao/cross_host_refactor_v2_3

@rao-ashish
Contributor

📝 Summary of Changes
This PR builds on XLA #40919, and is the next in a sequence of PRs that will refactor cross-host data transfer implementations to eventually rely on a shared helper function CrossHostTransferBuffers. CrossHostTransferBuffers is planned to eventually be integrated into the PJRT APIs to enable receiving data into preallocated receive buffers (this feature is being planned in collaboration with @gspschmid, @emilyfertig, and @pschuh).

This PR implements CrossHostTransferBuffers as a helper function inside StreamExecutorGpuClient, and modifies StreamExecutorGpuClient::CrossHostSendBuffers to call into it. This PR also ensures that if async dispatch is used (PJRT_GPU_ENABLE_ASYNC_DISPATCH=1), CrossHostTransferBuffers schedules transfers through the async_dispatch_thread to prevent deadlocks.

🎯 Justification
It is difficult to achieve good comm/compute overlap with cross-host data transfers as the current implementation always allocates receive-buffers 'just-in-time', and because the GPU memory allocator blocks on the compute stream. CrossHostTransferBuffers will enable users to receive into preallocated receive buffers, making it easier to avoid the allocator blocking issue. This PR introduces CrossHostTransferBuffers as a StreamExecutorGpuClient helper function; it will eventually be exposed through the PJRT APIs.

🚀 Kind of Contribution
♻️ Cleanup (eventually ✨ New Feature)

🧪 Unit Tests:
This PR only refactors the implementation of CrossHost{Send/Receive}Buffers, so the pre-existing unit tests for those methods already test this PR.

🧪 Execution Tests:
Verified that these 4 correctness tests continue to pass.

Comment thread xla/pjrt/gpu/se_gpu_pjrt_client.cc Outdated
->GetLocalDeviceState();
if (!local_device_state.ok()) {
SetEventAsError(transfer_event, local_device_state.status());
if (!maybe_local_device_state.ok()) {
Contributor

nit: Prefer keeping the statusor and dereferencing it vs creating a new variable: https://abseil.io/tips/181

Contributor Author

Thanks, modified to keep it as a statusor
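
For readers unfamiliar with the pattern discussed in this thread, here is a minimal sketch of checking a status-or result and then dereferencing it in place, rather than copying the value into a second variable. To stay self-contained it does not use Abseil: `MiniStatusOr`, `DeviceState`, `GetDeviceState`, and `StreamCountOrMinusOne` are hypothetical stand-ins invented for illustration and are not part of XLA or `absl::StatusOr`.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>

// Hypothetical, minimal stand-in for a status-or type. It only models the
// parts needed here: ok(), an error message, and pointer-like access.
template <typename T>
class MiniStatusOr {
 public:
  MiniStatusOr(T value) : value_(std::move(value)) {}

  static MiniStatusOr Error(std::string message) {
    MiniStatusOr result;
    result.error_ = std::move(message);
    return result;
  }

  bool ok() const { return value_.has_value(); }
  const std::string& error() const { return error_; }

  // Pointer-like access; callers must check ok() first.
  T& operator*() { return *value_; }
  T* operator->() { return &*value_; }

 private:
  MiniStatusOr() = default;
  std::optional<T> value_;
  std::string error_;
};

// Hypothetical payload type for the example.
struct DeviceState {
  int stream_count;
};

// Hypothetical lookup that either yields a state or an error.
MiniStatusOr<DeviceState> GetDeviceState(bool healthy) {
  if (!healthy) {
    return MiniStatusOr<DeviceState>::Error("device unavailable");
  }
  return DeviceState{4};
}

// The result keeps the name of the value it holds and is dereferenced
// directly after the ok() check; no second variable is introduced.
int StreamCountOrMinusOne(bool healthy) {
  MiniStatusOr<DeviceState> device_state = GetDeviceState(healthy);
  if (!device_state.ok()) {
    return -1;
  }
  return device_state->stream_count;
}
```

The early return on `!ok()` keeps the error path next to the call, and every later use goes through `->` on the same object.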

if (local_device_state->async_dispatch_thread()) {
local_device_state->async_dispatch_thread()->Schedule(
tsl::WithCurrentContext(
[this, local_device_state, device_id, transfer_dependency_avs,
Contributor

std::move the transfer_dependency_avs? (and below)

Also, to make sure I understand, in this PR each buffer newly waits on all transfer dependency AVs, not just its own, right?

Contributor Author

We can't move transfer_dependency_avs here because they are reused in each iteration of the loop over transfers_by_device.

Yes, we're now waiting on all transfer dependencies instead of per-group dependencies. If per-transfer dependencies become necessary later on, we could add a field into CrossHostTransferSpec to represent those, but for now the user could split up the arrays passed into device_put so that if different arrays are known to become ready at different times, they are batched into separate device_put calls.
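
To illustrate why the capture has to be a copy, here is a small standalone sketch of a loop that builds one task per device from a shared list of dependencies. It is not the PR's code: `Dependency`, `TaskResult`, `MakeTasksByCopy`, and `MakeTasksByMove` are hypothetical names, and plain `std::shared_ptr<int>` stands in for the async values used in XLA.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for a transfer dependency handle.
using Dependency = std::shared_ptr<int>;

struct TaskResult {
  int device_id;
  std::size_t num_deps;  // How many dependencies the task captured.
};

using Task = std::function<TaskResult()>;

// One task per device. Each lambda captures its own copy of `deps`, so the
// list is still complete when the next iteration runs.
std::vector<Task> MakeTasksByCopy(const std::vector<int>& device_ids,
                                  const std::vector<Dependency>& deps) {
  std::vector<Task> tasks;
  for (int device_id : device_ids) {
    tasks.push_back(
        [device_id, deps] { return TaskResult{device_id, deps.size()}; });
  }
  return tasks;
}

// Same loop, but the lambda move-captures `deps`. The first iteration takes
// the contents; a moved-from std::vector is empty after move construction,
// so every later task captures no dependencies. Shown only as the pitfall.
std::vector<Task> MakeTasksByMove(const std::vector<int>& device_ids,
                                  std::vector<Dependency> deps) {
  std::vector<Task> tasks;
  for (int device_id : device_ids) {
    tasks.push_back([device_id, deps = std::move(deps)] {
      return TaskResult{device_id, deps.size()};
    });
  }
  return tasks;
}
```

With two dependencies and three devices, every task from `MakeTasksByCopy` sees both dependencies, while only the first task from `MakeTasksByMove` does. The copies are cheap here because each element is a reference-counted handle rather than the underlying data.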

Contributor

Oops I lost track of the loop. Got it, thanks!

@rao-ashish force-pushed the asrao/cross_host_refactor_v2_3 branch from 8d492f3 to d45e24a on April 17, 2026 01:38