Skip to content

Commit 03c196f

Browse files
authored
[coll] Prevent all-to-all connection. (#12075)
1 parent b20ce6d commit 03c196f

File tree

8 files changed

+390
-224
lines changed

8 files changed

+390
-224
lines changed

src/collective/allgather.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023-2024, XGBoost Contributors
2+
* Copyright 2023-2026, XGBoost Contributors
33
*/
44
#pragma once
55
#include <cstddef> // for size_t
@@ -13,6 +13,7 @@
1313
#include "../common/type.h" // for EraseType
1414
#include "comm.h" // for Comm, Channel
1515
#include "comm_group.h" // for CommGroup
16+
#include "topo.h" // for BootstrapNext, BootstrapPrev
1617
#include "xgboost/collective/result.h" // for Result
1718
#include "xgboost/linalg.h" // for MakeVec
1819
#include "xgboost/span.h" // for Span
@@ -142,8 +143,8 @@ template <typename T>
142143
std::vector<std::int64_t> sizes(comm.World(), 0);
143144
sizes[comm.Rank()] = data.Values().size_bytes();
144145
auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()});
145-
auto rc = comm.Backend(DeviceOrd::CPU())
146-
->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
146+
auto rc =
147+
comm.Backend(DeviceOrd::CPU())->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
147148
if (!rc.OK()) {
148149
return rc;
149150
}
@@ -161,7 +162,8 @@ template <typename T>
161162

162163
return backend->AllgatherV(
163164
comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments,
164-
data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), AllgatherVAlgo::kBcast);
165+
data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(),
166+
data.Device().IsCUDA() ? AllgatherVAlgo::kBcast : AllgatherVAlgo::kRing);
165167
}
166168

167169
/**

src/collective/allreduce.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023-2024, XGBoost Contributors
2+
* Copyright 2023-2026, XGBoost Contributors
33
*/
44
#pragma once
55
#include <cstdint> // for int8_t
@@ -12,6 +12,7 @@
1212
#include "broadcast.h" // for Broadcast
1313
#include "comm.h" // for Comm, RestoreType
1414
#include "comm_group.h" // for GlobalCommGroup
15+
#include "topo.h" // for ParentLevel, Parent, Child
1516
#include "xgboost/collective/result.h" // for Result
1617
#include "xgboost/context.h" // for Context
1718
#include "xgboost/span.h" // for Span
@@ -147,25 +148,24 @@ AllreduceV(Comm const& comm, std::vector<T>* data, Fn redop) {
147148
};
148149
};
149150

150-
auto shifted_rank = rank;
151151
std::vector<T> incoming;
152152
std::vector<T> out;
153153
bool continue_reduce = true;
154-
for (std::int32_t step = 1; step < world; step <<= 1) {
154+
for (std::int32_t level = 0; (std::int32_t{1} << level) < world; ++level) {
155155
if (!continue_reduce) {
156156
continue;
157157
}
158-
if (shifted_rank % (step * 2) == step) {
159-
auto parent = shifted_rank - step;
158+
if (rank > 0 && binomial_tree::ParentLevel(rank) == level) {
159+
auto parent = binomial_tree::Parent(rank);
160160
auto rc = send(parent, *data);
161161
if (!rc.OK()) {
162162
return Fail("AllreduceV failed to send data to parent.", std::move(rc));
163163
}
164164
continue_reduce = false;
165165
continue;
166166
}
167-
if (shifted_rank % (step * 2) == 0 && shifted_rank + step < world) {
168-
auto child = shifted_rank + step;
167+
if (binomial_tree::HasChild(rank, level, world)) {
168+
auto child = binomial_tree::Child(rank, level);
169169
auto rc = recv(child, &incoming);
170170
if (!rc.OK()) {
171171
return Fail("AllreduceV failed to receive data from child.", std::move(rc));

src/collective/broadcast.cc

Lines changed: 80 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,40 @@
11
/**
2-
* Copyright 2023, XGBoost Contributors
2+
* Copyright 2023-2026, XGBoost Contributors
33
*/
44
#include "broadcast.h"
55

6-
#include <cmath> // for ceil, log2
76
#include <cstdint> // for int32_t, int8_t
87
#include <utility> // for move
8+
#include <vector> // for vector
99

10-
#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32
11-
#include "comm.h" // for Comm
10+
#include "comm.h" // for Comm, binomial_tree
11+
#include "topo.h" // for Parent, Child
1212
#include "xgboost/collective/result.h" // for Result
1313
#include "xgboost/span.h" // for Span
1414

1515
namespace xgboost::collective::cpu_impl {
1616
namespace {
17-
std::int32_t ShiftedParentRank(std::int32_t shifted_rank, std::int32_t depth) {
18-
std::uint32_t mask{std::uint32_t{0} - 1}; // Oxff...
19-
RBitField32 maskbits{common::Span<std::uint32_t>{&mask, 1}};
20-
RBitField32 rankbits{
21-
common::Span<std::uint32_t>{reinterpret_cast<std::uint32_t*>(&shifted_rank), 1}};
22-
// prepare for counting trailing zeros.
23-
for (std::int32_t i = 0; i < depth + 1; ++i) {
24-
if (rankbits.Check(i)) {
25-
maskbits.Set(i);
26-
} else {
27-
maskbits.Clear(i);
28-
}
29-
}
30-
31-
CHECK_NE(mask, 0);
32-
auto k = TrailingZeroBits(mask);
33-
auto shifted_parent = shifted_rank - (1 << k);
34-
return shifted_parent;
35-
}
36-
37-
// Shift the root node to rank 0
38-
std::int32_t ShiftLeft(std::int32_t rank, std::int32_t world, std::int32_t root) {
39-
auto shifted_rank = (rank + world - root) % world;
40-
return shifted_rank;
41-
}
42-
// shift back to the original rank
43-
std::int32_t ShiftRight(std::int32_t rank, std::int32_t world, std::int32_t root) {
44-
auto orig = (rank + root) % world;
45-
return orig;
46-
}
47-
} // namespace
48-
49-
Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root) {
50-
// Binomial tree broadcast
51-
// * Wiki
52-
// https://en.wikipedia.org/wiki/Broadcast_(parallel_pattern)#Binomial_Tree_Broadcast
53-
// * Impl
54-
// https://people.mpi-inf.mpg.de/~mehlhorn/ftp/NewToolbox/collective.pdf
55-
17+
// Binomial tree broadcast using a fixed tree rooted at rank 0.
18+
Result BroadcastTree(Comm const& comm, common::Span<std::int8_t> data) {
5619
auto rank = comm.Rank();
5720
auto world = comm.World();
5821

59-
// shift root to rank 0
60-
auto shifted_rank = ShiftLeft(rank, world, root);
61-
std::int32_t depth = std::ceil(std::log2(static_cast<double>(world))) - 1;
62-
63-
if (shifted_rank != 0) { // not root
64-
auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
65-
auto rc = Success() << [&] { return comm.Chan(parent)->RecvAll(data); }
66-
<< [&] { return comm.Chan(parent)->Block(); };
22+
if (rank != 0) {
23+
auto parent = binomial_tree::Parent(rank);
24+
auto rc = Success() << [&] {
25+
return comm.Chan(parent)->RecvAll(data);
26+
} << [&] {
27+
return comm.Chan(parent)->Block();
28+
};
6729
if (!rc.OK()) {
6830
return Fail("broadcast failed.", std::move(rc));
6931
}
7032
}
7133

72-
for (std::int32_t i = depth; i >= 0; --i) {
73-
CHECK_GE((i + 1), 0); // weird clang-tidy error that i might be negative
74-
if (shifted_rank % (1 << (i + 1)) == 0 && shifted_rank + (1 << i) < world) {
75-
auto sft_peer = shifted_rank + (1 << i);
76-
auto peer = ShiftRight(sft_peer, world, root);
77-
CHECK_NE(peer, root);
78-
auto rc = comm.Chan(peer)->SendAll(data);
34+
for (std::int32_t level = binomial_tree::Depth(world); level >= 0; --level) {
35+
if (binomial_tree::HasChild(rank, level, world)) {
36+
auto child = binomial_tree::Child(rank, level);
37+
auto rc = comm.Chan(child)->SendAll(data);
7938
if (!rc.OK()) {
8039
return rc;
8140
}
@@ -84,4 +43,67 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
8443

8544
return comm.Block();
8645
}
46+
47+
// Compute the path from `src` to rank 0 through the binomial tree (excluding 0).
48+
std::vector<std::int32_t> TreePathToRoot(std::int32_t node) {
49+
std::vector<std::int32_t> path;
50+
auto cursor = node;
51+
while (cursor > 0) {
52+
path.push_back(cursor);
53+
cursor = binomial_tree::Parent(cursor);
54+
}
55+
return path;
56+
}
57+
58+
// Relay data from `node` up to rank 0 through the binomial tree.
59+
// Only nodes on the path from `node` to 0 participate; all others skip.
60+
Result RelayToRoot(Comm const& comm, common::Span<std::int8_t> data, std::int32_t node) {
61+
auto rank = comm.Rank();
62+
auto path = TreePathToRoot(node);
63+
64+
for (auto node : path) {
65+
CHECK_GT(node, 0);
66+
auto parent = binomial_tree::Parent(node);
67+
68+
if (rank == node) {
69+
auto rc = Success() << [&] {
70+
return comm.Chan(parent)->SendAll(data);
71+
} << [&] {
72+
return comm.Chan(parent)->Block();
73+
};
74+
if (!rc.OK()) {
75+
return Fail("Relay broadcast: failed to send from " + std::to_string(node), std::move(rc));
76+
}
77+
} else if (rank == parent) {
78+
auto rc = Success() << [&] {
79+
return comm.Chan(node)->RecvAll(data);
80+
} << [&] {
81+
return comm.Chan(node)->Block();
82+
};
83+
if (!rc.OK()) {
84+
return Fail("Relay broadcast: failed to recv at " + std::to_string(parent), std::move(rc));
85+
}
86+
}
87+
}
88+
return Success();
89+
}
90+
} // namespace
91+
92+
Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root) {
93+
if (comm.World() <= 1) {
94+
return Success();
95+
}
96+
CHECK(!data.empty());
97+
98+
if (root == 0) {
99+
return BroadcastTree(comm, data);
100+
}
101+
102+
// For non-zero root, relay data up to rank 0 through the tree, then broadcast.
103+
return Success() << [&] {
104+
return RelayToRoot(comm, data, root);
105+
} << [&] {
106+
return BroadcastTree(comm, data);
107+
};
108+
}
87109
} // namespace xgboost::collective::cpu_impl

0 commit comments

Comments
 (0)