Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion xla/backends/gpu/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ cc_library(
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/service:lockable",
"//xla/service:rendezvous",
"//xla/tsl/platform:logging",
"//xla/tsl/util:tied_ref",
"@com_google_absl//absl/base:core_headers",
Expand Down
15 changes: 0 additions & 15 deletions xla/backends/gpu/collectives/gpu_clique.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ limitations under the License.
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/service/lockable.h"
#include "xla/service/rendezvous.h"
#include "xla/tsl/util/tied_ref.h"

namespace xla::gpu {
Expand Down Expand Up @@ -91,11 +90,6 @@ class GpuClique : public Clique {
// Returns a parent clique iff *this one was created by clique splitting.
const GpuClique* parent() const { return parent_; }

std::pair<RendezvousFlag*, RendezvousFlag*> GetFirstRendezvousFlags() {
return std::make_pair(&pre_call_rendezvous_flag_,
&post_call_rendezvous_flag_);
}

private:
friend LockableGpuClique;

Expand All @@ -117,15 +111,6 @@ class GpuClique : public Clique {
// A parent GPU clique iff *this clique was constructed by split operation.
const GpuClique* parent_;

// Before and after a first call to this particular instance of a collective
// thunk we do a round of rendezvous to make sure that all participants are
// ready to execute the collective operation and that all of them successfully
// allocated on-device state required for it. This is required to avoid
// deadlocks when one device goes too far ahead and causes a deadlock in CUDA
// driver (root cause rumored to be fixed in 590 driver series).
RendezvousFlag pre_call_rendezvous_flag_;
RendezvousFlag post_call_rendezvous_flag_;

// We keep device communicators in a sorted container to guarantee that they
// are destroyed in deterministic order.
mutable absl::Mutex mu_;
Expand Down
26 changes: 0 additions & 26 deletions xla/backends/gpu/runtime/collective_cliques.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,32 +126,6 @@ absl::StatusOr<bool> CollectiveCliques::peer_access_enabled(
return (*clique->second)->peer_access_enabled();
}

absl::StatusOr<std::pair<RendezvousFlag*, RendezvousFlag*>>
CollectiveCliques::GetCliqueFirstRendezvousFlags(
    const GpuCliqueKey& clique_key) const {
  // The clique must already have been acquired (locked) for `clique_key`;
  // otherwise report the missing key back to the caller.
  auto it = cliques_map_.find(clique_key);
  if (it == cliques_map_.end()) {
    return NotFound("No clique found for clique key: %s",
                    clique_key.ToString());
  }
  // Forward to the clique itself, which owns the rendezvous flags.
  return (*it->second)->GetFirstRendezvousFlags();
}

// Returns true iff, for every requested clique key, both the pre-call and
// post-call first-rendezvous flags of the corresponding acquired clique have
// completed. An empty set of acquired cliques trivially counts as completed.
//
// Propagates lookup failures (e.g. a key with no acquired clique) as a
// status instead of crashing the process, matching the StatusOr<bool>
// contract of this function.
absl::StatusOr<bool> AllFirstRendezvousCompleted(
    const CollectiveCliques& collective_cliques,
    const std::vector<GpuCliqueKey>& requested_clique_keys) {
  if (collective_cliques.empty()) return true;
  for (const GpuCliqueKey& clique_key : requested_clique_keys) {
    auto rend_flags =
        collective_cliques.GetCliqueFirstRendezvousFlags(clique_key);
    // Bug fix: the previous implementation CHECK-failed here, turning a
    // recoverable NotFound error into a process crash.
    if (!rend_flags.ok()) return rend_flags.status();
    if (!rend_flags->first->IsCompleted() ||
        !rend_flags->second->IsCompleted()) {
      return false;
    }
  }
  return true;
}

absl::StatusOr<CollectiveCliques> AcquireCollectiveCliques(
const CollectiveParams& params, const CollectiveCliqueRequests& cliques) {
std::vector<CollectiveCliqueRequests::CliqueRequest> ordered_cliques =
Expand Down
7 changes: 0 additions & 7 deletions xla/backends/gpu/runtime/collective_cliques.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ class CollectiveCliques {

bool empty() const { return cliques_map_.empty(); }

absl::StatusOr<std::pair<RendezvousFlag*, RendezvousFlag*>>
GetCliqueFirstRendezvousFlags(const GpuCliqueKey& clique_key) const;

private:
AcquiredCliquesMap cliques_map_;
};
Expand All @@ -94,10 +91,6 @@ absl::StatusOr<tsl::TiedRef<T>> CollectiveCliques::Tie(
absl::StatusOr<CollectiveCliques> AcquireCollectiveCliques(
const CollectiveParams& params, const CollectiveCliqueRequests& cliques);

absl::StatusOr<bool> AllFirstRendezvousCompleted(
const CollectiveCliques& collective_cliques,
const std::vector<GpuCliqueKey>& requested_clique_keys);

} // namespace xla::gpu

#endif // XLA_BACKENDS_GPU_RUNTIME_COLLECTIVE_CLIQUES_H_
9 changes: 3 additions & 6 deletions xla/backends/gpu/runtime/collective_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,16 +385,13 @@ absl::Status CollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
debug_options
.xla_gpu_first_collective_call_terminate_timeout_seconds()));
};
std::pair<RendezvousFlag*, RendezvousFlag*> rend_flags;
ASSIGN_OR_RETURN(
rend_flags,
params.collective_cliques->GetCliqueFirstRendezvousFlags(clique_key));
RETURN_IF_ERROR(first_call_rendezvous("before", *(rend_flags.first)));

RETURN_IF_ERROR(first_call_rendezvous("before", pre_call_rendezvous_flag_));

// Launch collective operation on the compute stream.
RETURN_IF_ERROR(RunCollective(params, clique_key, *params.stream, *comm));

RETURN_IF_ERROR(first_call_rendezvous("after", *(rend_flags.second)));
RETURN_IF_ERROR(first_call_rendezvous("after", post_call_rendezvous_flag_));

return absl::OkStatus();
}
Expand Down
9 changes: 9 additions & 0 deletions xla/backends/gpu/runtime/collective_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ class CollectiveThunk : public Thunk {
virtual const CollectiveConfig& config() const = 0;

private:
// Before and after a first call to this particular instance of a collective
// thunk we do a round of rendezvous to make sure that all participants are
// ready to execute the collective operation and that all of them successfully
// allocated on-device state required for it. This is required to avoid
// deadlocks when one device goes too far ahead and causes a deadlock in CUDA
// driver (root cause rumored to be fixed in 590 driver series).
RendezvousFlag pre_call_rendezvous_flag_;
RendezvousFlag post_call_rendezvous_flag_;

CommunicationId communication_id_;

// Device assignment is owned by PjRtExecutable and never changes between
Expand Down
6 changes: 1 addition & 5 deletions xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -747,10 +747,6 @@ absl::Status ExecuteThunksImpl(const DebugOptions* debug_options,
AcquireCollectiveCliques(collective_params,
collective_clique_requests));
}
ASSIGN_OR_RETURN(
bool skip_rendezvous_after_init,
AllFirstRendezvousCompleted(
collective_cliques, collective_clique_requests.RequestedCliques()));

ASSIGN_OR_RETURN(ScratchMemory scratch_memory,
AcquireScratchMemory(
Expand Down Expand Up @@ -785,7 +781,7 @@ absl::Status ExecuteThunksImpl(const DebugOptions* debug_options,
// collective operations and clique initialization is famous for introducing
// deadlocks if we try to execute it concurrently with other potentially
// memory-allocating operations.
if (!skip_rendezvous_after_init) {
if (!collective_cliques.empty()) {
RETURN_IF_ERROR(RendezvousAfterInitialization(*run_options, debug_options));
}

Expand Down
Loading