11/* *
2- * Copyright 2023, XGBoost Contributors
2+ * Copyright 2023-2026 , XGBoost Contributors
33 */
44#include " broadcast.h"
55
6- #include < cmath> // for ceil, log2
76#include < cstdint> // for int32_t, int8_t
87#include < utility> // for move
8+ #include < vector> // for vector
99
10- #include " ../common/bitfield. h" // for TrailingZeroBits, RBitField32
11- #include " comm .h" // for Comm
10+ #include " comm. h" // for Comm, binomial_tree
11+ #include " topo .h" // for Parent, Child
1212#include " xgboost/collective/result.h" // for Result
1313#include " xgboost/span.h" // for Span
1414
1515namespace xgboost ::collective::cpu_impl {
1616namespace {
17- std::int32_t ShiftedParentRank (std::int32_t shifted_rank, std::int32_t depth) {
18- std::uint32_t mask{std::uint32_t {0 } - 1 }; // Oxff...
19- RBitField32 maskbits{common::Span<std::uint32_t >{&mask, 1 }};
20- RBitField32 rankbits{
21- common::Span<std::uint32_t >{reinterpret_cast <std::uint32_t *>(&shifted_rank), 1 }};
22- // prepare for counting trailing zeros.
23- for (std::int32_t i = 0 ; i < depth + 1 ; ++i) {
24- if (rankbits.Check (i)) {
25- maskbits.Set (i);
26- } else {
27- maskbits.Clear (i);
28- }
29- }
30-
31- CHECK_NE (mask, 0 );
32- auto k = TrailingZeroBits (mask);
33- auto shifted_parent = shifted_rank - (1 << k);
34- return shifted_parent;
35- }
36-
37- // Shift the root node to rank 0
38- std::int32_t ShiftLeft (std::int32_t rank, std::int32_t world, std::int32_t root) {
39- auto shifted_rank = (rank + world - root) % world;
40- return shifted_rank;
41- }
42- // shift back to the original rank
43- std::int32_t ShiftRight (std::int32_t rank, std::int32_t world, std::int32_t root) {
44- auto orig = (rank + root) % world;
45- return orig;
46- }
47- } // namespace
48-
49- Result Broadcast (Comm const & comm, common::Span<std::int8_t > data, std::int32_t root) {
50- // Binomial tree broadcast
51- // * Wiki
52- // https://en.wikipedia.org/wiki/Broadcast_(parallel_pattern)#Binomial_Tree_Broadcast
53- // * Impl
54- // https://people.mpi-inf.mpg.de/~mehlhorn/ftp/NewToolbox/collective.pdf
55-
17+ // Binomial tree broadcast using a fixed tree rooted at rank 0.
18+ Result BroadcastTree (Comm const & comm, common::Span<std::int8_t > data) {
5619 auto rank = comm.Rank ();
5720 auto world = comm.World ();
5821
59- // shift root to rank 0
60- auto shifted_rank = ShiftLeft (rank, world, root);
61- std::int32_t depth = std::ceil (std::log2 (static_cast <double >(world))) - 1 ;
62-
63- if (shifted_rank != 0 ) { // not root
64- auto parent = ShiftRight (ShiftedParentRank (shifted_rank, depth), world, root);
65- auto rc = Success () << [&] { return comm.Chan (parent)->RecvAll (data); }
66- << [&] { return comm.Chan (parent)->Block (); };
22+ if (rank != 0 ) {
23+ auto parent = binomial_tree::Parent (rank);
24+ auto rc = Success () << [&] {
25+ return comm.Chan (parent)->RecvAll (data);
26+ } << [&] {
27+ return comm.Chan (parent)->Block ();
28+ };
6729 if (!rc.OK ()) {
6830 return Fail (" broadcast failed." , std::move (rc));
6931 }
7032 }
7133
72- for (std::int32_t i = depth; i >= 0 ; --i) {
73- CHECK_GE ((i + 1 ), 0 ); // weird clang-tidy error that i might be negative
74- if (shifted_rank % (1 << (i + 1 )) == 0 && shifted_rank + (1 << i) < world) {
75- auto sft_peer = shifted_rank + (1 << i);
76- auto peer = ShiftRight (sft_peer, world, root);
77- CHECK_NE (peer, root);
78- auto rc = comm.Chan (peer)->SendAll (data);
34+ for (std::int32_t level = binomial_tree::Depth (world); level >= 0 ; --level) {
35+ if (binomial_tree::HasChild (rank, level, world)) {
36+ auto child = binomial_tree::Child (rank, level);
37+ auto rc = comm.Chan (child)->SendAll (data);
7938 if (!rc.OK ()) {
8039 return rc;
8140 }
@@ -84,4 +43,67 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
8443
8544 return comm.Block ();
8645}
46+
47+ // Compute the path from `src` to rank 0 through the binomial tree (excluding 0).
48+ std::vector<std::int32_t > TreePathToRoot (std::int32_t node) {
49+ std::vector<std::int32_t > path;
50+ auto cursor = node;
51+ while (cursor > 0 ) {
52+ path.push_back (cursor);
53+ cursor = binomial_tree::Parent (cursor);
54+ }
55+ return path;
56+ }
57+
58+ // Relay data from `node` up to rank 0 through the binomial tree.
59+ // Only nodes on the path from `node` to 0 participate; all others skip.
60+ Result RelayToRoot (Comm const & comm, common::Span<std::int8_t > data, std::int32_t node) {
61+ auto rank = comm.Rank ();
62+ auto path = TreePathToRoot (node);
63+
64+ for (auto node : path) {
65+ CHECK_GT (node, 0 );
66+ auto parent = binomial_tree::Parent (node);
67+
68+ if (rank == node) {
69+ auto rc = Success () << [&] {
70+ return comm.Chan (parent)->SendAll (data);
71+ } << [&] {
72+ return comm.Chan (parent)->Block ();
73+ };
74+ if (!rc.OK ()) {
75+ return Fail (" Relay broadcast: failed to send from " + std::to_string (node), std::move (rc));
76+ }
77+ } else if (rank == parent) {
78+ auto rc = Success () << [&] {
79+ return comm.Chan (node)->RecvAll (data);
80+ } << [&] {
81+ return comm.Chan (node)->Block ();
82+ };
83+ if (!rc.OK ()) {
84+ return Fail (" Relay broadcast: failed to recv at " + std::to_string (parent), std::move (rc));
85+ }
86+ }
87+ }
88+ return Success ();
89+ }
90+ } // namespace
91+
92+ Result Broadcast (Comm const & comm, common::Span<std::int8_t > data, std::int32_t root) {
93+ if (comm.World () <= 1 ) {
94+ return Success ();
95+ }
96+ CHECK (!data.empty ());
97+
98+ if (root == 0 ) {
99+ return BroadcastTree (comm, data);
100+ }
101+
102+ // For non-zero root, relay data up to rank 0 through the tree, then broadcast.
103+ return Success () << [&] {
104+ return RelayToRoot (comm, data, root);
105+ } << [&] {
106+ return BroadcastTree (comm, data);
107+ };
108+ }
87109} // namespace xgboost::collective::cpu_impl
0 commit comments