Skip to content

Commit bc870c8

Browse files
committed
Add CrossHostTransferBuffers; modify CrossHostSendBuffers to call into it
1 parent 9b3bd63 commit bc870c8

File tree

2 files changed

+114
-35
lines changed

2 files changed

+114
-35
lines changed

xla/pjrt/gpu/se_gpu_pjrt_client.cc

Lines changed: 110 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -663,10 +663,7 @@ StreamExecutorGpuClient::CrossHostSendBuffers(
663663
std::vector<tsl::RCReference<PjRtRawBuffer>> raw_buffers;
664664
raw_buffers.reserve(buffers.size());
665665

666-
std::vector<std::vector<tsl::RCReference<tsl::AsyncValue>>>
667-
transfer_dependency_avs;
668-
transfer_dependency_avs.reserve(buffers.size());
669-
666+
std::vector<tsl::RCReference<tsl::AsyncValue>> transfer_dependency_avs;
670667
std::vector<tsl::RCReference<PjRtDeviceEventPromise>> usage_event_promises;
671668
usage_event_promises.reserve(buffers.size());
672669

@@ -693,67 +690,145 @@ StreamExecutorGpuClient::CrossHostSendBuffers(
693690
[&](tsl::RCReference<CommonPjRtRawBuffer> buf_raw_buffer,
694691
std::vector<tsl::RCReference<tsl::AsyncValue>>
695692
buf_definition_events) mutable
696-
-> absl::StatusOr<PjRtDeviceEventRef> {
693+
-> absl::StatusOr<PjRtDeviceEventRef> {
697694
// Keep raw_buffer alive until the usage_event completes,
698695
// preventing the allocation from being freed while the
699696
// send is in-flight.
700697
usage_event.AndThen([buf_raw_buffer]() {});
701698
raw_buffers.push_back(std::move(buf_raw_buffer));
702-
transfer_dependency_avs.push_back(
703-
std::move(buf_definition_events));
699+
for (tsl::RCReference<tsl::AsyncValue>& definition_event :
700+
buf_definition_events) {
701+
transfer_dependency_avs.push_back(
702+
std::move(definition_event));
703+
}
704704
return PjRtDeviceEventRef(usage_event);
705705
},
706706
"CrossHostSendBuffers"));
707707
}
708708

709-
// Group the sends by local device.
710-
absl::flat_hash_map<PjRtDevice*, std::vector<int>> sends_by_device;
709+
// Build the CrossHostTransferSpec for each buffer.
710+
std::vector<CrossHostTransferSpec> transfer_specs;
711+
transfer_specs.reserve(buffers.size());
711712
for (int i = 0; i < buffers.size(); ++i) {
712-
sends_by_device[buffers[i]->device()].push_back(i);
713+
transfer_specs.push_back(CrossHostTransferSpec{
714+
/*src_global_device_id=*/buffers[i]->device()->global_device_id(),
715+
dst_global_device_ids[i], std::move(raw_buffers[i])});
713716
}
714717

715718
// Schedule sends.
716-
for (const auto& [device, send_idxs] : sends_by_device) {
717-
const GlobalDeviceId src_global_device_id = device->global_device_id();
719+
TF_ASSIGN_OR_RETURN(
720+
std::vector<PjRtDeviceEventRef> usage_events,
721+
CrossHostTransferBuffers(std::move(transfer_dependency_avs),
722+
std::move(transfer_specs)));
723+
724+
// Populate usage events.
725+
for (int i = 0; i < buffers.size(); ++i) {
726+
usage_event_promises[i]->Set(usage_events[i]);
727+
}
728+
729+
return futures;
730+
}
731+
732+
absl::StatusOr<std::vector<PjRtDeviceEventRef>>
733+
StreamExecutorGpuClient::CrossHostTransferBuffers(
734+
std::vector<tsl::RCReference<tsl::AsyncValue>> transfer_dependency_avs,
735+
std::vector<CrossHostTransferSpec> transfer_specs) {
736+
// Validate arguments.
737+
for (int i = 0; i < transfer_specs.size(); ++i) {
738+
if (transfer_specs[i].raw_buffer->memory_space()->devices().size() != 1) {
739+
return InvalidArgument(
740+
"CrossHostTransferBuffers: Received a raw buffer with a memory space "
741+
"that is not attached to exactly 1 device.");
742+
}
743+
PjRtDevice* buffer_device =
744+
transfer_specs[i].raw_buffer->memory_space()->devices()[0];
745+
if (!buffer_device->IsAddressable()) {
746+
return InvalidArgument(
747+
"CrossHostTransferBuffers: raw buffer %d is on non-addressable "
748+
"device with global device id %d.",
749+
i, buffer_device->global_device_id().value());
750+
}
751+
// Each transfer must be between an addressable and a non-addressable
752+
// device. If both devices are addressable, then both a data transfer and a
753+
// 'normal' XLA SPMD executable may try to acquire the same GPU clique,
754+
// causing issues.
755+
GlobalDeviceId remote_id = (transfer_specs[i].src_global_device_id ==
756+
buffer_device->global_device_id())
757+
? transfer_specs[i].dst_global_device_id
758+
: transfer_specs[i].src_global_device_id;
759+
TF_ASSIGN_OR_RETURN(PjRtDevice * remote_device, LookupDevice(remote_id));
760+
if (remote_device->IsAddressable()) {
761+
return InvalidArgument(
762+
"CrossHostTransferBuffers: remote device for buffer %d is "
763+
"addressable (global device id %d), but cross-host transfers must "
764+
"be between an addressable and a non-addressable device.",
765+
i, remote_id.value());
766+
}
767+
}
768+
769+
// Group the transfers by their buffers' device.
770+
absl::flat_hash_map<PjRtDevice*, std::vector<int>> transfers_by_device;
771+
for (int i = 0; i < transfer_specs.size(); ++i) {
772+
PjRtDevice* buffer_device =
773+
transfer_specs[i].raw_buffer->memory_space()->devices()[0];
774+
transfers_by_device[buffer_device].push_back(i);
775+
}
776+
777+
// We will register a single transfer event for all transfers to/from the same
778+
// device. We will collect the references to those events inside
779+
// output_transfer_events. This will eventually be returned to the user.
780+
std::vector<PjRtDeviceEventRef> output_transfer_events(transfer_specs.size(),
781+
PjRtDeviceEventRef());
782+
783+
// Schedule transfers.
784+
for (const auto& [device, transfer_idxs] : transfers_by_device) {
785+
const GlobalDeviceId device_id = device->global_device_id();
718786

719787
// Create a transfer event for transfers on this device.
720788
tsl::AsyncValueRef<BufferSequencingEvent> transfer_event =
721789
BufferSequencingEvent::Create(this->async_work_runner());
722790

723-
// Extract transfer dependencies, form transfer specs, and fulfill
724-
// usage_event_promises.
725-
std::vector<tsl::RCReference<tsl::AsyncValue>> curr_transfer_dependency_avs;
726-
std::vector<CrossHostTransferSpec> transfer_specs;
727-
transfer_specs.reserve(send_idxs.size());
791+
// Form transfer specs.
792+
std::vector<CrossHostTransferSpec> curr_transfer_specs;
793+
curr_transfer_specs.reserve(transfer_idxs.size());
728794

729-
for (int idx : send_idxs) {
730-
for (tsl::RCReference<tsl::AsyncValue>& event :
731-
transfer_dependency_avs[idx]) {
732-
curr_transfer_dependency_avs.push_back(std::move(event));
733-
}
734-
transfer_specs.push_back(CrossHostTransferSpec{
735-
src_global_device_id, dst_global_device_ids[idx],
736-
std::move(raw_buffers[idx])});
737-
738-
usage_event_promises[idx]->Set(PjRtDeviceEventRef(transfer_event));
795+
for (int idx : transfer_idxs) {
796+
curr_transfer_specs.push_back(std::move(transfer_specs[idx]));
797+
output_transfer_events[idx] = PjRtDeviceEventRef(transfer_event);
739798
}
740799

741800
// Get the local_device_state and use it to schedule transfers. Fail
742801
// transfers early if we cannot get the local_device_state.
743-
absl::StatusOr<LocalDeviceState*> local_device_state =
802+
absl::StatusOr<LocalDeviceState*> maybe_local_device_state =
744803
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
745804
->GetLocalDeviceState();
746-
if (!local_device_state.ok()) {
747-
SetEventAsError(transfer_event, local_device_state.status());
805+
if (!maybe_local_device_state.ok()) {
806+
SetEventAsError(transfer_event, maybe_local_device_state.status());
748807
continue;
749808
}
750-
751-
ScheduleTransfersOnLocalDevice(
752-
*local_device_state, src_global_device_id, std::move(transfer_event),
753-
std::move(curr_transfer_dependency_avs), std::move(transfer_specs));
809+
LocalDeviceState* local_device_state = *maybe_local_device_state;
810+
811+
// Launch ScheduleTransfersOnLocalDevice on either the async dispatch thread
812+
// or the calling thread.
813+
if (local_device_state->async_dispatch_thread()) {
814+
local_device_state->async_dispatch_thread()->Schedule(
815+
tsl::WithCurrentContext(
816+
[this, local_device_state, device_id, transfer_dependency_avs,
817+
curr_transfer_specs = std::move(curr_transfer_specs),
818+
transfer_event = std::move(transfer_event)]() mutable {
819+
ScheduleTransfersOnLocalDevice(
820+
local_device_state, device_id, std::move(transfer_event),
821+
std::move(transfer_dependency_avs),
822+
std::move(curr_transfer_specs));
823+
}));
824+
} else {
825+
ScheduleTransfersOnLocalDevice(
826+
local_device_state, device_id, std::move(transfer_event),
827+
transfer_dependency_avs, std::move(curr_transfer_specs));
828+
}
754829
}
755830

756-
return futures;
831+
return output_transfer_events;
757832
}
758833

759834
void StreamExecutorGpuClient::ScheduleTransfersOnLocalDevice(

xla/pjrt/gpu/se_gpu_pjrt_client.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,10 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient {
226226
tsl::RCReference<PjRtRawBuffer> raw_buffer;
227227
};
228228

229+
absl::StatusOr<std::vector<PjRtDeviceEventRef>> CrossHostTransferBuffers(
230+
std::vector<tsl::RCReference<tsl::AsyncValue>> transfer_dependency_avs,
231+
std::vector<CrossHostTransferSpec> transfer_specs);
232+
229233
void ScheduleTransfersOnLocalDevice(
230234
LocalDeviceState* local_device_state, GlobalDeviceId device_id,
231235
tsl::AsyncValueRef<BufferSequencingEvent> transfer_event,

0 commit comments

Comments
 (0)