Skip to content

Commit f4f5597

Browse files
committed
[xla:gpu] Unify CUDA allocators under cuMemCreate allocator
1 parent 6eaf528 commit f4f5597

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

52 files changed

+812
-1604
lines changed

xla/backends/gpu/collectives/BUILD

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,6 @@ cc_library(
546546
"//xla/runtime:device_id",
547547
"//xla/runtime:process_id",
548548
"//xla/stream_executor:stream_executor_h",
549-
"//xla/stream_executor/cuda:nccl_allocator", # buildcleaner: keep (static registration)
550549
"//xla/tsl/cuda:nccl",
551550
"//xla/tsl/platform:env",
552551
"//xla/tsl/platform:logging",
@@ -564,7 +563,6 @@ cc_library(
564563
"@com_google_absl//absl/time",
565564
"@com_google_absl//absl/types:span",
566565
"@tsl//tsl/platform:casts",
567-
"@tsl//tsl/platform:numbers",
568566
"@tsl//tsl/profiler/lib:traceme",
569567
],
570568
alwayslink = True, # registers collectives implementation
@@ -619,7 +617,6 @@ cc_library(
619617
"@local_config_rocm//rocm:rccl", # buildcleaner: keep
620618
"@local_config_rocm//rocm:rocm_headers",
621619
"@tsl//tsl/platform:casts",
622-
"@tsl//tsl/platform:numbers",
623620
],
624621
alwayslink = True, # registers collectives implementation
625622
)
@@ -784,7 +781,6 @@ cc_library(
784781
"//xla/stream_executor:device_address",
785782
"//xla/stream_executor:stream",
786783
"//xla/stream_executor/cuda:nvshmem",
787-
"//xla/stream_executor/cuda:nvshmem_allocator", # buildcleaner: keep (static registration)
788784
"//xla/tsl/concurrency:async_value",
789785
"//xla/tsl/platform:errors",
790786
"//xla/tsl/platform:statusor",

xla/backends/gpu/collectives/gpu_collectives.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,19 +157,19 @@ class GpuCollectives : public Collectives {
157157
// communicator topology supports host RMA at runtime.
158158
virtual bool SupportsOneSidedComm() const { return false; }
159159

160-
// Returns minimum alignment requirement for symmetric memory.
161-
virtual size_t SymmetricMemoryAlignment() const { return 1; }
162-
163160
// Returns a slice of device memory `buff` containing `count` values of data
164161
// type `dtype` starting from `offset`.
165162
static stream_executor::DeviceAddressBase Slice(
166163
stream_executor::DeviceAddressBase buff, PrimitiveType dtype,
167164
size_t offset, size_t count);
168165

169-
// TODO(b/410686553): Use smart wrapper instead of void*.
170-
virtual absl::StatusOr<void*> Allocate(uint64_t bytes) = 0;
166+
virtual absl::StatusOr<void*> Allocate(uint64_t bytes) {
167+
return absl::UnimplementedError("Collectives allocator not available");
168+
}
171169

172-
virtual absl::Status Deallocate(void* buffer) = 0;
170+
virtual absl::Status Deallocate(void* buffer) {
171+
return absl::UnimplementedError("Collectives deallocator not available");
172+
}
173173

174174
// Creates a single communicator.
175175
virtual absl::StatusOr<std::unique_ptr<Communicator>>

xla/backends/gpu/collectives/nccl_collectives.cc

Lines changed: 1 addition & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ limitations under the License.
3838
#include "absl/time/clock.h"
3939
#include "absl/time/time.h"
4040
#include "absl/types/span.h"
41-
#include "xla/tsl/platform/status_macros.h"
4241
#include "third_party/nccl/nccl.h"
4342
#include "xla/backends/gpu/collectives/cancellation_token.h"
4443
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
@@ -60,10 +59,10 @@ limitations under the License.
6059
#include "xla/stream_executor/stream_executor.h"
6160
#include "xla/tsl/platform/env.h"
6261
#include "xla/tsl/platform/logging.h"
62+
#include "xla/tsl/platform/status_macros.h"
6363
#include "xla/tsl/platform/threadpool.h"
6464
#include "xla/util.h"
6565
#include "tsl/platform/casts.h"
66-
#include "tsl/platform/numbers.h"
6766
#include "tsl/profiler/lib/traceme.h"
6867

6968
namespace xla::gpu {
@@ -239,12 +238,6 @@ bool NcclCollectives::SupportsOneSidedComm() const {
239238
return NCCL_VERSION_CODE >= 22900;
240239
}
241240

242-
size_t NcclCollectives::SymmetricMemoryAlignment() const {
243-
// TODO(ezhulenev): Query memory alignment from CUDA executor for multicast
244-
// memory (CU_MULTICAST_GRANULARITY_MINIMUM). Find how to query it for NCCL.
245-
return 4096;
246-
}
247-
248241
static absl::StatusOr<ncclConfig_t> AsNcclConfig(
249242
const GpuCollectives::Config& config,
250243
const se::StreamExecutor* stream_executor) {
@@ -504,44 +497,6 @@ static absl::StatusOr<xla::gpu::GpuCollectives*> GetNvshmemCollectives() {
504497
return nvshmem_collectives;
505498
}
506499

507-
absl::StatusOr<void*> NcclCollectives::Allocate(uint64_t bytes) {
508-
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
509-
ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
510-
return nvshmem_collectives->Allocate(bytes);
511-
}
512-
513-
void* ptr = nullptr;
514-
ncclResult_t res = ncclMemAlloc(&ptr, bytes);
515-
if (res != ncclSuccess) {
516-
return Internal(
517-
"Failed to allocate %s (%llu bytes) from device collective memory: %s, "
518-
"Last NCCL warning(error) log entry (may be unrelated): %s",
519-
tsl::strings::HumanReadableNumBytes(bytes), bytes,
520-
ncclGetErrorString(res), ncclGetLastError(nullptr));
521-
}
522-
VLOG(2) << "Allocated collective memory " << ptr << " of " << bytes
523-
<< " bytes";
524-
return ptr;
525-
}
526-
527-
absl::Status NcclCollectives::Deallocate(void* location) {
528-
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
529-
ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
530-
return nvshmem_collectives->Deallocate(location);
531-
}
532-
533-
ncclResult_t res = ncclMemFree(location);
534-
if (res != ncclSuccess) {
535-
return Internal(
536-
"Failed to free device collective memory at %p; result: %s, Last NCCL "
537-
"warning(error) log entry (may be unrelated): %s",
538-
location, ncclGetErrorString(res), ncclGetLastError(nullptr));
539-
}
540-
541-
VLOG(2) << "Deallocated collective memory " << location;
542-
return absl::OkStatus();
543-
}
544-
545500
absl::StatusOr<GpuCollectives::CliqueIdCallback>
546501
NcclCollectives::InitializeTopology(const Topology& topology) {
547502
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {

xla/backends/gpu/collectives/nccl_collectives.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ class NcclCollectives : public GpuCollectives {
4444
bool SupportsDeviceComm() const final;
4545
bool SupportsOneSidedComm() const final;
4646

47-
size_t SymmetricMemoryAlignment() const final;
48-
4947
absl::StatusOr<CliqueId> CreateUniqueCliqueId() const final;
5048

5149
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
@@ -82,10 +80,6 @@ class NcclCollectives : public GpuCollectives {
8280
return absl::UnimplementedError("Not implemented.");
8381
}
8482

85-
absl::StatusOr<void*> Allocate(uint64_t bytes) final;
86-
87-
absl::Status Deallocate(void* location) final;
88-
8983
absl::StatusOr<CliqueIdCallback> InitializeTopology(
9084
const Topology& topology) final;
9185
};

xla/backends/gpu/collectives/rccl_collectives.cc

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ limitations under the License.
6363
#include "xla/tsl/platform/threadpool.h"
6464
#include "xla/util.h"
6565
#include "tsl/platform/casts.h"
66-
#include "tsl/platform/numbers.h"
6766

6867
#if (TF_ROCM_VERSION >= 50200)
6968
#include "rocm/include/rccl/rccl.h"
@@ -337,44 +336,6 @@ static absl::StatusOr<xla::gpu::GpuCollectives*> GetNvshmemCollectives() {
337336
return nvshmem_collectives;
338337
}
339338

340-
absl::StatusOr<void*> RcclCollectives::Allocate(uint64_t bytes) {
341-
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
342-
TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
343-
return nvshmem_collectives->Allocate(bytes);
344-
}
345-
346-
void* ptr = nullptr;
347-
ncclResult_t res = ncclMemAlloc(&ptr, bytes);
348-
if (res != ncclSuccess) {
349-
return absl::InternalError(absl::StrFormat(
350-
"failed to allocate %s (%llu bytes) from device collective memory: %s, "
351-
"Last NCCL warning(error) log entry (may be unrelated): %s",
352-
tsl::strings::HumanReadableNumBytes(bytes), bytes,
353-
ncclGetErrorString(res), ncclGetLastError(nullptr)));
354-
}
355-
VLOG(2) << "Allocated collective memory " << ptr << " of " << bytes
356-
<< " bytes";
357-
return ptr;
358-
}
359-
360-
absl::Status RcclCollectives::Deallocate(void* location) {
361-
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
362-
TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
363-
return nvshmem_collectives->Deallocate(location);
364-
}
365-
366-
ncclResult_t res = ncclMemFree(location);
367-
if (res != ncclSuccess) {
368-
return absl::InternalError(absl::StrFormat(
369-
"failed to free device collective memory at %p; result: %s, Last NCCL "
370-
"warning(error) log entry (may be unrelated): %s",
371-
location, ncclGetErrorString(res), ncclGetLastError(nullptr)));
372-
}
373-
374-
VLOG(2) << "Deallocated collective memory " << location;
375-
return absl::OkStatus();
376-
}
377-
378339
absl::StatusOr<CliqueIdCallback> RcclCollectives::InitializeTopology(
379340
const Topology& topology) {
380341
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {

xla/backends/gpu/collectives/rccl_collectives.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,6 @@ class RcclCollectives : public GpuCollectives {
7575
return absl::UnimplementedError("Not implemented.");
7676
}
7777

78-
absl::StatusOr<void*> Allocate(uint64_t bytes) final;
79-
80-
absl::Status Deallocate(void* location) final;
81-
8278
absl::StatusOr<CliqueIdCallback> InitializeTopology(
8379
const Topology& topology) final;
8480
};

xla/backends/gpu/tests/all_reduce_e2e_test.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,7 @@ class AllReduceTestNoParams : public CollectiveOpsWithFlagsBase {
9292
explicit AllReduceTestNoParams(bool is_async = false)
9393
: CollectiveOpsWithFlagsBase(/*enable_async=*/is_async,
9494
/*enable_p2p_memcpy=*/false,
95-
/*memory_size=*/32 * kMB,
96-
/*collectives_memory_size=*/0) {}
95+
/*memory_size=*/32 * kMB) {}
9796

9897
void SetUp() override {
9998
CollectiveOpsE2ETestBase::SetUp();

xla/backends/gpu/tests/collective_ops_e2e_test.cc

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,8 @@ bool IsAsync(const HloInstruction* inst) {
7575

7676
class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
7777
public:
78-
explicit CollectiveOpsTestE2E(size_t memory_size = 128 * kMB,
79-
size_t collectives_memory_size = 0)
80-
: CollectiveOpsE2ETestBase(memory_size, collectives_memory_size) {}
78+
explicit CollectiveOpsTestE2E(size_t memory_size = 128 * kMB)
79+
: CollectiveOpsE2ETestBase(memory_size) {}
8180

8281
bool HasFp8Support() {
8382
if (Capability().IsCuda()) {
@@ -133,8 +132,7 @@ class AsyncCollectiveOps : public CollectiveOpsWithFlagsBase,
133132
AsyncCollectiveOps()
134133
: CollectiveOpsWithFlagsBase(/*enable_async=*/GetParam(),
135134
/*enable_p2p_memcpy=*/false,
136-
/*memory_size=*/8 * kGB,
137-
/*collectives_memory_size=*/0) {}
135+
/*memory_size=*/8 * kGB) {}
138136
};
139137

140138
class MemcpyCollectiveOps : public CollectiveOpsWithFlagsBase,
@@ -143,8 +141,7 @@ class MemcpyCollectiveOps : public CollectiveOpsWithFlagsBase,
143141
MemcpyCollectiveOps()
144142
: CollectiveOpsWithFlagsBase(/*enable_async=*/true,
145143
/*enable_p2p_memcpy=*/GetParam(),
146-
/*memory_size=*/32 * kMB,
147-
/*collectives_memory_size=*/0) {}
144+
/*memory_size=*/32 * kMB) {}
148145
};
149146

150147
class AsyncMemcpyCollectiveOps
@@ -155,8 +152,7 @@ class AsyncMemcpyCollectiveOps
155152
: CollectiveOpsWithFlagsBase(
156153
/*enable_async=*/std::get<0>(GetParam()),
157154
/*enable_p2p_memcpy=*/std::get<1>(GetParam()),
158-
/*memory_size=*/32 * kMB,
159-
/*collectives_memory_size=*/0) {}
155+
/*memory_size=*/32 * kMB) {}
160156
};
161157

162158
std::string GetAsyncTestName(bool is_async) {
@@ -1269,8 +1265,7 @@ TEST_F(CollectiveOpsTestE2E, HostMemoryOffloadingWithDonation) {
12691265
class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
12701266
public:
12711267
CollectiveOpsTestE2EWindowedNonWindowed()
1272-
: CollectiveOpsTestE2E(/*memory_size=*/4 * kGB,
1273-
/*collectives_memory_size=*/0) {}
1268+
: CollectiveOpsTestE2E(/*memory_size=*/4 * kGB) {}
12741269

12751270
void CollectiveOpsCompareWindowedNonWindowed(
12761271
absl::string_view hlo_text, bool disable_dot_merger = false,

xla/backends/gpu/tests/collective_ops_e2e_test_base.cc

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ limitations under the License.
2727
#include "absl/log/log.h"
2828
#include "absl/status/statusor.h"
2929
#include "absl/strings/string_view.h"
30-
#include "xla/tsl/platform/status_macros.h"
3130
#include "xla/backends/gpu/tests/hlo_pjrt_gpu_test_base.h"
3231
#include "xla/hlo/ir/hlo_instruction.h"
3332
#include "xla/hlo/ir/hlo_module.h"
@@ -41,19 +40,18 @@ limitations under the License.
4140
#include "xla/service/gpu/backend_configs.pb.h"
4241
#include "xla/service/hlo_module_config.h"
4342
#include "xla/service/hlo_runner_interface.h"
43+
#include "xla/tsl/platform/status_macros.h"
4444
#include "xla/tsl/platform/statusor.h"
4545
#include "xla/xla.pb.h"
4646
#include "xla/xla_data.pb.h"
4747

4848
namespace xla {
4949
namespace {
5050

51-
std::unique_ptr<PjRtClient> CreatePjRtClient(size_t memory_size,
52-
size_t collectives_memory_size) {
51+
std::unique_ptr<PjRtClient> CreatePjRtClient(size_t memory_size) {
5352
xla::GpuClientOptions options;
5453
options.allocator_config.kind = xla::GpuAllocatorConfig::Kind::kBFC;
5554
options.allocator_config.gpu_system_memory_size = memory_size;
56-
options.allocator_config.collective_memory_size = collectives_memory_size;
5755
options.use_tfrt_gpu_client = true;
5856

5957
absl::StatusOr<std::unique_ptr<xla::PjRtClient>> pjrt_client =
@@ -64,10 +62,8 @@ std::unique_ptr<PjRtClient> CreatePjRtClient(size_t memory_size,
6462

6563
} // namespace
6664

67-
CollectiveOpsE2ETestBase::CollectiveOpsE2ETestBase(
68-
size_t memory_size, size_t collectives_memory_size)
69-
: HloPjRtGpuTestBase(
70-
CreatePjRtClient(memory_size, collectives_memory_size)) {}
65+
CollectiveOpsE2ETestBase::CollectiveOpsE2ETestBase(size_t memory_size)
66+
: HloPjRtGpuTestBase(CreatePjRtClient(memory_size)) {}
7167

7268
absl::StatusOr<CollectiveOpsE2ETestBase::ExecutionResult>
7369
CollectiveOpsE2ETestBase::ExecuteReplicated(std::unique_ptr<HloModule> module) {

xla/backends/gpu/tests/collective_ops_e2e_test_base.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ inline constexpr size_t kGB = 1024LL * kMB;
4444

4545
class CollectiveOpsE2ETestBase : public gpu::HloPjRtGpuTestBase {
4646
public:
47-
CollectiveOpsE2ETestBase(size_t memory_size, size_t collectives_memory_size);
47+
explicit CollectiveOpsE2ETestBase(size_t memory_size);
4848

4949
DebugOptions GetDebugOptionsForTest() const override {
5050
DebugOptions debug_options =
@@ -104,8 +104,8 @@ class CollectiveOpsE2ETestBase : public gpu::HloPjRtGpuTestBase {
104104
class CollectiveOpsWithFlagsBase : public CollectiveOpsE2ETestBase {
105105
public:
106106
CollectiveOpsWithFlagsBase(bool enable_async, bool enable_p2p_memcpy,
107-
size_t memory_size, size_t collectives_memory_size)
108-
: CollectiveOpsE2ETestBase(memory_size, collectives_memory_size),
107+
size_t memory_size)
108+
: CollectiveOpsE2ETestBase(memory_size),
109109
enable_async_(enable_async),
110110
enable_p2p_memcpy_(enable_p2p_memcpy) {}
111111

0 commit comments

Comments (0)