Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 93ff047

Browse files
authored
Fix number of boost rounds after failures (#59)
* Update README * Fix number of boosting rounds after failures * Fix for client mode
1 parent cee3008 commit 93ff047

File tree

4 files changed

+66
-5
lines changed

4 files changed

+66
-5
lines changed

xgboost_ray/main.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -819,9 +819,6 @@ def handle_actor_failure(actor_id):
819819
f"checkpointed model instead.")
820820
return kwargs["xgb_model"], {}, _training_state.additional_results
821821

822-
kwargs["num_boost_round"] = kwargs.get(
823-
"num_boost_round", 10) - _training_state.checkpoint.iteration - 1
824-
825822
# The callback_returns dict contains actor-rank indexed lists of
826823
# results obtained through the `put_queue` function, usually
827824
# sent via callbacks.
@@ -928,13 +925,14 @@ def handle_actor_failure(actor_id):
928925

929926
def train(params: Dict,
930927
dtrain: RayDMatrix,
928+
num_boost_round: int = 10,
931929
*args,
932930
evals=(),
933931
evals_result: Optional[Dict] = None,
934932
additional_results: Optional[Dict] = None,
935933
ray_params: Union[None, RayParams, Dict] = None,
936934
_remote: Optional[bool] = None,
937-
**kwargs):
935+
**kwargs) -> xgb.Booster:
938936
"""Distributed XGBoost training via Ray.
939937
940938
This function will connect to a Ray cluster, create ``num_actors``
@@ -1000,6 +998,7 @@ def _wrapped(*args, **kwargs):
1000998
_additional_results = {}
1001999
bst = train(
10021000
*args,
1001+
num_boost_round=num_boost_round,
10031002
evals_result=_evals_result,
10041003
additional_results=_additional_results,
10051004
**kwargs)
@@ -1108,7 +1107,19 @@ def _wrapped(*args, **kwargs):
11081107
start_actor_ranks = set(range(ray_params.num_actors)) # Start these
11091108

11101109
total_training_time = 0.
1110+
boost_rounds_left = num_boost_round
1111+
last_checkpoint_value = checkpoint.value
11111112
while tries <= max_actor_restarts:
1113+
# Only update number of iterations if the checkpoint changed
1114+
# If it didn't change, we already subtracted the iterations.
1115+
if checkpoint.iteration >= 0 and \
1116+
checkpoint.value != last_checkpoint_value:
1117+
boost_rounds_left -= checkpoint.iteration + 1
1118+
1119+
last_checkpoint_value = checkpoint.value
1120+
1121+
logger.debug(f"Boost rounds left: {boost_rounds_left}")
1122+
11121123
training_state = _TrainingState(
11131124
actors=actors,
11141125
queue=queue,
@@ -1124,6 +1135,7 @@ def _wrapped(*args, **kwargs):
11241135
bst, train_evals_result, train_additional_results = _train(
11251136
params,
11261137
dtrain,
1138+
boost_rounds_left,
11271139
*args,
11281140
evals=evals,
11291141
ray_params=ray_params,

xgboost_ray/tests/test_end_to_end.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import ray
66

77
from xgboost_ray import RayParams, train, RayDMatrix, predict
8+
from xgboost_ray.tests.utils import get_num_trees
89

910

1011
class XGBoostRayEndToEndTest(unittest.TestCase):
@@ -114,11 +115,14 @@ def testTrainPredict(self, init=True, remote=None):
114115
bst = train(
115116
self.params,
116117
dtrain,
118+
num_boost_round=38,
117119
ray_params=RayParams(num_actors=2),
118120
evals=[(dtrain, "dtrain")],
119121
evals_result=evals_result,
120122
_remote=remote)
121123

124+
self.assertEqual(get_num_trees(bst), 38)
125+
122126
self.assertTrue("dtrain" in evals_result)
123127

124128
x_mat = RayDMatrix(self.x)

xgboost_ray/tests/test_fault_tolerance.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from xgboost_ray import train, RayDMatrix, RayParams
1515
from xgboost_ray.main import RayXGBoostActorAvailable
1616
from xgboost_ray.tests.utils import flatten_obj, _checkpoint_callback, \
17-
_fail_callback, tree_obj, _kill_callback, _sleep_callback
17+
_fail_callback, tree_obj, _kill_callback, _sleep_callback, get_num_trees
1818

1919

2020
class _FakeTask(MagicMock):
@@ -90,6 +90,8 @@ def keep(actors, *args, **kwargs):
9090
ray_params=RayParams(max_actor_restarts=1, num_actors=2),
9191
additional_results=additional_results)
9292

93+
self.assertEqual(20, get_num_trees(bst))
94+
9395
x_mat = xgb.DMatrix(self.x)
9496
pred_y = bst.predict(x_mat)
9597
self.assertSequenceEqual(list(self.y), list(pred_y))
@@ -129,6 +131,8 @@ def keep(actors, *args, **kwargs):
129131
max_failed_actors=1),
130132
additional_results=additional_results)
131133

134+
self.assertEqual(20, get_num_trees(bst))
135+
132136
x_mat = xgb.DMatrix(self.x)
133137
pred_y = bst.predict(x_mat)
134138
self.assertSequenceEqual(list(self.y), list(pred_y))
@@ -172,6 +176,8 @@ def keep(actors, *args, **kwargs):
172176
max_failed_actors=1),
173177
additional_results=additional_results)
174178

179+
self.assertEqual(20, get_num_trees(bst))
180+
175181
x_mat = xgb.DMatrix(self.x)
176182
pred_y = bst.predict(x_mat)
177183
self.assertSequenceEqual(list(self.y), list(pred_y))
@@ -186,6 +192,37 @@ def keep(actors, *args, **kwargs):
186192
# Both workers finished, so n=32
187193
self.assertEqual(additional_results["total_n"], 32)
188194

195+
@patch("xgboost_ray.main.ELASTIC_RESTART_DISABLED", True)
196+
def testTrainingContinuationElasticMultiKilled(self):
197+
"""This should still show 20 boost rounds after two failures."""
198+
logging.getLogger().setLevel(10)
199+
200+
additional_results = {}
201+
202+
bst = train(
203+
self.params,
204+
RayDMatrix(self.x, self.y),
205+
callbacks=[
206+
_kill_callback(
207+
self.die_lock_file, fail_iteration=6, actor_rank=0),
208+
_kill_callback(
209+
self.die_lock_file_2, fail_iteration=14, actor_rank=1),
210+
],
211+
num_boost_round=20,
212+
ray_params=RayParams(
213+
max_actor_restarts=2,
214+
num_actors=2,
215+
elastic_training=True,
216+
max_failed_actors=2),
217+
additional_results=additional_results)
218+
219+
self.assertEqual(20, get_num_trees(bst))
220+
221+
x_mat = xgb.DMatrix(self.x)
222+
pred_y = bst.predict(x_mat)
223+
self.assertSequenceEqual(list(self.y), list(pred_y))
224+
print(f"Got correct predictions: {pred_y}")
225+
189226
@patch("xgboost_ray.main.ELASTIC_RESTART_DISABLED", True)
190227
def testTrainingContinuationElasticFailed(self):
191228
"""This should continue after one actor failed training."""
@@ -211,6 +248,8 @@ def keep(actors, *args, **kwargs):
211248
max_failed_actors=1),
212249
additional_results=additional_results)
213250

251+
self.assertEqual(20, get_num_trees(bst))
252+
214253
x_mat = xgb.DMatrix(self.x)
215254
pred_y = bst.predict(x_mat)
216255
self.assertSequenceEqual(list(self.y), list(pred_y))

xgboost_ray/tests/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
from xgboost_ray.session import get_actor_rank, put_queue
1414

1515

16+
def get_num_trees(bst: xgb.Booster):
17+
import json
18+
data = [json.loads(d) for d in bst.get_dump(dump_format="json")]
19+
return len(data) // 4
20+
21+
1622
def create_data(num_rows: int, num_cols: int, dtype: np.dtype = np.float32):
1723

1824
return pd.DataFrame(

0 commit comments

Comments
 (0)