11/* *
2- * Copyright 2023-2024 , XGBoost Contributors
2+ * Copyright 2023-2026 , XGBoost Contributors
33 */
44#include " comm.h"
55
@@ -108,7 +108,9 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
108108
109109 rc = std::move (rc) << [&] {
110110 return cpu_impl::RingAllgather (comm, s_buffer, HOST_NAME_MAX, 0 , prev_ch, next_ch);
111- } << [&] { return block (); };
111+ } << [&] {
112+ return block ();
113+ };
112114 if (!rc.OK ()) {
113115 return Fail (" Failed to get host names from peers." , std::move (rc));
114116 }
@@ -119,7 +121,9 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
119121 auto s_ports = common::Span{reinterpret_cast <std::int8_t *>(peers_port.data ()),
120122 peers_port.size () * sizeof (ninfo.port )};
121123 return cpu_impl::RingAllgather (comm, s_ports, sizeof (ninfo.port ), 0 , prev_ch, next_ch);
122- } << [&] { return block (); };
124+ } << [&] {
125+ return block ();
126+ };
123127 if (!rc.OK ()) {
124128 return Fail (" Failed to get the port from peers." , std::move (rc));
125129 }
@@ -143,9 +147,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
143147 for (std::int32_t r = (comm.Rank () + 1 ); r < comm.World (); ++r) {
144148 auto const & peer = peers[r];
145149 auto worker = std::make_shared<TCPSocket>();
146- rc = std::move (rc)
147- << [&] { return Connect (peer.host , peer.port , retry, timeout, worker.get ()); }
148- << [&] { return worker->RecvTimeout (timeout); };
150+ rc = std::move (rc) << [&] {
151+ return Connect (peer.host , peer.port , retry, timeout, worker.get ());
152+ } << [&] {
153+ return worker->RecvTimeout (timeout);
154+ };
149155 if (!rc.OK ()) {
150156 return rc;
151157 }
@@ -204,17 +210,18 @@ std::string InitLog(std::string task_id, std::int32_t rank) {
204210
205211RabitComm::RabitComm (std::string const & tracker_host, std::int32_t tracker_port,
206212 std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
207- StringView nccl_path)
213+ StringView nccl_path, std:: int32_t worker_port )
208214 : HostComm{tracker_host, tracker_port, timeout, retry, std::move (task_id)},
209- nccl_path_{std::move (nccl_path)} {
215+ nccl_path_{std::move (nccl_path)},
216+ worker_port_{worker_port} {
210217 if (this ->TrackerInfo ().host .empty ()) {
211218 // Not in a distributed environment.
212219 LOG (CONSOLE) << InitLog (task_id_, rank_);
213220 return ;
214221 }
215222
216223 loop_.reset (new Loop{std::chrono::seconds{timeout_}}); // NOLINT
217- auto rc = this ->Bootstrap (timeout_, retry_, task_id_);
224+ auto rc = this ->Bootstrap (timeout_, retry_, task_id_, worker_port_ );
218225 if (!rc.OK ()) {
219226 this ->ResetState ();
220227 SafeColl (Fail (" Failed to bootstrap the communication group." , std::move (rc)));
@@ -230,7 +237,7 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
230237#endif // !defined(XGBOOST_USE_NCCL)
231238
232239[[nodiscard]] Result RabitComm::Bootstrap (std::chrono::seconds timeout, std::int32_t retry,
233- std::string task_id) {
240+ std::string task_id, std:: int32_t worker_port ) {
234241 TCPSocket tracker;
235242 std::int32_t world{-1 };
236243 auto rc = ConnectTrackerImpl (this ->TrackerInfo (), timeout, retry, task_id, &tracker, this ->Rank (),
@@ -243,8 +250,14 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
243250
244251 // Start command
245252 TCPSocket listener = TCPSocket::Create (tracker.Domain ());
246- std::int32_t lport{0 };
253+ std::int32_t lport{worker_port };
247254 rc = std::move (rc) << [&] {
255+ if (lport > 0 ) {
256+ // User-specified port, bind to INADDR_ANY with the given port.
257+ auto addr = (tracker.Domain () == SockDomain::kV6 ) ? " ::" : " 0.0.0.0" ;
258+ return listener.Bind (addr, &lport);
259+ }
260+ // Default: let the OS pick an available port.
248261 return listener.BindHost (&lport);
249262 } << [&] {
250263 return listener.Listen ();
@@ -306,8 +319,11 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
306319 error_worker_.detach ();
307320
308321 proto::Start start;
309- rc = std::move (rc) << [&] { return start.WorkerSend (lport, &tracker, eport); }
310- << [&] { return start.WorkerRecv (&tracker, &world); };
322+ rc = std::move (rc) << [&] {
323+ return start.WorkerSend (lport, &tracker, eport);
324+ } << [&] {
325+ return start.WorkerRecv (&tracker, &world);
326+ };
311327 if (!rc.OK ()) {
312328 return rc;
313329 }
@@ -418,8 +434,11 @@ RabitComm::~RabitComm() noexcept(false) {
418434 }
419435 TCPSocket out;
420436 proto::Print print;
421- return Success () << [&] { return this ->ConnectTracker (&out); }
422- << [&] { return print.WorkerSend (&out, msg); };
437+ return Success () << [&] {
438+ return this ->ConnectTracker (&out);
439+ } << [&] {
440+ return print.WorkerSend (&out, msg);
441+ };
423442}
424443
425444[[nodiscard]] Result RabitComm::SignalError (Result const & res) {
0 commit comments