@@ -663,10 +663,7 @@ StreamExecutorGpuClient::CrossHostSendBuffers(
663663 std::vector<tsl::RCReference<PjRtRawBuffer>> raw_buffers;
664664 raw_buffers.reserve (buffers.size ());
665665
666- std::vector<std::vector<tsl::RCReference<tsl::AsyncValue>>>
667- transfer_dependency_avs;
668- transfer_dependency_avs.reserve (buffers.size ());
669-
666+ std::vector<tsl::RCReference<tsl::AsyncValue>> transfer_dependency_avs;
670667 std::vector<tsl::RCReference<PjRtDeviceEventPromise>> usage_event_promises;
671668 usage_event_promises.reserve (buffers.size ());
672669
@@ -693,67 +690,145 @@ StreamExecutorGpuClient::CrossHostSendBuffers(
693690 [&](tsl::RCReference<CommonPjRtRawBuffer> buf_raw_buffer,
694691 std::vector<tsl::RCReference<tsl::AsyncValue>>
695692 buf_definition_events) mutable
696- -> absl::StatusOr<PjRtDeviceEventRef> {
693+ -> absl::StatusOr<PjRtDeviceEventRef> {
697694 // Keep raw_buffer alive until the usage_event completes,
698695 // preventing the allocation from being freed while the
699696 // send is in-flight.
700697 usage_event.AndThen ([buf_raw_buffer]() {});
701698 raw_buffers.push_back (std::move (buf_raw_buffer));
702- transfer_dependency_avs.push_back (
703- std::move (buf_definition_events));
699+ for (tsl::RCReference<tsl::AsyncValue>& definition_event :
700+ buf_definition_events) {
701+ transfer_dependency_avs.push_back (
702+ std::move (definition_event));
703+ }
704704 return PjRtDeviceEventRef (usage_event);
705705 },
706706 " CrossHostSendBuffers" ));
707707 }
708708
709- // Group the sends by local device.
710- absl::flat_hash_map<PjRtDevice*, std::vector<int >> sends_by_device;
709+ // Build the CrossHostTransferSpec for each buffer.
710+ std::vector<CrossHostTransferSpec> transfer_specs;
711+ transfer_specs.reserve (buffers.size ());
711712 for (int i = 0 ; i < buffers.size (); ++i) {
712- sends_by_device[buffers[i]->device ()].push_back (i);
713+ transfer_specs.push_back (CrossHostTransferSpec{
714+ /* src_global_device_id=*/ buffers[i]->device ()->global_device_id (),
715+ dst_global_device_ids[i], std::move (raw_buffers[i])});
713716 }
714717
715718 // Schedule sends.
716- for (const auto & [device, send_idxs] : sends_by_device) {
717- const GlobalDeviceId src_global_device_id = device->global_device_id ();
719+ TF_ASSIGN_OR_RETURN (
720+ std::vector<PjRtDeviceEventRef> usage_events,
721+ CrossHostTransferBuffers (std::move (transfer_dependency_avs),
722+ std::move (transfer_specs)));
723+
724+ // Populate usage events.
725+ for (int i = 0 ; i < buffers.size (); ++i) {
726+ usage_event_promises[i]->Set (usage_events[i]);
727+ }
728+
729+ return futures;
730+ }
731+
732+ absl::StatusOr<std::vector<PjRtDeviceEventRef>>
733+ StreamExecutorGpuClient::CrossHostTransferBuffers (
734+ std::vector<tsl::RCReference<tsl::AsyncValue>> transfer_dependency_avs,
735+ std::vector<CrossHostTransferSpec> transfer_specs) {
736+ // Validate arguments.
737+ for (int i = 0 ; i < transfer_specs.size (); ++i) {
738+ if (transfer_specs[i].raw_buffer ->memory_space ()->devices ().size () != 1 ) {
739+ return InvalidArgument (
740+ " CrossHostTransferBuffers: Received a raw buffer with a memory space "
741+ " that is not attached to exactly 1 device." );
742+ }
743+ PjRtDevice* buffer_device =
744+ transfer_specs[i].raw_buffer ->memory_space ()->devices ()[0 ];
745+ if (!buffer_device->IsAddressable ()) {
746+ return InvalidArgument (
747+ " CrossHostTransferBuffers: raw buffer %d is on non-addressable "
748+ " device with global device id %d." ,
749+ i, buffer_device->global_device_id ().value ());
750+ }
751+ // Each transfer must be between an addressable and a non-addressable
752+ // device. If both devices are addressable, then both a data transfer and a
753+ // 'normal' XLA SPMD executable may try to acquire the same GPU clique,
754+ // causing issues.
755+ GlobalDeviceId remote_id = (transfer_specs[i].src_global_device_id ==
756+ buffer_device->global_device_id ())
757+ ? transfer_specs[i].dst_global_device_id
758+ : transfer_specs[i].src_global_device_id ;
759+ TF_ASSIGN_OR_RETURN (PjRtDevice * remote_device, LookupDevice (remote_id));
760+ if (remote_device->IsAddressable ()) {
761+ return InvalidArgument (
762+ " CrossHostTransferBuffers: remote device for buffer %d is "
763+ " addressable (global device id %d), but cross-host transfers must "
764+ " be between an addressable and a non-addressable device." ,
765+ i, remote_id.value ());
766+ }
767+ }
768+
769+ // Group the transfers by their buffers' device.
770+ absl::flat_hash_map<PjRtDevice*, std::vector<int >> transfers_by_device;
771+ for (int i = 0 ; i < transfer_specs.size (); ++i) {
772+ PjRtDevice* buffer_device =
773+ transfer_specs[i].raw_buffer ->memory_space ()->devices ()[0 ];
774+ transfers_by_device[buffer_device].push_back (i);
775+ }
776+
777+ // We will register a single transfer event for all transfers to/from the same
778+ // device. We will collect the references to those events inside
779+ // output_transfer_events. This will eventually be returned to the user.
780+ std::vector<PjRtDeviceEventRef> output_transfer_events (transfer_specs.size (),
781+ PjRtDeviceEventRef ());
782+
783+ // Schedule transfers.
784+ for (const auto & [device, transfer_idxs] : transfers_by_device) {
785+ const GlobalDeviceId device_id = device->global_device_id ();
718786
719787 // Create a transfer event for transfers on this device.
720788 tsl::AsyncValueRef<BufferSequencingEvent> transfer_event =
721789 BufferSequencingEvent::Create (this ->async_work_runner ());
722790
723- // Extract transfer dependencies, form transfer specs, and fulfill
724- // usage_event_promises.
725- std::vector<tsl::RCReference<tsl::AsyncValue>> curr_transfer_dependency_avs;
726- std::vector<CrossHostTransferSpec> transfer_specs;
727- transfer_specs.reserve (send_idxs.size ());
791+ // Form transfer specs.
792+ std::vector<CrossHostTransferSpec> curr_transfer_specs;
793+ curr_transfer_specs.reserve (transfer_idxs.size ());
728794
729- for (int idx : send_idxs) {
730- for (tsl::RCReference<tsl::AsyncValue>& event :
731- transfer_dependency_avs[idx]) {
732- curr_transfer_dependency_avs.push_back (std::move (event));
733- }
734- transfer_specs.push_back (CrossHostTransferSpec{
735- src_global_device_id, dst_global_device_ids[idx],
736- std::move (raw_buffers[idx])});
737-
738- usage_event_promises[idx]->Set (PjRtDeviceEventRef (transfer_event));
795+ for (int idx : transfer_idxs) {
796+ curr_transfer_specs.push_back (std::move (transfer_specs[idx]));
797+ output_transfer_events[idx] = PjRtDeviceEventRef (transfer_event);
739798 }
740799
741800 // Get the local_device_state and use it to schedule transfers. Fail
742801 // transfers early if we cannot get the local_device_state.
743- absl::StatusOr<LocalDeviceState*> local_device_state =
802+ absl::StatusOr<LocalDeviceState*> maybe_local_device_state =
744803 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
745804 ->GetLocalDeviceState ();
746- if (!local_device_state .ok ()) {
747- SetEventAsError (transfer_event, local_device_state .status ());
805+ if (!maybe_local_device_state .ok ()) {
806+ SetEventAsError (transfer_event, maybe_local_device_state .status ());
748807 continue ;
749808 }
750-
751- ScheduleTransfersOnLocalDevice (
752- *local_device_state, src_global_device_id, std::move (transfer_event),
753- std::move (curr_transfer_dependency_avs), std::move (transfer_specs));
809+ LocalDeviceState* local_device_state = *maybe_local_device_state;
810+
811+ // Launch ScheduleTransfersOnLocalDevice on either the async dispatch thread
812+ // or the calling thread.
813+ if (local_device_state->async_dispatch_thread ()) {
814+ local_device_state->async_dispatch_thread ()->Schedule (
815+ tsl::WithCurrentContext (
816+ [this , local_device_state, device_id, transfer_dependency_avs,
817+ curr_transfer_specs = std::move (curr_transfer_specs),
818+ transfer_event = std::move (transfer_event)]() mutable {
819+ ScheduleTransfersOnLocalDevice (
820+ local_device_state, device_id, std::move (transfer_event),
821+ std::move (transfer_dependency_avs),
822+ std::move (curr_transfer_specs));
823+ }));
824+ } else {
825+ ScheduleTransfersOnLocalDevice (
826+ local_device_state, device_id, std::move (transfer_event),
827+ transfer_dependency_avs, std::move (curr_transfer_specs));
828+ }
754829 }
755830
756- return futures ;
831+ return output_transfer_events ;
757832}
758833
759834void StreamExecutorGpuClient::ScheduleTransfersOnLocalDevice (