3737 RAY_INSTALLED = False
3838
3939from xgboost_ray .tune import _try_add_tune_callback , _get_tune_resources , \
40- TUNE_USING_PG
40+ TUNE_USING_PG , is_session_enabled
4141
4242from xgboost_ray .matrix import RayDMatrix , combine_data , \
4343 RayDeviceQuantileDMatrix , RayDataIter , concat_dataframes
@@ -93,6 +93,13 @@ def _assert_ray_support():
9393 "Try: `pip install ray`" )
9494
9595
96+ def _is_client_connected () -> bool :
97+ try :
98+ return ray .util .client .ray .is_connected ()
99+ except Exception :
100+ return False
101+
102+
96103class _RabitTracker (xgb .RabitTracker ):
97104 """
98105 This method overwrites the xgboost-provided RabitTracker to switch
@@ -639,7 +646,7 @@ def _create_placement_group(cpus_per_actor, gpus_per_actor,
639646
640647def _create_communication_processes (added_tune_callback : bool = False ):
641648 # Create Queue and Event actors and make sure to colocate with driver node.
642- node_ip = ray . services . get_node_ip_address ()
649+ node_ip = get_node_ip_address ()
643650 # Have to explicitly set num_cpus to 0.
644651 placement_option = {"num_cpus" : 0 }
645652 if added_tune_callback and TUNE_USING_PG :
@@ -925,6 +932,7 @@ def train(params: Dict,
925932 evals_result : Optional [Dict ] = None ,
926933 additional_results : Optional [Dict ] = None ,
927934 ray_params : Union [None , RayParams , Dict ] = None ,
935+ _remote : Optional [bool ] = None ,
928936 ** kwargs ):
929937 """Distributed XGBoost training via Ray.
930938
@@ -967,11 +975,51 @@ def train(params: Dict,
967975 ray_params (Union[None, RayParams, Dict]): Parameters to configure
968976 Ray-specific behavior. See :class:`RayParams` for a list of valid
969977 configuration parameters.
978+ _remote (bool): Whether to run the driver process in a remote
979+ function. This is enabled by default in Ray client mode.
970980 **kwargs: Keyword arguments will be passed to the local
971981 `xgb.train()` calls.
972982
973983 Returns: An ``xgboost.Booster`` object.
974984 """
985+ os .environ .setdefault ("RAY_IGNORE_UNHANDLED_ERRORS" , "1" )
986+
987+ if _remote is None :
988+ _remote = _is_client_connected () and \
989+ not is_session_enabled ()
990+
991+ if not ray .is_initialized ():
992+ ray .init ()
993+
994+ if _remote :
995+ # Run this function as a remote function to support Ray client mode
996+ @ray .remote (num_cpus = 0 )
997+ def _wrapped (* args , ** kwargs ):
998+ _evals_result = {}
999+ _additional_results = {}
1000+ bst = train (
1001+ * args ,
1002+ evals_result = _evals_result ,
1003+ additional_results = _additional_results ,
1004+ ** kwargs )
1005+ return bst , _evals_result , _additional_results
1006+
1007+ bst , train_evals_result , train_additional_results = ray .get (
1008+ _wrapped .remote (
1009+ params ,
1010+ dtrain ,
1011+ * args ,
1012+ evals = evals ,
1013+ ray_params = ray_params ,
1014+ _remote = False ,
1015+ ** kwargs ,
1016+ ))
1017+ if isinstance (evals_result , dict ):
1018+ evals_result .update (train_evals_result )
1019+ if isinstance (additional_results , dict ):
1020+ additional_results .update (train_additional_results )
1021+ return bst
1022+
9751023 ray_params = _validate_ray_params (ray_params )
9761024
9771025 max_actor_restarts = ray_params .max_actor_restarts \
@@ -986,9 +1034,6 @@ def train(params: Dict,
9861034 "`dtrain = RayDMatrix(data=data, label=label)`." .format (
9871035 type (dtrain )))
9881036
989- if not ray .is_initialized ():
990- ray .init ()
991-
9921037 cpus_per_actor , gpus_per_actor = _autodetect_resources (
9931038 ray_params = ray_params ,
9941039 use_tree_method = "tree_method" in params
@@ -1211,6 +1256,7 @@ def _predict(model: xgb.Booster, data: RayDMatrix, ray_params: RayParams,
12111256def predict (model : xgb .Booster ,
12121257 data : RayDMatrix ,
12131258 ray_params : Union [None , RayParams , Dict ] = None ,
1259+ _remote : Optional [bool ] = None ,
12141260 ** kwargs ) -> Optional [np .ndarray ]:
12151261 """Distributed XGBoost predict via Ray.
12161262
@@ -1225,12 +1271,28 @@ def predict(model: xgb.Booster,
12251271 ray_params (Union[None, RayParams, Dict]): Parameters to configure
12261272 Ray-specific behavior. See :class:`RayParams` for a list of valid
12271273 configuration parameters.
1274+ _remote (bool): Whether to run the driver process in a remote
1275+ function. This is enabled by default in Ray client mode.
12281276 **kwargs: Keyword arguments will be passed to the local
12291277 `xgb.predict()` calls.
12301278
12311279 Returns: ``np.ndarray`` containing the predicted labels.
12321280
12331281 """
1282+ os .environ .setdefault ("RAY_IGNORE_UNHANDLED_ERRORS" , "1" )
1283+
1284+ if _remote is None :
1285+ _remote = _is_client_connected () and \
1286+ not is_session_enabled ()
1287+
1288+ if not ray .is_initialized ():
1289+ ray .init ()
1290+
1291+ if _remote :
1292+ return ray .get (
1293+ ray .remote (num_cpus = 0 )(predict ).remote (
1294+ model , data , ray_params , _remote = False , ** kwargs ))
1295+
12341296 ray_params = _validate_ray_params (ray_params )
12351297
12361298 max_actor_restarts = ray_params .max_actor_restarts \
0 commit comments