Skip to content

Commit b811965

Browse files
sizhit2 authored and tensorflower-gardener committed
Support memory space in underlying buffer allocation.
Enable PinnedMemory in D2D transfer.

PiperOrigin-RevId: 755894750
1 parent 7eced04 commit b811965

File tree

5 files changed

+47
-41
lines changed

5 files changed

+47
-41
lines changed

third_party/xla/xla/pjrt/gpu/tfrt/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -233,6 +233,7 @@ cc_library(
233233
"//xla/service:shaped_buffer",
234234
"//xla/stream_executor:device_memory",
235235
"//xla/stream_executor:device_memory_allocator",
236+
"//xla/stream_executor:stream_executor_h",
236237
"//xla/tsl/concurrency:async_value",
237238
"//xla/tsl/framework:allocator",
238239
"//xla/tsl/platform:statusor",

third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc

Lines changed: 28 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,8 @@ absl::StatusOr<std::unique_ptr<TfrtGpuBuffer>> AllocateTfrtGpuDestinationBuffer(
185185
TF_ASSIGN_OR_RETURN(
186186
auto device_buffer,
187187
MaybeOwningGpuMemory::AllocateShared(
188-
client->allocator(), device->local_device_id().value(), byte_size));
188+
client->allocator(), device->local_device_id().value(), byte_size,
189+
LayoutUtil::MemorySpace(on_device_shape)));
189190
auto buffer_async_value_ref =
190191
tsl::MakeAvailableAsyncValueRef<MaybeOwningGpuMemory>(
191192
std::move(device_buffer));
@@ -2794,15 +2795,11 @@ bool TfrtGpuBuffer::IsDeleted() {
27942795

27952796
absl::StatusOr<std::unique_ptr<PjRtBuffer>> TfrtGpuBuffer::CopyToMemorySpace(
27962797
PjRtMemorySpace* dst_memory_space) {
2797-
// TODO: b/382117736 - Support non-default memory spaces.
27982798
tsl::profiler::TraceMe traceme("TfrtGpuBuffer::CopyToMemorySpace");
27992799
PjRtDevice* dst_device = dst_memory_space->devices()[0];
28002800

2801-
// TODO(sizhi): Support copy data to the pinned host memory space.
2802-
if (dst_memory_space->kind() == PinnedHostMemorySpace::kKind) {
2803-
return Unimplemented(
2804-
"Copy data to pinned host memory space is not implemented.");
2805-
}
2801+
VLOG(2) << " TfrtGpuBuffer::CopyToMemorySpace: dst_device: " << dst_device
2802+
<< "dst_memory_space: " << dst_memory_space->kind();
28062803

28072804
// Copying across PjRtClients involves a copy through the host.
28082805
if (dst_device->client() != client_) {
@@ -2824,8 +2821,8 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> TfrtGpuBuffer::CopyToMemorySpace(
28242821
}
28252822

28262823
// Copy each leaf buffer to a destination buffer.
2827-
auto usage_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
2828-
TrackedTfrtGpuDeviceBuffer* src_device_buffer = AcquireUsage(usage_event);
2824+
auto src_usage_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
2825+
TrackedTfrtGpuDeviceBuffer* src_device_buffer = AcquireUsage(src_usage_event);
28292826
if (src_device_buffer == nullptr) {
28302827
return InvalidArgument(
28312828
"CopyToMemorySpace called on deleted or donated buffer");
@@ -2834,50 +2831,51 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> TfrtGpuBuffer::CopyToMemorySpace(
28342831
TfrtGpuDevice* gpu_dst_device = tsl::down_cast<TfrtGpuDevice*>(dst_device);
28352832
tsl::AsyncValueRef<MaybeOwningGpuMemory> src_buffer =
28362833
src_device_buffer->buffer();
2837-
auto dst_buffer = tsl::MakeUnconstructedAsyncValueRef<MaybeOwningGpuMemory>();
2834+
28382835
auto dst_definition_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
2836+
TF_ASSIGN_OR_RETURN(auto output_buffer,
2837+
AllocateTfrtGpuDestinationBuffer(
2838+
on_device_shape_, {dst_definition_event.CopyRef()},
2839+
gpu_dst_device, client_, dst_memory_space));
2840+
auto dst_usage_event = tsl::MakeConstructedAsyncValueRef<GpuEvent>();
2841+
TrackedTfrtGpuDeviceBuffer* allocated_dst_device_buffer =
2842+
output_buffer->AcquireUsage(dst_usage_event);
2843+
CHECK(allocated_dst_device_buffer != nullptr);
2844+
auto allocated_dst_buffer = allocated_dst_device_buffer->buffer();
28392845

28402846
absl::AnyInvocable<void()> transfer_d2d =
2841-
[src_buffer(src_buffer.CopyRef()), dst_buffer(dst_buffer.CopyRef()),
2847+
[src_buffer(src_buffer.CopyRef()),
2848+
allocated_dst_buffer(allocated_dst_buffer.CopyRef()),
28422849
dst_definition_event(dst_definition_event.CopyRef()),
28432850
src_definition_event(src_device_buffer->definition_event().CopyRef()),
2844-
dst_device(gpu_dst_device), usage_event(usage_event.CopyRef())]() {
2851+
dst_device(gpu_dst_device), src_usage_event(src_usage_event.CopyRef()),
2852+
dst_usage_event(dst_usage_event.CopyRef())]() {
28452853
tsl::profiler::TraceMe traceme("D2D copy");
2854+
2855+
MarkGpuEventReadyOnExit ready_on_exit_src(std::move(src_usage_event));
2856+
MarkGpuEventReadyOnExit ready_on_exit_dst(std::move(dst_usage_event));
2857+
28462858
if (const absl::Status* error =
28472859
dst_definition_event.GetErrorIfPresent()) {
2848-
dst_buffer.SetError(*error);
2860+
allocated_dst_buffer.SetError(*error);
28492861
dst_definition_event.SetError(*error);
2850-
usage_event.SetStateConcrete();
28512862
return;
28522863
}
28532864

28542865
if (const absl::Status* error =
28552866
src_definition_event.GetErrorIfPresent()) {
2856-
dst_buffer.SetError(*error);
2867+
allocated_dst_buffer.SetError(*error);
28572868
dst_definition_event.SetError(*error);
2858-
usage_event.SetStateConcrete();
2859-
return;
2860-
}
2861-
MarkGpuEventReadyOnExit ready_on_exit(std::move(usage_event));
2862-
absl::StatusOr<MaybeOwningGpuMemory> allocated_dst_buffer =
2863-
MaybeOwningGpuMemory::AllocateShared(
2864-
dst_device->allocator(),
2865-
dst_device->local_hardware_id().value(),
2866-
src_buffer->buffer().size());
2867-
if (!allocated_dst_buffer.ok()) {
2868-
dst_buffer.SetError(allocated_dst_buffer.status());
2869-
dst_definition_event.SetError(allocated_dst_buffer.status());
28702869
return;
28712870
}
2872-
dst_buffer.emplace(std::move(allocated_dst_buffer.value()));
28732871

28742872
absl::StatusOr<BoundedStreamPool::Handle> stream =
28752873
dst_device->stream_pool().Borrow();
28762874
if (!stream.ok()) {
28772875
dst_definition_event.SetError(stream.status());
28782876
return;
28792877
}
2880-
se::DeviceMemoryBase dst(dst_buffer->buffer());
2878+
se::DeviceMemoryBase dst(allocated_dst_buffer->buffer());
28812879
absl::Status status = stream->get()->Memcpy(
28822880
&dst, src_buffer->buffer(), src_buffer->buffer().size());
28832881
if (!status.ok()) {
@@ -2895,11 +2893,7 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> TfrtGpuBuffer::CopyToMemorySpace(
28952893
EnqueueWorkWhenReady(client_->blocking_thread_pool(),
28962894
{src_device_buffer->definition_event().CopyRCRef()},
28972895
std::move(transfer_d2d));
2898-
return std::unique_ptr<PjRtBuffer>(std::make_unique<TfrtGpuBuffer>(
2899-
on_device_shape_,
2900-
std::make_unique<TrackedTfrtGpuDeviceBuffer>(
2901-
std::move(dst_buffer), std::move(dst_definition_event)),
2902-
client(), tsl::down_cast<TfrtGpuDevice*>(dst_device), dst_memory_space));
2896+
return output_buffer;
29032897
}
29042898

29052899
void TfrtGpuBuffer::DropExternalReference() {

third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -610,9 +610,6 @@ TEST(TfrtGpuClientTest, BufferFromHostBufferPinnedMemory) {
610610
}
611611

612612
TEST(TfrtGpuClientTest, CopyToPinnedHostMemorySpace) {
613-
// TODO(sizhi): Re-enable this test after the feature is implemented.
614-
GTEST_SKIP() << "Skipping this test.";
615-
616613
TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions()));
617614
std::vector<int32_t> data{1, 2, 3, 4};
618615
Shape shape = ShapeUtil::MakeShape(S32, {4});
@@ -642,9 +639,6 @@ TEST(TfrtGpuClientTest, CopyToPinnedHostMemorySpace) {
642639
}
643640

644641
TEST(TfrtGpuClientTest, CopyToPinnedHostMemorySpaceInt4) {
645-
// TODO(sizhi): Re-enable this test after the feature is implemented.
646-
GTEST_SKIP() << "Skipping this test.";
647-
648642
TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtGpuClient(GpuClientOptions()));
649643
std::vector<int8_t> data{1, 2, 3, 4};
650644
Shape shape = ShapeUtil::MakeShape(S4, {4});

third_party/xla/xla/pjrt/gpu/tfrt/tracked_tfrt_gpu_device_buffer.cc

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515
#include "xla/pjrt/gpu/tfrt/tracked_tfrt_gpu_device_buffer.h"
1616

1717
#include <cstddef>
18+
#include <cstdint>
1819
#include <functional>
1920
#include <utility>
2021

@@ -29,6 +30,7 @@ limitations under the License.
2930
#include "xla/shape_tree.h"
3031
#include "xla/stream_executor/device_memory.h"
3132
#include "xla/stream_executor/device_memory_allocator.h"
33+
#include "xla/stream_executor/stream_executor.h"
3234
#include "xla/tsl/concurrency/async_value_ref.h"
3335
#include "xla/tsl/framework/allocator.h"
3436
#include "xla/tsl/platform/statusor.h"
@@ -58,10 +60,20 @@ void MaybeOwningGpuMemory::SetUnOwned() {
5860

5961
absl::StatusOr<MaybeOwningGpuMemory> MaybeOwningGpuMemory::AllocateShared(
6062
se::DeviceMemoryAllocator* allocator, int device_ordinal, size_t size) {
63+
return AllocateShared(allocator, device_ordinal, size,
64+
static_cast<int>(se::MemoryType::kDevice));
65+
}
66+
67+
absl::StatusOr<MaybeOwningGpuMemory> MaybeOwningGpuMemory::AllocateShared(
68+
se::DeviceMemoryAllocator* allocator, int device_ordinal, size_t size,
69+
int64_t memory_space) {
6170
if (size == 0) {
6271
return MaybeOwningGpuMemory(se::DeviceMemoryBase());
6372
}
64-
TF_ASSIGN_OR_RETURN(auto memory, allocator->Allocate(device_ordinal, size));
73+
TF_ASSIGN_OR_RETURN(
74+
auto memory,
75+
allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/true,
76+
memory_space));
6577
return MaybeOwningGpuMemory(std::move(memory));
6678
}
6779

third_party/xla/xla/pjrt/gpu/tfrt/tracked_tfrt_gpu_device_buffer.h

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
#define XLA_PJRT_GPU_TFRT_TRACKED_TFRT_GPU_DEVICE_BUFFER_H_
1818

1919
#include <cstddef>
20+
#include <cstdint>
2021
#include <functional>
2122
#include <utility>
2223

@@ -79,6 +80,10 @@ class MaybeOwningGpuMemory {
7980
static absl::StatusOr<MaybeOwningGpuMemory> AllocateShared(
8081
se::DeviceMemoryAllocator* allocator, int device_ordinal, size_t size);
8182

83+
static absl::StatusOr<MaybeOwningGpuMemory> AllocateShared(
84+
se::DeviceMemoryAllocator* allocator, int device_ordinal, size_t size,
85+
int64_t memory_space);
86+
8287
stream_executor::DeviceMemoryBase buffer() const { return buffer_; }
8388
size_t size() const { return buffer_.size(); }
8489
bool owns_data() const { return !owning_buffer_.is_null(); }

0 commit comments

Comments
 (0)