This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 69e431e

krfricke and Yard1 authored
Update Ray core APIs (#228)
Placement group APIs are updated to use scheduling_strategy.

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
1 parent cad8c3e commit 69e431e
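
For context, the change this commit applies across the repo is the move from passing placement group arguments directly to .options() to wrapping them in a PlacementGroupSchedulingStrategy. Below is a minimal runnable sketch of the before and after; the Worker actor and the one-CPU bundle are illustrative, not part of xgboost_ray.

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init(num_cpus=2)

# Reserve a single one-CPU bundle for the actor.
pg = placement_group([{"CPU": 1}])
ray.get(pg.ready())

@ray.remote
class Worker:
    def ping(self):
        return "pong"

# Old style (removed by this commit): separate keyword arguments.
#   Worker.options(placement_group=pg,
#                  placement_group_capture_child_tasks=True)

# New style: both settings wrapped in one scheduling strategy object.
worker = Worker.options(
    num_cpus=1,
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg,
        placement_group_capture_child_tasks=True)).remote()

print(ray.get(worker.ping.remote()))  # -> "pong"

The strategy object groups all placement-group-related scheduling settings behind a single keyword, which is why the main.py diff below can drop two separate .options() arguments.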

File tree: 3 files changed (+35, -11 lines)

xgboost_ray/data_sources/modin.py

Lines changed: 5 additions & 0 deletions

@@ -14,8 +14,13 @@

 try:
     import modin  # noqa: F401
+    from modin.config.envvars import Engine
     from distutils.version import LooseVersion
     MODIN_INSTALLED = LooseVersion(modin.__version__) >= LooseVersion("0.9.0")
+
+    # Check if importing the Ray engine leads to errors
+    Engine().get()
+
 except (ImportError, AttributeError):
     MODIN_INSTALLED = False
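
Assembled as a standalone snippet, the patched guard looks like the sketch below. It runs whether or not modin is installed: a failed import lands in the except clause, and the engine probe surfaces misconfiguration at import time rather than later, during data loading. The final print line is illustrative only.

from distutils.version import LooseVersion

try:
    import modin  # noqa: F401
    from modin.config.envvars import Engine

    MODIN_INSTALLED = LooseVersion(modin.__version__) >= LooseVersion("0.9.0")

    # Probe the configured execution engine up front. If modin cannot
    # resolve its Ray engine, the error raised here (the guard catches
    # ImportError and AttributeError) marks modin as unusable.
    Engine().get()
except (ImportError, AttributeError):
    MODIN_INSTALLED = False

print("modin usable:", MODIN_INSTALLED)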

xgboost_ray/main.py

Lines changed: 13 additions & 8 deletions

@@ -38,6 +38,7 @@ class EarlyStopException(XGBoostError):
 from ray.util.annotations import PublicAPI, DeveloperAPI
 from ray.util.placement_group import PlacementGroup, \
     remove_placement_group, get_current_placement_group
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from ray.util.queue import Queue

 from xgboost_ray.util import Event, MultiActorTask, force_on_current_node

@@ -747,17 +748,21 @@ def _create_actor(
     # Send DEFAULT_PG here, which changed in Ray >= 1.5.0
     # If we send `None`, this will ignore the parent placement group and
     # lead to errors e.g. when used within Ray Tune
-    return _RemoteRayXGBoostActor.options(
+    actor_cls = _RemoteRayXGBoostActor.options(
         num_cpus=num_cpus_per_actor,
         num_gpus=num_gpus_per_actor,
         resources=resources_per_actor,
-        placement_group_capture_child_tasks=True,
-        placement_group=placement_group or DEFAULT_PG).remote(
-            rank=rank,
-            num_actors=num_actors,
-            queue=queue,
-            checkpoint_frequency=checkpoint_frequency,
-            distributed_callbacks=distributed_callbacks)
+        scheduling_strategy=PlacementGroupSchedulingStrategy(
+            placement_group=placement_group or DEFAULT_PG,
+            placement_group_capture_child_tasks=True,
+        ))
+
+    return actor_cls.remote(
+        rank=rank,
+        num_actors=num_actors,
+        queue=queue,
+        checkpoint_frequency=checkpoint_frequency,
+        distributed_callbacks=distributed_callbacks)


 def _trigger_data_load(actor, dtrain, evals):
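
The comment kept in the hunk above is the key constraint: sending placement_group=None would detach the actor from its parent's placement group and break nested use such as Ray Tune trials, so the code falls back to DEFAULT_PG instead. Below is a minimal runnable sketch of that capture mechanism; Child, parent_task, and the two-CPU bundle are illustrative and not part of xgboost_ray.

import ray
from ray.util.placement_group import (placement_group,
                                      get_current_placement_group)
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init(num_cpus=2)

@ray.remote
class Child:
    def ping(self):
        return "pong"

@ray.remote
def parent_task():
    # Inside a task scheduled into a placement group, this returns the
    # enclosing group. This is the same mechanism xgboost_ray relies on
    # to keep training actors inside a Ray Tune trial's placement group.
    pg = get_current_placement_group()
    child = Child.options(
        num_cpus=1,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True)).remote()
    return ray.get(child.ping.remote())

pg = placement_group([{"CPU": 2}])
ray.get(pg.ready())
result = parent_task.options(
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg,
        placement_group_capture_child_tasks=True)).remote()
print(ray.get(result))  # -> "pong"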

xgboost_ray/tests/test_tune.py

Lines changed: 17 additions & 3 deletions

@@ -17,6 +17,13 @@
 from xgboost_ray.tune import TuneReportCallback,\
     TuneReportCheckpointCallback, _try_add_tune_callback

+try:
+    from ray.air import Checkpoint
+except Exception:
+
+    class Checkpoint:
+        pass
+

 class XGBoostRayTuneTest(unittest.TestCase):
     def setUp(self):

@@ -146,13 +153,17 @@ def testEndToEndCheckpointing(self):
             log_to_file=True,
             local_dir=self.experiment_dir)

-        self.assertTrue(os.path.exists(analysis.best_checkpoint))
+        if isinstance(analysis.best_checkpoint, Checkpoint):
+            self.assertTrue(analysis.best_checkpoint)
+        else:
+            self.assertTrue(os.path.exists(analysis.best_checkpoint))

     def testEndToEndCheckpointingOrigTune(self):
         ray_params = RayParams(cpus_per_actor=1, num_actors=2)
         analysis = tune.run(
             self.train_func(
-                ray_params, callbacks=[OrigTuneReportCheckpointCallback()]),
+                ray_params,
+                callbacks=[OrigTuneReportCheckpointCallback(frequency=1)]),
             config=self.params,
             resources_per_trial=ray_params.get_tune_resources(),
             num_samples=1,

@@ -161,7 +172,10 @@ def testEndToEndCheckpointingOrigTune(self):
             log_to_file=True,
             local_dir=self.experiment_dir)

-        self.assertTrue(os.path.exists(analysis.best_checkpoint))
+        if isinstance(analysis.best_checkpoint, Checkpoint):
+            self.assertTrue(analysis.best_checkpoint)
+        else:
+            self.assertTrue(os.path.exists(analysis.best_checkpoint))


 if __name__ == "__main__":
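
The shim at the top of the file handles a Ray API transition: newer Ray returns a ray.air.Checkpoint object from analysis.best_checkpoint, while older Ray returns a filesystem path. Below is a minimal sketch of the resulting dual-mode check; assert_best_checkpoint is a hypothetical helper name, not part of the test suite.

import os

# On Ray versions without ray.air, define a dummy Checkpoint class so
# the isinstance() check is safe and always falls through to the
# path-based branch.
try:
    from ray.air import Checkpoint
except Exception:

    class Checkpoint:
        pass

def assert_best_checkpoint(best_checkpoint):
    if isinstance(best_checkpoint, Checkpoint):
        # New API: a Checkpoint object; truthiness is all we assert.
        assert best_checkpoint
    else:
        # Old API: a path on disk that must exist.
        assert os.path.exists(best_checkpoint)

assert_best_checkpoint(os.getcwd())  # path branch; passes on any machine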

0 commit comments