@@ -587,20 +587,28 @@ def _create_communication_processes():
     return queue, stop_event
 
 
+@dataclass
+class _TrainingState:
+    actors: List
+    queue: Queue
+    stop_event: Event
+
+    checkpoint: _Checkpoint
+    additional_results: Dict
+
+    placement_group: PlacementGroup
+
+    failed_actor_ranks: set
+
+
 def _train(params: Dict,
            dtrain: RayDMatrix,
            *args,
            evals=(),
            ray_params: RayParams,
            cpus_per_actor: int,
            gpus_per_actor: int,
-           _checkpoint: _Checkpoint,
-           _additional_results: Dict,
-           _actors: List,
-           _queue: Queue,
-           _stop_event: Event,
-           _placement_group: PlacementGroup,
-           _failed_actor_ranks: set,
+           _training_state: _TrainingState,
            **kwargs) -> Tuple[xgb.Booster, Dict, Dict]:
     """This is the local train function wrapped by :func:`train() <train>`.
 
@@ -628,16 +636,16 @@ def _train(params: Dict,
     # failed actors (which we might want to restart later), and set its
     # entry in the actor list to None.
     def handle_actor_failure(actor_id):
-        rank = _actors.index(actor_id)
-        _failed_actor_ranks.add(rank)
-        _actors[rank] = None
+        rank = _training_state.actors.index(actor_id)
+        _training_state.failed_actor_ranks.add(rank)
+        _training_state.actors[rank] = None
 
     # Here we create new actors. In the first invocation of _train(), this
     # will be all actors. In future invocations, this may be less than
     # the num_actors setting, depending on the failure mode.
     newly_created = 0
-    for i in list(_failed_actor_ranks):
-        if _actors[i] is not None:
+    for i in list(_training_state.failed_actor_ranks):
+        if _training_state.actors[i] is not None:
             raise RuntimeError(
                 f"Trying to create actor with rank {i}, but it already "
                 f"exists.")
@@ -647,36 +655,36 @@ def handle_actor_failure(actor_id):
             num_cpus_per_actor=cpus_per_actor,
             num_gpus_per_actor=gpus_per_actor,
             resources_per_actor=ray_params.resources_per_actor,
-            placement_group=_placement_group,
-            queue=_queue,
+            placement_group=_training_state.placement_group,
+            queue=_training_state.queue,
             checkpoint_frequency=ray_params.checkpoint_frequency)
         # Set actor entry in our list
-        _actors[i] = actor
+        _training_state.actors[i] = actor
         # Remove from this set so it is not created again
-        _failed_actor_ranks.remove(i)
+        _training_state.failed_actor_ranks.remove(i)
         newly_created += 1
 
     # Maybe we got a new Queue actor, so send it to all actors.
     wait_queue = [
-        actor.set_queue.remote(_queue) for actor in _actors
-        if actor is not None
+        actor.set_queue.remote(_training_state.queue)
+        for actor in _training_state.actors if actor is not None
     ]
     ray.get(wait_queue)
 
     # Maybe we got a new Event actor, so send it to all actors.
     wait_event = [
-        actor.set_stop_event.remote(_stop_event) for actor in _actors
-        if actor is not None
+        actor.set_stop_event.remote(_training_state.stop_event)
+        for actor in _training_state.actors if actor is not None
     ]
     ray.get(wait_event)
 
-    alive_actors = sum(1 for a in _actors if a is not None)
+    alive_actors = sum(1 for a in _training_state.actors if a is not None)
     logger.info(f"[RayXGBoost] Created {newly_created} new actors "
                 f"({alive_actors} total actors).")
 
     # Split data across workers
     wait_load = []
-    for actor in _actors:
+    for actor in _training_state.actors:
         if actor is None:
             continue
         # If data is already on the node, will not load again
@@ -685,8 +693,8 @@ def handle_actor_failure(actor_id):
     try:
         ray.get(wait_load)
     except Exception as exc:
-        _stop_event.set()
-        _get_actor_alive_status(_actors, handle_actor_failure)
+        _training_state.stop_event.set()
+        _get_actor_alive_status(_training_state.actors, handle_actor_failure)
         raise RayActorError from exc
 
     logger.info("[RayXGBoost] Starting XGBoost training.")
@@ -697,31 +705,33 @@ def handle_actor_failure(actor_id):
 
     # Load checkpoint if we have one. In that case we need to adjust the
     # number of training rounds.
-    if _checkpoint.value:
-        kwargs["xgb_model"] = pickle.loads(_checkpoint.value)
-        if _checkpoint.iteration == -1:
+    if _training_state.checkpoint.value:
+        kwargs["xgb_model"] = pickle.loads(_training_state.checkpoint.value)
+        if _training_state.checkpoint.iteration == -1:
             # -1 means training already finished.
             logger.error(
                 f"Trying to load continue from checkpoint, but the checkpoint"
                 f"indicates training already finished. Returning last"
                 f"checkpointed model instead.")
-            return kwargs["xgb_model"], {}, _additional_results
+            return kwargs["xgb_model"], {}, _training_state.additional_results
 
         kwargs["num_boost_round"] = kwargs.get("num_boost_round", 10) - \
-            _checkpoint.iteration - 1
+            _training_state.checkpoint.iteration - 1
 
     # The callback_returns dict contains actor-rank indexed lists of
     # results obtained through the `put_queue` function, usually
     # sent via callbacks.
-    callback_returns = _additional_results.get("callback_returns")
+    callback_returns = _training_state.additional_results.get(
+        "callback_returns")
     if callback_returns is None:
-        callback_returns = [list() for _ in range(len(_actors))]
-        _additional_results["callback_returns"] = callback_returns
+        callback_returns = [list() for _ in range(len(_training_state.actors))]
+        _training_state.additional_results[
+            "callback_returns"] = callback_returns
 
     # Trigger the train function
     training_futures = [
         actor.train.remote(rabit_args, params, dtrain, evals, *args, **kwargs)
-        for actor in _actors if actor is not None
+        for actor in _training_state.actors if actor is not None
     ]
 
     # Failure handling loop. Here we wait until all training tasks finished.
@@ -731,10 +741,10 @@ def handle_actor_failure(actor_id):
     try:
         not_ready = training_futures
         while not_ready:
-            if _queue:
+            if _training_state.queue:
                 _handle_queue(
-                    queue=_queue,
-                    checkpoint=_checkpoint,
+                    queue=_training_state.queue,
+                    checkpoint=_training_state.checkpoint,
                     callback_returns=callback_returns)
             ready, not_ready = ray.wait(not_ready, timeout=0)
             logger.debug("[RayXGBoost] Waiting for results...")
@@ -743,19 +753,19 @@ def handle_actor_failure(actor_id):
         ray.get(training_futures)
 
         # Get items from queue one last time
-        if _queue:
+        if _training_state.queue:
             _handle_queue(
-                queue=_queue,
-                checkpoint=_checkpoint,
+                queue=_training_state.queue,
+                checkpoint=_training_state.checkpoint,
                 callback_returns=callback_returns)
 
     # The inner loop should catch all exceptions
     except Exception as exc:
         # Stop all other actors from training
-        _stop_event.set()
+        _training_state.stop_event.set()
 
         # Check which actors are still alive
-        _get_actor_alive_status(_actors, handle_actor_failure)
+        _get_actor_alive_status(_training_state.actors, handle_actor_failure)
 
         # Todo: Try to fetch newer checkpoint, store in `_checkpoint`
         # Shut down rabit
@@ -776,16 +786,17 @@ def handle_actor_failure(actor_id):
     evals_result = all_results[0]["evals_result"]
 
     if callback_returns:
-        _additional_results["callback_returns"] = callback_returns
+        _training_state.additional_results[
+            "callback_returns"] = callback_returns
 
     total_n = sum(res["train_n"] or 0 for res in all_results)
 
-    _additional_results["total_n"] = total_n
+    _training_state.additional_results["total_n"] = total_n
 
     logger.info(f"[RayXGBoost] Finished XGBoost training on training data "
                 f"with total N={total_n:,}.")
 
-    return bst, evals_result, _additional_results
+    return bst, evals_result, _training_state.additional_results
 
 
 def train(params: Dict,
@@ -906,7 +917,17 @@ def train(params: Dict,
         pg = None
 
     start_actor_ranks = set(range(ray_params.num_actors))  # Start these
+
     while tries <= max_actor_restarts:
+        training_state = _TrainingState(
+            actors=actors,
+            queue=queue,
+            stop_event=stop_event,
+            checkpoint=checkpoint,
+            additional_results=current_results,
+            placement_group=pg,
+            failed_actor_ranks=start_actor_ranks)
+
         try:
             bst, train_evals_result, train_additional_results = _train(
                 params,
@@ -916,13 +937,7 @@ def train(params: Dict,
                 ray_params=ray_params,
                 cpus_per_actor=cpus_per_actor,
                 gpus_per_actor=gpus_per_actor,
-                _checkpoint=checkpoint,
-                _additional_results=current_results,
-                _actors=actors,
-                _queue=queue,
-                _stop_event=stop_event,
-                _placement_group=pg,
-                _failed_actor_ranks=start_actor_ranks,
+                _training_state=training_state,
                 **kwargs)
             break
         except (RayActorError, RayTaskError) as exc:
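
Note: the sketch below is not part of the commit; it is a minimal, simplified illustration of the pattern the diff introduces, using hypothetical stand-in names (TrainingStateSketch, _train_sketch) and plain values instead of real Ray actor handles, queues, events, and placement groups. It shows why the mutable per-attempt state is bundled into one dataclass and rebuilt at the top of each retry of the outer loop, so a restarted _train() call sees the current actor list and failed ranks instead of seven separate keyword arguments.

from dataclasses import dataclass
from typing import Dict, List, Optional, Set


@dataclass
class TrainingStateSketch:
    # Stand-ins for the real fields (actors, queue, stop_event, checkpoint,
    # additional_results, placement_group, failed_actor_ranks).
    actors: List[Optional[str]]
    additional_results: Dict
    failed_actor_ranks: Set[int]


def _train_sketch(state: TrainingStateSketch) -> Dict:
    # Recreate any "actors" whose ranks were previously marked as failed,
    # mirroring the actor-restart loop in _train().
    for rank in list(state.failed_actor_ranks):
        state.actors[rank] = f"actor-{rank}"
        state.failed_actor_ranks.remove(rank)
    # Record a result on the shared state object, as _train() does with
    # additional_results["total_n"].
    state.additional_results["total_n"] = sum(
        1 for a in state.actors if a is not None)
    return state.additional_results


if __name__ == "__main__":
    # One "retry" of the outer loop: bundle the current state and hand it
    # to the train function as a single argument.
    state = TrainingStateSketch(
        actors=[None, None],
        additional_results={},
        failed_actor_ranks={0, 1},
    )
    print(_train_sketch(state))  # {'total_n': 2}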