
Commit dbf673e

Introduce training state dataclass, replacing private variables in _train() (#35)

1 parent aa1f918 · commit dbf673e
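The refactor is the classic "parameter object" move: the seven loop-carried, underscore-prefixed arguments to `_train()` become fields of a single `_TrainingState` dataclass. A minimal, self-contained sketch of that pattern follows — note the `str` stand-ins replace the real Ray actor handles, and `_train` here is a toy illustration, not the library function:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class _TrainingState:
    # Mutable state threaded through every _train() attempt.
    actors: List[Optional[str]]
    additional_results: Dict = field(default_factory=dict)
    failed_actor_ranks: set = field(default_factory=set)


def _train(state: _TrainingState) -> Dict:
    # Recreate any actors whose ranks were marked as failed.
    for rank in list(state.failed_actor_ranks):
        state.actors[rank] = f"actor-{rank}"
        state.failed_actor_ranks.remove(rank)
    state.additional_results["total_n"] = len(state.actors)
    return state.additional_results


state = _TrainingState(actors=[None, None], failed_actor_ranks={0, 1})
print(_train(state))  # the same state object survives across retries
```

Because the state travels as one object, callers that retry `_train()` only pass a single keyword argument, and adding a new piece of state no longer changes the function signature.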

File tree

2 files changed: +82 −66 lines


xgboost_ray/main.py — 66 additions, 51 deletions
```diff
@@ -587,20 +587,28 @@ def _create_communication_processes():
     return queue, stop_event


+@dataclass
+class _TrainingState:
+    actors: List
+    queue: Queue
+    stop_event: Event
+
+    checkpoint: _Checkpoint
+    additional_results: Dict
+
+    placement_group: PlacementGroup
+
+    failed_actor_ranks: set
+
+
 def _train(params: Dict,
            dtrain: RayDMatrix,
            *args,
            evals=(),
            ray_params: RayParams,
            cpus_per_actor: int,
            gpus_per_actor: int,
-           _checkpoint: _Checkpoint,
-           _additional_results: Dict,
-           _actors: List,
-           _queue: Queue,
-           _stop_event: Event,
-           _placement_group: PlacementGroup,
-           _failed_actor_ranks: set,
+           _training_state: _TrainingState,
            **kwargs) -> Tuple[xgb.Booster, Dict, Dict]:
     """This is the local train function wrapped by :func:`train() <train>`.
@@ -628,16 +636,16 @@ def _train(params: Dict,
     # failed actors (which we might want to restart later), and set its
     # entry in the actor list to None.
     def handle_actor_failure(actor_id):
-        rank = _actors.index(actor_id)
-        _failed_actor_ranks.add(rank)
-        _actors[rank] = None
+        rank = _training_state.actors.index(actor_id)
+        _training_state.failed_actor_ranks.add(rank)
+        _training_state.actors[rank] = None

     # Here we create new actors. In the first invocation of _train(), this
     # will be all actors. In future invocations, this may be less than
     # the num_actors setting, depending on the failure mode.
     newly_created = 0
-    for i in list(_failed_actor_ranks):
-        if _actors[i] is not None:
+    for i in list(_training_state.failed_actor_ranks):
+        if _training_state.actors[i] is not None:
             raise RuntimeError(
                 f"Trying to create actor with rank {i}, but it already "
                 f"exists.")
@@ -647,36 +655,36 @@ def handle_actor_failure(actor_id):
             num_cpus_per_actor=cpus_per_actor,
             num_gpus_per_actor=gpus_per_actor,
             resources_per_actor=ray_params.resources_per_actor,
-            placement_group=_placement_group,
-            queue=_queue,
+            placement_group=_training_state.placement_group,
+            queue=_training_state.queue,
             checkpoint_frequency=ray_params.checkpoint_frequency)
         # Set actor entry in our list
-        _actors[i] = actor
+        _training_state.actors[i] = actor
         # Remove from this set so it is not created again
-        _failed_actor_ranks.remove(i)
+        _training_state.failed_actor_ranks.remove(i)
         newly_created += 1

     # Maybe we got a new Queue actor, so send it to all actors.
     wait_queue = [
-        actor.set_queue.remote(_queue) for actor in _actors
-        if actor is not None
+        actor.set_queue.remote(_training_state.queue)
+        for actor in _training_state.actors if actor is not None
     ]
     ray.get(wait_queue)

     # Maybe we got a new Event actor, so send it to all actors.
     wait_event = [
-        actor.set_stop_event.remote(_stop_event) for actor in _actors
-        if actor is not None
+        actor.set_stop_event.remote(_training_state.stop_event)
+        for actor in _training_state.actors if actor is not None
     ]
     ray.get(wait_event)

-    alive_actors = sum(1 for a in _actors if a is not None)
+    alive_actors = sum(1 for a in _training_state.actors if a is not None)
     logger.info(f"[RayXGBoost] Created {newly_created} new actors "
                 f"({alive_actors} total actors).")

     # Split data across workers
     wait_load = []
-    for actor in _actors:
+    for actor in _training_state.actors:
         if actor is None:
             continue
         # If data is already on the node, will not load again
@@ -685,8 +693,8 @@ def handle_actor_failure(actor_id):
     try:
         ray.get(wait_load)
     except Exception as exc:
-        _stop_event.set()
-        _get_actor_alive_status(_actors, handle_actor_failure)
+        _training_state.stop_event.set()
+        _get_actor_alive_status(_training_state.actors, handle_actor_failure)
         raise RayActorError from exc

     logger.info("[RayXGBoost] Starting XGBoost training.")
@@ -697,31 +705,33 @@ def handle_actor_failure(actor_id):

     # Load checkpoint if we have one. In that case we need to adjust the
     # number of training rounds.
-    if _checkpoint.value:
-        kwargs["xgb_model"] = pickle.loads(_checkpoint.value)
-        if _checkpoint.iteration == -1:
+    if _training_state.checkpoint.value:
+        kwargs["xgb_model"] = pickle.loads(_training_state.checkpoint.value)
+        if _training_state.checkpoint.iteration == -1:
             # -1 means training already finished.
             logger.error(
                 f"Trying to load continue from checkpoint, but the checkpoint"
                 f"indicates training already finished. Returning last"
                 f"checkpointed model instead.")
-            return kwargs["xgb_model"], {}, _additional_results
+            return kwargs["xgb_model"], {}, _training_state.additional_results

         kwargs["num_boost_round"] = kwargs.get("num_boost_round", 10) - \
-            _checkpoint.iteration - 1
+            _training_state.checkpoint.iteration - 1

     # The callback_returns dict contains actor-rank indexed lists of
     # results obtained through the `put_queue` function, usually
     # sent via callbacks.
-    callback_returns = _additional_results.get("callback_returns")
+    callback_returns = _training_state.additional_results.get(
+        "callback_returns")
     if callback_returns is None:
-        callback_returns = [list() for _ in range(len(_actors))]
-        _additional_results["callback_returns"] = callback_returns
+        callback_returns = [list() for _ in range(len(_training_state.actors))]
+        _training_state.additional_results[
+            "callback_returns"] = callback_returns

     # Trigger the train function
     training_futures = [
         actor.train.remote(rabit_args, params, dtrain, evals, *args, **kwargs)
-        for actor in _actors if actor is not None
+        for actor in _training_state.actors if actor is not None
     ]

     # Failure handling loop. Here we wait until all training tasks finished.
@@ -731,10 +741,10 @@ def handle_actor_failure(actor_id):
     try:
         not_ready = training_futures
         while not_ready:
-            if _queue:
+            if _training_state.queue:
                 _handle_queue(
-                    queue=_queue,
-                    checkpoint=_checkpoint,
+                    queue=_training_state.queue,
+                    checkpoint=_training_state.checkpoint,
                     callback_returns=callback_returns)
             ready, not_ready = ray.wait(not_ready, timeout=0)
             logger.debug("[RayXGBoost] Waiting for results...")
@@ -743,19 +753,19 @@ def handle_actor_failure(actor_id):
         ray.get(training_futures)

         # Get items from queue one last time
-        if _queue:
+        if _training_state.queue:
             _handle_queue(
-                queue=_queue,
-                checkpoint=_checkpoint,
+                queue=_training_state.queue,
+                checkpoint=_training_state.checkpoint,
                 callback_returns=callback_returns)

     # The inner loop should catch all exceptions
     except Exception as exc:
         # Stop all other actors from training
-        _stop_event.set()
+        _training_state.stop_event.set()

         # Check which actors are still alive
-        _get_actor_alive_status(_actors, handle_actor_failure)
+        _get_actor_alive_status(_training_state.actors, handle_actor_failure)

         # Todo: Try to fetch newer checkpoint, store in `_checkpoint`
         # Shut down rabit
@@ -776,16 +786,17 @@ def handle_actor_failure(actor_id):
     evals_result = all_results[0]["evals_result"]

     if callback_returns:
-        _additional_results["callback_returns"] = callback_returns
+        _training_state.additional_results[
+            "callback_returns"] = callback_returns

     total_n = sum(res["train_n"] or 0 for res in all_results)

-    _additional_results["total_n"] = total_n
+    _training_state.additional_results["total_n"] = total_n

     logger.info(f"[RayXGBoost] Finished XGBoost training on training data "
                 f"with total N={total_n:,}.")

-    return bst, evals_result, _additional_results
+    return bst, evals_result, _training_state.additional_results


 def train(params: Dict,
@@ -906,7 +917,17 @@ def train(params: Dict,
         pg = None

     start_actor_ranks = set(range(ray_params.num_actors))  # Start these
+
     while tries <= max_actor_restarts:
+        training_state = _TrainingState(
+            actors=actors,
+            queue=queue,
+            stop_event=stop_event,
+            checkpoint=checkpoint,
+            additional_results=current_results,
+            placement_group=pg,
+            failed_actor_ranks=start_actor_ranks)
+
         try:
             bst, train_evals_result, train_additional_results = _train(
                 params,
@@ -916,13 +937,7 @@ def train(params: Dict,
                 ray_params=ray_params,
                 cpus_per_actor=cpus_per_actor,
                 gpus_per_actor=gpus_per_actor,
-                _checkpoint=checkpoint,
-                _additional_results=current_results,
-                _actors=actors,
-                _queue=queue,
-                _stop_event=stop_event,
-                _placement_group=pg,
-                _failed_actor_ranks=start_actor_ranks,
+                _training_state=training_state,
                 **kwargs)
             break
         except (RayActorError, RayTaskError) as exc:
```
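Worth noting in the `train()` hunks above: a fresh `_TrainingState` is constructed on every pass of the retry loop, but it wraps the same long-lived containers (`actors`, `current_results`, `start_actor_ranks`), so the failure bookkeeping recorded by `handle_actor_failure()` carries over into the next attempt. Continuing the toy sketch from the top of this page — `RayActorError` here is a local stand-in for Ray's exception, and the loop is simplified:

```python
class RayActorError(Exception):
    pass


def train_with_retries(max_actor_restarts: int = 2) -> Dict:
    actors: List[Optional[str]] = [None, None]
    current_results: Dict = {}
    start_actor_ranks = set(range(len(actors)))  # create these first
    tries = 0
    while tries <= max_actor_restarts:
        # A fresh facade each attempt, but the underlying containers
        # persist, so progress and failed ranks survive across retries.
        training_state = _TrainingState(
            actors=actors,
            additional_results=current_results,
            failed_actor_ranks=start_actor_ranks)
        try:
            return _train(training_state)
        except RayActorError:
            tries += 1
    raise RuntimeError("Too many actor restarts.")
```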

xgboost_ray/tests/test_colocation.py — 16 additions, 15 deletions
```diff
@@ -70,14 +70,13 @@ def test_communication_colocation(self, ray_start_cluster):
         assert len(ray.state.node_ids()) == 2
         assert local_node in ray.state.node_ids()

-        def _mock_train(*args, _queue, _stop_event, **kwargs):
-            assert ray.get(_queue.actor.get_node_id.remote()
+        def _mock_train(*args, _training_state, **kwargs):
+            assert ray.get(_training_state.queue.actor.get_node_id.remote()
                            ) == ray.state.current_node_id()
             assert ray.get(
-                _stop_event.actor.get_node_id.remote()) == \
+                _training_state.stop_event.actor.get_node_id.remote()) == \
                 ray.state.current_node_id()
-            return _train(
-                *args, _queue=_queue, _stop_event=_stop_event, **kwargs)
+            return _train(*args, _training_state=_training_state, **kwargs)

         with patch("xgboost_ray.main._train", _mock_train):
             train(
@@ -99,18 +98,19 @@ def test_no_tune_spread(self, ray_start_cluster):
         ray_params = RayParams(
             max_actor_restarts=1, num_actors=2, cpus_per_actor=2)

-        def _mock_train(*args, _actors, **kwargs):
+        def _mock_train(*args, _training_state, **kwargs):
             try:
-                results = _train(*args, _actors=_actors, **kwargs)
+                results = _train(
+                    *args, _training_state=_training_state, **kwargs)
                 return results
             except Exception:
                 raise
             finally:
-                assert len(_actors) == 2
-                if not any(a is None for a in _actors):
+                assert len(_training_state.actors) == 2
+                if not any(a is None for a in _training_state.actors):
                     actor_infos = ray.actors()
                     actor_nodes = []
-                    for a in _actors:
+                    for a in _training_state.actors:
                         actor_info = actor_infos.get(a._actor_id.hex())
                         actor_node = actor_info["Address"]["NodeID"]
                         actor_nodes.append(actor_node)
@@ -141,18 +141,19 @@ def test_tune_pack(self, ray_start_cluster):
         ray_params = RayParams(
             max_actor_restarts=1, num_actors=2, cpus_per_actor=1)

-        def _mock_train(*args, _actors, **kwargs):
+        def _mock_train(*args, _training_state, **kwargs):
             try:
-                results = _train(*args, _actors=_actors, **kwargs)
+                results = _train(
+                    *args, _training_state=_training_state, **kwargs)
                 return results
             except Exception:
                 raise
             finally:
-                assert len(_actors) == 2
-                if not any(a is None for a in _actors):
+                assert len(_training_state.actors) == 2
+                if not any(a is None for a in _training_state.actors):
                     actor_infos = ray.actors()
                     actor_nodes = []
-                    for a in _actors:
+                    for a in _training_state.actors:
                         actor_info = actor_infos.get(a._actor_id.hex())
                         actor_node = actor_info["Address"]["NodeID"]
                         actor_nodes.append(actor_node)
```
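All three test changes follow one pattern: patch `xgboost_ray.main._train` with a wrapper that accepts the new `_training_state` keyword, inspects it, and delegates to the real `_train`. A condensed sketch of that pattern — it assumes `xgboost_ray` is importable, and the assertion is illustrative rather than one of the actual test checks:

```python
from unittest.mock import patch

from xgboost_ray.main import _train


def _mock_train(*args, _training_state, **kwargs):
    # The wrapper sees exactly the state object train() constructs
    # on each retry, so tests can assert on actors, queue, etc.
    assert _training_state.actors is not None
    return _train(*args, _training_state=_training_state, **kwargs)


# train() now routes through _mock_train transparently.
with patch("xgboost_ray.main._train", _mock_train):
    ...  # call xgboost_ray.train(...) as usual
```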
