Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions xla/backends/gpu/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,6 @@ cc_library(
"//xla/runtime:device_id",
"//xla/runtime:process_id",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/cuda:nccl_allocator", # buildcleaner: keep (static registration)
"//xla/tsl/cuda:nccl",
"//xla/tsl/platform:env",
"//xla/tsl/platform:logging",
Expand All @@ -564,7 +563,6 @@ cc_library(
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:numbers",
"@tsl//tsl/profiler/lib:traceme",
],
alwayslink = True, # registers collectives implementation
Expand Down Expand Up @@ -619,7 +617,6 @@ cc_library(
"@local_config_rocm//rocm:rccl", # buildcleaner: keep
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:numbers",
],
alwayslink = True, # registers collectives implementation
)
Expand Down Expand Up @@ -784,7 +781,6 @@ cc_library(
"//xla/stream_executor:device_address",
"//xla/stream_executor:stream",
"//xla/stream_executor/cuda:nvshmem",
"//xla/stream_executor/cuda:nvshmem_allocator", # buildcleaner: keep (static registration)
"//xla/tsl/concurrency:async_value",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
Expand Down
12 changes: 6 additions & 6 deletions xla/backends/gpu/collectives/gpu_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,19 @@ class GpuCollectives : public Collectives {
// communicator topology supports host RMA at runtime.
virtual bool SupportsOneSidedComm() const { return false; }

// Returns the minimum alignment requirement for symmetric memory. This
// default implementation returns 1; the method is virtual so that a concrete
// collectives implementation can report a different value.
// NOTE(review): the unit is presumably bytes (1 == no extra alignment
// constraint) -- confirm against callers.
virtual size_t SymmetricMemoryAlignment() const { return 1; }

// Returns a slice of device memory `buff` containing `count` values of data
// type `dtype` starting from `offset`.
static stream_executor::DeviceAddressBase Slice(
stream_executor::DeviceAddressBase buff, PrimitiveType dtype,
size_t offset, size_t count);

// TODO(b/410686553): Use smart wrapper instead of void*.
virtual absl::StatusOr<void*> Allocate(uint64_t bytes) = 0;
// Default allocation hook. This implementation never allocates: it ignores
// `bytes` and always returns an Unimplemented error ("Collectives allocator
// not available"). It is virtual, so an implementation that can allocate
// collective memory can override it.
virtual absl::StatusOr<void*> Allocate(uint64_t bytes) {
return absl::UnimplementedError("Collectives allocator not available");
}

virtual absl::Status Deallocate(void* buffer) = 0;
// Default deallocation hook. This implementation never frees anything: it
// ignores `buffer` and always returns an Unimplemented error ("Collectives
// deallocator not available"). It is virtual, so an implementation that owns
// collective memory can override it.
virtual absl::Status Deallocate(void* buffer) {
return absl::UnimplementedError("Collectives deallocator not available");
}

// Creates a single communicator.
virtual absl::StatusOr<std::unique_ptr<Communicator>>
Expand Down
47 changes: 1 addition & 46 deletions xla/backends/gpu/collectives/nccl_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ limitations under the License.
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/tsl/platform/status_macros.h"
#include "third_party/nccl/nccl.h"
#include "xla/backends/gpu/collectives/cancellation_token.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
Expand All @@ -60,10 +59,10 @@ limitations under the License.
#include "xla/stream_executor/stream_executor.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/status_macros.h"
#include "xla/tsl/platform/threadpool.h"
#include "xla/util.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/numbers.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::gpu {
Expand Down Expand Up @@ -239,12 +238,6 @@ bool NcclCollectives::SupportsOneSidedComm() const {
return NCCL_VERSION_CODE >= 22900;
}

// Reports the symmetric memory alignment for the NCCL implementation. The
// value is currently a fixed constant (4096).
size_t NcclCollectives::SymmetricMemoryAlignment() const {
  // TODO(ezhulenev): Query memory alignment from CUDA executor for multicast
  // memory (CU_MULTICAST_GRANULARITY_MINIMUM). Find how to query it for NCCL.
  constexpr size_t kSymmetricMemoryAlignment = 4096;
  return kSymmetricMemoryAlignment;
}

static absl::StatusOr<ncclConfig_t> AsNcclConfig(
const GpuCollectives::Config& config,
const se::StreamExecutor* stream_executor) {
Expand Down Expand Up @@ -504,44 +497,6 @@ static absl::StatusOr<xla::gpu::GpuCollectives*> GetNvshmemCollectives() {
return nvshmem_collectives;
}

// Allocates `bytes` of device collective memory and returns a pointer to it.
//
// When the `xla_gpu_experimental_enable_nvshmem` debug flag is set, the
// request is forwarded to the collectives returned by
// GetNvshmemCollectives(). Otherwise the memory is obtained with
// ncclMemAlloc(); if that call does not return ncclSuccess, an Internal error
// is returned that includes the NCCL error string and the last NCCL log entry.
absl::StatusOr<void*> NcclCollectives::Allocate(uint64_t bytes) {
  const bool use_nvshmem =
      xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem();
  if (use_nvshmem) {
    ASSIGN_OR_RETURN(auto* nvshmem, GetNvshmemCollectives());
    return nvshmem->Allocate(bytes);
  }

  void* buffer = nullptr;
  const ncclResult_t status = ncclMemAlloc(&buffer, bytes);
  if (status == ncclSuccess) {
    VLOG(2) << "Allocated collective memory " << buffer << " of " << bytes
            << " bytes";
    return buffer;
  }

  // Allocation failed: report the size both human-readable and in raw bytes.
  return Internal(
      "Failed to allocate %s (%llu bytes) from device collective memory: %s, "
      "Last NCCL warning(error) log entry (may be unrelated): %s",
      tsl::strings::HumanReadableNumBytes(bytes), bytes,
      ncclGetErrorString(status), ncclGetLastError(nullptr));
}

// Releases collective memory at `location`.
//
// When the `xla_gpu_experimental_enable_nvshmem` debug flag is set, the
// request is forwarded to the collectives returned by
// GetNvshmemCollectives(). Otherwise the memory is released with
// ncclMemFree(); if that call does not return ncclSuccess, an Internal error
// is returned that includes the NCCL error string and the last NCCL log entry.
absl::Status NcclCollectives::Deallocate(void* location) {
  const bool use_nvshmem =
      xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem();
  if (use_nvshmem) {
    ASSIGN_OR_RETURN(auto* nvshmem, GetNvshmemCollectives());
    return nvshmem->Deallocate(location);
  }

  const ncclResult_t status = ncclMemFree(location);
  if (status == ncclSuccess) {
    VLOG(2) << "Deallocated collective memory " << location;
    return absl::OkStatus();
  }

  return Internal(
      "Failed to free device collective memory at %p; result: %s, Last NCCL "
      "warning(error) log entry (may be unrelated): %s",
      location, ncclGetErrorString(status), ncclGetLastError(nullptr));
}

absl::StatusOr<GpuCollectives::CliqueIdCallback>
NcclCollectives::InitializeTopology(const Topology& topology) {
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
Expand Down
6 changes: 0 additions & 6 deletions xla/backends/gpu/collectives/nccl_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ class NcclCollectives : public GpuCollectives {
bool SupportsDeviceComm() const final;
bool SupportsOneSidedComm() const final;

size_t SymmetricMemoryAlignment() const final;

absl::StatusOr<CliqueId> CreateUniqueCliqueId() const final;

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
Expand Down Expand Up @@ -82,10 +80,6 @@ class NcclCollectives : public GpuCollectives {
return absl::UnimplementedError("Not implemented.");
}

absl::StatusOr<void*> Allocate(uint64_t bytes) final;

absl::Status Deallocate(void* location) final;

absl::StatusOr<CliqueIdCallback> InitializeTopology(
const Topology& topology) final;
};
Expand Down
39 changes: 0 additions & 39 deletions xla/backends/gpu/collectives/rccl_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ limitations under the License.
#include "xla/tsl/platform/threadpool.h"
#include "xla/util.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/numbers.h"

#if (TF_ROCM_VERSION >= 50200)
#include "rocm/include/rccl/rccl.h"
Expand Down Expand Up @@ -337,44 +336,6 @@ static absl::StatusOr<xla::gpu::GpuCollectives*> GetNvshmemCollectives() {
return nvshmem_collectives;
}

// Allocates `bytes` of device collective memory and returns a pointer to it.
//
// When the `xla_gpu_experimental_enable_nvshmem` debug flag is set, the
// request is forwarded to the collectives returned by
// GetNvshmemCollectives(). Otherwise the memory is obtained with
// ncclMemAlloc(); if that call does not return ncclSuccess, an Internal error
// is returned that includes the error string and the last NCCL log entry.
absl::StatusOr<void*> RcclCollectives::Allocate(uint64_t bytes) {
  const bool use_nvshmem =
      xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem();
  if (use_nvshmem) {
    TF_ASSIGN_OR_RETURN(auto* nvshmem, GetNvshmemCollectives());
    return nvshmem->Allocate(bytes);
  }

  void* buffer = nullptr;
  const ncclResult_t status = ncclMemAlloc(&buffer, bytes);
  if (status == ncclSuccess) {
    VLOG(2) << "Allocated collective memory " << buffer << " of " << bytes
            << " bytes";
    return buffer;
  }

  // Allocation failed: report the size both human-readable and in raw bytes.
  return absl::InternalError(absl::StrFormat(
      "failed to allocate %s (%llu bytes) from device collective memory: %s, "
      "Last NCCL warning(error) log entry (may be unrelated): %s",
      tsl::strings::HumanReadableNumBytes(bytes), bytes,
      ncclGetErrorString(status), ncclGetLastError(nullptr)));
}

// Releases collective memory at `location`.
//
// When the `xla_gpu_experimental_enable_nvshmem` debug flag is set, the
// request is forwarded to the collectives returned by
// GetNvshmemCollectives(). Otherwise the memory is released with
// ncclMemFree(); if that call does not return ncclSuccess, an Internal error
// is returned that includes the error string and the last NCCL log entry.
absl::Status RcclCollectives::Deallocate(void* location) {
  const bool use_nvshmem =
      xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem();
  if (use_nvshmem) {
    TF_ASSIGN_OR_RETURN(auto* nvshmem, GetNvshmemCollectives());
    return nvshmem->Deallocate(location);
  }

  const ncclResult_t status = ncclMemFree(location);
  if (status == ncclSuccess) {
    VLOG(2) << "Deallocated collective memory " << location;
    return absl::OkStatus();
  }

  return absl::InternalError(absl::StrFormat(
      "failed to free device collective memory at %p; result: %s, Last NCCL "
      "warning(error) log entry (may be unrelated): %s",
      location, ncclGetErrorString(status), ncclGetLastError(nullptr)));
}

absl::StatusOr<CliqueIdCallback> RcclCollectives::InitializeTopology(
const Topology& topology) {
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
Expand Down
4 changes: 0 additions & 4 deletions xla/backends/gpu/collectives/rccl_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,6 @@ class RcclCollectives : public GpuCollectives {
return absl::UnimplementedError("Not implemented.");
}

absl::StatusOr<void*> Allocate(uint64_t bytes) final;

absl::Status Deallocate(void* location) final;

absl::StatusOr<CliqueIdCallback> InitializeTopology(
const Topology& topology) final;
};
Expand Down
3 changes: 1 addition & 2 deletions xla/backends/gpu/tests/all_reduce_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ class AllReduceTestNoParams : public CollectiveOpsWithFlagsBase {
explicit AllReduceTestNoParams(bool is_async = false)
: CollectiveOpsWithFlagsBase(/*enable_async=*/is_async,
/*enable_p2p_memcpy=*/false,
/*memory_size=*/32 * kMB,
/*collectives_memory_size=*/0) {}
/*memory_size=*/32 * kMB) {}

void SetUp() override {
CollectiveOpsE2ETestBase::SetUp();
Expand Down
17 changes: 6 additions & 11 deletions xla/backends/gpu/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@ bool IsAsync(const HloInstruction* inst) {

class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
public:
explicit CollectiveOpsTestE2E(size_t memory_size = 128 * kMB,
size_t collectives_memory_size = 0)
: CollectiveOpsE2ETestBase(memory_size, collectives_memory_size) {}
explicit CollectiveOpsTestE2E(size_t memory_size = 128 * kMB)
: CollectiveOpsE2ETestBase(memory_size) {}

bool HasFp8Support() {
if (Capability().IsCuda()) {
Expand Down Expand Up @@ -133,8 +132,7 @@ class AsyncCollectiveOps : public CollectiveOpsWithFlagsBase,
AsyncCollectiveOps()
: CollectiveOpsWithFlagsBase(/*enable_async=*/GetParam(),
/*enable_p2p_memcpy=*/false,
/*memory_size=*/8 * kGB,
/*collectives_memory_size=*/0) {}
/*memory_size=*/8 * kGB) {}
};

class MemcpyCollectiveOps : public CollectiveOpsWithFlagsBase,
Expand All @@ -143,8 +141,7 @@ class MemcpyCollectiveOps : public CollectiveOpsWithFlagsBase,
MemcpyCollectiveOps()
: CollectiveOpsWithFlagsBase(/*enable_async=*/true,
/*enable_p2p_memcpy=*/GetParam(),
/*memory_size=*/32 * kMB,
/*collectives_memory_size=*/0) {}
/*memory_size=*/32 * kMB) {}
};

class AsyncMemcpyCollectiveOps
Expand All @@ -155,8 +152,7 @@ class AsyncMemcpyCollectiveOps
: CollectiveOpsWithFlagsBase(
/*enable_async=*/std::get<0>(GetParam()),
/*enable_p2p_memcpy=*/std::get<1>(GetParam()),
/*memory_size=*/32 * kMB,
/*collectives_memory_size=*/0) {}
/*memory_size=*/32 * kMB) {}
};

std::string GetAsyncTestName(bool is_async) {
Expand Down Expand Up @@ -1269,8 +1265,7 @@ TEST_F(CollectiveOpsTestE2E, HostMemoryOffloadingWithDonation) {
class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
public:
CollectiveOpsTestE2EWindowedNonWindowed()
: CollectiveOpsTestE2E(/*memory_size=*/4 * kGB,
/*collectives_memory_size=*/0) {}
: CollectiveOpsTestE2E(/*memory_size=*/4 * kGB) {}

void CollectiveOpsCompareWindowedNonWindowed(
absl::string_view hlo_text, bool disable_dot_merger = false,
Expand Down
12 changes: 4 additions & 8 deletions xla/backends/gpu/tests/collective_ops_e2e_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/tsl/platform/status_macros.h"
#include "xla/backends/gpu/tests/hlo_pjrt_gpu_test_base.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
Expand All @@ -41,19 +40,18 @@ limitations under the License.
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/tsl/platform/status_macros.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"

namespace xla {
namespace {

std::unique_ptr<PjRtClient> CreatePjRtClient(size_t memory_size,
size_t collectives_memory_size) {
std::unique_ptr<PjRtClient> CreatePjRtClient(size_t memory_size) {
xla::GpuClientOptions options;
options.allocator_config.kind = xla::GpuAllocatorConfig::Kind::kBFC;
options.allocator_config.gpu_system_memory_size = memory_size;
options.allocator_config.collective_memory_size = collectives_memory_size;
options.use_tfrt_gpu_client = true;

absl::StatusOr<std::unique_ptr<xla::PjRtClient>> pjrt_client =
Expand All @@ -64,10 +62,8 @@ std::unique_ptr<PjRtClient> CreatePjRtClient(size_t memory_size,

} // namespace

// Constructs the test base on top of a PjRt client created by
// CreatePjRtClient(), forwarding both `memory_size` and
// `collectives_memory_size` to it. The constructor body itself is empty.
CollectiveOpsE2ETestBase::CollectiveOpsE2ETestBase(
size_t memory_size, size_t collectives_memory_size)
: HloPjRtGpuTestBase(
CreatePjRtClient(memory_size, collectives_memory_size)) {}
// Constructs the test base on top of a PjRt client created by
// CreatePjRtClient(memory_size). The constructor body itself is empty.
CollectiveOpsE2ETestBase::CollectiveOpsE2ETestBase(size_t memory_size)
: HloPjRtGpuTestBase(CreatePjRtClient(memory_size)) {}

absl::StatusOr<CollectiveOpsE2ETestBase::ExecutionResult>
CollectiveOpsE2ETestBase::ExecuteReplicated(std::unique_ptr<HloModule> module) {
Expand Down
6 changes: 3 additions & 3 deletions xla/backends/gpu/tests/collective_ops_e2e_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ inline constexpr size_t kGB = 1024LL * kMB;

class CollectiveOpsE2ETestBase : public gpu::HloPjRtGpuTestBase {
public:
CollectiveOpsE2ETestBase(size_t memory_size, size_t collectives_memory_size);
explicit CollectiveOpsE2ETestBase(size_t memory_size);

DebugOptions GetDebugOptionsForTest() const override {
DebugOptions debug_options =
Expand Down Expand Up @@ -104,8 +104,8 @@ class CollectiveOpsE2ETestBase : public gpu::HloPjRtGpuTestBase {
class CollectiveOpsWithFlagsBase : public CollectiveOpsE2ETestBase {
public:
CollectiveOpsWithFlagsBase(bool enable_async, bool enable_p2p_memcpy,
size_t memory_size, size_t collectives_memory_size)
: CollectiveOpsE2ETestBase(memory_size, collectives_memory_size),
size_t memory_size)
: CollectiveOpsE2ETestBase(memory_size),
enable_async_(enable_async),
enable_p2p_memcpy_(enable_p2p_memcpy) {}

Expand Down
4 changes: 1 addition & 3 deletions xla/backends/gpu/tests/collective_ops_ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ absl::NoDestructor<std::unique_ptr<SynchronizationSignals>> global_signals;

class CollectiveOpsTestFFI : public CollectiveOpsE2ETestBase {
public:
CollectiveOpsTestFFI()
: CollectiveOpsE2ETestBase(/*memory_size=*/1 * kMB,
/*collectives_memory_size=*/1 * kMB) {}
CollectiveOpsTestFFI() : CollectiveOpsE2ETestBase(/*memory_size=*/50 * kMB) {}
void SetUp() override {
CollectiveOpsE2ETestBase::SetUp();
*global_signals =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ namespace {
class CollectiveOpsTestE2EShardedUnsharded : public CollectiveOpsE2ETestBase {
public:
CollectiveOpsTestE2EShardedUnsharded()
: CollectiveOpsE2ETestBase(/*memory_size=*/64 * kMB,
/*collectives_memory_size=*/0) {}
: CollectiveOpsE2ETestBase(/*memory_size=*/64 * kMB) {}

void CollectiveOpsCompareShardedUnsharded(
const std::string& hlo_text, const int64_t num_partitions = 2,
Expand Down
7 changes: 3 additions & 4 deletions xla/backends/gpu/tests/ragged_all_to_all_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ class RaggedAllToAllTestBase : public CollectiveOpsWithFlagsBase {
public:
RaggedAllToAllTestBase(bool enable_async, RaggedAllToAllImplType impl_type)
: CollectiveOpsWithFlagsBase(enable_async, /*enable_p2p_memcpy=*/false,
/*memory_size=*/64 * kMB,
/*collectives_memory_size=*/0),
/*memory_size=*/64 * kMB),
impl_type_(impl_type) {}

// Creates random test data for a ragged-all-to-all.
Expand Down Expand Up @@ -482,8 +481,8 @@ TEST_P(RaggedAllToAllTest, RaggedAllToAll_2GPUs_S4) {
send_sizes = s32[2] parameter(3)
output_offsets = s32[2] parameter(4)
recv_sizes = s32[2] parameter(5)
ROOT ra2a = s4[4,2]{1,0:E(4)} ragged-all-to-all(input, output,
input_offsets, send_sizes, output_offsets, recv_sizes),
ROOT ra2a = s4[4,2]{1,0:E(4)} ragged-all-to-all(input, output,
input_offsets, send_sizes, output_offsets, recv_sizes),
replica_groups={{0,1}}
})";

Expand Down
Loading
Loading