@@ -681,6 +681,8 @@ class _TrainingState:
681681 checkpoint : _Checkpoint
682682 additional_results : Dict
683683
684+ training_started_at : float = 0.
685+
684686 placement_group : Optional [PlacementGroup ] = None
685687
686688 failed_actor_ranks : set = field (default_factory = set )
@@ -830,6 +832,8 @@ def handle_actor_failure(actor_id):
830832 _training_state .additional_results [
831833 "callback_returns" ] = callback_returns
832834
835+ _training_state .training_started_at = time .time ()
836+
833837 # Trigger the train function
834838 training_futures = [
835839 actor .train .remote (rabit_args , params , dtrain , evals , * args , ** kwargs )
@@ -919,9 +923,6 @@ def handle_actor_failure(actor_id):
919923
920924 _training_state .additional_results ["total_n" ] = total_n
921925
922- logger .info (f"[RayXGBoost] Finished XGBoost training on training data "
923- f"with total N={ total_n :,} ." )
924-
925926 return bst , evals_result , _training_state .additional_results
926927
927928
@@ -1020,6 +1021,8 @@ def _wrapped(*args, **kwargs):
10201021 additional_results .update (train_additional_results )
10211022 return bst
10221023
1024+ start_time = time .time ()
1025+
10231026 ray_params = _validate_ray_params (ray_params )
10241027
10251028 max_actor_restarts = ray_params .max_actor_restarts \
@@ -1104,13 +1107,15 @@ def _wrapped(*args, **kwargs):
11041107
11051108 start_actor_ranks = set (range (ray_params .num_actors )) # Start these
11061109
1110+ total_training_time = 0.
11071111 while tries <= max_actor_restarts :
11081112 training_state = _TrainingState (
11091113 actors = actors ,
11101114 queue = queue ,
11111115 stop_event = stop_event ,
11121116 checkpoint = checkpoint ,
11131117 additional_results = current_results ,
1118+ training_started_at = 0. ,
11141119 placement_group = pg ,
11151120 failed_actor_ranks = start_actor_ranks ,
11161121 pending_actors = pending_actors )
@@ -1126,8 +1131,14 @@ def _wrapped(*args, **kwargs):
11261131 gpus_per_actor = gpus_per_actor ,
11271132 _training_state = training_state ,
11281133 ** kwargs )
1134+ if training_state .training_started_at > 0. :
1135+ total_training_time += time .time (
1136+ ) - training_state .training_started_at
11291137 break
11301138 except (RayActorError , RayTaskError ) as exc :
1139+ if training_state .training_started_at > 0. :
1140+ total_training_time += time .time (
1141+ ) - training_state .training_started_at
11311142 alive_actors = sum (1 for a in actors if a is not None )
11321143 start_again = False
11331144 if ray_params .elastic_training :
@@ -1186,6 +1197,16 @@ def _wrapped(*args, **kwargs):
11861197 ) from exc
11871198 tries += 1
11881199
1200+ total_time = time .time () - start_time
1201+
1202+ train_additional_results ["training_time_s" ] = total_training_time
1203+ train_additional_results ["total_time_s" ] = total_time
1204+
1205+ logger .info ("[RayXGBoost] Finished XGBoost training on training data "
1206+ "with total N={total_n:,} in {total_time_s:.2f} seconds "
1207+ "({training_time_s:.2f} pure XGBoost training time)." .format (
1208+ ** train_additional_results ))
1209+
11891210 _shutdown (
11901211 actors = actors ,
11911212 pending_actors = pending_actors ,
0 commit comments