Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit e5ccecc

Browse files
authored
Fix cutting edge CI failure (#276)
Fixes CI failure with cutting edge by ensuring the docstrings replaced in sklearn integration are searched for in superclasses.
1 parent 44e0f01 commit e5ccecc

File tree

2 files changed

+51
-26
lines changed

2 files changed

+51
-26
lines changed

xgboost_ray/sklearn.py

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import warnings
3333
import functools
3434
import inspect
35+
from inspect import _finddoc
3536

3637
from ray.util.annotations import PublicAPI, DeveloperAPI
3738

@@ -212,22 +213,41 @@ def _cls_predict_proba(n_classes: int, prediction, vstack: Callable):
212213
)
213214

214215

215-
def _treat_estimator_doc(doc: str) -> str:
216+
def _get_doc(object: Any) -> Optional[str]:
217+
"""Same as ``inspect.getdoc``, but without ``cleandoc`` applied."""
218+
try:
219+
doc = object.__doc__
220+
except AttributeError:
221+
return None
222+
if doc is None:
223+
try:
224+
doc = _finddoc(object)
225+
except (AttributeError, TypeError):
226+
return None
227+
if not isinstance(doc, str):
228+
return None
229+
return doc
230+
231+
232+
def _treat_estimator_doc(doc: Optional[str]) -> Optional[str]:
216233
"""Helper function to make nececssary changes in estimator docstrings"""
217-
doc = doc.replace(*_N_JOBS_DOC_REPLACE).replace(
218-
"scikit-learn API for XGBoost",
219-
"scikit-learn API for Ray-distributed XGBoost").replace(
220-
":doc:`tree method\n </treemethod>`", "tree method")
234+
if doc:
235+
doc = doc.replace(*_N_JOBS_DOC_REPLACE).replace(
236+
"scikit-learn API for XGBoost",
237+
"scikit-learn API for Ray-distributed XGBoost").replace(
238+
":doc:`tree method\n </treemethod>`", "tree method")
221239
return doc
222240

223241

224-
def _treat_X_doc(doc: str) -> str:
225-
doc = doc.replace("Data to predict with.",
226-
"Data to predict with. Can also be a ``RayDMatrix``.")
227-
doc = doc.replace("Feature matrix.",
228-
"Feature matrix. Can also be a ``RayDMatrix``.")
229-
doc = doc.replace("Feature matrix",
230-
"Feature matrix. Can also be a ``RayDMatrix``.")
242+
def _treat_X_doc(doc: Optional[str]) -> Optional[str]:
243+
if doc:
244+
doc = doc.replace(
245+
"Data to predict with.",
246+
"Data to predict with. Can also be a ``RayDMatrix``.")
247+
doc = doc.replace("Feature matrix.",
248+
"Feature matrix. Can also be a ``RayDMatrix``.")
249+
doc = doc.replace("Feature matrix",
250+
"Feature matrix. Can also be a ``RayDMatrix``.")
231251
return doc
232252

233253

@@ -514,7 +534,7 @@ def fit(
514534
self._set_evaluation_result(evals_result)
515535
return self
516536

517-
fit.__doc__ = _treat_X_doc(XGBRegressor.fit.__doc__) + _RAY_PARAMS_DOC
537+
fit.__doc__ = _treat_X_doc(_get_doc(XGBRegressor.fit)) + _RAY_PARAMS_DOC
518538

519539
def _can_use_inplace_predict(self) -> bool:
520540
return False
@@ -542,16 +562,16 @@ def predict(
542562
_remote=_remote,
543563
ray_dmatrix_params=ray_dmatrix_params)
544564

545-
predict.__doc__ = _treat_X_doc(
546-
XGBRegressor.predict.__doc__) + _RAY_PARAMS_DOC
565+
predict.__doc__ = _treat_X_doc(_get_doc(
566+
XGBRegressor.predict)) + _RAY_PARAMS_DOC
547567

548568
def load_model(self, fname):
549569
if not hasattr(self, "_Booster"):
550570
self._Booster = Booster()
551571
return super().load_model(fname)
552572

553573

554-
RayXGBRegressor.__doc__ = _treat_estimator_doc(XGBRegressor.__doc__)
574+
RayXGBRegressor.__doc__ = _treat_estimator_doc(_get_doc(XGBRegressor))
555575

556576

557577
class RayXGBRFRegressor(RayXGBRegressor):
@@ -589,7 +609,7 @@ def get_num_boosting_rounds(self):
589609
return 1
590610

591611

592-
RayXGBRFRegressor.__doc__ = _treat_estimator_doc(XGBRFRegressor.__doc__)
612+
RayXGBRFRegressor.__doc__ = _treat_estimator_doc(_get_doc(XGBRFRegressor))
593613

594614

595615
@PublicAPI(stability="beta")
@@ -734,7 +754,7 @@ def fit(
734754
self._set_evaluation_result(evals_result)
735755
return self
736756

737-
fit.__doc__ = _treat_X_doc(XGBClassifier.fit.__doc__) + _RAY_PARAMS_DOC
757+
fit.__doc__ = _treat_X_doc(_get_doc(XGBClassifier.fit)) + _RAY_PARAMS_DOC
738758

739759
def _ray_fit_preprocess(self, y) -> Callable:
740760
"""This has been separated out so that it can be easily overwritten
@@ -838,7 +858,8 @@ def predict(
838858
return self._le.inverse_transform(column_indexes)
839859
return column_indexes
840860

841-
predict.__doc__ = _treat_X_doc(XGBModel.predict.__doc__) + _RAY_PARAMS_DOC
861+
predict.__doc__ = _treat_X_doc(_get_doc(
862+
XGBModel.predict)) + _RAY_PARAMS_DOC
842863

843864
def predict_proba(
844865
self,
@@ -872,10 +893,10 @@ def load_model(self, fname):
872893
return super().load_model(fname)
873894

874895
predict_proba.__doc__ = (
875-
_treat_X_doc(XGBClassifier.predict_proba.__doc__) + _RAY_PARAMS_DOC)
896+
_treat_X_doc(_get_doc(XGBClassifier.predict_proba)) + _RAY_PARAMS_DOC)
876897

877898

878-
RayXGBClassifier.__doc__ = _treat_estimator_doc(XGBClassifier.__doc__)
899+
RayXGBClassifier.__doc__ = _treat_estimator_doc(_get_doc(XGBClassifier))
879900

880901

881902
class RayXGBRFClassifier(RayXGBClassifier):
@@ -935,7 +956,7 @@ def get_num_boosting_rounds(self):
935956
return 1
936957

937958

938-
RayXGBRFClassifier.__doc__ = _treat_estimator_doc(XGBRFClassifier.__doc__)
959+
RayXGBRFClassifier.__doc__ = _treat_estimator_doc(_get_doc(XGBRFClassifier))
939960

940961

941962
@PublicAPI(stability="beta")
@@ -1053,7 +1074,7 @@ def fit(
10531074
self._set_evaluation_result(evals_result)
10541075
return self
10551076

1056-
fit.__doc__ = _treat_X_doc(XGBRanker.fit.__doc__) + _RAY_PARAMS_DOC
1077+
fit.__doc__ = _treat_X_doc(_get_doc(XGBRanker.fit)) + _RAY_PARAMS_DOC
10571078

10581079
def _can_use_inplace_predict(self) -> bool:
10591080
return False
@@ -1081,12 +1102,13 @@ def predict(
10811102
_remote=_remote,
10821103
ray_dmatrix_params=ray_dmatrix_params)
10831104

1084-
predict.__doc__ = _treat_X_doc(XGBRanker.predict.__doc__) + _RAY_PARAMS_DOC
1105+
predict.__doc__ = _treat_X_doc(_get_doc(
1106+
XGBRanker.predict)) + _RAY_PARAMS_DOC
10851107

10861108
def load_model(self, fname):
10871109
if not hasattr(self, "_Booster"):
10881110
self._Booster = Booster()
10891111
return super().load_model(fname)
10901112

10911113

1092-
RayXGBRanker.__doc__ = _treat_estimator_doc(XGBRanker.__doc__)
1114+
RayXGBRanker.__doc__ = _treat_estimator_doc(_get_doc(XGBRanker))

xgboost_ray/tests/test_sklearn.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -976,7 +976,10 @@ def test_constraint_parameters(self):
976976
reg.fit(X, y)
977977

978978
config = json.loads(reg.get_booster().save_config())
979-
if XGBOOST_VERSION >= Version("1.6.0"):
979+
if XGBOOST_VERSION > Version("1.7.4"):
980+
assert (config["learner"]["gradient_booster"]["tree_train_param"][
981+
"interaction_constraints"] == "[[0, 1], [2, 3, 4]]")
982+
elif XGBOOST_VERSION >= Version("1.6.0"):
980983
assert (config["learner"]["gradient_booster"]["updater"][
981984
"grow_histmaker"]["train_param"]["interaction_constraints"] ==
982985
"[[0, 1], [2, 3, 4]]")

0 commit comments

Comments
 (0)