 import pandas as pd

 from xgboost_ray.xgb import xgboost as xgb
-from xgboost.core import XGBoostError, EarlyStopException
+from xgboost.core import XGBoostError
+
+try:
+    from xgboost.core import EarlyStopException
+except ImportError:
+
+    class EarlyStopException(XGBoostError):
+        pass

 from xgboost_ray.callback import DistributedCallback, \
     DistributedCallbackContainer
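For context on the import shim above: newer xgboost releases no longer expose `EarlyStopException` in `xgboost.core`, so the fallback class keeps existing `except EarlyStopException` handlers working on either version. A minimal standalone sketch of the intended behaviour (assumes xgboost is installed; the shim is inlined here rather than imported, since the module path is not shown in this hunk):

```python
from xgboost.core import XGBoostError

try:
    from xgboost.core import EarlyStopException
except ImportError:
    # Fallback mirrors the shim in the diff above.
    class EarlyStopException(XGBoostError):
        pass

# Callers can keep catching the same name; because the fallback
# subclasses XGBoostError, broader handlers also still match.
try:
    raise EarlyStopException("stop at iteration 10")
except XGBoostError:
    pass
```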
@@ -64,28 +71,56 @@ def inner_f(*args, **kwargs):
 from xgboost_ray.session import init_session, put_queue, \
     set_session_queue

-# Whether to use SPREAD placement group strategy for training.
-_USE_SPREAD_STRATEGY = int(os.getenv("RXGB_USE_SPREAD_STRATEGY", 1))

-# How long to wait for placement group creation before failing.
-PLACEMENT_GROUP_TIMEOUT_S = int(
-    os.getenv("RXGB_PLACEMENT_GROUP_TIMEOUT_S", 100))
+def _get_environ(item: str, old_val: Any):
+    env_var = f"RXGB_{item}"
+    new_val = old_val
+    if env_var in os.environ:
+        new_val_str = os.environ.get(env_var)
+
+        if isinstance(old_val, bool):
+            new_val = bool(int(new_val_str))
+        elif isinstance(old_val, int):
+            new_val = int(new_val_str)
+        elif isinstance(old_val, float):
+            new_val = float(new_val_str)
+        else:
+            new_val = new_val_str
+
+    return new_val
+
+
+@dataclass
+class _XGBoostEnv:
+    # Whether to use SPREAD placement group strategy for training.
+    USE_SPREAD_STRATEGY: bool = True
+
+    # How long to wait for placement group creation before failing.
+    PLACEMENT_GROUP_TIMEOUT_S: int = 100
+
+    # Status report frequency when waiting for initial actors
+    # and during training
+    STATUS_FREQUENCY_S: int = 30
+
+    # If restarting failed actors is disabled
+    ELASTIC_RESTART_DISABLED: bool = False
+
+    # How often to check for new available resources
+    ELASTIC_RESTART_RESOURCE_CHECK_S: int = 30

-# Status report frequency when waiting for initial actors and during training
-STATUS_FREQUENCY_S = int(os.getenv("RXGB_STATUS_FREQUENCY_S", 30))
+    # How long to wait before triggering a new start of the training loop
+    # when new actors become available
+    ELASTIC_RESTART_GRACE_PERIOD_S: int = 10

-# If restarting failed actors is disabled
-ELASTIC_RESTART_DISABLED = bool(
-    int(os.getenv("RXGB_ELASTIC_RESTART_DISABLED", 0)))
+    def __getattribute__(self, item):
+        old_val = super(_XGBoostEnv, self).__getattribute__(item)
+        new_val = _get_environ(item, old_val)
+        if new_val != old_val:
+            setattr(self, item, new_val)
+        return super(_XGBoostEnv, self).__getattribute__(item)

-# How often to check for new available resources
-ELASTIC_RESTART_RESOURCE_CHECK_S = int(
-    os.getenv("RXGB_ELASTIC_RESTART_RESOURCE_CHECK_S", 30))

-# How long to wait before triggering a new start of the training loop
-# when new actors become available
-ELASTIC_RESTART_GRACE_PERIOD_S = int(
-    os.getenv("RXGB_ELASTIC_RESTART_GRACE_PERIOD_S", 10))
+ENV = _XGBoostEnv()

 xgboost_version = xgb.__version__ if xgb else "0.0.0"

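The `ENV` singleton replaces the old module-level constants: every attribute read re-checks the matching `RXGB_`-prefixed environment variable and coerces its value to the field's declared type. A rough usage sketch (the `xgboost_ray.main` import path is an assumption; this hunk does not show the file name):

```python
import os

# Hypothetical import path for illustration only.
from xgboost_ray.main import ENV

os.environ["RXGB_PLACEMENT_GROUP_TIMEOUT_S"] = "300"
assert ENV.PLACEMENT_GROUP_TIMEOUT_S == 300  # string coerced via int()

os.environ["RXGB_USE_SPREAD_STRATEGY"] = "0"
assert ENV.USE_SPREAD_STRATEGY is False  # bool fields go through bool(int(...))
```

A consequence of the lazy lookup is that overrides set after import time are still picked up, which the old `os.getenv(...)` module constants could not do.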
@@ -138,22 +173,32 @@ def _is_client_connected() -> bool:
         return False


-class _RabitTracker(RabitTracker):
+class _RabitTrackerCompatMixin:
+    """Fallback calls to legacy terminology"""
+
+    def accept_workers(self, n_workers: int):
+        return self.accept_slaves(n_workers)
+
+    def worker_envs(self):
+        return self.slave_envs()
+
+
+class _RabitTracker(RabitTracker, _RabitTrackerCompatMixin):
     """
     This method overwrites the xgboost-provided RabitTracker to switch
     from a daemon thread to a multiprocessing Process. This is so that
     we are able to terminate/kill the tracking process at will.
     """

-    def start(self, nslave):
+    def start(self, nworker):
         # TODO: refactor RabitTracker to support spawn process creation.
         # In python 3.8, spawn is used as default process creation on macOS.
         # But spawn doesn't work because `run` is not pickleable.
         # For now we force the start method to use fork.
         multiprocessing.set_start_method("fork", force=True)

         def run():
-            self.accept_slaves(nslave)
+            self.accept_workers(nworker)

         self.thread = multiprocessing.Process(target=run, args=())
         self.thread.start()
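The mixin bridges the slave→worker rename in xgboost's tracker API: since `RabitTracker` comes first in the MRO, newer xgboost versions use the tracker's own `accept_workers`/`worker_envs`, while on older versions the mixin forwards the new names to the legacy `accept_slaves`/`slave_envs`. A minimal sketch of that fallback with a stand-in legacy class (`_LegacyTracker` and `_Tracker` are illustrative names, not from the diff):

```python
class _RabitTrackerCompatMixin:
    """Fallback calls to legacy terminology (copied from the hunk above)."""

    def accept_workers(self, n_workers: int):
        return self.accept_slaves(n_workers)

    def worker_envs(self):
        return self.slave_envs()


class _LegacyTracker:
    """Stand-in for an old RabitTracker that only knows the slave_* names."""

    def accept_slaves(self, n_workers):
        return f"accepting {n_workers} workers"

    def slave_envs(self):
        return {"DMLC_TRACKER_URI": "127.0.0.1", "DMLC_TRACKER_PORT": 9091}


class _Tracker(_LegacyTracker, _RabitTrackerCompatMixin):
    pass


tracker = _Tracker()
assert tracker.accept_workers(4) == "accepting 4 workers"
assert "DMLC_TRACKER_URI" in tracker.worker_envs()
```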
@@ -178,10 +223,10 @@ def _start_rabit_tracker(num_workers: int):

     env = {"DMLC_NUM_WORKER": num_workers}

-    rabit_tracker = _RabitTracker(hostIP=host, nslave=num_workers)
+    rabit_tracker = _RabitTracker(host, num_workers)

     # Get tracker Host + IP
-    env.update(rabit_tracker.slave_envs())
+    env.update(rabit_tracker.worker_envs())
     rabit_tracker.start(num_workers)

     logger.debug(
@@ -704,7 +749,7 @@ def _create_actor(

 def _trigger_data_load(actor, dtrain, evals):
     wait_load = [actor.load_data.remote(dtrain)]
-    for deval, name in evals:
+    for deval, _name in evals:
         wait_load.append(actor.load_data.remote(deval))
     return wait_load

@@ -778,7 +823,7 @@ def _create_placement_group(cpus_per_actor, gpus_per_actor,
     pg = placement_group(bundles, strategy=strategy)
     # Wait for placement group to get created.
     logger.debug("Waiting for placement group to start.")
-    ready, _ = ray.wait([pg.ready()], timeout=PLACEMENT_GROUP_TIMEOUT_S)
+    ready, _ = ray.wait([pg.ready()], timeout=ENV.PLACEMENT_GROUP_TIMEOUT_S)
     if ready:
         logger.debug("Placement group has started.")
     else:
@@ -955,7 +1000,7 @@ def handle_actor_failure(actor_id):
     # Construct list before calling any() to force evaluation
     ready_states = [task.is_ready() for task in prepare_actor_tasks]
     while not all(ready_states):
-        if time.time() >= last_status + STATUS_FREQUENCY_S:
+        if time.time() >= last_status + ENV.STATUS_FREQUENCY_S:
             wait_time = time.time() - start_wait
             logger.info(f"Waiting until actors are ready "
                         f"({wait_time:.0f} seconds passed).")
@@ -1029,7 +1074,7 @@ def handle_actor_failure(actor_id):
                 callback_returns=callback_returns)

             if ray_params.elastic_training \
-                    and not ELASTIC_RESTART_DISABLED:
+                    and not ENV.ELASTIC_RESTART_DISABLED:
                 _maybe_schedule_new_actors(
                     training_state=_training_state,
                     num_cpus_per_actor=cpus_per_actor,
@@ -1041,7 +1086,7 @@ def handle_actor_failure(actor_id):
                 # This may raise RayXGBoostActorAvailable
                 _update_scheduled_actor_states(_training_state)

-            if time.time() >= last_status + STATUS_FREQUENCY_S:
+            if time.time() >= last_status + ENV.STATUS_FREQUENCY_S:
                 wait_time = time.time() - start_wait
                 logger.info(f"Training in progress "
                             f"({wait_time:.0f} seconds since last restart).")
@@ -1290,7 +1335,7 @@ def _wrapped(*args, **kwargs):
     if not dtrain.loaded and not dtrain.distributed:
         dtrain.load_data(ray_params.num_actors)

-    for (deval, name) in evals:
+    for (deval, _name) in evals:
         if not deval.has_label:
             raise ValueError(
                 "Evaluation data has no label set. Please make sure to set "
@@ -1321,7 +1366,7 @@ def _wrapped(*args, **kwargs):
             placement_strategy = None
         else:
             placement_strategy = "PACK"
-    elif bool(_USE_SPREAD_STRATEGY):
+    elif bool(ENV.USE_SPREAD_STRATEGY):
         placement_strategy = "SPREAD"

     if placement_strategy is not None: