|
32 | 32 | import warnings |
33 | 33 | import functools |
34 | 34 | import inspect |
| 35 | +from inspect import _finddoc |
35 | 36 |
|
36 | 37 | from ray.util.annotations import PublicAPI, DeveloperAPI |
37 | 38 |
|
@@ -212,22 +213,41 @@ def _cls_predict_proba(n_classes: int, prediction, vstack: Callable): |
212 | 213 | ) |
213 | 214 |
|
214 | 215 |
|
215 | | -def _treat_estimator_doc(doc: str) -> str: |
| 216 | +def _get_doc(object: Any) -> Optional[str]: |
| 217 | + """Same as ``inspect.getdoc``, but without ``cleandoc`` applied.""" |
| 218 | + try: |
| 219 | + doc = object.__doc__ |
| 220 | + except AttributeError: |
| 221 | + return None |
| 222 | + if doc is None: |
| 223 | + try: |
| 224 | + doc = _finddoc(object) |
| 225 | + except (AttributeError, TypeError): |
| 226 | + return None |
| 227 | + if not isinstance(doc, str): |
| 228 | + return None |
| 229 | + return doc |
| 230 | + |
| 231 | + |
| 232 | +def _treat_estimator_doc(doc: Optional[str]) -> Optional[str]: |
216 | 233 | """Helper function to make nececssary changes in estimator docstrings""" |
217 | | - doc = doc.replace(*_N_JOBS_DOC_REPLACE).replace( |
218 | | - "scikit-learn API for XGBoost", |
219 | | - "scikit-learn API for Ray-distributed XGBoost").replace( |
220 | | - ":doc:`tree method\n </treemethod>`", "tree method") |
| 234 | + if doc: |
| 235 | + doc = doc.replace(*_N_JOBS_DOC_REPLACE).replace( |
| 236 | + "scikit-learn API for XGBoost", |
| 237 | + "scikit-learn API for Ray-distributed XGBoost").replace( |
| 238 | + ":doc:`tree method\n </treemethod>`", "tree method") |
221 | 239 | return doc |
222 | 240 |
|
223 | 241 |
|
224 | | -def _treat_X_doc(doc: str) -> str: |
225 | | - doc = doc.replace("Data to predict with.", |
226 | | - "Data to predict with. Can also be a ``RayDMatrix``.") |
227 | | - doc = doc.replace("Feature matrix.", |
228 | | - "Feature matrix. Can also be a ``RayDMatrix``.") |
229 | | - doc = doc.replace("Feature matrix", |
230 | | - "Feature matrix. Can also be a ``RayDMatrix``.") |
| 242 | +def _treat_X_doc(doc: Optional[str]) -> Optional[str]: |
| 243 | + if doc: |
| 244 | + doc = doc.replace( |
| 245 | + "Data to predict with.", |
| 246 | + "Data to predict with. Can also be a ``RayDMatrix``.") |
| 247 | + doc = doc.replace("Feature matrix.", |
| 248 | + "Feature matrix. Can also be a ``RayDMatrix``.") |
| 249 | + doc = doc.replace("Feature matrix", |
| 250 | + "Feature matrix. Can also be a ``RayDMatrix``.") |
231 | 251 | return doc |
232 | 252 |
|
233 | 253 |
|
@@ -514,7 +534,7 @@ def fit( |
514 | 534 | self._set_evaluation_result(evals_result) |
515 | 535 | return self |
516 | 536 |
|
517 | | - fit.__doc__ = _treat_X_doc(XGBRegressor.fit.__doc__) + _RAY_PARAMS_DOC |
| 537 | + fit.__doc__ = _treat_X_doc(_get_doc(XGBRegressor.fit)) + _RAY_PARAMS_DOC |
518 | 538 |
|
519 | 539 | def _can_use_inplace_predict(self) -> bool: |
520 | 540 | return False |
@@ -542,16 +562,16 @@ def predict( |
542 | 562 | _remote=_remote, |
543 | 563 | ray_dmatrix_params=ray_dmatrix_params) |
544 | 564 |
|
545 | | - predict.__doc__ = _treat_X_doc( |
546 | | - XGBRegressor.predict.__doc__) + _RAY_PARAMS_DOC |
| 565 | + predict.__doc__ = _treat_X_doc(_get_doc( |
| 566 | + XGBRegressor.predict)) + _RAY_PARAMS_DOC |
547 | 567 |
|
548 | 568 | def load_model(self, fname): |
549 | 569 | if not hasattr(self, "_Booster"): |
550 | 570 | self._Booster = Booster() |
551 | 571 | return super().load_model(fname) |
552 | 572 |
|
553 | 573 |
|
554 | | -RayXGBRegressor.__doc__ = _treat_estimator_doc(XGBRegressor.__doc__) |
| 574 | +RayXGBRegressor.__doc__ = _treat_estimator_doc(_get_doc(XGBRegressor)) |
555 | 575 |
|
556 | 576 |
|
557 | 577 | class RayXGBRFRegressor(RayXGBRegressor): |
@@ -589,7 +609,7 @@ def get_num_boosting_rounds(self): |
589 | 609 | return 1 |
590 | 610 |
|
591 | 611 |
|
592 | | -RayXGBRFRegressor.__doc__ = _treat_estimator_doc(XGBRFRegressor.__doc__) |
| 612 | +RayXGBRFRegressor.__doc__ = _treat_estimator_doc(_get_doc(XGBRFRegressor)) |
593 | 613 |
|
594 | 614 |
|
595 | 615 | @PublicAPI(stability="beta") |
@@ -734,7 +754,7 @@ def fit( |
734 | 754 | self._set_evaluation_result(evals_result) |
735 | 755 | return self |
736 | 756 |
|
737 | | - fit.__doc__ = _treat_X_doc(XGBClassifier.fit.__doc__) + _RAY_PARAMS_DOC |
| 757 | + fit.__doc__ = _treat_X_doc(_get_doc(XGBClassifier.fit)) + _RAY_PARAMS_DOC |
738 | 758 |
|
739 | 759 | def _ray_fit_preprocess(self, y) -> Callable: |
740 | 760 | """This has been separated out so that it can be easily overwritten |
@@ -838,7 +858,8 @@ def predict( |
838 | 858 | return self._le.inverse_transform(column_indexes) |
839 | 859 | return column_indexes |
840 | 860 |
|
841 | | - predict.__doc__ = _treat_X_doc(XGBModel.predict.__doc__) + _RAY_PARAMS_DOC |
| 861 | + predict.__doc__ = _treat_X_doc(_get_doc( |
| 862 | + XGBModel.predict)) + _RAY_PARAMS_DOC |
842 | 863 |
|
843 | 864 | def predict_proba( |
844 | 865 | self, |
@@ -872,10 +893,10 @@ def load_model(self, fname): |
872 | 893 | return super().load_model(fname) |
873 | 894 |
|
874 | 895 | predict_proba.__doc__ = ( |
875 | | - _treat_X_doc(XGBClassifier.predict_proba.__doc__) + _RAY_PARAMS_DOC) |
| 896 | + _treat_X_doc(_get_doc(XGBClassifier.predict_proba)) + _RAY_PARAMS_DOC) |
876 | 897 |
|
877 | 898 |
|
878 | | -RayXGBClassifier.__doc__ = _treat_estimator_doc(XGBClassifier.__doc__) |
| 899 | +RayXGBClassifier.__doc__ = _treat_estimator_doc(_get_doc(XGBClassifier)) |
879 | 900 |
|
880 | 901 |
|
881 | 902 | class RayXGBRFClassifier(RayXGBClassifier): |
@@ -935,7 +956,7 @@ def get_num_boosting_rounds(self): |
935 | 956 | return 1 |
936 | 957 |
|
937 | 958 |
|
938 | | -RayXGBRFClassifier.__doc__ = _treat_estimator_doc(XGBRFClassifier.__doc__) |
| 959 | +RayXGBRFClassifier.__doc__ = _treat_estimator_doc(_get_doc(XGBRFClassifier)) |
939 | 960 |
|
940 | 961 |
|
941 | 962 | @PublicAPI(stability="beta") |
@@ -1053,7 +1074,7 @@ def fit( |
1053 | 1074 | self._set_evaluation_result(evals_result) |
1054 | 1075 | return self |
1055 | 1076 |
|
1056 | | - fit.__doc__ = _treat_X_doc(XGBRanker.fit.__doc__) + _RAY_PARAMS_DOC |
| 1077 | + fit.__doc__ = _treat_X_doc(_get_doc(XGBRanker.fit)) + _RAY_PARAMS_DOC |
1057 | 1078 |
|
1058 | 1079 | def _can_use_inplace_predict(self) -> bool: |
1059 | 1080 | return False |
@@ -1081,12 +1102,13 @@ def predict( |
1081 | 1102 | _remote=_remote, |
1082 | 1103 | ray_dmatrix_params=ray_dmatrix_params) |
1083 | 1104 |
|
1084 | | - predict.__doc__ = _treat_X_doc(XGBRanker.predict.__doc__) + _RAY_PARAMS_DOC |
| 1105 | + predict.__doc__ = _treat_X_doc(_get_doc( |
| 1106 | + XGBRanker.predict)) + _RAY_PARAMS_DOC |
1085 | 1107 |
|
1086 | 1108 | def load_model(self, fname): |
1087 | 1109 | if not hasattr(self, "_Booster"): |
1088 | 1110 | self._Booster = Booster() |
1089 | 1111 | return super().load_model(fname) |
1090 | 1112 |
|
1091 | 1113 |
|
1092 | | -RayXGBRanker.__doc__ = _treat_estimator_doc(XGBRanker.__doc__) |
| 1114 | +RayXGBRanker.__doc__ = _treat_estimator_doc(_get_doc(XGBRanker)) |
0 commit comments