Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions xla/backends/gpu/collectives/nccl_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,23 @@ class NcclCommunicator::NcclRegisteredBufferHandle
};

//==-----------------------------------------------------------------------===//
// NCCL Device Communicator
// NCCL Communicator
//==-----------------------------------------------------------------------===//

// Constructs a communicator around an existing NCCL handle `comm`.
//
// Stores `stream_executor` and `comm` as given, and moves `executor` and
// `cancel` into the corresponding members. One-sided communication support is
// queried from NCCL here, once, and cached so that SupportsOneSidedComm() can
// return the stored value without querying NCCL again.
NcclCommunicator::NcclCommunicator(se::StreamExecutor* stream_executor,
ncclComm_t comm,
std::unique_ptr<tsl::Executor> executor,
std::shared_ptr<CancellationToken> cancel)
: stream_executor_(stream_executor),
comm_(comm),
executor_(std::move(executor)),
cancel_(std::move(cancel)),
// QuerySupportsOneSidedComm() reads `comm_`. C++ initializes members in
// class declaration order, not in the order written here.
// NOTE(review): this relies on `comm_` being declared before
// `supports_one_sided_comm_` in the class -- confirm in the header.
supports_one_sided_comm_(QuerySupportsOneSidedComm()) {
// Log the device ordinal and the communicator's string form at verbosity 1.
VLOG(1) << absl::StreamFormat("[%d] Created NCCL communicator %s",
stream_executor_->device_ordinal(),
this->ToString());
}

bool NcclCommunicator::SupportsDeviceComm() const {
#if NCCL_VERSION_CODE >= 22800
return true;
Expand All @@ -252,6 +266,10 @@ bool NcclCommunicator::SupportsDeviceComm() const {
}

// Returns whether this communicator supports one-sided communication.
//
// This is the value cached by the constructor from
// QuerySupportsOneSidedComm(); no NCCL call is made here.
bool NcclCommunicator::SupportsOneSidedComm() const {
return supports_one_sided_comm_;
}

bool NcclCommunicator::QuerySupportsOneSidedComm() const {
#if NCCL_VERSION_CODE >= 22907
ncclCommProperties_t props = NCCL_COMM_PROPERTIES_INITIALIZER;
if (ncclCommQueryProperties(comm_, &props) == ncclSuccess) {
Expand All @@ -277,19 +295,11 @@ NcclCommunicator::CreateDeviceComm(
#endif // NCCL_VERSION_CODE >= 22800
}

//==-----------------------------------------------------------------------===//
// NCCL Symmetric Memory
//==-----------------------------------------------------------------------===//

// Creates a symmetric memory object for `addr` on this communicator.
//
// Forwards this communicator's NCCL handle and `addr` to
// NcclSymmetricMemory::Create and returns its result (value or error status)
// unchanged.
absl::StatusOr<std::unique_ptr<SymmetricMemory>>
NcclCommunicator::CreateSymmetricMemory(se::DeviceAddressBase addr) {
return NcclSymmetricMemory::Create(comm_, addr);
}

//==-----------------------------------------------------------------------===//
// NCCL Communicator
//==-----------------------------------------------------------------------===//

absl::StatusOr<std::unique_ptr<NcclCommunicator>> NcclCommunicator::Create(
se::StreamExecutor* stream_executor,
absl::AnyInvocable<absl::StatusOr<ncclComm_t>()> make_comm,
Expand Down Expand Up @@ -757,7 +767,7 @@ absl::Status NcclCommunicator::LaunchAllGather(

// If all buffers are contiguous returns a device address range that covers
// all of them, otherwise returns an empty optional.
static std::optional<se::DeviceAddressBase> IsContinguous(
static std::optional<se::DeviceAddressBase> IsContiguous(
absl::Span<const se::DeviceAddressBase> buffers) {
if (buffers.empty()) {
return std::nullopt;
Expand Down Expand Up @@ -794,8 +804,8 @@ absl::Status NcclCommunicator::LaunchAllToAll(
absl::StrAppendFormat(out, "%p", buffer.opaque());
};

auto send_contiguous = IsContinguous(send_buffers);
auto recv_contiguous = IsContinguous(recv_buffers);
auto send_contiguous = IsContiguous(send_buffers);
auto recv_contiguous = IsContiguous(recv_buffers);

VLOG(3) << absl::StreamFormat(
"[%d] Launch NCCL AllToAll operation; send_buffers=[%s]; "
Expand Down
16 changes: 7 additions & 9 deletions xla/backends/gpu/collectives/nccl_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,7 @@ class NcclCommunicator : public GpuCommunicator {

NcclCommunicator(se::StreamExecutor* stream_executor, ncclComm_t comm,
std::unique_ptr<tsl::Executor> executor,
std::shared_ptr<CancellationToken> cancel)
: stream_executor_(stream_executor),
comm_(comm),
executor_(std::move(executor)),
cancel_(std::move(cancel)) {
VLOG(1) << absl::StreamFormat("[%d] Created NCCL communicator %s",
stream_executor_->device_ordinal(),
this->ToString());
}
std::shared_ptr<CancellationToken> cancel);

absl::Status GroupStart();
absl::Status GroupEnd();
Expand Down Expand Up @@ -239,6 +231,9 @@ class NcclCommunicator : public GpuCommunicator {
const SignalDesc& signal_desc,
const Executor& executor) final;

// Queries NCCL for one-sided comm support. Called once at construction; the
// result is cached in `supports_one_sided_comm_` and is what
// SupportsOneSidedComm() returns.
bool QuerySupportsOneSidedComm() const;

// Polls the communicator until any pending non-blocking operations are "done"
// or aborted.
absl::Status PollUntilDone() const;
Expand Down Expand Up @@ -291,6 +286,9 @@ class NcclCommunicator : public GpuCommunicator {
// Has comm_ been aborted?
bool aborted_ = false;

// Cached result of querying NCCL for one-sided comm support. Set in the
// constructor from QuerySupportsOneSidedComm() and returned by
// SupportsOneSidedComm().
bool supports_one_sided_comm_ = false;

// Nesting level of current NCCL group
int group_nesting_level_ = 0;

Expand Down
Loading