Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit cee3008

Browse files
authored
Track total training time (#58)
* Update README
* Track total training time
* Fix test colocation
1 parent 1e6494f commit cee3008

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

xgboost_ray/main.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,8 @@ class _TrainingState:
681681
checkpoint: _Checkpoint
682682
additional_results: Dict
683683

684+
training_started_at: float = 0.
685+
684686
placement_group: Optional[PlacementGroup] = None
685687

686688
failed_actor_ranks: set = field(default_factory=set)
@@ -830,6 +832,8 @@ def handle_actor_failure(actor_id):
830832
_training_state.additional_results[
831833
"callback_returns"] = callback_returns
832834

835+
_training_state.training_started_at = time.time()
836+
833837
# Trigger the train function
834838
training_futures = [
835839
actor.train.remote(rabit_args, params, dtrain, evals, *args, **kwargs)
@@ -919,9 +923,6 @@ def handle_actor_failure(actor_id):
919923

920924
_training_state.additional_results["total_n"] = total_n
921925

922-
logger.info(f"[RayXGBoost] Finished XGBoost training on training data "
923-
f"with total N={total_n:,}.")
924-
925926
return bst, evals_result, _training_state.additional_results
926927

927928

@@ -1020,6 +1021,8 @@ def _wrapped(*args, **kwargs):
10201021
additional_results.update(train_additional_results)
10211022
return bst
10221023

1024+
start_time = time.time()
1025+
10231026
ray_params = _validate_ray_params(ray_params)
10241027

10251028
max_actor_restarts = ray_params.max_actor_restarts \
@@ -1104,13 +1107,15 @@ def _wrapped(*args, **kwargs):
11041107

11051108
start_actor_ranks = set(range(ray_params.num_actors)) # Start these
11061109

1110+
total_training_time = 0.
11071111
while tries <= max_actor_restarts:
11081112
training_state = _TrainingState(
11091113
actors=actors,
11101114
queue=queue,
11111115
stop_event=stop_event,
11121116
checkpoint=checkpoint,
11131117
additional_results=current_results,
1118+
training_started_at=0.,
11141119
placement_group=pg,
11151120
failed_actor_ranks=start_actor_ranks,
11161121
pending_actors=pending_actors)
@@ -1126,8 +1131,14 @@ def _wrapped(*args, **kwargs):
11261131
gpus_per_actor=gpus_per_actor,
11271132
_training_state=training_state,
11281133
**kwargs)
1134+
if training_state.training_started_at > 0.:
1135+
total_training_time += time.time(
1136+
) - training_state.training_started_at
11291137
break
11301138
except (RayActorError, RayTaskError) as exc:
1139+
if training_state.training_started_at > 0.:
1140+
total_training_time += time.time(
1141+
) - training_state.training_started_at
11311142
alive_actors = sum(1 for a in actors if a is not None)
11321143
start_again = False
11331144
if ray_params.elastic_training:
@@ -1186,6 +1197,16 @@ def _wrapped(*args, **kwargs):
11861197
) from exc
11871198
tries += 1
11881199

1200+
total_time = time.time() - start_time
1201+
1202+
train_additional_results["training_time_s"] = total_training_time
1203+
train_additional_results["total_time_s"] = total_time
1204+
1205+
logger.info("[RayXGBoost] Finished XGBoost training on training data "
1206+
"with total N={total_n:,} in {total_time_s:.2f} seconds "
1207+
"({training_time_s:.2f} pure XGBoost training time).".format(
1208+
**train_additional_results))
1209+
11891210
_shutdown(
11901211
actors=actors,
11911212
pending_actors=pending_actors,

xgboost_ray/tests/test_colocation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import shutil
33
import tempfile
44
import unittest
5-
from unittest.mock import patch, DEFAULT
5+
from unittest.mock import patch
66
import pytest
77

88
import numpy as np
@@ -79,7 +79,7 @@ def _mock_train(*args, _training_state, **kwargs):
7979
assert ray.get(
8080
_training_state.stop_event.actor.get_node_id.remote()) == \
8181
ray.state.current_node_id()
82-
return DEFAULT, DEFAULT, DEFAULT
82+
return _train(*args, _training_state=_training_state, **kwargs)
8383

8484
with patch("xgboost_ray.main._train") as mocked:
8585
mocked.side_effect = _mock_train

0 commit comments

Comments
 (0)