Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 03e0a34

Browse files
authored
Lazily read environment variables (#166)
1 parent 88c3188 commit 03e0a34

File tree

5 files changed

+102
-54
lines changed

5 files changed

+102
-54
lines changed

xgboost_ray/elastic.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55

66
from xgboost_ray.main import RayParams, _TrainingState, \
77
logger, ActorHandle, _PrepareActorTask, _create_actor, \
8-
RayXGBoostActorAvailable, \
9-
ELASTIC_RESTART_RESOURCE_CHECK_S, ELASTIC_RESTART_GRACE_PERIOD_S
8+
RayXGBoostActorAvailable, ENV
109

1110
from xgboost_ray.matrix import RayDMatrix
1211

@@ -36,7 +35,7 @@ def _maybe_schedule_new_actors(
3635

3736
# Check periodically every n seconds.
3837
if now < training_state.last_resource_check_at + \
39-
ELASTIC_RESTART_RESOURCE_CHECK_S:
38+
ENV.ELASTIC_RESTART_RESOURCE_CHECK_S:
4039
return False
4140

4241
training_state.last_resource_check_at = now
@@ -108,7 +107,7 @@ def _update_scheduled_actor_states(training_state: _TrainingState):
108107
# If an actor became ready but other actors are pending, we wait
109108
# for n seconds before restarting, as chances are that they become
110109
# ready as well (e.g. if a large node came up).
111-
grace_period = ELASTIC_RESTART_GRACE_PERIOD_S
110+
grace_period = ENV.ELASTIC_RESTART_GRACE_PERIOD_S
112111
if training_state.restart_training_at is None:
113112
logger.debug(
114113
f"A RayXGBoostActor became ready for training. Waiting "

xgboost_ray/main.py

Lines changed: 75 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@
1616
import pandas as pd
1717

1818
from xgboost_ray.xgb import xgboost as xgb
19-
from xgboost.core import XGBoostError, EarlyStopException
19+
from xgboost.core import XGBoostError
20+
21+
try:
22+
from xgboost.core import EarlyStopException
23+
except ImportError:
24+
25+
class EarlyStopException(XGBoostError):
26+
pass
2027

2128
from xgboost_ray.callback import DistributedCallback, \
2229
DistributedCallbackContainer
@@ -64,28 +71,56 @@ def inner_f(*args, **kwargs):
6471
from xgboost_ray.session import init_session, put_queue, \
6572
set_session_queue
6673

67-
# Whether to use SPREAD placement group strategy for training.
68-
_USE_SPREAD_STRATEGY = int(os.getenv("RXGB_USE_SPREAD_STRATEGY", 1))
6974

70-
# How long to wait for placement group creation before failing.
71-
PLACEMENT_GROUP_TIMEOUT_S = int(
72-
os.getenv("RXGB_PLACEMENT_GROUP_TIMEOUT_S", 100))
75+
def _get_environ(item: str, old_val: Any):
76+
env_var = f"RXGB_{item}"
77+
new_val = old_val
78+
if env_var in os.environ:
79+
new_val_str = os.environ.get(env_var)
80+
81+
if isinstance(old_val, bool):
82+
new_val = bool(int(new_val_str))
83+
elif isinstance(old_val, int):
84+
new_val = int(new_val_str)
85+
elif isinstance(old_val, float):
86+
new_val = float(new_val_str)
87+
else:
88+
new_val = new_val_str
89+
90+
return new_val
91+
92+
93+
@dataclass
94+
class _XGBoostEnv:
95+
# Whether to use SPREAD placement group strategy for training.
96+
USE_SPREAD_STRATEGY: bool = True
97+
98+
# How long to wait for placement group creation before failing.
99+
PLACEMENT_GROUP_TIMEOUT_S: int = 100
100+
101+
# Status report frequency when waiting for initial actors
102+
# and during training
103+
STATUS_FREQUENCY_S: int = 30
104+
105+
# If restarting failed actors is disabled
106+
ELASTIC_RESTART_DISABLED: bool = False
107+
108+
# How often to check for new available resources
109+
ELASTIC_RESTART_RESOURCE_CHECK_S: int = 30
73110

74-
# Status report frequency when waiting for initial actors and during training
75-
STATUS_FREQUENCY_S = int(os.getenv("RXGB_STATUS_FREQUENCY_S", 30))
111+
# How long to wait before triggering a new start of the training loop
112+
# when new actors become available
113+
ELASTIC_RESTART_GRACE_PERIOD_S: int = 10
76114

77-
# If restarting failed actors is disabled
78-
ELASTIC_RESTART_DISABLED = bool(
79-
int(os.getenv("RXGB_ELASTIC_RESTART_DISABLED", 0)))
115+
def __getattribute__(self, item):
    """Return attribute ``item``, overridden by ``RXGB_<item>`` if set.

    The environment is consulted on every access, so a variable that is
    set, changed, or removed after import takes effect on the next read.

    Args:
        item: Name of the attribute being looked up.

    Returns:
        The stored attribute value, or the environment override coerced
        by ``_get_environ`` to the stored value's type.
    """
    stored = super(_XGBoostEnv, self).__getattribute__(item)
    # Only public, non-callable attributes are settings. Skipping
    # underscore-prefixed names and methods keeps interpreter-internal
    # lookups (``__class__``, ``__dict__``, ...) away from the environment.
    if item.startswith("_") or callable(stored):
        return stored
    # Do not write the override back onto the instance: caching it would
    # keep returning a stale value after the variable is removed again.
    return _get_environ(item, stored)
80121

81-
# How often to check for new available resources
82-
ELASTIC_RESTART_RESOURCE_CHECK_S = int(
83-
os.getenv("RXGB_ELASTIC_RESTART_RESOURCE_CHECK_S", 30))
84122

85-
# How long to wait before triggering a new start of the training loop
86-
# when new actors become available
87-
ELASTIC_RESTART_GRACE_PERIOD_S = int(
88-
os.getenv("RXGB_ELASTIC_RESTART_GRACE_PERIOD_S", 10))
123+
ENV = _XGBoostEnv()
89124

90125
xgboost_version = xgb.__version__ if xgb else "0.0.0"
91126

@@ -138,22 +173,32 @@ def _is_client_connected() -> bool:
138173
return False
139174

140175

141-
class _RabitTracker(RabitTracker):
176+
class _RabitTrackerCompatMixin:
    """Worker-named aliases that forward to the legacy slave-named methods.

    ``accept_workers`` and ``worker_envs`` simply call ``accept_slaves`` and
    ``slave_envs`` on ``self``; the class mixing this in must provide those.
    """

    # NOTE(review): _RabitTracker lists this mixin after RabitTracker in its
    # bases, so these fallbacks are presumably only reached when RabitTracker
    # itself does not define the worker-named methods -- confirm against the
    # installed xgboost version.

    def accept_workers(self, n_workers: int):
        """Forward to ``accept_slaves`` and return its result unchanged."""
        return self.accept_slaves(n_workers)

    def worker_envs(self):
        """Forward to ``slave_envs`` and return its result unchanged."""
        return self.slave_envs()
184+
185+
186+
class _RabitTracker(RabitTracker, _RabitTrackerCompatMixin):
142187
"""
143188
This class overrides the xgboost-provided RabitTracker to switch
144189
from a daemon thread to a multiprocessing Process. This is so that
145190
we are able to terminate/kill the tracking process at will.
146191
"""
147192

148-
def start(self, nslave):
193+
def start(self, nworker):
149194
# TODO: refactor RabitTracker to support spawn process creation.
150195
# In python 3.8, spawn is used as default process creation on macOS.
151196
# But spawn doesn't work because `run` is not pickleable.
152197
# For now we force the start method to use fork.
153198
multiprocessing.set_start_method("fork", force=True)
154199

155200
def run():
156-
self.accept_slaves(nslave)
201+
self.accept_workers(nworker)
157202

158203
self.thread = multiprocessing.Process(target=run, args=())
159204
self.thread.start()
@@ -178,10 +223,10 @@ def _start_rabit_tracker(num_workers: int):
178223

179224
env = {"DMLC_NUM_WORKER": num_workers}
180225

181-
rabit_tracker = _RabitTracker(hostIP=host, nslave=num_workers)
226+
rabit_tracker = _RabitTracker(host, num_workers)
182227

183228
# Get tracker Host + IP
184-
env.update(rabit_tracker.slave_envs())
229+
env.update(rabit_tracker.worker_envs())
185230
rabit_tracker.start(num_workers)
186231

187232
logger.debug(
@@ -704,7 +749,7 @@ def _create_actor(
704749

705750
def _trigger_data_load(actor, dtrain, evals):
706751
wait_load = [actor.load_data.remote(dtrain)]
707-
for deval, name in evals:
752+
for deval, _name in evals:
708753
wait_load.append(actor.load_data.remote(deval))
709754
return wait_load
710755

@@ -778,7 +823,7 @@ def _create_placement_group(cpus_per_actor, gpus_per_actor,
778823
pg = placement_group(bundles, strategy=strategy)
779824
# Wait for placement group to get created.
780825
logger.debug("Waiting for placement group to start.")
781-
ready, _ = ray.wait([pg.ready()], timeout=PLACEMENT_GROUP_TIMEOUT_S)
826+
ready, _ = ray.wait([pg.ready()], timeout=ENV.PLACEMENT_GROUP_TIMEOUT_S)
782827
if ready:
783828
logger.debug("Placement group has started.")
784829
else:
@@ -955,7 +1000,7 @@ def handle_actor_failure(actor_id):
9551000
# Construct list before calling all() to force evaluation
9561001
ready_states = [task.is_ready() for task in prepare_actor_tasks]
9571002
while not all(ready_states):
958-
if time.time() >= last_status + STATUS_FREQUENCY_S:
1003+
if time.time() >= last_status + ENV.STATUS_FREQUENCY_S:
9591004
wait_time = time.time() - start_wait
9601005
logger.info(f"Waiting until actors are ready "
9611006
f"({wait_time:.0f} seconds passed).")
@@ -1029,7 +1074,7 @@ def handle_actor_failure(actor_id):
10291074
callback_returns=callback_returns)
10301075

10311076
if ray_params.elastic_training \
1032-
and not ELASTIC_RESTART_DISABLED:
1077+
and not ENV.ELASTIC_RESTART_DISABLED:
10331078
_maybe_schedule_new_actors(
10341079
training_state=_training_state,
10351080
num_cpus_per_actor=cpus_per_actor,
@@ -1041,7 +1086,7 @@ def handle_actor_failure(actor_id):
10411086
# This may raise RayXGBoostActorAvailable
10421087
_update_scheduled_actor_states(_training_state)
10431088

1044-
if time.time() >= last_status + STATUS_FREQUENCY_S:
1089+
if time.time() >= last_status + ENV.STATUS_FREQUENCY_S:
10451090
wait_time = time.time() - start_wait
10461091
logger.info(f"Training in progress "
10471092
f"({wait_time:.0f} seconds since last restart).")
@@ -1290,7 +1335,7 @@ def _wrapped(*args, **kwargs):
12901335
if not dtrain.loaded and not dtrain.distributed:
12911336
dtrain.load_data(ray_params.num_actors)
12921337

1293-
for (deval, name) in evals:
1338+
for (deval, _name) in evals:
12941339
if not deval.has_label:
12951340
raise ValueError(
12961341
"Evaluation data has no label set. Please make sure to set "
@@ -1321,7 +1366,7 @@ def _wrapped(*args, **kwargs):
13211366
placement_strategy = None
13221367
else:
13231368
placement_strategy = "PACK"
1324-
elif bool(_USE_SPREAD_STRATEGY):
1369+
elif bool(ENV.USE_SPREAD_STRATEGY):
13251370
placement_strategy = "SPREAD"
13261371

13271372
if placement_strategy is not None:

xgboost_ray/tests/test_colocation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,10 @@ def inner_func(config):
176176
num_samples=1,
177177
)
178178

179-
@patch("xgboost_ray.main.PLACEMENT_GROUP_TIMEOUT_S", 5)
180179
def test_timeout(self):
181180
"""Checks that an error occurs when placement group setup times out."""
181+
os.environ["RXGB_PLACEMENT_GROUP_TIMEOUT_S"] = "5"
182+
182183
with self.ray_start_cluster() as cluster:
183184
ray.init(address=cluster.address)
184185

xgboost_ray/tests/test_fault_tolerance.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ class XGBoostRayFaultToleranceTest(unittest.TestCase):
3333
"""
3434

3535
def setUp(self):
36+
# Set default
37+
os.environ["RXGB_ELASTIC_RESTART_DISABLED"] = "0"
38+
3639
repeat = 8 # Repeat data a couple of times for stability
3740
self.x = np.array([
3841
[1, 0, 0, 0], # Feature 0 -> Label 0
@@ -107,9 +110,10 @@ def keep(actors, *args, **kwargs):
107110
# Two workers finished, so N=32
108111
self.assertEqual(additional_results["total_n"], 32)
109112

110-
@patch("xgboost_ray.main.ELASTIC_RESTART_DISABLED", True)
111113
def testTrainingContinuationElasticKilled(self):
112114
"""This should continue after one actor died."""
115+
os.environ["RXGB_ELASTIC_RESTART_DISABLED"] = "1"
116+
113117
logging.getLogger().setLevel(10)
114118

115119
additional_results = {}
@@ -148,7 +152,6 @@ def keep(actors, *args, **kwargs):
148152
# Only one worker finished, so n=16
149153
self.assertEqual(additional_results["total_n"], 16)
150154

151-
@patch("xgboost_ray.main.ELASTIC_RESTART_DISABLED", False)
152155
def testTrainingContinuationElasticKilledRestarted(self):
153156
"""This should continue after one actor died and restart it."""
154157
logging.getLogger().setLevel(10)
@@ -201,9 +204,10 @@ def keep(actors, *args, **kwargs):
201204
# Both workers finished, so n=32
202205
self.assertEqual(additional_results["total_n"], 32)
203206

204-
@patch("xgboost_ray.main.ELASTIC_RESTART_DISABLED", True)
205207
def testTrainingContinuationElasticMultiKilled(self):
206208
"""This should still show 20 boost rounds after two failures."""
209+
os.environ["RXGB_ELASTIC_RESTART_DISABLED"] = "1"
210+
207211
logging.getLogger().setLevel(10)
208212

209213
additional_results = {}
@@ -232,9 +236,9 @@ def testTrainingContinuationElasticMultiKilled(self):
232236
self.assertSequenceEqual(list(self.y), list(pred_y))
233237
print(f"Got correct predictions: {pred_y}")
234238

235-
@patch("xgboost_ray.main.ELASTIC_RESTART_DISABLED", True)
236239
def testTrainingContinuationElasticFailed(self):
237240
"""This should continue after one actor failed training."""
241+
os.environ["RXGB_ELASTIC_RESTART_DISABLED"] = "1"
238242

239243
additional_results = {}
240244
keep_actors = {}
@@ -285,6 +289,8 @@ def testTrainingStop(self):
285289

286290
def testTrainingStopElastic(self):
287291
"""This should now stop training after one actor died."""
292+
os.environ["RXGB_ELASTIC_RESTART_DISABLED"] = "0"
293+
288294
# The `train()` function raises a RuntimeError
289295
ft_manager = FaultToleranceManager.remote()
290296

@@ -419,8 +425,6 @@ def testSameResultWithAndWithoutError(self):
419425
@patch("xgboost_ray.main._PrepareActorTask", _FakeTask)
420426
@patch("xgboost_ray.elastic._PrepareActorTask", _FakeTask)
421427
@patch("xgboost_ray.main._RemoteRayXGBoostActor", MagicMock)
422-
@patch("xgboost_ray.main.ELASTIC_RESTART_GRACE_PERIOD_S", 30)
423-
@patch("xgboost_ray.elastic.ELASTIC_RESTART_GRACE_PERIOD_S", 30)
424428
def testMaybeScheduleNewActors(self):
425429
"""Test scheduling of new actors if resources become available.
426430
@@ -436,6 +440,8 @@ def testMaybeScheduleNewActors(self):
436440
from xgboost_ray.elastic import _update_scheduled_actor_states
437441
from xgboost_ray.elastic import _maybe_schedule_new_actors
438442

443+
os.environ["RXGB_ELASTIC_RESTART_GRACE_PERIOD_S"] = "30"
444+
439445
# Three actors are dead
440446
actors = [
441447
MagicMock(), None,
@@ -520,7 +526,7 @@ def fake_create_actor(rank, *args, **kwargs):
520526
# actor.
521527
_update_scheduled_actor_states(training_state=state)
522528

523-
# Grace period is set through ELASTIC_RESTART_GRACE_PERIOD_S
529+
# Grace period is set through ENV.ELASTIC_RESTART_GRACE_PERIOD_S
524530
# Allow for some slack in test execution
525531
self.assertGreaterEqual(state.restart_training_at,
526532
time.time() + 22)

xgboost_ray/tests/test_sklearn_matrix.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,22 +43,21 @@ def testClassifier(self, n_class=2):
4343
train_matrix = RayDMatrix(X_train, y_train)
4444
test_matrix = RayDMatrix(X_test, y_test)
4545

46-
with self.assertRaisesRegex(ValueError, "use_label_encoder"):
46+
with self.assertRaisesRegex(Exception, "use_label_encoder"):
4747
RayXGBClassifier(
4848
use_label_encoder=True, **self.params).fit(train_matrix, None)
4949

50-
with self.assertRaisesRegex(ValueError, "num_class"):
50+
with self.assertRaisesRegex(Exception, "num_class"):
5151
RayXGBClassifier(
5252
use_label_encoder=False, **self.params).fit(
5353
train_matrix, None)
5454

55-
with self.assertRaisesRegex(ValueError,
56-
r"must be \(RayDMatrix, str\)"):
55+
with self.assertRaisesRegex(Exception, r"must be \(RayDMatrix, str\)"):
5756
RayXGBClassifier(
5857
use_label_encoder=False, **self.params).fit(
5958
train_matrix, None, eval_set=[(X_test, y_test)])
6059

61-
with self.assertRaisesRegex(ValueError,
60+
with self.assertRaisesRegex(Exception,
6261
r"must be \(array_like, array_like\)"):
6362
RayXGBClassifier(
6463
use_label_encoder=False, **self.params).fit(
@@ -97,15 +96,14 @@ def testClassifierLegacy(self, n_class=2):
9796
train_matrix = RayDMatrix(X_train, y_train)
9897
test_matrix = RayDMatrix(X_test, y_test)
9998

100-
with self.assertRaisesRegex(ValueError, "num_class"):
99+
with self.assertRaisesRegex(Exception, "num_class"):
101100
RayXGBClassifier(**self.params).fit(train_matrix, None)
102101

103-
with self.assertRaisesRegex(ValueError,
104-
r"must be \(RayDMatrix, str\)"):
102+
with self.assertRaisesRegex(Exception, r"must be \(RayDMatrix, str\)"):
105103
RayXGBClassifier(**self.params).fit(
106104
train_matrix, None, eval_set=[(X_test, y_test)])
107105

108-
with self.assertRaisesRegex(ValueError,
106+
with self.assertRaisesRegex(Exception,
109107
r"must be \(array_like, array_like\)"):
110108
RayXGBClassifier(**self.params).fit(
111109
X_train, y_train, eval_set=[(test_matrix, "eval")])
@@ -140,12 +138,11 @@ def testRegressor(self):
140138
train_matrix = RayDMatrix(X_train, y_train)
141139
test_matrix = RayDMatrix(X_test, y_test)
142140

143-
with self.assertRaisesRegex(ValueError,
144-
r"must be \(RayDMatrix, str\)"):
141+
with self.assertRaisesRegex(Exception, r"must be \(RayDMatrix, str\)"):
145142
RayXGBRegressor(**self.params).fit(
146143
train_matrix, None, eval_set=[(X_test, y_test)])
147144

148-
with self.assertRaisesRegex(ValueError,
145+
with self.assertRaisesRegex(Exception,
149146
r"must be \(array_like, array_like\)"):
150147
RayXGBRegressor(**self.params).fit(
151148
X_train, y_train, eval_set=[(test_matrix, "eval")])

0 commit comments

Comments
 (0)