@@ -240,9 +240,23 @@ class NcclCommunicator::NcclRegisteredBufferHandle
240240};
241241
242242// ==-----------------------------------------------------------------------===//
243- // NCCL Device Communicator
243+ // NCCL Communicator
244244// ==-----------------------------------------------------------------------===//
245245
// Constructs a communicator around an existing `comm` handle.
//
// Stores `stream_executor` and `comm` as given, and moves `executor` and
// `cancel` into the corresponding members. The result of
// QuerySupportsOneSidedComm() is evaluated once here and cached in
// `supports_one_sided_comm_`, which SupportsOneSidedComm() later returns.
// Emits a VLOG(1) message with the device ordinal and ToString() once all
// members are initialized.
246+ NcclCommunicator::NcclCommunicator (se::StreamExecutor* stream_executor,
247+ ncclComm_t comm,
248+ std::unique_ptr<tsl::Executor> executor,
249+ std::shared_ptr<CancellationToken> cancel)
250+ : stream_executor_(stream_executor),
251+ comm_ (comm),
252+ executor_(std::move(executor)),
253+ cancel_(std::move(cancel)),
// NOTE(review): QuerySupportsOneSidedComm() reads `comm_`. Members are
// initialized in class declaration order, not in the order listed here, so
// `comm_` must be declared before `supports_one_sided_comm_` -- confirm in
// the header.
254+ supports_one_sided_comm_(QuerySupportsOneSidedComm()) {
255+ VLOG (1 ) << absl::StreamFormat (" [%d] Created NCCL communicator %s" ,
256+ stream_executor_->device_ordinal (),
257+ this ->ToString ());
258+ }
259+
246260bool NcclCommunicator::SupportsDeviceComm () const {
247261#if NCCL_VERSION_CODE >= 22800
248262 return true ;
@@ -252,6 +266,10 @@ bool NcclCommunicator::SupportsDeviceComm() const {
252266}
253267
// Returns the value cached in `supports_one_sided_comm_`. The constructor
// fills that member from QuerySupportsOneSidedComm(), so this accessor only
// reads the stored flag.
254268bool NcclCommunicator::SupportsOneSidedComm () const {
269+ return supports_one_sided_comm_;
270+ }
271+
272+ bool NcclCommunicator::QuerySupportsOneSidedComm () const {
255273#if NCCL_VERSION_CODE >= 22907
256274 ncclCommProperties_t props = NCCL_COMM_PROPERTIES_INITIALIZER;
257275 if (ncclCommQueryProperties (comm_, &props) == ncclSuccess) {
@@ -277,19 +295,11 @@ NcclCommunicator::CreateDeviceComm(
277295#endif // NCCL_VERSION_CODE >= 22800
278296}
279297
280- // ==-----------------------------------------------------------------------===//
281- // NCCL Symmetric Memory
282- // ==-----------------------------------------------------------------------===//
283-
// Creates a SymmetricMemory object for `addr` by forwarding this
// communicator's `comm_` handle and `addr` to NcclSymmetricMemory::Create,
// and returns that call's result (value or error status) unchanged.
284298absl::StatusOr<std::unique_ptr<SymmetricMemory>>
285299NcclCommunicator::CreateSymmetricMemory (se::DeviceAddressBase addr) {
286300 return NcclSymmetricMemory::Create (comm_, addr);
287301}
288302
289- // ==-----------------------------------------------------------------------===//
290- // NCCL Communicator
291- // ==-----------------------------------------------------------------------===//
292-
293303absl::StatusOr<std::unique_ptr<NcclCommunicator>> NcclCommunicator::Create (
294304 se::StreamExecutor* stream_executor,
295305 absl::AnyInvocable<absl::StatusOr<ncclComm_t>()> make_comm,
@@ -757,7 +767,7 @@ absl::Status NcclCommunicator::LaunchAllGather(
757767
758768// If all buffers are contiguous returns a device address range that covers
759769// all of them, otherwise returns an empty optional.
760- static std::optional<se::DeviceAddressBase> IsContinguous (
770+ static std::optional<se::DeviceAddressBase> IsContiguous (
761771 absl::Span<const se::DeviceAddressBase> buffers) {
762772 if (buffers.empty ()) {
763773 return std::nullopt ;
@@ -794,8 +804,8 @@ absl::Status NcclCommunicator::LaunchAllToAll(
794804 absl::StrAppendFormat (out, " %p" , buffer.opaque ());
795805 };
796806
797- auto send_contiguous = IsContinguous (send_buffers);
798- auto recv_contiguous = IsContinguous (recv_buffers);
807+ auto send_contiguous = IsContiguous (send_buffers);
808+ auto recv_contiguous = IsContiguous (recv_buffers);
799809
800810 VLOG (3 ) << absl::StreamFormat (
801811 " [%d] Launch NCCL AllToAll operation; send_buffers=[%s]; "
0 commit comments