33from typing import Dict , Optional
44
55import ray
6+ from ray .train ._internal .session import get_session
67from ray .util .annotations import PublicAPI
78
89from xgboost_ray .session import get_rabit_rank , put_queue
910from xgboost_ray .util import Unavailable , force_on_current_node
1011from xgboost_ray .xgb import xgboost as xgb
1112
1213try :
13- from ray import tune
14+ from ray import train , tune
1415 from ray .tune import is_session_enabled
1516 from ray .tune .integration .xgboost import (
1617 TuneReportCallback as OrigTuneReportCallback ,
@@ -39,30 +40,53 @@ def is_session_enabled():
3940 flatten_dict = is_session_enabled
4041 TUNE_INSTALLED = False
4142
43+
4244if TUNE_INSTALLED :
43- # New style callbacks.
44- class TuneReportCallback (OrigTuneReportCallback ):
45- def after_iteration (self , model , epoch : int , evals_log : Dict ):
46- if get_rabit_rank () == 0 :
47- report_dict = self ._get_report_dict (evals_log )
48- put_queue (lambda : tune .report (** report_dict ))
49-
50- class _TuneCheckpointCallback (_OrigTuneCheckpointCallback ):
51- def after_iteration (self , model , epoch : int , evals_log : Dict ):
52- if get_rabit_rank () == 0 :
53- put_queue (
54- lambda : self ._create_checkpoint (
55- model , epoch , self ._filename , self ._frequency
45+ if not hasattr (train , "report" ):
46+
47+ # New style callbacks.
48+ class TuneReportCallback (OrigTuneReportCallback ):
49+ def after_iteration (self , model , epoch : int , evals_log : Dict ):
50+ if get_rabit_rank () == 0 :
51+ report_dict = self ._get_report_dict (evals_log )
52+ put_queue (lambda : tune .report (** report_dict ))
53+
54+ class _TuneCheckpointCallback (_OrigTuneCheckpointCallback ):
55+ def after_iteration (self , model , epoch : int , evals_log : Dict ):
56+ if get_rabit_rank () == 0 :
57+ put_queue (
58+ lambda : self ._create_checkpoint (
59+ model , epoch , self ._filename , self ._frequency
60+ )
5661 )
57- )
5862
59- class TuneReportCheckpointCallback (OrigTuneReportCheckpointCallback ):
60- _checkpoint_callback_cls = _TuneCheckpointCallback
61- _report_callbacks_cls = TuneReportCallback
63+ class TuneReportCheckpointCallback (OrigTuneReportCheckpointCallback ):
64+ _checkpoint_callback_cls = _TuneCheckpointCallback
65+ _report_callbacks_cls = TuneReportCallback
66+
67+ else :
68+
69+ class TuneReportCheckpointCallback (OrigTuneReportCheckpointCallback ):
70+ def after_iteration (self , model , epoch : int , evals_log : Dict ):
71+ if get_rabit_rank () == 0 :
72+ put_queue (
73+ lambda : super (
74+ TuneReportCheckpointCallback , self
75+ ).after_iteration (model = model , epoch = epoch , evals_log = evals_log )
76+ )
77+
78+ class TuneReportCallback (OrigTuneReportCallback ):
79+ def after_iteration (self , model , epoch : int , evals_log : Dict ):
80+ if get_rabit_rank () == 0 :
81+ put_queue (
82+ lambda : super (TuneReportCallback , self ).after_iteration (
83+ model = model , epoch = epoch , evals_log = evals_log
84+ )
85+ )
6286
6387
6488def _try_add_tune_callback (kwargs : Dict ):
65- if TUNE_INSTALLED and is_session_enabled ():
89+ if TUNE_INSTALLED and ( is_session_enabled () or get_session () ):
6690 callbacks = kwargs .get ("callbacks" , []) or []
6791 new_callbacks = []
6892 has_tune_callback = False
@@ -88,10 +112,19 @@ def _try_add_tune_callback(kwargs: Dict):
88112 )
89113 has_tune_callback = True
90114 elif isinstance (cb , OrigTuneReportCheckpointCallback ):
115+ if getattr (cb , "_report" , None ):
116+ orig_metrics = cb ._report ._metrics
117+ orig_filename = cb ._checkpoint ._filename
118+ orig_frequency = cb ._checkpoint ._frequency
119+ else :
120+ orig_metrics = cb ._metrics
121+ orig_filename = cb ._filename
122+ orig_frequency = cb ._frequency
123+
91124 replace_cb = TuneReportCheckpointCallback (
92- metrics = cb . _report . _metrics ,
93- filename = cb . _checkpoint . _filename ,
94- frequency = cb . _checkpoint . _frequency ,
125+ metrics = orig_metrics ,
126+ filename = orig_filename ,
127+ frequency = orig_frequency ,
95128 )
96129 new_callbacks .append (replace_cb )
97130 logging .warning (
0 commit comments