Skip to content

Commit 0219159

Browse files
authored
[coll] Support worker port. (#12010)
1 parent e7358f9 commit 0219159

File tree

11 files changed

+148
-46
lines changed

11 files changed

+148
-46
lines changed

python-package/xgboost/collective.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pickle
77
from dataclasses import dataclass
88
from enum import IntEnum, unique
9-
from typing import Any, Dict, Optional, TypeAlias, Union
9+
from typing import Any, Callable, Dict, Optional, TypeAlias, Union
1010

1111
import numpy as np
1212

@@ -46,6 +46,21 @@ class Config:
4646
4747
tracker_timeout : See :py:class:`~xgboost.tracker.RabitTracker`.
4848
49+
worker_port :
50+
51+
The port each worker listens to for peer-to-peer connections. By default,
52+
workers use an available port assigned by the OS. This option can be used in
53+
restricted network environments where only specific ports are open.
54+
55+
This can be an integer for a fixed port used by all workers, or a callback
56+
function that takes no arguments and returns a port number. The callback is
57+
invoked per-worker at the worker side.
58+
59+
.. note::
60+
61+
The option does not affect the NCCL communicator group, which must be
62+
configured via NCCL's own environment variables.
63+
4964
"""
5065

5166
retry: Optional[int] = None
@@ -55,6 +70,18 @@ class Config:
5570
tracker_port: Optional[int] = None
5671
tracker_timeout: Optional[int] = None
5772

73+
worker_port: Optional[Union[Callable[[], int], int]] = None
74+
75+
def update_worker_args(self, args: _Conf) -> _Conf:
76+
"""Worker side arguments resolution."""
77+
if self.worker_port is None:
78+
return args
79+
if callable(self.worker_port):
80+
args["dmlc_worker_port"] = self.worker_port()
81+
else:
82+
args["dmlc_worker_port"] = self.worker_port
83+
return args
84+
5885
def get_comm_config(self, args: _Conf) -> _Conf:
5986
"""Update the arguments for the communicator."""
6087
if self.retry is not None:

python-package/xgboost/dask/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,9 @@ def do_train( # pylint: disable=too-many-positional-arguments
761761
local_history: TrainingCallback.EvalsLog = {}
762762
global_config.update({"nthread": n_threads})
763763

764+
if coll_cfg is not None:
765+
coll_args = coll_cfg.update_worker_args(coll_args)
766+
764767
with CommunicatorContext(**coll_args), config.config_context(**global_config):
765768
Xy, evals = _get_dmatrices(
766769
train_ref,

python-package/xgboost/spark/core.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1120,6 +1120,9 @@ def _train_booster(
11201120
if not launch_tracker_on_driver:
11211121
_rabit_args = json.loads(messages[0])["rabit_msg"]
11221122

1123+
if conf is not None:
1124+
_rabit_args = conf.update_worker_args(_rabit_args)
1125+
11231126
evals_result: Dict[str, Any] = {}
11241127
with (
11251128
config_context(verbosity=verbosity, use_rmm=use_rmm),
@@ -1641,7 +1644,11 @@ def saveMetadata(
16411644
if instance.isDefined("coll_cfg"):
16421645
conf: Config = instance.getOrDefault("coll_cfg")
16431646
if conf is not None:
1644-
extraMetadata["coll_cfg"] = asdict(conf)
1647+
extraMetadata["coll_cfg"] = {
1648+
k: v for k, v in asdict(conf).items() if not callable(v)
1649+
}
1650+
if callable(conf.worker_port):
1651+
logger.warning("The `worker_port` is not serialized.")
16451652

16461653
DefaultParamsWriter.saveMetadata(
16471654
instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams

src/collective/comm.cc

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023-2024, XGBoost Contributors
2+
* Copyright 2023-2026, XGBoost Contributors
33
*/
44
#include "comm.h"
55

@@ -108,7 +108,9 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
108108

109109
rc = std::move(rc) << [&] {
110110
return cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch);
111-
} << [&] { return block(); };
111+
} << [&] {
112+
return block();
113+
};
112114
if (!rc.OK()) {
113115
return Fail("Failed to get host names from peers.", std::move(rc));
114116
}
@@ -119,7 +121,9 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
119121
auto s_ports = common::Span{reinterpret_cast<std::int8_t*>(peers_port.data()),
120122
peers_port.size() * sizeof(ninfo.port)};
121123
return cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch);
122-
} << [&] { return block(); };
124+
} << [&] {
125+
return block();
126+
};
123127
if (!rc.OK()) {
124128
return Fail("Failed to get the port from peers.", std::move(rc));
125129
}
@@ -143,9 +147,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
143147
for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) {
144148
auto const& peer = peers[r];
145149
auto worker = std::make_shared<TCPSocket>();
146-
rc = std::move(rc)
147-
<< [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); }
148-
<< [&] { return worker->RecvTimeout(timeout); };
150+
rc = std::move(rc) << [&] {
151+
return Connect(peer.host, peer.port, retry, timeout, worker.get());
152+
} << [&] {
153+
return worker->RecvTimeout(timeout);
154+
};
149155
if (!rc.OK()) {
150156
return rc;
151157
}
@@ -204,17 +210,18 @@ std::string InitLog(std::string task_id, std::int32_t rank) {
204210

205211
RabitComm::RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
206212
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
207-
StringView nccl_path)
213+
StringView nccl_path, std::int32_t worker_port)
208214
: HostComm{tracker_host, tracker_port, timeout, retry, std::move(task_id)},
209-
nccl_path_{std::move(nccl_path)} {
215+
nccl_path_{std::move(nccl_path)},
216+
worker_port_{worker_port} {
210217
if (this->TrackerInfo().host.empty()) {
211218
// Not in a distributed environment.
212219
LOG(CONSOLE) << InitLog(task_id_, rank_);
213220
return;
214221
}
215222

216223
loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT
217-
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
224+
auto rc = this->Bootstrap(timeout_, retry_, task_id_, worker_port_);
218225
if (!rc.OK()) {
219226
this->ResetState();
220227
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
@@ -230,7 +237,7 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
230237
#endif // !defined(XGBOOST_USE_NCCL)
231238

232239
[[nodiscard]] Result RabitComm::Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
233-
std::string task_id) {
240+
std::string task_id, std::int32_t worker_port) {
234241
TCPSocket tracker;
235242
std::int32_t world{-1};
236243
auto rc = ConnectTrackerImpl(this->TrackerInfo(), timeout, retry, task_id, &tracker, this->Rank(),
@@ -243,8 +250,14 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
243250

244251
// Start command
245252
TCPSocket listener = TCPSocket::Create(tracker.Domain());
246-
std::int32_t lport{0};
253+
std::int32_t lport{worker_port};
247254
rc = std::move(rc) << [&] {
255+
if (lport > 0) {
256+
// User-specified port, bind to INADDR_ANY with the given port.
257+
auto addr = (tracker.Domain() == SockDomain::kV6) ? "::" : "0.0.0.0";
258+
return listener.Bind(addr, &lport);
259+
}
260+
// Default: let the OS pick an available port.
248261
return listener.BindHost(&lport);
249262
} << [&] {
250263
return listener.Listen();
@@ -306,8 +319,11 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
306319
error_worker_.detach();
307320

308321
proto::Start start;
309-
rc = std::move(rc) << [&] { return start.WorkerSend(lport, &tracker, eport); }
310-
<< [&] { return start.WorkerRecv(&tracker, &world); };
322+
rc = std::move(rc) << [&] {
323+
return start.WorkerSend(lport, &tracker, eport);
324+
} << [&] {
325+
return start.WorkerRecv(&tracker, &world);
326+
};
311327
if (!rc.OK()) {
312328
return rc;
313329
}
@@ -418,8 +434,11 @@ RabitComm::~RabitComm() noexcept(false) {
418434
}
419435
TCPSocket out;
420436
proto::Print print;
421-
return Success() << [&] { return this->ConnectTracker(&out); }
422-
<< [&] { return print.WorkerSend(&out, msg); };
437+
return Success() << [&] {
438+
return this->ConnectTracker(&out);
439+
} << [&] {
440+
return print.WorkerSend(&out, msg);
441+
};
423442
}
424443

425444
[[nodiscard]] Result RabitComm::SignalError(Result const& res) {

src/collective/comm.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023-2024, XGBoost Contributors
2+
* Copyright 2023-2026, XGBoost Contributors
33
*/
44
#pragma once
55
#include <chrono> // for seconds
@@ -127,16 +127,19 @@ class HostComm : public Comm {
127127

128128
class RabitComm : public HostComm {
129129
std::string nccl_path_ = std::string{DefaultNcclName()};
130+
// User-specified port for the worker listener socket. 0 means the OS picks an available
131+
// port.
132+
std::int32_t worker_port_{0};
130133

131134
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
132-
std::string task_id);
135+
std::string task_id, std::int32_t worker_port);
133136

134137
public:
135138
// bootstrapping construction.
136139
RabitComm() = default;
137140
RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
138141
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
139-
StringView nccl_path);
142+
StringView nccl_path, std::int32_t worker_port);
140143
~RabitComm() noexcept(false) override;
141144

142145
[[nodiscard]] bool IsFederated() const override { return false; }

src/collective/comm_group.cc

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023-2024, XGBoost Contributors
2+
* Copyright 2023-2026, XGBoost Contributors
33
*/
44
#include "comm_group.h"
55

@@ -84,11 +84,17 @@ CommGroup::CommGroup()
8484
auto tracker_host = get_param("dmlc_tracker_uri", std::string{}, String{});
8585
auto tracker_port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
8686
auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{});
87-
auto ptr = new CommGroup{
88-
std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
89-
tracker_host, static_cast<std::int32_t>(tracker_port), std::chrono::seconds{timeout},
90-
static_cast<std::int32_t>(retry), task_id, nccl}},
91-
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
87+
auto worker_port = get_param("dmlc_worker_port", static_cast<std::int64_t>(0), Integer{});
88+
CHECK_LE(worker_port, std::numeric_limits<in_port_t>::max());
89+
CHECK_GE(worker_port, 0);
90+
CHECK_LE(tracker_port, std::numeric_limits<in_port_t>::max());
91+
CHECK_GE(tracker_port, 0);
92+
auto ptr = new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{
93+
// NOLINT
94+
tracker_host, static_cast<std::int32_t>(tracker_port),
95+
std::chrono::seconds{timeout}, static_cast<std::int32_t>(retry),
96+
task_id, nccl, static_cast<std::int32_t>(worker_port)}},
97+
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
9298
return ptr;
9399
} else if (type == "federated") {
94100
#if defined(XGBOOST_USE_FEDERATED)

src/predictor/interpretability/shap.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77
#include <type_traits> // for remove_const_t
88
#include <vector> // for vector
99

10-
#include "../../common/threading_utils.h" // for ParallelFor
11-
#include "../../gbm/gbtree_model.h" // for GBTreeModel
12-
#include "../../tree/tree_view.h" // for ScalarTreeView
13-
#include "../data_accessor.h" // for GHistIndexMatrixView
14-
#include "../predict_fn.h" // for GetTreeLimit
15-
#include "../treeshap.h" // for CalculateContributions
16-
#include "dmlc/omp.h" // for omp_get_thread_num
17-
#include "xgboost/base.h" // for bst_omp_uint
18-
#include "xgboost/logging.h" // for CHECK
19-
#include "xgboost/multi_target_tree_model.h" // for MTNotImplemented
10+
#include "../../common/threading_utils.h" // for ParallelFor
11+
#include "../../gbm/gbtree_model.h" // for GBTreeModel
12+
#include "../../tree/tree_view.h" // for ScalarTreeView
13+
#include "../data_accessor.h" // for GHistIndexMatrixView
14+
#include "../predict_fn.h" // for GetTreeLimit
15+
#include "../treeshap.h" // for CalculateContributions
16+
#include "dmlc/omp.h" // for omp_get_thread_num
17+
#include "xgboost/base.h" // for bst_omp_uint
18+
#include "xgboost/logging.h" // for CHECK
19+
#include "xgboost/tree_model.h" // for MTNotImplemented
2020

2121
namespace xgboost::interpretability {
2222
namespace {

src/predictor/interpretability/shap.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "../../common/cuda_context.cuh" // for CUDAContext
2626
#include "../../common/cuda_rt_utils.h" // for SetDevice
2727
#include "../../common/device_helpers.cuh"
28-
#include "../../common/error_msg.h"
2928
#include "../../common/nvtx_utils.h"
3029
#include "../../data/batch_utils.h" // for StaticBatch
3130
#include "../../data/cat_container.cuh" // for EncPolicy, MakeCatAccessor

tests/cpp/collective/test_worker.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
/**
2-
* Copyright 2023-2024, XGBoost Contributors
2+
* Copyright 2023-2026, XGBoost Contributors
33
*/
44
#pragma once
55
#include <gtest/gtest.h>
66
#include <xgboost/global_config.h> // for InitNewThread
77

8-
#include <chrono> // for seconds
9-
#include <cstdint> // for int32_t
10-
#include <fstream> // for ifstream
11-
#include <string> // for string
12-
#include <thread> // for thread
13-
#include <utility> // for move
14-
#include <vector> // for vector
8+
#include <algorithm> // for max
9+
#include <chrono> // for seconds
10+
#include <cstdint> // for int32_t
11+
#include <fstream> // for ifstream
12+
#include <string> // for string
13+
#include <thread> // for thread
14+
#include <utility> // for move
15+
#include <vector> // for vector
1516

1617
#include "../../../src/collective/comm.h" // for RabitComm
1718
#include "../../../src/collective/communicator-inl.h" // for Init, Finalize
@@ -42,7 +43,7 @@ class WorkerForTest {
4243
tracker_port_{port},
4344
world_size_{world},
4445
task_id_{"t:" + std::to_string(rank)},
45-
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_, DefaultNcclName()} {
46+
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_, DefaultNcclName(), 0} {
4647
CHECK_EQ(world_size_, comm_.World());
4748
}
4849
virtual ~WorkerForTest() noexcept(false) { SafeColl(comm_.Shutdown()); }

tests/python/test_tracker.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from hypothesis import HealthCheck, given, settings, strategies
99
from xgboost import RabitTracker, collective
1010
from xgboost import testing as tm
11+
from xgboost.testing.collective import get_avail_port
1112

1213

1314
def test_rabit_tracker() -> None:
@@ -53,6 +54,32 @@ def test_socket_error() -> None:
5354
tracker.free()
5455

5556

57+
@pytest.mark.skipif(**tm.no_loky())
58+
def test_worker_port() -> None:
59+
from loky import get_reusable_executor
60+
61+
n_workers = 4
62+
63+
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers)
64+
tracker.start()
65+
args = tracker.worker_args()
66+
67+
def local_test(worker_id: int, rabit_args: dict) -> int:
68+
cfg = collective.Config(worker_port=get_avail_port)
69+
cfg.update_worker_args(rabit_args)
70+
with collective.CommunicatorContext(**rabit_args):
71+
a = np.array([1])
72+
result = collective.allreduce(a, collective.Op.SUM)
73+
assert result[0] == n_workers
74+
75+
return 1
76+
77+
fn = update_wrapper(partial(local_test, rabit_args=args), local_test)
78+
with get_reusable_executor(max_workers=n_workers) as pool:
79+
results = pool.map(fn, range(n_workers))
80+
assert sum(results) == n_workers
81+
82+
5683
def run_rabit_ops(pool, n_workers: int, address: str) -> None:
5784
tracker = RabitTracker(host_ip=address, n_workers=n_workers)
5885
tracker.start()

0 commit comments

Comments
 (0)