Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 15396fd

Browse files
authored
Implement ranking support (#189)
Adds support for the `qid` parameter, allowing ranking to work correctly. The `group` parameter is not supported; this is also the case for XGBoost's Dask interface. Support for LightGBM will come in the future, although it is not yet clear what that will look like, as LightGBM does not support the `qid` parameter. Also fixes the `ray_dmatrix_params` argument being mandatory for RayXGBRanker.
1 parent 070c2a7 commit 15396fd

File tree

5 files changed

+153
-20
lines changed

5 files changed

+153
-20
lines changed

xgboost_ray/main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,8 @@ def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
302302
param["label"] = [param["label"]]
303303
if not isinstance(param["weight"], list):
304304
param["weight"] = [param["weight"]]
305+
if not isinstance(param["qid"], list):
306+
param["qid"] = [param["qid"]]
305307
if not isinstance(param["data"], list):
306308
param["base_margin"] = [param["base_margin"]]
307309

@@ -322,6 +324,7 @@ def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
322324
"data": concat_dataframes(param["data"]),
323325
"label": concat_dataframes(param["label"]),
324326
"weight": concat_dataframes(param["weight"]),
327+
"qid": concat_dataframes(param["qid"]),
325328
"base_margin": concat_dataframes(param["base_margin"]),
326329
"label_lower_bound": concat_dataframes(
327330
param["label_lower_bound"]),
@@ -335,6 +338,7 @@ def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
335338

336339
if LEGACY_MATRIX:
337340
param.pop("base_margin", None)
341+
param.pop("qid", None)
338342

339343
matrix = xgb.DMatrix(**param)
340344

xgboost_ray/matrix.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __init__(
8383
label: List[Optional[Data]],
8484
missing: Optional[float],
8585
weight: List[Optional[Data]],
86+
qid: List[Optional[Data]],
8687
base_margin: List[Optional[Data]],
8788
label_lower_bound: List[Optional[Data]],
8889
label_upper_bound: List[Optional[Data]],
@@ -97,6 +98,7 @@ def __init__(
9798
self._label = label
9899
self._missing = missing
99100
self._weight = weight
101+
self._qid = qid
100102
self._base_margin = base_margin
101103
self._label_lower_bound = label_lower_bound
102104
self._label_upper_bound = label_upper_bound
@@ -128,6 +130,7 @@ def next(self, input_data: Callable):
128130
data=self._prop(self._data),
129131
label=self._prop(self._label),
130132
weight=self._prop(self._weight),
133+
qid=self._prop(self._qid),
131134
group=None,
132135
label_lower_bound=self._prop(self._label_lower_bound),
133136
label_upper_bound=self._prop(self._label_upper_bound),
@@ -148,6 +151,7 @@ def __init__(self,
148151
label_upper_bound: Optional[Data] = None,
149152
feature_names: Optional[List[str]] = None,
150153
feature_types: Optional[List[np.dtype]] = None,
154+
qid: Optional[Data] = None,
151155
filetype: Optional[RayFileType] = None,
152156
ignore: Optional[List[str]] = None,
153157
**kwargs):
@@ -160,6 +164,7 @@ def __init__(self,
160164
self.label_upper_bound = label_upper_bound
161165
self.feature_names = feature_names
162166
self.feature_types = feature_types
167+
self.qid = qid
163168

164169
self.data_source = None
165170
self.actor_shards = None
@@ -233,6 +238,10 @@ def _split_dataframe(
233238
if exclude:
234239
exclude_cols.add(exclude)
235240

241+
qid, exclude = data_source.get_column(local_data, self.qid)
242+
if exclude:
243+
exclude_cols.add(exclude)
244+
236245
base_margin, exclude = data_source.get_column(local_data,
237246
self.base_margin)
238247
if exclude:
@@ -253,7 +262,7 @@ def _split_dataframe(
253262
x = x[[col for col in x.columns if col not in exclude_cols]]
254263

255264
return x, label, weight, base_margin, label_lower_bound, \
256-
label_upper_bound
265+
label_upper_bound, qid
257266

258267
def load_data(self,
259268
num_actors: int,
@@ -341,7 +350,7 @@ def load_data(self,
341350
# yet. Instead, we'll be selecting the rows below.
342351
local_df = data_source.load_data(
343352
self.data, ignore=self.ignore, indices=None, **self.kwargs)
344-
x, y, w, b, ll, lu = self._split_dataframe(
353+
x, y, w, b, ll, lu, qid = self._split_dataframe(
345354
local_df, data_source=data_source)
346355

347356
if isinstance(x, list):
@@ -362,7 +371,8 @@ def load_data(self,
362371
"label_lower_bound": ray.put(ll.iloc[indices]
363372
if ll is not None else None),
364373
"label_upper_bound": ray.put(lu.iloc[indices]
365-
if lu is not None else None)
374+
if lu is not None else None),
375+
"qid": ray.put(qid.iloc[indices] if qid is not None else None),
366376
}
367377
refs[i] = actor_refs
368378

@@ -505,7 +515,7 @@ def load_data(self,
505515
indices=rank_shards,
506516
ignore=self.ignore,
507517
**self.kwargs)
508-
x, y, w, b, ll, lu = self._split_dataframe(
518+
x, y, w, b, ll, lu, qid = self._split_dataframe(
509519
local_df, data_source=data_source)
510520

511521
if isinstance(x, list):
@@ -517,15 +527,16 @@ def load_data(self,
517527
indices = _get_sharding_indices(sharding, rank, num_actors, n)
518528

519529
if not indices:
520-
x, y, w, b, ll, lu = None, None, None, None, None, None
530+
x, y, w, b, ll, lu, qid = (None, None, None, None, None, None,
531+
None)
521532
n = 0
522533
else:
523534
local_df = data_source.load_data(
524535
self.data,
525536
ignore=self.ignore,
526537
indices=indices,
527538
**self.kwargs)
528-
x, y, w, b, ll, lu = self._split_dataframe(
539+
x, y, w, b, ll, lu, qid = self._split_dataframe(
529540
local_df, data_source=data_source)
530541

531542
if isinstance(x, list):
@@ -540,7 +551,8 @@ def load_data(self,
540551
"weight": ray.put(w),
541552
"base_margin": ray.put(b),
542553
"label_lower_bound": ray.put(ll),
543-
"label_upper_bound": ray.put(lu)
554+
"label_upper_bound": ray.put(lu),
555+
"qid": ray.put(qid),
544556
}
545557
}
546558

@@ -648,6 +660,7 @@ def __init__(self,
648660
label_upper_bound: Optional[Data] = None,
649661
feature_names: Optional[List[str]] = None,
650662
feature_types: Optional[List[np.dtype]] = None,
663+
qid: Optional[Data] = None,
651664
num_actors: Optional[int] = None,
652665
filetype: Optional[RayFileType] = None,
653666
ignore: Optional[List[str]] = None,
@@ -656,10 +669,20 @@ def __init__(self,
656669
lazy: bool = False,
657670
**kwargs):
658671

672+
if kwargs.get("group", None) is not None:
673+
raise ValueError(
674+
"`group` parameter is not supported. "
675+
"If you are using XGBoost-Ray, use `qid` parameter instead. "
676+
"If you are using LightGBM-Ray, ranking is not yet supported.")
677+
678+
if qid is not None and weight is not None:
679+
raise NotImplementedError("per-group weight is not implemented.")
680+
659681
self._uid = uuid.uuid4().int
660682

661683
self.feature_names = feature_names
662684
self.feature_types = feature_types
685+
self.qid = qid
663686
self.missing = missing
664687

665688
self.num_actors = num_actors
@@ -691,6 +714,7 @@ def __init__(self,
691714
feature_types=feature_types,
692715
filetype=filetype,
693716
ignore=ignore,
717+
qid=qid,
694718
**kwargs)
695719
else:
696720
self.loader = _CentralRayDMatrixLoader(
@@ -705,6 +729,7 @@ def __init__(self,
705729
feature_types=feature_types,
706730
filetype=filetype,
707731
ignore=ignore,
732+
qid=qid,
708733
**kwargs)
709734

710735
self.refs: Dict[int, Dict[str, ray.ObjectRef]] = {}
@@ -809,6 +834,7 @@ def __init__(self,
809834
label_upper_bound: Optional[Data] = None,
810835
feature_names: Optional[List[str]] = None,
811836
feature_types: Optional[List[np.dtype]] = None,
837+
qid: Optional[Data] = None,
812838
*args,
813839
**kwargs):
814840
if cp is None:
@@ -831,6 +857,7 @@ def __init__(self,
831857
label_upper_bound=None,
832858
feature_names=feature_names,
833859
feature_types=feature_types,
860+
qid=qid,
834861
*args,
835862
**kwargs)
836863

xgboost_ray/sklearn.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,13 @@ def inner_f(*args, **kwargs):
239239
return inner_f
240240

241241

242-
def _check_if_params_are_ray_dmatrix(X, sample_weight, base_margin, eval_set,
242+
def _check_if_params_are_ray_dmatrix(X,
243+
sample_weight,
244+
base_margin,
245+
eval_set,
243246
sample_weight_eval_set,
244-
base_margin_eval_set):
247+
base_margin_eval_set,
248+
eval_qid=None):
245249
train_dmatrix = None
246250
evals = ()
247251
eval_set = eval_set or ()
@@ -266,6 +270,8 @@ def _check_if_params_are_ray_dmatrix(X, sample_weight, base_margin, eval_set,
266270
params_to_warn_about.append("sample_weight_eval_set")
267271
if base_margin_eval_set is not None:
268272
params_to_warn_about.append("base_margin_eval_set")
273+
if eval_qid is not None:
274+
params_to_warn_about.append("eval_qid")
269275
if params_to_warn_about:
270276
warnings.warn(
271277
"`eval_set` is composed of RayDMatrix tuples, "
@@ -951,18 +957,24 @@ def fit(
951957
ray_dmatrix_params: Optional[Dict] = None,
952958
):
953959

954-
# check if group information is provided
955-
if group is None and qid is None:
956-
raise ValueError("group or qid is required for ranking task")
960+
if not (group is None and eval_group is None):
961+
raise ValueError("Use `qid` instead of `group` for RayXGBRanker.")
962+
if qid is None:
963+
raise ValueError("`qid` is required for ranking.")
957964

958965
if eval_set is not None:
959-
if eval_group is None and eval_qid is None:
960-
raise ValueError("eval_group or eval_qid is required if"
961-
" eval_set is not None")
966+
if eval_qid is None:
967+
raise ValueError("`eval_qid `is required if"
968+
" `eval_set` is not None")
969+
970+
evals_result = {}
971+
ray_dmatrix_params = ray_dmatrix_params or {}
972+
973+
params = self.get_xgb_params()
962974

963975
train_dmatrix, evals = _check_if_params_are_ray_dmatrix(
964976
X, sample_weight, base_margin, eval_set, sample_weight_eval_set,
965-
base_margin_eval_set)
977+
base_margin_eval_set, eval_qid)
966978

967979
if train_dmatrix is None:
968980
train_dmatrix, evals = _wrap_evaluation_matrices(
@@ -986,9 +998,6 @@ def fit(
986998
}),
987999
**self._ray_get_wrap_evaluation_matrices_compat_kwargs())
9881000

989-
evals_result = {}
990-
params = self.get_xgb_params()
991-
9921001
try:
9931002
model, feval, params = self._configure_fit(xgb_model, eval_metric,
9941003
params)

xgboost_ray/tests/test_end_to_end.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import ray
1010
from ray.exceptions import RayActorError, RayTaskError
1111

12+
from scipy.sparse import csr_matrix
13+
1214
from xgboost_ray import RayParams, train, RayDMatrix, predict, RayShardingMode
1315
from xgboost_ray.main import RayXGBoostTrainingError
1416
from xgboost_ray.callback import DistributedCallback
@@ -341,6 +343,38 @@ def testKwargsValidation(self):
341343
ray_params=RayParams(num_actors=1, max_actor_restarts=0),
342344
totally_invalid_kwarg="")
343345

346+
def testRanking(self):
    """Train a pairwise ranking model from a ``RayDMatrix`` built with ``qid``.

    A 20x4 dense matrix with eight non-zero entries is split into four
    query groups of five consecutive rows each. The model is trained for
    10 rounds on two actors and evaluated on the training data itself.

    The recorded ``auc`` history must be non-decreasing, and the ``aucpr``
    history must not drop by 1e-3 or more between consecutive rounds.
    """
    Xrow = np.array([1, 2, 6, 8, 11, 14, 16, 17])
    Xcol = np.array([0, 0, 1, 1, 2, 2, 3, 3])
    X = csr_matrix(
        (np.ones(shape=8), (Xrow, Xcol)), shape=(20, 4)).toarray()
    y = np.array([
        0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0,
        0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0
    ])

    # Four query groups, five consecutive rows each.
    qid = np.array([0] * 5 + [1] * 5 + [2] * 5 + [3] * 5)
    dtrain = RayDMatrix(X, label=y, qid=qid)

    params = {
        "eta": 1,
        "objective": "rank:pairwise",
        "eval_metric": ["auc", "aucpr"],
        "max_depth": 1
    }
    evals_result = {}
    train(
        params,
        dtrain,
        10,
        evals=[(dtrain, "train")],
        evals_result=evals_result,
        ray_params=RayParams(num_actors=2, max_actor_restarts=0))

    auc_rec = evals_result["train"]["auc"]
    self.assertTrue(
        all(p <= q for p, q in zip(auc_rec, auc_rec[1:])),
        msg="train auc is not non-decreasing: {}".format(auc_rec))

    # The previous version passed a bare generator expression to
    # assertTrue; a generator object is always truthy, so this check could
    # never fail. Wrap it in all() so the history is actually inspected.
    # NOTE(review): a 1e-3 tolerance is used instead of a strict `p <= q`
    # because aucpr is presumably not guaranteed to be strictly monotone
    # round over round - confirm the intended strictness.
    aucpr_rec = evals_result["train"]["aucpr"]
    self.assertTrue(
        all((p - q) < 1.0e-3 for p, q in zip(aucpr_rec, aucpr_rec[1:])),
        msg="train aucpr decreased between rounds: {}".format(aucpr_rec))
377+
344378

345379
if __name__ == "__main__":
346380
import pytest

xgboost_ray/tests/test_sklearn.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242
RayXGBRFClassifier, RayXGBRFRegressor,
4343
RayXGBRanker)
4444

45-
from xgboost_ray.main import XGBOOST_VERSION_TUPLE
45+
from xgboost_ray.main import (XGBOOST_VERSION_TUPLE, RayDMatrix, RayParams,
46+
train, predict)
4647
from xgboost_ray.matrix import RayShardingMode
4748

4849

@@ -1211,6 +1212,64 @@ def test_estimator_type(self):
12111212
cls = RayXGBClassifier()
12121213
cls.load_model(path) # no error
12131214

1215+
def test_ranking(self):
    """Check that ``RayXGBRanker`` matches the functional train/predict API.

    Random train/validation/test data is generated, a ``RayXGBRanker`` is
    fitted with ``qid``/``eval_qid``, and its predictions on the test data
    are compared (via ``np.testing.assert_almost_equal``) with those of a
    booster trained through ``train``/``predict`` on ``RayDMatrix`` inputs
    for the same number of rounds.
    """

    def two_actor_params():
        # A fresh RayParams per call, as each call site builds its own.
        return RayParams(num_actors=2, max_actor_restarts=0)

    # Random draws are kept in this exact order so that the sequence
    # consumed from NumPy's global RNG is unchanged.
    train_features = np.random.rand(1000, 10)
    train_labels = np.random.randint(5, size=1000)
    valid_features = np.random.rand(200, 10)
    valid_labels = np.random.randint(5, size=200)
    test_features = np.random.rand(100, 10)

    # np.repeat without `axis` flattens its input, so every query id
    # covers 50 consecutive rows (20 train groups, 4 validation groups).
    train_groups = np.repeat(np.array([list(range(20))]), 50)
    valid_groups = np.repeat(np.array([list(range(4))]), 50)

    # scikit-learn style estimator.
    ranker = RayXGBRanker(
        objective="rank:pairwise",
        learning_rate=0.1,
        gamma=1.0,
        min_child_weight=0.1,
        max_depth=6,
        n_estimators=4,
        random_state=1,
        n_jobs=2)
    ranker.fit(
        train_features,
        train_labels,
        qid=train_groups,
        eval_set=[(valid_features, valid_labels)],
        eval_qid=[valid_groups])
    assert ranker.evals_result()

    sklearn_pred = ranker.predict(test_features)

    # Functional API on the same data.
    dtrain = RayDMatrix(train_features, train_labels, qid=train_groups)
    dvalid = RayDMatrix(valid_features, valid_labels, qid=valid_groups)
    dtest = RayDMatrix(test_features)

    booster_params = {
        "objective": "rank:pairwise",
        "eta": 0.1,
        "gamma": 1.0,
        "min_child_weight": 0.1,
        "max_depth": 6,
        "random_state": 1
    }
    booster = train(
        booster_params,
        dtrain,
        num_boost_round=4,
        evals=[(dvalid, "validation")],
        ray_params=two_actor_params())
    functional_pred = predict(
        booster, dtest, ray_params=two_actor_params())

    np.testing.assert_almost_equal(sklearn_pred, functional_pred)
1272+
12141273

12151274
if __name__ == "__main__":
12161275
import pytest

0 commit comments

Comments
 (0)