Skip to content

Commit cd80557

Browse files
committed
Merge remote-tracking branch 'upstream/master' into quantile-refactor
2 parents b3a02c5 + 0219159 commit cd80557

File tree

24 files changed

+512
-374
lines changed

24 files changed

+512
-374
lines changed

doc/treemethod.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namely ``updater`` and ``tree_method``. XGBoost has 3 builtin tree methods, nam
88
free standing updaters including ``refresh``, ``prune`` and ``sync``. The parameter
99
``updater`` is more primitive than ``tree_method`` as the latter is just a
1010
pre-configuration of the former. The difference is mostly due to historical reasons that
11-
each updater requires some specific configurations and might has missing features. As we
11+
each updater requires some specific configurations and might have missing features. As we
1212
are moving forward, the gap between them is becoming more and more irrelevant. We will
1313
collectively document them under tree methods.
1414

include/xgboost/predictor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ class Predictor {
107107
*/
108108
virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
109109
gbm::GBTreeModel const& model, bst_tree_t tree_begin,
110-
bst_tree_t tree_end = 0) const = 0;
110+
bst_tree_t tree_end = 0,
111+
std::vector<float> const* tree_weights = nullptr) const = 0;
111112

112113
/**
113114
* \brief Inplace prediction.

plugin/sycl/predictor/predictor.cc

100755100644
Lines changed: 71 additions & 100 deletions
Large diffs are not rendered by default.

python-package/xgboost/collective.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pickle
77
from dataclasses import dataclass
88
from enum import IntEnum, unique
9-
from typing import Any, Dict, Optional, TypeAlias, Union
9+
from typing import Any, Callable, Dict, Optional, TypeAlias, Union
1010

1111
import numpy as np
1212

@@ -46,6 +46,21 @@ class Config:
4646
4747
tracker_timeout : See :py:class:`~xgboost.tracker.RabitTracker`.
4848
49+
worker_port :
50+
51+
The port each worker listens on for peer-to-peer connections. By default,
52+
workers use an available port assigned by the OS. This option can be used in
53+
restricted network environments where only specific ports are open.
54+
55+
This can be an integer for a fixed port used by all workers, or a callback
56+
function that takes no arguments and returns a port number. The callback is
57+
invoked per-worker at the worker side.
58+
59+
.. note::
60+
61+
The option does not affect the NCCL communicator group, which must be
62+
configured via NCCL's own environment variables.
63+
4964
"""
5065

5166
retry: Optional[int] = None
@@ -55,6 +70,18 @@ class Config:
5570
tracker_port: Optional[int] = None
5671
tracker_timeout: Optional[int] = None
5772

73+
worker_port: Optional[Union[Callable[[], int], int]] = None
74+
75+
def update_worker_args(self, args: _Conf) -> _Conf:
76+
"""Worker side arguments resolution."""
77+
if self.worker_port is None:
78+
return args
79+
if callable(self.worker_port):
80+
args["dmlc_worker_port"] = self.worker_port()
81+
else:
82+
args["dmlc_worker_port"] = self.worker_port
83+
return args
84+
5885
def get_comm_config(self, args: _Conf) -> _Conf:
5986
"""Update the arguments for the communicator."""
6087
if self.retry is not None:

python-package/xgboost/dask/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,9 @@ def do_train( # pylint: disable=too-many-positional-arguments
761761
local_history: TrainingCallback.EvalsLog = {}
762762
global_config.update({"nthread": n_threads})
763763

764+
if coll_cfg is not None:
765+
coll_args = coll_cfg.update_worker_args(coll_args)
766+
764767
with CommunicatorContext(**coll_args), config.config_context(**global_config):
765768
Xy, evals = _get_dmatrices(
766769
train_ref,

python-package/xgboost/spark/core.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1120,6 +1120,9 @@ def _train_booster(
11201120
if not launch_tracker_on_driver:
11211121
_rabit_args = json.loads(messages[0])["rabit_msg"]
11221122

1123+
if conf is not None:
1124+
_rabit_args = conf.update_worker_args(_rabit_args)
1125+
11231126
evals_result: Dict[str, Any] = {}
11241127
with (
11251128
config_context(verbosity=verbosity, use_rmm=use_rmm),
@@ -1641,7 +1644,11 @@ def saveMetadata(
16411644
if instance.isDefined("coll_cfg"):
16421645
conf: Config = instance.getOrDefault("coll_cfg")
16431646
if conf is not None:
1644-
extraMetadata["coll_cfg"] = asdict(conf)
1647+
extraMetadata["coll_cfg"] = {
1648+
k: v for k, v in asdict(conf).items() if not callable(v)
1649+
}
1650+
if callable(conf.worker_port):
1651+
logger.warning("The `worker_port` is not serialized.")
16451652

16461653
DefaultParamsWriter.saveMetadata(
16471654
instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams

src/collective/comm.cc

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023-2024, XGBoost Contributors
2+
* Copyright 2023-2026, XGBoost Contributors
33
*/
44
#include "comm.h"
55

@@ -108,7 +108,9 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
108108

109109
rc = std::move(rc) << [&] {
110110
return cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch);
111-
} << [&] { return block(); };
111+
} << [&] {
112+
return block();
113+
};
112114
if (!rc.OK()) {
113115
return Fail("Failed to get host names from peers.", std::move(rc));
114116
}
@@ -119,7 +121,9 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
119121
auto s_ports = common::Span{reinterpret_cast<std::int8_t*>(peers_port.data()),
120122
peers_port.size() * sizeof(ninfo.port)};
121123
return cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch);
122-
} << [&] { return block(); };
124+
} << [&] {
125+
return block();
126+
};
123127
if (!rc.OK()) {
124128
return Fail("Failed to get the port from peers.", std::move(rc));
125129
}
@@ -143,9 +147,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
143147
for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) {
144148
auto const& peer = peers[r];
145149
auto worker = std::make_shared<TCPSocket>();
146-
rc = std::move(rc)
147-
<< [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); }
148-
<< [&] { return worker->RecvTimeout(timeout); };
150+
rc = std::move(rc) << [&] {
151+
return Connect(peer.host, peer.port, retry, timeout, worker.get());
152+
} << [&] {
153+
return worker->RecvTimeout(timeout);
154+
};
149155
if (!rc.OK()) {
150156
return rc;
151157
}
@@ -204,17 +210,18 @@ std::string InitLog(std::string task_id, std::int32_t rank) {
204210

205211
RabitComm::RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
206212
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
207-
StringView nccl_path)
213+
StringView nccl_path, std::int32_t worker_port)
208214
: HostComm{tracker_host, tracker_port, timeout, retry, std::move(task_id)},
209-
nccl_path_{std::move(nccl_path)} {
215+
nccl_path_{std::move(nccl_path)},
216+
worker_port_{worker_port} {
210217
if (this->TrackerInfo().host.empty()) {
211218
// Not in a distributed environment.
212219
LOG(CONSOLE) << InitLog(task_id_, rank_);
213220
return;
214221
}
215222

216223
loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT
217-
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
224+
auto rc = this->Bootstrap(timeout_, retry_, task_id_, worker_port_);
218225
if (!rc.OK()) {
219226
this->ResetState();
220227
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
@@ -230,7 +237,7 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
230237
#endif // !defined(XGBOOST_USE_NCCL)
231238

232239
[[nodiscard]] Result RabitComm::Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
233-
std::string task_id) {
240+
std::string task_id, std::int32_t worker_port) {
234241
TCPSocket tracker;
235242
std::int32_t world{-1};
236243
auto rc = ConnectTrackerImpl(this->TrackerInfo(), timeout, retry, task_id, &tracker, this->Rank(),
@@ -243,8 +250,14 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
243250

244251
// Start command
245252
TCPSocket listener = TCPSocket::Create(tracker.Domain());
246-
std::int32_t lport{0};
253+
std::int32_t lport{worker_port};
247254
rc = std::move(rc) << [&] {
255+
if (lport > 0) {
256+
// User-specified port: bind to the wildcard address (IPv4 or IPv6) on that port.
257+
auto addr = (tracker.Domain() == SockDomain::kV6) ? "::" : "0.0.0.0";
258+
return listener.Bind(addr, &lport);
259+
}
260+
// Default: let the OS pick an available port.
248261
return listener.BindHost(&lport);
249262
} << [&] {
250263
return listener.Listen();
@@ -306,8 +319,11 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
306319
error_worker_.detach();
307320

308321
proto::Start start;
309-
rc = std::move(rc) << [&] { return start.WorkerSend(lport, &tracker, eport); }
310-
<< [&] { return start.WorkerRecv(&tracker, &world); };
322+
rc = std::move(rc) << [&] {
323+
return start.WorkerSend(lport, &tracker, eport);
324+
} << [&] {
325+
return start.WorkerRecv(&tracker, &world);
326+
};
311327
if (!rc.OK()) {
312328
return rc;
313329
}
@@ -418,8 +434,11 @@ RabitComm::~RabitComm() noexcept(false) {
418434
}
419435
TCPSocket out;
420436
proto::Print print;
421-
return Success() << [&] { return this->ConnectTracker(&out); }
422-
<< [&] { return print.WorkerSend(&out, msg); };
437+
return Success() << [&] {
438+
return this->ConnectTracker(&out);
439+
} << [&] {
440+
return print.WorkerSend(&out, msg);
441+
};
423442
}
424443

425444
[[nodiscard]] Result RabitComm::SignalError(Result const& res) {

src/collective/comm.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023-2024, XGBoost Contributors
2+
* Copyright 2023-2026, XGBoost Contributors
33
*/
44
#pragma once
55
#include <chrono> // for seconds
@@ -127,16 +127,19 @@ class HostComm : public Comm {
127127

128128
class RabitComm : public HostComm {
129129
std::string nccl_path_ = std::string{DefaultNcclName()};
130+
// User-specified port for the worker listener socket. 0 means the OS picks an available
131+
// port.
132+
std::int32_t worker_port_{0};
130133

131134
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
132-
std::string task_id);
135+
std::string task_id, std::int32_t worker_port);
133136

134137
public:
135138
// bootstrapping construction.
136139
RabitComm() = default;
137140
RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
138141
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
139-
StringView nccl_path);
142+
StringView nccl_path, std::int32_t worker_port);
140143
~RabitComm() noexcept(false) override;
141144

142145
[[nodiscard]] bool IsFederated() const override { return false; }

src/collective/comm_group.cc

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023-2024, XGBoost Contributors
2+
* Copyright 2023-2026, XGBoost Contributors
33
*/
44
#include "comm_group.h"
55

@@ -84,11 +84,17 @@ CommGroup::CommGroup()
8484
auto tracker_host = get_param("dmlc_tracker_uri", std::string{}, String{});
8585
auto tracker_port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
8686
auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{});
87-
auto ptr = new CommGroup{
88-
std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
89-
tracker_host, static_cast<std::int32_t>(tracker_port), std::chrono::seconds{timeout},
90-
static_cast<std::int32_t>(retry), task_id, nccl}},
91-
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
87+
auto worker_port = get_param("dmlc_worker_port", static_cast<std::int64_t>(0), Integer{});
88+
CHECK_LE(worker_port, std::numeric_limits<in_port_t>::max());
89+
CHECK_GE(worker_port, 0);
90+
CHECK_LE(tracker_port, std::numeric_limits<in_port_t>::max());
91+
CHECK_GE(tracker_port, 0);
92+
auto ptr = new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{
93+
// NOLINT
94+
tracker_host, static_cast<std::int32_t>(tracker_port),
95+
std::chrono::seconds{timeout}, static_cast<std::int32_t>(retry),
96+
task_id, nccl, static_cast<std::int32_t>(worker_port)}},
97+
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
9298
return ptr;
9399
} else if (type == "federated") {
94100
#if defined(XGBOOST_USE_FEDERATED)

src/data/ellpack_page_source.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,8 @@ class EllpackFormatPolicy {
187187
public:
188188
EllpackFormatPolicy() {
189189
StringView msg{" The overhead of iterating through external memory might be significant."};
190-
if (!has_hmm_) {
190+
if (!(has_hmm_ || curt::SupportsAts())) {
191191
LOG(WARNING) << "CUDA heterogeneous memory management is not available." << msg;
192-
} else if (!curt::SupportsAts()) {
193-
LOG(WARNING) << "CUDA address translation service is not available." << msg;
194192
}
195193
if (!(GlobalConfigThreadLocalStore::Get()->use_rmm ||
196194
GlobalConfigThreadLocalStore::Get()->use_cuda_async_pool)) {

0 commit comments

Comments
 (0)