This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 1acfc85

krfricke and amogkam authored
Elastic training: Restart actors (#36)
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
1 parent dbf673e commit 1acfc85
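
For context, the code paths touched by this commit are only active when elastic training is enabled through the user-facing `RayParams`. A minimal usage sketch is shown below; the dataset, parameter values, and the failure-tolerance knobs (`max_failed_actors`, `max_actor_restarts`) are illustrative and come from the library's public API, not from this diff:

# Hypothetical user-side setup; values are illustrative only.
from sklearn.datasets import load_breast_cancer

from xgboost_ray import RayDMatrix, RayParams, train

data, labels = load_breast_cancer(return_X_y=True)
train_set = RayDMatrix(data, labels)

ray_params = RayParams(
    num_actors=4,            # distributed training actors
    elastic_training=True,   # keep training while some actors are dead
    max_failed_actors=2,     # tolerate up to 2 missing actors at a time
    max_actor_restarts=3)    # replace/restart dead actors up to 3 times

bst = train(
    {"objective": "binary:logistic", "eval_metric": ["logloss", "error"]},
    train_set,
    num_boost_round=100,
    ray_params=ray_params)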

File tree

9 files changed: +975, -116 lines changed


.github/workflows/test.yaml

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@ jobs:

  test_linux_ray_master:
    runs-on: ubuntu-latest
-    timeout-minutes: 10
+    timeout-minutes: 12

    steps:
    - uses: actions/checkout@v2
@@ -62,7 +62,7 @@ jobs:

  test_linux_ray_release:
    runs-on: ubuntu-latest
-    timeout-minutes: 10
+    timeout-minutes: 12

    steps:
    - uses: actions/checkout@v2
@@ -102,7 +102,7 @@ jobs:
    # Test compatibility when some optional libraries are missing
    # Test runs on latest ray release
    runs-on: ubuntu-latest
-    timeout-minutes: 10
+    timeout-minutes: 12

    steps:
    - uses: actions/checkout@v2

xgboost_ray/elastic.py

Lines changed: 158 additions & 0 deletions
import time
from typing import Optional, Dict, List, Tuple, Callable

import ray

from xgboost_ray.main import RayParams, _TrainingState, \
    logger, ActorHandle, _PrepareActorTask, _create_actor, \
    RayXGBoostActorAvailable, \
    ELASTIC_RESTART_RESOURCE_CHECK_S, ELASTIC_RESTART_GRACE_PERIOD_S

from xgboost_ray.matrix import RayDMatrix


def _maybe_schedule_new_actors(
        training_state: _TrainingState, num_cpus_per_actor: int,
        num_gpus_per_actor: int, resources_per_actor: Optional[Dict],
        ray_params: RayParams, load_data: List[RayDMatrix]) -> bool:
    """Schedule new actors for elastic training if resources are available.

    Potentially starts new actors and triggers data loading."""

    # This is only enabled for elastic training.
    if not ray_params.elastic_training:
        return False

    missing_actor_ranks = [
        rank for rank, actor in enumerate(training_state.actors)
        if actor is None and rank not in training_state.pending_actors
    ]

    # If all actors are alive, there is nothing to do.
    if not missing_actor_ranks:
        return False

    now = time.time()

    # Check periodically every n seconds.
    if now < training_state.last_resource_check_at + \
            ELASTIC_RESTART_RESOURCE_CHECK_S:
        return False

    training_state.last_resource_check_at = now

    new_pending_actors: Dict[int, Tuple[ActorHandle, _PrepareActorTask]] = {}
    for rank in missing_actor_ranks:
        # Actor rank should not be already pending
        if rank in training_state.pending_actors \
                or rank in new_pending_actors:
            continue

        # Try to schedule this actor
        actor = _create_actor(
            rank=rank,
            num_actors=ray_params.num_actors,
            num_cpus_per_actor=num_cpus_per_actor,
            num_gpus_per_actor=num_gpus_per_actor,
            resources_per_actor=resources_per_actor,
            placement_group=training_state.placement_group,
            queue=training_state.queue,
            checkpoint_frequency=ray_params.checkpoint_frequency)

        task = _PrepareActorTask(
            actor,
            queue=training_state.queue,
            stop_event=training_state.stop_event,
            load_data=load_data)

        new_pending_actors[rank] = (actor, task)
        logger.debug(f"Re-scheduled actor with rank {rank}. Waiting for "
                     f"placement and data loading before promoting it "
                     f"to training.")
    if new_pending_actors:
        training_state.pending_actors.update(new_pending_actors)
        logger.info(f"Re-scheduled {len(new_pending_actors)} actors for "
                    f"training. Once data loading finished, they will be "
                    f"integrated into training again.")
    return bool(new_pending_actors)


def _update_scheduled_actor_states(training_state: _TrainingState):
    """Update status of scheduled actors in elastic training.

    If actors finished their preparation tasks, promote them to
    proper training actors (set the `training_state.actors` entry).

    Also schedule a `RayXGBoostActorAvailable` exception so that training
    is restarted with the new actors.

    """
    now = time.time()
    actor_became_ready = False

    # Wrap in list so we can alter the `training_state.pending_actors` dict
    for rank in list(training_state.pending_actors.keys()):
        actor, task = training_state.pending_actors[rank]
        if task.is_ready():
            # Promote to proper actor
            training_state.actors[rank] = actor
            del training_state.pending_actors[rank]
            actor_became_ready = True

    if actor_became_ready:
        if not training_state.pending_actors:
            # No other actors are pending, so let's restart right away.
            training_state.restart_training_at = now - 1.

        # If an actor became ready but other actors are pending, we wait
        # for n seconds before restarting, as chances are that they become
        # ready as well (e.g. if a large node came up).
        grace_period = ELASTIC_RESTART_GRACE_PERIOD_S
        if training_state.restart_training_at is None:
            logger.debug(
                f"A RayXGBoostActor became ready for training. Waiting "
                f"{grace_period} seconds before triggering training restart.")
            training_state.restart_training_at = now + grace_period

    if training_state.restart_training_at is not None:
        if now > training_state.restart_training_at:
            training_state.restart_training_at = None
            raise RayXGBoostActorAvailable(
                "A new RayXGBoostActor became available for training. "
                "Triggering restart.")


def _get_actor_alive_status(actors: List[ActorHandle],
                            callback: Callable[[ActorHandle], None]):
    """Loop through all actors. Invoke a callback on dead actors."""
    obj_to_rank = {}

    alive = 0
    dead = 0

    for rank, actor in enumerate(actors):
        if actor is None:
            dead += 1
            continue
        obj = actor.pid.remote()
        obj_to_rank[obj] = rank

    not_ready = list(obj_to_rank.keys())
    while not_ready:
        ready, not_ready = ray.wait(not_ready, timeout=0)

        for obj in ready:
            try:
                pid = ray.get(obj)
                rank = obj_to_rank[obj]
                logger.debug(f"Actor {actors[rank]} with PID {pid} is alive.")
                alive += 1
            except Exception:
                rank = obj_to_rank[obj]
                logger.debug(f"Actor {actors[rank]} is _not_ alive.")
                dead += 1
                callback(actors[rank])
    logger.info(f"Actor status: {alive} alive, {dead} dead "
                f"({alive+dead} total)")

    return alive, dead
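
These helpers are internal to the library. As a rough sketch of how a driver loop could tie them together between boosting rounds, consider the wrapper below; the function `_elastic_monitor_step` and the `on_dead_actor` callback are hypothetical and not part of this commit, and `xgboost_ray.main` performs this wiring internally:

# Hypothetical driver-side monitoring step, for illustration only.
# `training_state`, `ray_params`, the loaded RayDMatrix objects, and the
# dead-actor callback are assumed to be set up by the caller.
from xgboost_ray.elastic import (
    _get_actor_alive_status, _maybe_schedule_new_actors,
    _update_scheduled_actor_states)


def _elastic_monitor_step(training_state, ray_params, load_data,
                          num_cpus_per_actor, num_gpus_per_actor,
                          resources_per_actor, on_dead_actor):
    # Count alive/dead actors and let the caller react to the dead ones.
    alive, dead = _get_actor_alive_status(
        training_state.actors, callback=on_dead_actor)

    # If resources freed up, try to schedule replacement actors.
    _maybe_schedule_new_actors(
        training_state=training_state,
        num_cpus_per_actor=num_cpus_per_actor,
        num_gpus_per_actor=num_gpus_per_actor,
        resources_per_actor=resources_per_actor,
        ray_params=ray_params,
        load_data=load_data)

    # Promote prepared actors; this may raise RayXGBoostActorAvailable
    # to signal that training should restart with the new actor set.
    _update_scheduled_actor_states(training_state)

    return alive, dead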
