This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit a1089c6

Compatibility for xgboost>=1.6.0 (#167)
1 parent 03e0a34

File tree: 2 files changed, +46 -21 lines


xgboost_ray/sklearn.py

Lines changed: 33 additions & 11 deletions
@@ -347,10 +347,17 @@ def _ray_predict(
             **compat_predict_kwargs,
         )
 
-    def _ray_get_wrap_evaluation_matrices_compat_kwargs(self) -> dict:
+    def _ray_get_wrap_evaluation_matrices_compat_kwargs(
+            self, label_transform=None) -> dict:
+        ret = {}
+        if "label_transform" in inspect.signature(
+                _wrap_evaluation_matrices).parameters:
+            # XGBoost < 1.6.0
+            identity_func = lambda x: x  # noqa
+            ret["label_transform"] = label_transform or identity_func
         if hasattr(self, "enable_categorical"):
-            return {"enable_categorical": self.enable_categorical}
-        return {}
+            ret["enable_categorical"] = self.enable_categorical
+        return ret
 
     # copied from the file in the top comment
     # provided here for compatibility with legacy xgboost versions
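
The helper above feature-detects the callee instead of checking a version number: a keyword argument is forwarded only if the target function actually declares it. A minimal standalone sketch of the same pattern, where call_compat and the two stand-in callees are hypothetical names, not part of xgboost-ray:

import inspect

def call_compat(func, *args, **optional_kwargs):
    # Forward only the keyword arguments the callee actually declares,
    # so one call site keeps working across library versions that add
    # or remove parameters.
    accepted = inspect.signature(func).parameters
    kwargs = {k: v for k, v in optional_kwargs.items() if k in accepted}
    return func(*args, **kwargs)

def old_api(data, label_transform=None):  # stand-in for the pre-1.6.0 shape
    return (label_transform or (lambda x: x))(data)

def new_api(data):  # stand-in for the 1.6.0+ shape
    return data

call_compat(old_api, [1, 2], label_transform=lambda x: x[::-1])  # -> [2, 1]
call_compat(new_api, [1, 2], label_transform=lambda x: x[::-1])  # -> [1, 2]

Note that this simple membership check would not see parameters absorbed by a **kwargs catch-all; for the concrete _wrap_evaluation_matrices signatures here, that does not arise.
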
@@ -450,8 +457,13 @@ def fit(
         else:
             obj = None
 
-        model, feval, params = self._configure_fit(xgb_model, eval_metric,
-                                                   params)
+        try:
+            model, feval, params = self._configure_fit(xgb_model, eval_metric,
+                                                       params)
+        except TypeError:
+            # XGBoost >= 1.6.0
+            model, feval, params, early_stopping_rounds = self._configure_fit(
+                xgb_model, eval_metric, params, early_stopping_rounds)
 
         # remove those as they will be set in RayXGBoostActor
         params.pop("n_jobs", None)
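
The try/except TypeError fallback used here is a common way to bridge a signature change: call with the old arity first, and retry with the new one when the call itself is rejected. A rough sketch of the control flow, with configure_v1 and configure_v2 as hypothetical stand-ins for the two _configure_fit signatures:

def configure_v1(model, metric, params):  # pre-1.6.0 shape
    return model, metric, params

def configure_v2(model, metric, params, early_stopping_rounds):  # 1.6.0+ shape
    return model, metric, params, early_stopping_rounds

def call_configure(configure, model, metric, params, early_stopping_rounds):
    try:
        # Old signature: three arguments, three return values.
        model, feval, params = configure(model, metric, params)
    except TypeError:
        # A 1.6.0+ callee rejects the three-argument call with TypeError
        # (missing positional argument), so retry with the new signature.
        model, feval, params, early_stopping_rounds = configure(
            model, metric, params, early_stopping_rounds)
    return model, feval, params, early_stopping_rounds

One caveat of the pattern: a TypeError raised inside the callee for an unrelated reason would also trigger the fallback, so it is best reserved for calls whose bodies are unlikely to raise TypeError themselves.
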
@@ -638,8 +650,13 @@ def fit(
             params["objective"] = "multi:softprob"
             params["num_class"] = self.n_classes_
 
-        model, feval, params = self._configure_fit(xgb_model, eval_metric,
-                                                   params)
+        try:
+            model, feval, params = self._configure_fit(xgb_model, eval_metric,
+                                                       params)
+        except TypeError:
+            # XGBoost >= 1.6.0
+            model, feval, params, early_stopping_rounds = self._configure_fit(
+                xgb_model, eval_metric, params, early_stopping_rounds)
 
         if train_dmatrix is None:
             train_dmatrix, evals = _wrap_evaluation_matrices(
@@ -656,13 +673,13 @@ def fit(
                 base_margin_eval_set=base_margin_eval_set,
                 eval_group=None,
                 eval_qid=None,
-                label_transform=label_transform,
                 # changed in xgboost-ray:
                 create_dmatrix=lambda **kwargs: RayDMatrix(**{
                     **kwargs,
                     **ray_dmatrix_params
                 }),
-                **self._ray_get_wrap_evaluation_matrices_compat_kwargs())
+                **self._ray_get_wrap_evaluation_matrices_compat_kwargs(
+                    label_transform=label_transform))
 
         # remove those as they will be set in RayXGBoostActor
         params.pop("n_jobs", None)
@@ -970,8 +987,13 @@ def fit(
         evals_result = {}
         params = self.get_xgb_params()
 
-        model, feval, params = self._configure_fit(xgb_model, eval_metric,
-                                                   params)
+        try:
+            model, feval, params = self._configure_fit(xgb_model, eval_metric,
+                                                       params)
+        except TypeError:
+            # XGBoost >= 1.6.0
+            model, feval, params, early_stopping_rounds = self._configure_fit(
+                xgb_model, eval_metric, params, early_stopping_rounds)
         if callable(feval):
             raise ValueError(
                 "Custom evaluation metric is not yet supported for XGBRanker.")

xgboost_ray/tests/test_sklearn_matrix.py

Lines changed: 13 additions & 10 deletions
@@ -11,6 +11,9 @@
 
 from xgboost_ray.main import XGBOOST_VERSION_TUPLE
 
+has_label_encoder = (XGBOOST_VERSION_TUPLE >= (1, 0, 0)
+                     and XGBOOST_VERSION_TUPLE < (1, 6, 0))
+
 
 class XGBoostRaySklearnMatrixTest(unittest.TestCase):
     def setUp(self):
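
has_label_encoder narrows the old version check to the [1.0.0, 1.6.0) window, where the sklearn estimators still label-encoded targets (hence the renamed tests below). XGBOOST_VERSION_TUPLE is provided by xgboost_ray.main; a sketch of how such a gate can be built from scratch, assuming a plain major.minor.patch version string (EncoderTest is a hypothetical test case, not from this repo):

import unittest

import xgboost as xgb

# "1.5.2" -> (1, 5, 2); pre-release suffixes such as "2.0.0-dev" would
# need extra parsing before int() is applied.
version_tuple = tuple(int(part) for part in xgb.__version__.split(".")[:3])

has_label_encoder = (1, 0, 0) <= version_tuple < (1, 6, 0)

class EncoderTest(unittest.TestCase):
    @unittest.skipIf(not has_label_encoder,
                     f"not supported in xgb version {xgb.__version__}")
    def testUsesLabelEncoder(self):
        ...

Comparing tuples rather than version strings keeps the ordering correct past single digits, e.g. (1, 10, 0) > (1, 6, 0) while "1.10.0" < "1.6.0" as strings.
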
@@ -26,9 +29,9 @@ def _init_ray(self):
         if not ray.is_initialized():
             ray.init(num_cpus=4)
 
-    @unittest.skipIf(XGBOOST_VERSION_TUPLE < (1, 0, 0),
+    @unittest.skipIf(not has_label_encoder,
                      f"not supported in xgb version {xgb.__version__}")
-    def testClassifier(self, n_class=2):
+    def testClassifierLabelEncoder(self, n_class=2):
         self._init_ray()
 
         from sklearn.datasets import load_digits
@@ -74,14 +77,14 @@ def testClassifier(self, n_class=2):
         clf.predict(test_matrix)
         clf.predict_proba(test_matrix)
 
-    @unittest.skipIf(XGBOOST_VERSION_TUPLE < (1, 0, 0),
+    @unittest.skipIf(not has_label_encoder,
                      f"not supported in xgb version {xgb.__version__}")
-    def testClassifierMulticlass(self):
-        self.testClassifier(n_class=3)
+    def testClassifierMulticlassLabelEncoder(self):
+        self.testClassifierLabelEncoder(n_class=3)
 
-    @unittest.skipIf(XGBOOST_VERSION_TUPLE >= (1, 0, 0),
+    @unittest.skipIf(has_label_encoder,
                      f"not supported in xgb version {xgb.__version__}")
-    def testClassifierLegacy(self, n_class=2):
+    def testClassifierNoLabelEncoder(self, n_class=2):
         self._init_ray()
 
         from sklearn.datasets import load_digits
@@ -118,10 +121,10 @@ def testClassifierLegacy(self, n_class=2):
         clf.predict(test_matrix)
         clf.predict_proba(test_matrix)
 
-    @unittest.skipIf(XGBOOST_VERSION_TUPLE >= (1, 0, 0),
+    @unittest.skipIf(has_label_encoder,
                      f"not supported in xgb version {xgb.__version__}")
-    def testClassifierMulticlassLegacy(self):
-        self.testClassifierLegacy(n_class=3)
+    def testClassifierMulticlassNoLabelEncoder(self):
+        self.testClassifierNoLabelEncoder(n_class=3)
 
     def testRegressor(self):
         self._init_ray()
