11# Tune imports.
2- import os
3- from typing import Dict , Union , List , Optional
2+ from typing import Dict , Optional
43
54import ray
65
7- try :
8- from typing import OrderedDict
9- except ImportError :
10- from collections import OrderedDict
11-
126import logging
137
148from ray .util .annotations import PublicAPI
159
1610from xgboost_ray .xgb import xgboost as xgb
1711
18- from xgboost_ray .compat import TrainingCallback
1912from xgboost_ray .session import put_queue , get_rabit_rank
2013from xgboost_ray .util import Unavailable , force_on_current_node
2114
@@ -42,90 +35,7 @@ def is_session_enabled():
4235 flatten_dict = is_session_enabled
4336 TUNE_INSTALLED = False
4437
45- # Todo(krfricke): Remove after next ray core release
46- if not hasattr (OrigTuneReportCallback , "_get_report_dict" ) or not issubclass (
47- OrigTuneReportCallback , TrainingCallback ):
48- TUNE_LEGACY = True
49- else :
50- TUNE_LEGACY = False
51-
52- # Todo(amogkam): Remove after Ray 1.3 release.
53- try :
54- from ray .tune import PlacementGroupFactory
55-
56- TUNE_USING_PG = True
57- except ImportError :
58- TUNE_USING_PG = False
59- PlacementGroupFactory = Unavailable
60-
61- if TUNE_LEGACY and TUNE_INSTALLED :
62- # Until the next release, keep compatible callbacks here.
63- class TuneReportCallback (OrigTuneReportCallback , TrainingCallback ):
64- def _get_report_dict (self , evals_log ):
65- if isinstance (evals_log , OrderedDict ):
66- # xgboost>=1.3
67- result_dict = flatten_dict (evals_log , delimiter = "-" )
68- for k in list (result_dict ):
69- result_dict [k ] = result_dict [k ][0 ]
70- else :
71- # xgboost<1.3
72- result_dict = dict (evals_log )
73- if not self ._metrics :
74- report_dict = result_dict
75- else :
76- report_dict = {}
77- for key in self ._metrics :
78- if isinstance (self ._metrics , dict ):
79- metric = self ._metrics [key ]
80- else :
81- metric = key
82- report_dict [key ] = result_dict [metric ]
83- return report_dict
84-
85- def after_iteration (self , model , epoch : int , evals_log : Dict ):
86- if get_rabit_rank () == 0 :
87- report_dict = self ._get_report_dict (evals_log )
88- put_queue (lambda : tune .report (** report_dict ))
89-
90- class _TuneCheckpointCallback (_OrigTuneCheckpointCallback ,
91- TrainingCallback ):
92- def __init__ (self , filename : str , frequency : int ):
93- super (_TuneCheckpointCallback , self ).__init__ (filename )
94- self ._frequency = frequency
95-
96- @staticmethod
97- def _create_checkpoint (model , epoch : int , filename : str ,
98- frequency : int ):
99- if epoch % frequency > 0 :
100- return
101- with tune .checkpoint_dir (step = epoch ) as checkpoint_dir :
102- model .save_model (os .path .join (checkpoint_dir , filename ))
103-
104- def after_iteration (self , model , epoch : int , evals_log : Dict ):
105- if get_rabit_rank () == 0 :
106- put_queue (lambda : self ._create_checkpoint (
107- model , epoch , self ._filename , self ._frequency ))
108-
109- class TuneReportCheckpointCallback (OrigTuneReportCheckpointCallback ,
110- TrainingCallback ):
111- _checkpoint_callback_cls = _TuneCheckpointCallback
112- _report_callbacks_cls = TuneReportCallback
113-
114- def __init__ (
115- self ,
116- metrics : Union [None , str , List [str ], Dict [str , str ]] = None ,
117- filename : str = "checkpoint" ,
118- frequency : int = 5 ):
119- self ._checkpoint = self ._checkpoint_callback_cls (
120- filename , frequency )
121- self ._report = self ._report_callbacks_cls (metrics )
122-
123- def after_iteration (self , model , epoch : int , evals_log : Dict ):
124- if get_rabit_rank () == 0 :
125- self ._checkpoint .after_iteration (model , epoch , evals_log )
126- self ._report .after_iteration (model , epoch , evals_log )
127-
128- elif TUNE_INSTALLED :
38+ if TUNE_INSTALLED :
12939 # New style callbacks.
13040 class TuneReportCallback (OrigTuneReportCallback ):
13141 def after_iteration (self , model , epoch : int , evals_log : Dict ):
@@ -168,15 +78,10 @@ def _try_add_tune_callback(kwargs: Dict):
16878 target = "xgboost_ray.tune.TuneReportCallback" ))
16979 has_tune_callback = True
17080 elif isinstance (cb , OrigTuneReportCheckpointCallback ):
171- if TUNE_LEGACY :
172- replace_cb = TuneReportCheckpointCallback (
173- metrics = cb ._report ._metrics ,
174- filename = cb ._checkpoint ._filename )
175- else :
176- replace_cb = TuneReportCheckpointCallback (
177- metrics = cb ._report ._metrics ,
178- filename = cb ._checkpoint ._filename ,
179- frequency = cb ._checkpoint ._frequency )
81+ replace_cb = TuneReportCheckpointCallback (
82+ metrics = cb ._report ._metrics ,
83+ filename = cb ._checkpoint ._filename ,
84+ frequency = cb ._checkpoint ._frequency )
18085 new_callbacks .append (replace_cb )
18186 logging .warning (
18287 REPLACE_MSG .format (
@@ -203,35 +108,21 @@ def _get_tune_resources(num_actors: int, cpus_per_actor: int,
203108 resources_per_actor : Optional [Dict ]):
204109 """Returns object to use for ``resources_per_trial`` with Ray Tune."""
205110 if TUNE_INSTALLED :
206- if not TUNE_USING_PG :
207- resources_per_actor = {} if not resources_per_actor \
208- else resources_per_actor
209- extra_custom_resources = {
210- k : v * num_actors
211- for k , v in resources_per_actor .items ()
212- }
213- return dict (
214- cpu = 1 ,
215- extra_cpu = cpus_per_actor * num_actors ,
216- extra_gpu = gpus_per_actor * num_actors ,
217- extra_custom_resources = extra_custom_resources ,
218- )
219- else :
220- from ray .tune import PlacementGroupFactory
221-
222- head_bundle = {"CPU" : 1 }
223- child_bundle = {"CPU" : cpus_per_actor , "GPU" : gpus_per_actor }
224- child_bundle_extra = {} if resources_per_actor is None else \
225- resources_per_actor
226- child_bundles = [{
227- ** child_bundle ,
228- ** child_bundle_extra
229- } for _ in range (num_actors )]
230- bundles = [head_bundle ] + child_bundles
231- placement_group_factory = PlacementGroupFactory (
232- bundles , strategy = "PACK" )
233-
234- return placement_group_factory
111+ from ray .tune import PlacementGroupFactory
112+
113+ head_bundle = {"CPU" : 1 }
114+ child_bundle = {"CPU" : cpus_per_actor , "GPU" : gpus_per_actor }
115+ child_bundle_extra = {} if resources_per_actor is None else \
116+ resources_per_actor
117+ child_bundles = [{
118+ ** child_bundle ,
119+ ** child_bundle_extra
120+ } for _ in range (num_actors )]
121+ bundles = [head_bundle ] + child_bundles
122+ placement_group_factory = PlacementGroupFactory (
123+ bundles , strategy = "PACK" )
124+
125+ return placement_group_factory
235126 else :
236127 raise RuntimeError ("Tune is not installed, so `get_tune_resources` is "
237128 "not supported. You can install Ray Tune via `pip "
0 commit comments