Tseah/convert torchft replica group #60820
TimothySeah wants to merge 22 commits into ray-project:master from
Conversation
Signed-off-by: Lonnie Liu <lonnie@anyscale.com>
cherrypick ray-project#59494 Signed-off-by: Lonnie Liu <lonnie@anyscale.com>
… you request 0 GPUs on CPU-only cluster (ray-project#59516) Cherry-pick of ray-project#59514 Signed-off-by: Balaji Veeramani <bveeramani@berkeley.edu>
…ct#59519) EWMA_ALPHA: update EWMA_ALPHA from 0.2 -> 0.1. This makes the level adjustment favor limiting concurrency by being more sensitive to downstream queuing. K_DEV: update K_DEV from 2.0 -> 1.0. This makes the stddev term favor limiting concurrency by being more sensitive to downstream queuing. cherry-pick of ray-project#59392 (see the illustrative sketch after this commit list)
…oject#59606) Created by release automation bot. Update with commit 0de2118 Signed-off-by: Lonnie Liu <lonnie@anyscale.com> Co-authored-by: Lonnie Liu <lonnie@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
This reverts commit 436def7.
Signed-off-by: Timothy Seah <tseah@anyscale.com>
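The EWMA_ALPHA/K_DEV commit above tunes Ray Data's backpressure constants. As a rough, self-contained illustration of why lowering both values makes the limiter more conservative, consider the toy class below; it is not Ray Data's actual implementation, and all names in it are made up for the example.

# Toy EWMA-based limiter, for illustration only (not Ray Data's code).
# A smaller alpha weights history more heavily, so the smoothed estimate
# reacts to spikes more cautiously; a smaller k_dev shrinks the tolerated
# deviation above the mean before concurrency gets limited.
class ToyQueueBackpressure:
    def __init__(self, alpha: float = 0.1, k_dev: float = 1.0):
        self.alpha = alpha
        self.k_dev = k_dev
        self.ewma = 0.0      # smoothed queue-size estimate
        self.ewma_sq = 0.0   # smoothed estimate of queue_size ** 2

    def observe(self, queue_size: float) -> None:
        self.ewma = self.alpha * queue_size + (1 - self.alpha) * self.ewma
        self.ewma_sq = self.alpha * queue_size**2 + (1 - self.alpha) * self.ewma_sq

    def threshold(self) -> float:
        variance = max(self.ewma_sq - self.ewma**2, 0.0)
        # Queue sizes above mean + k_dev * stddev trigger backpressure.
        return self.ewma + self.k_dev * variance**0.5

bp = ToyQueueBackpressure(alpha=0.1, k_dev=1.0)
for q in [5, 7, 20, 6, 4]:
    bp.observe(q)
print(f"backpressure threshold: {bp.threshold():.2f}")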
Code Review
This pull request introduces support for torchft in Ray Train, enabling more granular fault tolerance through replica groups. The changes are extensive, touching configuration, backend setup, controller logic, and worker group management. Key additions include the TorchftConfig, a new _TorchftBackend that manages per-replica-group process groups, and the logic in WorkerGroup to replace failed replica groups. The PR also includes several improvements to Ray Data's backpressure mechanism and a bug fix in the autoscaling coordinator.
Overall, this is a significant feature addition. I've found one critical issue related to the autoscaling/resizing logic and one opportunity for refactoring to reduce code duplication.
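For orientation, here is a hypothetical end-user sketch of how the new config might be wired into a trainer. The import location of TorchftConfig and the num_replica_groups field are assumptions for illustration, not the API this PR confirms.

# Hypothetical usage sketch; TorchftConfig's import path and fields are
# assumed here and may differ from what this PR actually exposes.
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.torch import TorchftConfig  # assumed import location


def train_fn():
    # torchft-managed training loop; each replica group runs its own
    # process group and can fail and recover independently.
    ...


trainer = TorchTrainer(
    train_fn,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    torch_config=TorchftConfig(num_replica_groups=2),  # assumed parameter name
)
result = trainer.fit()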
 async def _execute_resize_decision(
     self, decision: ResizeDecision
 ) -> TrainControllerLoopIterationResult:
     """Executes resize decisions."""

     for callback in self._controller_callbacks:
         callback.before_controller_execute_resize_decision(decision)

-    if self._worker_group:
-        self._shutdown_worker_group()
+    optional_controller_error = None

-    optional_controller_error = self._start_worker_group(
-        num_workers=decision.num_workers,
-        resources_per_worker=decision.resources_per_worker,
-    )
+    if self._worker_group:
+        # Replace bad workers in the existing worker group
+        # TODO: propagate poll_status rather than recalculating it
+        try:
+            self._replace_bad_workers(await self._poll_workers())
+        except Exception as e:
+            optional_controller_error = ControllerError(e)
+    else:
+        optional_controller_error = self._start_worker_group(
+            num_workers=decision.num_workers,
+            resources_per_worker=decision.resources_per_worker,
+        )
The logic in _execute_resize_decision has been changed from handling resizing to handling failure recovery. The original implementation, which performed a full restart of the worker group to apply a ResizeDecision, has been replaced with a call to _replace_bad_workers. This new logic only recovers failed workers and does not adjust the total number of workers, which effectively breaks autoscaling (both scaling up and down).
If the intention is to disable scaling when using torchft, this should be handled more explicitly, for instance by ensuring the ScalingPolicy does not generate a ResizeDecision. As it stands, this change is a regression in functionality.
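A minimal sketch of that suggestion follows, assuming the v2 scaling-policy interface exposes ResizeDecision and NoopDecision; the wrapper class, method name, and import path are illustrative guesses about Ray Train v2 internals, not part of this PR.

# Illustrative only: wrap an existing scaling policy and suppress resizes
# when torchft owns fault tolerance. The import path and NoopDecision name
# are assumptions; the PR may handle this differently.
from ray.train.v2._internal.execution.scaling_policy import (
    NoopDecision,
    ResizeDecision,
)


class TorchftAwareScalingPolicy:
    def __init__(self, base_policy, torchft_enabled: bool):
        self._base_policy = base_policy
        self._torchft_enabled = torchft_enabled

    def make_decision_for_running_worker_group(self, *args, **kwargs):
        decision = self._base_policy.make_decision_for_running_worker_group(
            *args, **kwargs
        )
        if self._torchft_enabled and isinstance(decision, ResizeDecision):
            # torchft replaces failed replica groups in place, so never
            # tear down and resize the whole worker group.
            return NoopDecision()
        return decision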
# Re-initialize backend (per-group TCPStore + init_process_group)
# via BackendSetupCallback
from ray.train.v2._internal.callbacks.backend_setup import BackendSetupCallback

for callback in self._callbacks:
    if isinstance(callback, BackendSetupCallback):
        # First update workers in state so the callback can access them
        new_workers_by_rank = {
            w.distributed_context.world_rank: w for w in new_workers
        }
        updated_workers = [
            new_workers_by_rank.get(w.distributed_context.world_rank, w)
            for w in workers
        ]
        self._worker_group_state = WorkerGroupState(
            start_time=self._worker_group_state.start_time,
            placement_group=pg,
            workers=updated_workers,
            sync_actor=sync_actor,
        )
        callback.reinitialize_workers(self, target_group.world_ranks)
        break

# Get train context args from callbacks
train_context_args = {}
for cb in self._callbacks:
    args = cb.before_init_train_context(new_workers)
    for arg, arg_values in args.items():
        assert len(arg_values) == len(new_workers), (
            f"Callback {cb} returned {arg} with "
            f"{len(arg_values)} values, expected {len(new_workers)}."
        )
        assert (
            arg not in train_context_args
        ), f"Callback {cb} returned {arg} which is already set."
        train_context_args[arg] = arg_values

# Initialize train context on new workers
try:
    self._init_train_context_on_workers(
        new_workers, sync_actor, train_context_args
    )
except RayActorError as actor_error:
    for worker in new_workers:
        ray.kill(worker.actor)
    error_msg = (
        "Replacement workers failed during train context initialization."
    )
    raise WorkerGroupStartupFailedError(error_msg) from actor_error

# Launch training function on new workers
ray_get_safe(
    [
        worker.actor.run_train_fn.remote(
            self._worker_group_context.train_fn_ref
        )
        for worker in new_workers
    ]
)

# Update state if not already updated above (in case no BackendSetupCallback)
if not any(isinstance(cb, BackendSetupCallback) for cb in self._callbacks):
    new_workers_by_rank = {
        w.distributed_context.world_rank: w for w in new_workers
    }
    updated_workers = [
        new_workers_by_rank.get(w.distributed_context.world_rank, w)
        for w in workers
    ]
    self._worker_group_state = WorkerGroupState(
        start_time=self._worker_group_state.start_time,
        placement_group=pg,
        workers=updated_workers,
        sync_actor=sync_actor,
    )
The logic to update self._worker_group_state with the new workers is duplicated. It appears once inside the loop that finds the BackendSetupCallback and again in a separate if block for the case where the callback is not found. This can be refactored to update the state once before the loop, improving code clarity and maintainability.
# Update worker group state with the new workers.
new_workers_by_rank = {
w.distributed_context.world_rank: w for w in new_workers
}
updated_workers = [
new_workers_by_rank.get(w.distributed_context.world_rank, w)
for w in workers
]
self._worker_group_state = WorkerGroupState(
start_time=self._worker_group_state.start_time,
placement_group=pg,
workers=updated_workers,
sync_actor=sync_actor,
)
# Re-initialize backend (per-group TCPStore + init_process_group)
# via BackendSetupCallback
from ray.train.v2._internal.callbacks.backend_setup import BackendSetupCallback
for callback in self._callbacks:
if isinstance(callback, BackendSetupCallback):
callback.reinitialize_workers(self, target_group.world_ranks)
break
# Get train context args from callbacks
train_context_args = {}
for cb in self._callbacks:
args = cb.before_init_train_context(new_workers)
for arg, arg_values in args.items():
assert len(arg_values) == len(new_workers), (
f"Callback {cb} returned {arg} with "
f"{len(arg_values)} values, expected {len(new_workers)}."
)
assert (
arg not in train_context_args
), f"Callback {cb} returned {arg} which is already set."
train_context_args[arg] = arg_values
# Initialize train context on new workers
try:
self._init_train_context_on_workers(
new_workers, sync_actor, train_context_args
)
except RayActorError as actor_error:
for worker in new_workers:
ray.kill(worker.actor)
error_msg = (
"Replacement workers failed during train context initialization."
)
raise WorkerGroupStartupFailedError(error_msg) from actor_error
# Launch training function on new workers
ray_get_safe(
[
worker.actor.run_train_fn.remote(
self._worker_group_context.train_fn_ref
)
for worker in new_workers
]
)
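Separately, the "per-group TCPStore + init_process_group" comment in the worker-replacement code above describes the core idea behind _TorchftBackend. The sketch below shows what per-replica-group initialization looks like in plain torch.distributed terms; the function name and arguments are illustrative assumptions, not the PR's actual helper.

# Illustrative sketch of per-replica-group process-group setup with
# torch.distributed; the function name and arguments are assumptions,
# not Ray Train's API.
import datetime

import torch.distributed as dist
from torch.distributed import TCPStore


def init_replica_group(
    group_rank: int,
    group_world_size: int,
    group_master_addr: str,
    group_master_port: int,
) -> None:
    # Each replica group gets its own TCPStore, so workers in different
    # groups rendezvous independently and can be restarted independently.
    store = TCPStore(
        host_name=group_master_addr,
        port=group_master_port,
        world_size=group_world_size,
        is_master=(group_rank == 0),
        timeout=datetime.timedelta(seconds=300),
    )
    # Rank and world size are group-local, not global across all workers.
    dist.init_process_group(
        backend="gloo",  # "nccl" on GPU workers
        store=store,
        rank=group_rank,
        world_size=group_world_size,
    )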
Description
Related issues
Additional information