@@ -185,7 +185,8 @@ absl::StatusOr<std::unique_ptr<TfrtGpuBuffer>> AllocateTfrtGpuDestinationBuffer(
185185 TF_ASSIGN_OR_RETURN (
186186 auto device_buffer,
187187 MaybeOwningGpuMemory::AllocateShared (
188- client->allocator (), device->local_device_id ().value (), byte_size));
188+ client->allocator (), device->local_device_id ().value (), byte_size,
189+ LayoutUtil::MemorySpace (on_device_shape)));
189190 auto buffer_async_value_ref =
190191 tsl::MakeAvailableAsyncValueRef<MaybeOwningGpuMemory>(
191192 std::move (device_buffer));
@@ -2794,15 +2795,11 @@ bool TfrtGpuBuffer::IsDeleted() {
27942795
27952796absl::StatusOr<std::unique_ptr<PjRtBuffer>> TfrtGpuBuffer::CopyToMemorySpace (
27962797 PjRtMemorySpace* dst_memory_space) {
2797- // TODO: b/382117736 - Support non-default memory spaces.
27982798 tsl::profiler::TraceMe traceme (" TfrtGpuBuffer::CopyToMemorySpace" );
27992799 PjRtDevice* dst_device = dst_memory_space->devices ()[0 ];
28002800
2801- // TODO(sizhi): Support copy data to the pinned host memory space.
2802- if (dst_memory_space->kind () == PinnedHostMemorySpace::kKind ) {
2803- return Unimplemented (
2804- " Copy data to pinned host memory space is not implemented." );
2805- }
2801+ VLOG (2 ) << " TfrtGpuBuffer::CopyToMemorySpace: dst_device: " << dst_device
2802+ << " dst_memory_space: " << dst_memory_space->kind ();
28062803
28072804 // Copying across PjRtClients involves a copy through the host.
28082805 if (dst_device->client () != client_) {
@@ -2824,8 +2821,8 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> TfrtGpuBuffer::CopyToMemorySpace(
28242821 }
28252822
28262823 // Copy each leaf buffer to a destination buffer.
2827- auto usage_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
2828- TrackedTfrtGpuDeviceBuffer* src_device_buffer = AcquireUsage (usage_event );
2824+ auto src_usage_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
2825+ TrackedTfrtGpuDeviceBuffer* src_device_buffer = AcquireUsage (src_usage_event );
28292826 if (src_device_buffer == nullptr ) {
28302827 return InvalidArgument (
28312828 " CopyToMemorySpace called on deleted or donated buffer" );
@@ -2834,50 +2831,51 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> TfrtGpuBuffer::CopyToMemorySpace(
28342831 TfrtGpuDevice* gpu_dst_device = tsl::down_cast<TfrtGpuDevice*>(dst_device);
28352832 tsl::AsyncValueRef<MaybeOwningGpuMemory> src_buffer =
28362833 src_device_buffer->buffer ();
2837- auto dst_buffer = tsl::MakeUnconstructedAsyncValueRef<MaybeOwningGpuMemory>();
2834+
28382835 auto dst_definition_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
2836+ TF_ASSIGN_OR_RETURN (auto output_buffer,
2837+ AllocateTfrtGpuDestinationBuffer (
2838+ on_device_shape_, {dst_definition_event.CopyRef ()},
2839+ gpu_dst_device, client_, dst_memory_space));
2840+ auto dst_usage_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
2841+ TrackedTfrtGpuDeviceBuffer* allocated_dst_device_buffer =
2842+ output_buffer->AcquireUsage (dst_usage_event);
2843+ CHECK (allocated_dst_device_buffer != nullptr );
2844+ auto allocated_dst_buffer = allocated_dst_device_buffer->buffer ();
28392845
28402846 absl::AnyInvocable<void ()> transfer_d2d =
2841- [src_buffer (src_buffer.CopyRef ()), dst_buffer (dst_buffer.CopyRef ()),
2847+ [src_buffer (src_buffer.CopyRef ()),
2848+ allocated_dst_buffer (allocated_dst_buffer.CopyRef ()),
28422849 dst_definition_event (dst_definition_event.CopyRef ()),
28432850 src_definition_event (src_device_buffer->definition_event ().CopyRef ()),
2844- dst_device (gpu_dst_device), usage_event (usage_event.CopyRef ())]() {
2851+ dst_device (gpu_dst_device), src_usage_event (src_usage_event.CopyRef ()),
2852+ dst_usage_event (dst_usage_event.CopyRef ())]() {
28452853 tsl::profiler::TraceMe traceme (" D2D copy" );
2854+
2855+ MarkGpuEventReadyOnExit ready_on_exit_src (std::move (src_usage_event));
2856+ MarkGpuEventReadyOnExit ready_on_exit_dst (std::move (dst_usage_event));
2857+
28462858 if (const absl::Status* error =
28472859 dst_definition_event.GetErrorIfPresent ()) {
2848- dst_buffer .SetError (*error);
2860+ allocated_dst_buffer .SetError (*error);
28492861 dst_definition_event.SetError (*error);
2850- usage_event.SetStateConcrete ();
28512862 return ;
28522863 }
28532864
28542865 if (const absl::Status* error =
28552866 src_definition_event.GetErrorIfPresent ()) {
2856- dst_buffer .SetError (*error);
2867+ allocated_dst_buffer .SetError (*error);
28572868 dst_definition_event.SetError (*error);
2858- usage_event.SetStateConcrete ();
2859- return ;
2860- }
2861- MarkGpuEventReadyOnExit ready_on_exit (std::move (usage_event));
2862- absl::StatusOr<MaybeOwningGpuMemory> allocated_dst_buffer =
2863- MaybeOwningGpuMemory::AllocateShared (
2864- dst_device->allocator (),
2865- dst_device->local_hardware_id ().value (),
2866- src_buffer->buffer ().size ());
2867- if (!allocated_dst_buffer.ok ()) {
2868- dst_buffer.SetError (allocated_dst_buffer.status ());
2869- dst_definition_event.SetError (allocated_dst_buffer.status ());
28702869 return ;
28712870 }
2872- dst_buffer.emplace (std::move (allocated_dst_buffer.value ()));
28732871
28742872 absl::StatusOr<BoundedStreamPool::Handle> stream =
28752873 dst_device->stream_pool ().Borrow ();
28762874 if (!stream.ok ()) {
28772875 dst_definition_event.SetError (stream.status ());
28782876 return ;
28792877 }
2880- se::DeviceMemoryBase dst (dst_buffer ->buffer ());
2878+ se::DeviceMemoryBase dst (allocated_dst_buffer ->buffer ());
28812879 absl::Status status = stream->get ()->Memcpy (
28822880 &dst, src_buffer->buffer (), src_buffer->buffer ().size ());
28832881 if (!status.ok ()) {
@@ -2895,11 +2893,7 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> TfrtGpuBuffer::CopyToMemorySpace(
28952893 EnqueueWorkWhenReady (client_->blocking_thread_pool (),
28962894 {src_device_buffer->definition_event ().CopyRCRef ()},
28972895 std::move (transfer_d2d));
2898- return std::unique_ptr<PjRtBuffer>(std::make_unique<TfrtGpuBuffer>(
2899- on_device_shape_,
2900- std::make_unique<TrackedTfrtGpuDeviceBuffer>(
2901- std::move (dst_buffer), std::move (dst_definition_event)),
2902- client (), tsl::down_cast<TfrtGpuDevice*>(dst_device), dst_memory_space));
2896+ return output_buffer;
29032897}
29042898
29052899void TfrtGpuBuffer::DropExternalReference () {
0 commit comments