Skip to content

Commit c6f43a8

Browse files
Rollback PR #37936, since it breaks ragged_collective_test_gpu in JAX when the machine has >2 GPUs.
Reverts 13760ad
PiperOrigin-RevId: 900696964
1 parent 85c65a2 commit c6f43a8

File tree

7 files changed

+13
-60
lines changed

7 files changed

+13
-60
lines changed

xla/backends/gpu/collectives/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ cc_library(
7575
"//xla/core/collectives:communicator",
7676
"//xla/core/collectives:rank_id",
7777
"//xla/service:lockable",
78-
"//xla/service:rendezvous",
7978
"//xla/tsl/platform:logging",
8079
"//xla/tsl/util:tied_ref",
8180
"@com_google_absl//absl/base:core_headers",

xla/backends/gpu/collectives/gpu_clique.h

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ limitations under the License.
3434
#include "xla/core/collectives/communicator.h"
3535
#include "xla/core/collectives/rank_id.h"
3636
#include "xla/service/lockable.h"
37-
#include "xla/service/rendezvous.h"
3837
#include "xla/tsl/util/tied_ref.h"
3938

4039
namespace xla::gpu {
@@ -91,11 +90,6 @@ class GpuClique : public Clique {
9190
// Returns a parent clique iff *this one was created by clique splitting.
9291
const GpuClique* parent() const { return parent_; }
9392

94-
std::pair<RendezvousFlag*, RendezvousFlag*> GetFirstRendezvousFlags() {
95-
return std::make_pair(&pre_call_rendezvous_flag_,
96-
&post_call_rendezvous_flag_);
97-
}
98-
9993
private:
10094
friend LockableGpuClique;
10195

@@ -117,15 +111,6 @@ class GpuClique : public Clique {
117111
// A parent GPU clique iff *this clique was constructed by split operation.
118112
const GpuClique* parent_;
119113

120-
// Before and after a first call to this particular instance of a collective
121-
// thunk we do a round of rendezvous to make sure that all participants are
122-
// ready to execute the collective operation and that all of them successfully
123-
// allocated on-device state required for it. This is required to avoid
124-
// deadlocks when one device goes too far ahead and causes a deadlock in CUDA
125-
// driver (root cause rumored to be fixed in 590 driver series).
126-
RendezvousFlag pre_call_rendezvous_flag_;
127-
RendezvousFlag post_call_rendezvous_flag_;
128-
129114
// We keep device communicators in a sorted container to guarantee that they
130115
// are destroyed in deterministic order.
131116
mutable absl::Mutex mu_;

xla/backends/gpu/runtime/collective_cliques.cc

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -126,32 +126,6 @@ absl::StatusOr<bool> CollectiveCliques::peer_access_enabled(
126126
return (*clique->second)->peer_access_enabled();
127127
}
128128

129-
absl::StatusOr<std::pair<RendezvousFlag*, RendezvousFlag*>>
130-
CollectiveCliques::GetCliqueFirstRendezvousFlags(
131-
const GpuCliqueKey& clique_key) const {
132-
// Check that we locked access to a clique for `clique_key`.
133-
auto clique = cliques_map_.find(clique_key);
134-
if (clique == cliques_map_.end()) {
135-
return NotFound("No clique found for clique key: %s",
136-
clique_key.ToString());
137-
}
138-
return (*clique->second)->GetFirstRendezvousFlags();
139-
}
140-
141-
absl::StatusOr<bool> AllFirstRendezvousCompleted(
142-
const CollectiveCliques& collective_cliques,
143-
const std::vector<GpuCliqueKey>& requested_clique_keys) {
144-
return collective_cliques.empty() ||
145-
absl::c_all_of(
146-
requested_clique_keys, [&](const GpuCliqueKey& clique_key) {
147-
auto rend_flags =
148-
collective_cliques.GetCliqueFirstRendezvousFlags(clique_key);
149-
CHECK(rend_flags.ok());
150-
return rend_flags.value().first->IsCompleted() &&
151-
rend_flags.value().second->IsCompleted();
152-
});
153-
}
154-
155129
absl::StatusOr<CollectiveCliques> AcquireCollectiveCliques(
156130
const CollectiveParams& params, const CollectiveCliqueRequests& cliques) {
157131
std::vector<CollectiveCliqueRequests::CliqueRequest> ordered_cliques =

xla/backends/gpu/runtime/collective_cliques.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ class CollectiveCliques {
6767

6868
bool empty() const { return cliques_map_.empty(); }
6969

70-
absl::StatusOr<std::pair<RendezvousFlag*, RendezvousFlag*>>
71-
GetCliqueFirstRendezvousFlags(const GpuCliqueKey& clique_key) const;
72-
7370
private:
7471
AcquiredCliquesMap cliques_map_;
7572
};
@@ -94,10 +91,6 @@ absl::StatusOr<tsl::TiedRef<T>> CollectiveCliques::Tie(
9491
absl::StatusOr<CollectiveCliques> AcquireCollectiveCliques(
9592
const CollectiveParams& params, const CollectiveCliqueRequests& cliques);
9693

97-
absl::StatusOr<bool> AllFirstRendezvousCompleted(
98-
const CollectiveCliques& collective_cliques,
99-
const std::vector<GpuCliqueKey>& requested_clique_keys);
100-
10194
} // namespace xla::gpu
10295

10396
#endif // XLA_BACKENDS_GPU_RUNTIME_COLLECTIVE_CLIQUES_H_

xla/backends/gpu/runtime/collective_thunk.cc

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -385,16 +385,13 @@ absl::Status CollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
385385
debug_options
386386
.xla_gpu_first_collective_call_terminate_timeout_seconds()));
387387
};
388-
std::pair<RendezvousFlag*, RendezvousFlag*> rend_flags;
389-
ASSIGN_OR_RETURN(
390-
rend_flags,
391-
params.collective_cliques->GetCliqueFirstRendezvousFlags(clique_key));
392-
RETURN_IF_ERROR(first_call_rendezvous("before", *(rend_flags.first)));
388+
389+
RETURN_IF_ERROR(first_call_rendezvous("before", pre_call_rendezvous_flag_));
393390

394391
// Launch collective operation on the compute stream.
395392
RETURN_IF_ERROR(RunCollective(params, clique_key, *params.stream, *comm));
396393

397-
RETURN_IF_ERROR(first_call_rendezvous("after", *(rend_flags.second)));
394+
RETURN_IF_ERROR(first_call_rendezvous("after", post_call_rendezvous_flag_));
398395

399396
return absl::OkStatus();
400397
}

xla/backends/gpu/runtime/collective_thunk.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,15 @@ class CollectiveThunk : public Thunk {
163163
virtual const CollectiveConfig& config() const = 0;
164164

165165
private:
166+
// Before and after a first call to this particular instance of a collective
167+
// thunk we do a round of rendezvous to make sure that all participants are
168+
// ready to execute the collective operation and that all of them successfully
169+
// allocated on-device state required for it. This is required to avoid
170+
// deadlocks when one device goes too far ahead and causes a deadlock in CUDA
171+
// driver (root cause rumored to be fixed in 590 driver series).
172+
RendezvousFlag pre_call_rendezvous_flag_;
173+
RendezvousFlag post_call_rendezvous_flag_;
174+
166175
CommunicationId communication_id_;
167176

168177
// Device assignment is owned by PjRtExecutable and never changes between

xla/service/gpu/gpu_executable.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -747,10 +747,6 @@ absl::Status ExecuteThunksImpl(const DebugOptions* debug_options,
747747
AcquireCollectiveCliques(collective_params,
748748
collective_clique_requests));
749749
}
750-
ASSIGN_OR_RETURN(
751-
bool skip_rendezvous_after_init,
752-
AllFirstRendezvousCompleted(
753-
collective_cliques, collective_clique_requests.RequestedCliques()));
754750

755751
ASSIGN_OR_RETURN(ScratchMemory scratch_memory,
756752
AcquireScratchMemory(
@@ -785,7 +781,7 @@ absl::Status ExecuteThunksImpl(const DebugOptions* debug_options,
785781
// collective operations and clique initialization is famous for introducing
786782
// deadlocks if we try to execute it concurrently with other potentially
787783
// memory-allocating operations.
788-
if (!skip_rendezvous_after_init) {
784+
if (!collective_cliques.empty()) {
789785
RETURN_IF_ERROR(RendezvousAfterInitialization(*run_options, debug_options));
790786
}
791787

0 commit comments

Comments
 (0)