Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 1e6494f

Browse files
authored
Allow train and predict to run in remote function, enabling Ray client mode (#57)
* Update README
* Detect ray client session and run in remote call
* Shut down ray
* Ray client compatibility
* Move init check
* Move order
* Fix typo
1 parent c261b93 commit 1e6494f

File tree

4 files changed

+124
-7
lines changed

4 files changed

+124
-7
lines changed

xgboost_ray/main.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
RAY_INSTALLED = False
3838

3939
from xgboost_ray.tune import _try_add_tune_callback, _get_tune_resources, \
40-
TUNE_USING_PG
40+
TUNE_USING_PG, is_session_enabled
4141

4242
from xgboost_ray.matrix import RayDMatrix, combine_data, \
4343
RayDeviceQuantileDMatrix, RayDataIter, concat_dataframes
@@ -93,6 +93,13 @@ def _assert_ray_support():
9393
"Try: `pip install ray`")
9494

9595

96+
def _is_client_connected() -> bool:
    """Report whether a Ray client session is currently connected.

    The check is best-effort: any exception raised while querying the Ray
    client API is treated as "not connected".

    Returns:
        bool: The result of ``ray.util.client.ray.is_connected()``, or
        ``False`` if that call raised.
    """
    try:
        connected = ray.util.client.ray.is_connected()
    except Exception:
        connected = False
    return connected
101+
102+
96103
class _RabitTracker(xgb.RabitTracker):
97104
"""
98105
This method overwrites the xgboost-provided RabitTracker to switch
@@ -639,7 +646,7 @@ def _create_placement_group(cpus_per_actor, gpus_per_actor,
639646

640647
def _create_communication_processes(added_tune_callback: bool = False):
641648
# Create Queue and Event actors and make sure to colocate with driver node.
642-
node_ip = ray.services.get_node_ip_address()
649+
node_ip = get_node_ip_address()
643650
# Have to explicitly set num_cpus to 0.
644651
placement_option = {"num_cpus": 0}
645652
if added_tune_callback and TUNE_USING_PG:
@@ -925,6 +932,7 @@ def train(params: Dict,
925932
evals_result: Optional[Dict] = None,
926933
additional_results: Optional[Dict] = None,
927934
ray_params: Union[None, RayParams, Dict] = None,
935+
_remote: Optional[bool] = None,
928936
**kwargs):
929937
"""Distributed XGBoost training via Ray.
930938
@@ -967,11 +975,51 @@ def train(params: Dict,
967975
ray_params (Union[None, RayParams, Dict]): Parameters to configure
968976
Ray-specific behavior. See :class:`RayParams` for a list of valid
969977
configuration parameters.
978+
_remote (bool): Whether to run the driver process in a remote
979+
function. This is enabled by default in Ray client mode.
970980
**kwargs: Keyword arguments will be passed to the local
971981
`xgb.train()` calls.
972982
973983
Returns: An ``xgboost.Booster`` object.
974984
"""
985+
os.environ.setdefault("RAY_IGNORE_UNHANDLED_ERRORS", "1")
986+
987+
if _remote is None:
988+
_remote = _is_client_connected() and \
989+
not is_session_enabled()
990+
991+
if not ray.is_initialized():
992+
ray.init()
993+
994+
if _remote:
995+
# Run this function as a remote function to support Ray client mode
996+
@ray.remote(num_cpus=0)
997+
def _wrapped(*args, **kwargs):
998+
_evals_result = {}
999+
_additional_results = {}
1000+
bst = train(
1001+
*args,
1002+
evals_result=_evals_result,
1003+
additional_results=_additional_results,
1004+
**kwargs)
1005+
return bst, _evals_result, _additional_results
1006+
1007+
bst, train_evals_result, train_additional_results = ray.get(
1008+
_wrapped.remote(
1009+
params,
1010+
dtrain,
1011+
*args,
1012+
evals=evals,
1013+
ray_params=ray_params,
1014+
_remote=False,
1015+
**kwargs,
1016+
))
1017+
if isinstance(evals_result, dict):
1018+
evals_result.update(train_evals_result)
1019+
if isinstance(additional_results, dict):
1020+
additional_results.update(train_additional_results)
1021+
return bst
1022+
9751023
ray_params = _validate_ray_params(ray_params)
9761024

9771025
max_actor_restarts = ray_params.max_actor_restarts \
@@ -986,9 +1034,6 @@ def train(params: Dict,
9861034
"`dtrain = RayDMatrix(data=data, label=label)`.".format(
9871035
type(dtrain)))
9881036

989-
if not ray.is_initialized():
990-
ray.init()
991-
9921037
cpus_per_actor, gpus_per_actor = _autodetect_resources(
9931038
ray_params=ray_params,
9941039
use_tree_method="tree_method" in params
@@ -1211,6 +1256,7 @@ def _predict(model: xgb.Booster, data: RayDMatrix, ray_params: RayParams,
12111256
def predict(model: xgb.Booster,
12121257
data: RayDMatrix,
12131258
ray_params: Union[None, RayParams, Dict] = None,
1259+
_remote: Optional[bool] = None,
12141260
**kwargs) -> Optional[np.ndarray]:
12151261
"""Distributed XGBoost predict via Ray.
12161262
@@ -1225,12 +1271,28 @@ def predict(model: xgb.Booster,
12251271
ray_params (Union[None, RayParams, Dict]): Parameters to configure
12261272
Ray-specific behavior. See :class:`RayParams` for a list of valid
12271273
configuration parameters.
1274+
_remote (bool): Whether to run the driver process in a remote
1275+
function. This is enabled by default in Ray client mode.
12281276
**kwargs: Keyword arguments will be passed to the local
12291277
`xgb.predict()` calls.
12301278
12311279
Returns: ``np.ndarray`` containing the predicted labels.
12321280
12331281
"""
1282+
os.environ.setdefault("RAY_IGNORE_UNHANDLED_ERRORS", "1")
1283+
1284+
if _remote is None:
1285+
_remote = _is_client_connected() and \
1286+
not is_session_enabled()
1287+
1288+
if not ray.is_initialized():
1289+
ray.init()
1290+
1291+
if _remote:
1292+
return ray.get(
1293+
ray.remote(num_cpus=0)(predict).remote(
1294+
model, data, ray_params, _remote=False, **kwargs))
1295+
12341296
ray_params = _validate_ray_params(ray_params)
12351297

12361298
max_actor_restarts = ray_params.max_actor_restarts \

xgboost_ray/matrix.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,6 @@ def __init__(self,
652652
self.feature_types = feature_types
653653
self.missing = missing
654654

655-
self.memory_node_ip = ray.services.get_node_ip_address()
656655
self.num_actors = num_actors
657656
self.sharding = sharding
658657

xgboost_ray/tests/test_end_to_end.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import ray
66

7-
from xgboost_ray import RayParams, train, RayDMatrix
7+
from xgboost_ray import RayParams, train, RayDMatrix, predict
88

99

1010
class XGBoostRayEndToEndTest(unittest.TestCase):
@@ -40,6 +40,10 @@ def setUp(self):
4040
"num_class": 4
4141
}
4242

43+
def tearDown(self):
    """Shut Ray down after each test so the next test starts clean."""
    # `ray.is_initialized` is a function; it must be called. Referencing it
    # without parentheses is always truthy and makes the guard meaningless.
    if ray.is_initialized():
        ray.shutdown()
46+
4347
def testSingleTraining(self):
4448
"""Test that XGBoost learns to predict full matrix"""
4549
dtrain = xgb.DMatrix(self.x, self.y)
@@ -99,6 +103,45 @@ def testJointTraining(self):
99103
pred_y = bst.predict(x_mat)
100104
self.assertSequenceEqual(list(self.y), list(pred_y))
101105

106+
def testTrainPredict(self, init=True, remote=None):
    """Train with an evaluation set, then predict on the same features.

    Args:
        init: When true, start a local Ray instance (2 CPUs, 0 GPUs)
            before training.
        remote: Forwarded as ``_remote`` to both ``train`` and ``predict``.
    """
    if init:
        ray.init(num_cpus=2, num_gpus=0)

    train_matrix = RayDMatrix(self.x, self.y)
    recorded_evals = {}

    booster = train(
        self.params,
        train_matrix,
        ray_params=RayParams(num_actors=2),
        evals=[(train_matrix, "dtrain")],
        evals_result=recorded_evals,
        _remote=remote)

    # The evaluation set name must show up in the collected results.
    self.assertTrue("dtrain" in recorded_evals)

    feature_matrix = RayDMatrix(self.x)
    predictions = predict(booster, feature_matrix, _remote=remote)
    self.assertSequenceEqual(list(self.y), list(predictions))
127+
128+
def testTrainPredictRemote(self):
    """Run the train/predict round trip with ``_remote`` forced to True."""
    force_remote = True
    self.testTrainPredict(init=True, remote=force_remote)
131+
132+
def testTrainPredictClient(self):
    """Train with evaluation and predict in a client session"""
    # Compare the version numerically. A plain string comparison is
    # lexicographic, so e.g. "1.10.0" <= "1.2.0" would be True and the
    # test would be skipped on newer Ray releases.
    import re

    parsed = re.match(r"(\d+)\.(\d+)\.(\d+)", ray.__version__)
    if parsed and tuple(int(part) for part in parsed.groups()) <= (1, 2, 0):
        self.skipTest("Ray client mocks do not work in Ray <= 1.2.0")
    from ray.util.client.ray_client_helpers import ray_start_client_server

    ray.init(num_cpus=2, num_gpus=0)
    self.assertFalse(ray.util.client.ray.is_connected())
    with ray_start_client_server():
        self.assertTrue(ray.util.client.ray.is_connected())

        # Ray is already initialized above, and `remote=None` lets
        # train/predict decide for themselves whether to run remotely.
        self.testTrainPredict(init=False, remote=None)
144+
102145

103146
if __name__ == "__main__":
104147
import pytest

xgboost_ray/tests/test_tune.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def tearDown(self):
6565

6666
# noinspection PyTypeChecker
6767
def testNumIters(self):
68+
"""Test that the number of reported tune results is correct"""
6869
ray_params = RayParams(cpus_per_actor=1, num_actors=2)
6970
analysis = tune.run(
7071
self.train_func(ray_params),
@@ -76,6 +77,18 @@ def testNumIters(self):
7677
list(analysis.results_df["training_iteration"]),
7778
list(analysis.results_df["config.num_boost_round"]))
7879

80+
def testNumItersClient(self):
    """Test ray client mode"""
    # Compare the version numerically. A plain string comparison is
    # lexicographic, so e.g. "1.10.0" <= "1.2.0" would be True and the
    # test would be skipped on newer Ray releases.
    import re

    parsed = re.match(r"(\d+)\.(\d+)\.(\d+)", ray.__version__)
    if parsed and tuple(int(part) for part in parsed.groups()) <= (1, 2, 0):
        self.skipTest("Ray client mocks do not work in Ray <= 1.2.0")

    from ray.util.client.ray_client_helpers import ray_start_client_server

    self.assertFalse(ray.util.client.ray.is_connected())
    with ray_start_client_server():
        self.assertTrue(ray.util.client.ray.is_connected())
        self.testNumIters()
91+
7992
def testElasticFails(self):
8093
"""Test if error is thrown when using Tune with elastic training."""
8194
ray_params = RayParams(

0 commit comments

Comments
 (0)