Skip to content

Commit 51a31f9

Browse files
ezhulenev and Google-ML-Automation
authored and committed
PR #40967: [xla:gpu] Cache NCCL Communicator properties
Imported from GitHub PR #40967

Querying the NCCL communicator properties on each call to the NCCL API is expensive! Query the properties in the constructor and cache them!

Copybara import of the project:

--
56437d2 by Eugene Zhulenev <ezhulenev@openxla.org>:

[xla:gpu] Cache NCCL Communicator properties

Merging this change closes #40967

FUTURE_COPYBARA_INTEGRATE_REVIEW=#40967 from ezhulenev:cache-nccl-comm-properties 56437d2
PiperOrigin-RevId: 900983474
1 parent 3bca370 commit 51a31f9

File tree

2 files changed

+29
-21
lines changed

2 files changed

+29
-21
lines changed

xla/backends/gpu/collectives/nccl_communicator.cc

Lines changed: 22 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -240,9 +240,23 @@ class NcclCommunicator::NcclRegisteredBufferHandle
240240
};
241241

242242
//==-----------------------------------------------------------------------===//
243-
// NCCL Device Communicator
243+
// NCCL Communicator
244244
//==-----------------------------------------------------------------------===//
245245

246+
NcclCommunicator::NcclCommunicator(se::StreamExecutor* stream_executor,
247+
ncclComm_t comm,
248+
std::unique_ptr<tsl::Executor> executor,
249+
std::shared_ptr<CancellationToken> cancel)
250+
: stream_executor_(stream_executor),
251+
comm_(comm),
252+
executor_(std::move(executor)),
253+
cancel_(std::move(cancel)),
254+
supports_one_sided_comm_(QuerySupportsOneSidedComm()) {
255+
VLOG(1) << absl::StreamFormat("[%d] Created NCCL communicator %s",
256+
stream_executor_->device_ordinal(),
257+
this->ToString());
258+
}
259+
246260
bool NcclCommunicator::SupportsDeviceComm() const {
247261
#if NCCL_VERSION_CODE >= 22800
248262
return true;
@@ -252,6 +266,10 @@ bool NcclCommunicator::SupportsDeviceComm() const {
252266
}
253267

254268
bool NcclCommunicator::SupportsOneSidedComm() const {
269+
return supports_one_sided_comm_;
270+
}
271+
272+
bool NcclCommunicator::QuerySupportsOneSidedComm() const {
255273
#if NCCL_VERSION_CODE >= 22907
256274
ncclCommProperties_t props = NCCL_COMM_PROPERTIES_INITIALIZER;
257275
if (ncclCommQueryProperties(comm_, &props) == ncclSuccess) {
@@ -277,19 +295,11 @@ NcclCommunicator::CreateDeviceComm(
277295
#endif // NCCL_VERSION_CODE >= 22800
278296
}
279297

280-
//==-----------------------------------------------------------------------===//
281-
// NCCL Symmetric Memory
282-
//==-----------------------------------------------------------------------===//
283-
284298
absl::StatusOr<std::unique_ptr<SymmetricMemory>>
285299
NcclCommunicator::CreateSymmetricMemory(se::DeviceAddressBase addr) {
286300
return NcclSymmetricMemory::Create(comm_, addr);
287301
}
288302

289-
//==-----------------------------------------------------------------------===//
290-
// NCCL Communicator
291-
//==-----------------------------------------------------------------------===//
292-
293303
absl::StatusOr<std::unique_ptr<NcclCommunicator>> NcclCommunicator::Create(
294304
se::StreamExecutor* stream_executor,
295305
absl::AnyInvocable<absl::StatusOr<ncclComm_t>()> make_comm,
@@ -757,7 +767,7 @@ absl::Status NcclCommunicator::LaunchAllGather(
757767

758768
// If all buffers are contiguous returns a device address range that covers
759769
// all of them, otherwise returns an empty optional.
760-
static std::optional<se::DeviceAddressBase> IsContinguous(
770+
static std::optional<se::DeviceAddressBase> IsContiguous(
761771
absl::Span<const se::DeviceAddressBase> buffers) {
762772
if (buffers.empty()) {
763773
return std::nullopt;
@@ -794,8 +804,8 @@ absl::Status NcclCommunicator::LaunchAllToAll(
794804
absl::StrAppendFormat(out, "%p", buffer.opaque());
795805
};
796806

797-
auto send_contiguous = IsContinguous(send_buffers);
798-
auto recv_contiguous = IsContinguous(recv_buffers);
807+
auto send_contiguous = IsContiguous(send_buffers);
808+
auto recv_contiguous = IsContiguous(recv_buffers);
799809

800810
VLOG(3) << absl::StreamFormat(
801811
"[%d] Launch NCCL AllToAll operation; send_buffers=[%s]; "

xla/backends/gpu/collectives/nccl_communicator.h

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -172,15 +172,7 @@ class NcclCommunicator : public GpuCommunicator {
172172

173173
NcclCommunicator(se::StreamExecutor* stream_executor, ncclComm_t comm,
174174
std::unique_ptr<tsl::Executor> executor,
175-
std::shared_ptr<CancellationToken> cancel)
176-
: stream_executor_(stream_executor),
177-
comm_(comm),
178-
executor_(std::move(executor)),
179-
cancel_(std::move(cancel)) {
180-
VLOG(1) << absl::StreamFormat("[%d] Created NCCL communicator %s",
181-
stream_executor_->device_ordinal(),
182-
this->ToString());
183-
}
175+
std::shared_ptr<CancellationToken> cancel);
184176

185177
absl::Status GroupStart();
186178
absl::Status GroupEnd();
@@ -239,6 +231,9 @@ class NcclCommunicator : public GpuCommunicator {
239231
const SignalDesc& signal_desc,
240232
const Executor& executor) final;
241233

234+
// Queries NCCL for one-sided comm support. Called once at construction.
235+
bool QuerySupportsOneSidedComm() const;
236+
242237
// Polls the communicator until any pending non-blocking operations are "done"
243238
// or aborted.
244239
absl::Status PollUntilDone() const;
@@ -291,6 +286,9 @@ class NcclCommunicator : public GpuCommunicator {
291286
// Has comm_ been aborted?
292287
bool aborted_ = false;
293288

289+
// Cached result of querying NCCL for one-sided comm support.
290+
bool supports_one_sided_comm_ = false;
291+
294292
// Nesting level of current NCCL group
295293
int group_nesting_level_ = 0;
296294

0 commit comments

Comments
 (0)