Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 6a4685b

Browse files
authored
Use new train.report API (#292)
We are converging on train.report throughout the Ray library codebase, replacing tune.report.
1 parent d415b49 commit 6a4685b

File tree

5 files changed, with 65 additions (+65) and 28 deletions (-28).

requirements/test-requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,12 +2,11 @@ packaging
22
petastorm
33
pytest
44
pyarrow
5-
ray[tune, data]
5+
ray[tune, data, default]
66
scikit-learn
77
modin
88
dask
99

1010
#workaround for now
1111
protobuf<4.0.0
1212
tensorboardX==2.2
13-
aiohttp

xgboost_ray/examples/simple_tune.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ def main(cpus_per_actor, num_actors, num_samples):
6666

6767
# Load the best model checkpoint.
6868
best_bst = xgboost_ray.tune.load_model(
69-
os.path.join(analysis.best_logdir, "tuned.xgb")
69+
os.path.join(analysis.best_trial.local_path, "tuned.xgb")
7070
)
7171

7272
best_bst.save_model("best_model.xgb")

xgboost_ray/main.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -557,7 +557,7 @@ def __init__(
557557

558558
self.checkpoint_frequency = checkpoint_frequency
559559

560-
self._data: Dict[RayDMatrix, xgb.DMatrix] = {}
560+
self._data: Dict[RayDMatrix, dict] = {}
561561
self._local_n: Dict[RayDMatrix, int] = {}
562562

563563
self._stop_event = stop_event

xgboost_ray/tests/test_tune.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -158,8 +158,13 @@ def testReplaceTuneCheckpoints(self):
158158

159159
replaced = in_dict["callbacks"][0]
160160
self.assertTrue(isinstance(replaced, TuneReportCheckpointCallback))
161-
self.assertSequenceEqual(replaced._report._metrics, ["met"])
162-
self.assertEqual(replaced._checkpoint._filename, "test")
161+
162+
if getattr(replaced, "_report", None):
163+
self.assertSequenceEqual(replaced._report._metrics, ["met"])
164+
self.assertEqual(replaced._checkpoint._filename, "test")
165+
else:
166+
self.assertSequenceEqual(replaced._metrics, ["met"])
167+
self.assertEqual(replaced._filename, "test")
163168

164169
def testEndToEndCheckpointing(self):
165170
ray_params = RayParams(cpus_per_actor=1, num_actors=2)

xgboost_ray/tune.py

Lines changed: 55 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -3,14 +3,15 @@
33
from typing import Dict, Optional
44

55
import ray
6+
from ray.train._internal.session import get_session
67
from ray.util.annotations import PublicAPI
78

89
from xgboost_ray.session import get_rabit_rank, put_queue
910
from xgboost_ray.util import Unavailable, force_on_current_node
1011
from xgboost_ray.xgb import xgboost as xgb
1112

1213
try:
13-
from ray import tune
14+
from ray import train, tune
1415
from ray.tune import is_session_enabled
1516
from ray.tune.integration.xgboost import (
1617
TuneReportCallback as OrigTuneReportCallback,
@@ -39,30 +40,53 @@ def is_session_enabled():
3940
flatten_dict = is_session_enabled
4041
TUNE_INSTALLED = False
4142

43+
4244
if TUNE_INSTALLED:
43-
# New style callbacks.
44-
class TuneReportCallback(OrigTuneReportCallback):
45-
def after_iteration(self, model, epoch: int, evals_log: Dict):
46-
if get_rabit_rank() == 0:
47-
report_dict = self._get_report_dict(evals_log)
48-
put_queue(lambda: tune.report(**report_dict))
49-
50-
class _TuneCheckpointCallback(_OrigTuneCheckpointCallback):
51-
def after_iteration(self, model, epoch: int, evals_log: Dict):
52-
if get_rabit_rank() == 0:
53-
put_queue(
54-
lambda: self._create_checkpoint(
55-
model, epoch, self._filename, self._frequency
45+
if not hasattr(train, "report"):
46+
47+
# New style callbacks.
48+
class TuneReportCallback(OrigTuneReportCallback):
49+
def after_iteration(self, model, epoch: int, evals_log: Dict):
50+
if get_rabit_rank() == 0:
51+
report_dict = self._get_report_dict(evals_log)
52+
put_queue(lambda: tune.report(**report_dict))
53+
54+
class _TuneCheckpointCallback(_OrigTuneCheckpointCallback):
55+
def after_iteration(self, model, epoch: int, evals_log: Dict):
56+
if get_rabit_rank() == 0:
57+
put_queue(
58+
lambda: self._create_checkpoint(
59+
model, epoch, self._filename, self._frequency
60+
)
5661
)
57-
)
5862

59-
class TuneReportCheckpointCallback(OrigTuneReportCheckpointCallback):
60-
_checkpoint_callback_cls = _TuneCheckpointCallback
61-
_report_callbacks_cls = TuneReportCallback
63+
class TuneReportCheckpointCallback(OrigTuneReportCheckpointCallback):
64+
_checkpoint_callback_cls = _TuneCheckpointCallback
65+
_report_callbacks_cls = TuneReportCallback
66+
67+
else:
68+
69+
class TuneReportCheckpointCallback(OrigTuneReportCheckpointCallback):
70+
def after_iteration(self, model, epoch: int, evals_log: Dict):
71+
if get_rabit_rank() == 0:
72+
put_queue(
73+
lambda: super(
74+
TuneReportCheckpointCallback, self
75+
).after_iteration(model=model, epoch=epoch, evals_log=evals_log)
76+
)
77+
78+
class TuneReportCallback(OrigTuneReportCallback):
79+
def after_iteration(self, model, epoch: int, evals_log: Dict):
80+
if get_rabit_rank() == 0:
81+
put_queue(
82+
lambda: super(TuneReportCallback, self).after_iteration(
83+
model=model, epoch=epoch, evals_log=evals_log
84+
)
85+
)
6286

6387

6488
def _try_add_tune_callback(kwargs: Dict):
65-
if TUNE_INSTALLED and is_session_enabled():
89+
if TUNE_INSTALLED and (is_session_enabled() or get_session()):
6690
callbacks = kwargs.get("callbacks", []) or []
6791
new_callbacks = []
6892
has_tune_callback = False
@@ -88,10 +112,19 @@ def _try_add_tune_callback(kwargs: Dict):
88112
)
89113
has_tune_callback = True
90114
elif isinstance(cb, OrigTuneReportCheckpointCallback):
115+
if getattr(cb, "_report", None):
116+
orig_metrics = cb._report._metrics
117+
orig_filename = cb._checkpoint._filename
118+
orig_frequency = cb._checkpoint._frequency
119+
else:
120+
orig_metrics = cb._metrics
121+
orig_filename = cb._filename
122+
orig_frequency = cb._frequency
123+
91124
replace_cb = TuneReportCheckpointCallback(
92-
metrics=cb._report._metrics,
93-
filename=cb._checkpoint._filename,
94-
frequency=cb._checkpoint._frequency,
125+
metrics=orig_metrics,
126+
filename=orig_filename,
127+
frequency=orig_frequency,
95128
)
96129
new_callbacks.append(replace_cb)
97130
logging.warning(

0 commit comments

Comments
 (0)