This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit 371cb19

[Ray lightning 1.6] update the change according to the comment in #163 (#195)

* adding the change (based on #163)

* Update ray_lightning/launchers/ray_horovod_launcher.py

* Update ray_lightning/launchers/ray_launcher.py

Co-authored-by: Amog Kamsetty <[email protected]>
JiahaoYao and amogkam authored Aug 16, 2022
1 parent 13468d7 commit 371cb19
Showing 4 changed files with 208 additions and 66 deletions.
58 changes: 42 additions & 16 deletions ray_lightning/launchers/ray_horovod_launcher.py
@@ -48,7 +48,8 @@ def __init__(self, strategy: "Strategy") -> None:
@property
def global_rank(self) -> int:
"""Return the global rank of the current process.
run on the worker node.
This function is run on the worker process.
"""
if not hvd.is_initialized():
return 0
@@ -57,7 +58,8 @@ def global_rank(self) -> int:
@property
def local_rank(self) -> int:
"""Return the local rank of the current process.
run on the worker node.
This function is run on the worker process.
"""
if not hvd.is_initialized():
return 0
@@ -66,7 +68,8 @@ def local_rank(self) -> int:
@property
def world_size(self) -> int:
"""Return the world size of the current process.
run on the worker node.
This function is run on the worker process.
"""
if not hvd.is_initialized():
return self.num_workers
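For context, all three properties follow the same guard: fall back to a default until Horovod has been initialized on the worker. A minimal standalone sketch of that pattern, not part of this commit (the `_RankInfo` class is hypothetical; it assumes `horovod.torch` is installed):

import horovod.torch as hvd

class _RankInfo:
    """Hypothetical helper mirroring the guarded rank/size properties above."""

    def __init__(self, num_workers: int):
        self.num_workers = num_workers

    @property
    def global_rank(self) -> int:
        # Before hvd.init() has run on this worker, report rank 0.
        return hvd.rank() if hvd.is_initialized() else 0

    @property
    def local_rank(self) -> int:
        return hvd.local_rank() if hvd.is_initialized() else 0

    @property
    def world_size(self) -> int:
        # Fall back to the configured worker count until Horovod reports it.
        return hvd.size() if hvd.is_initialized() else self.num_workers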
@@ -81,13 +84,18 @@ def launch(self,
*args: Any,
trainer: Optional["pl.Trainer"] = None,
**kwargs: Any) -> Any:
"""Launch the function on the workers and collect the results."""
"""Launch the function on the workers and collect the results.
This function is run on the driver process.
"""
ray_output = self.run_function_on_workers(
function, *args, trainer=trainer, **kwargs)

if trainer is None:
raise NotImplementedError(
"Ray launcher does not support trainer is None!")
"Ray launcher does not support trainer is None! "
"Did you override the `trainer` variable? "
"If not, please help file an issue on Github.")
self._recover_results_in_main_process(ray_output, trainer)
return_value = ray_output.trainer_results

@@ -99,8 +107,11 @@ def run_function_on_workers(self,
trainer: Optional["pl.Trainer"] = None,
**kwargs: Any):
"""Run the function on the workers and collect the results.
`executor.run_remote` is used to launch multiple ray remote tasks
to distributed training the model using the horovod backend.
This function is run on the driver process.
`executor.run_remote` is used to launch multiple ray remote tasks
to distributed training the model using the horovod backend.
"""

# put the model as the ray object
Expand All @@ -109,6 +120,7 @@ def run_function_on_workers(self,
model = trainer.model
model_ref = ray.put(model)
trainer.model = None
# the model always be at the 0th position in the args
new_args = tuple([None] + list(args[1:]))

# remove the executor temporarily from the args
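To make the driver-side flow concrete: the model is placed in the Ray object store once, each worker task receives a handle to it, and the driver gathers the results. The sketch below is not part of this commit; it uses plain `ray.remote` tasks as a stand-in for Horovod's `executor.run_remote`, and `train_fn` and `dummy_model` are made-up names.

import ray

ray.init(ignore_reinit_error=True)

def train_fn(model, worker_index: int) -> str:
    # On a real worker this is where the Horovod-backed training would run.
    return f"worker {worker_index} trained {model['name']}"

@ray.remote
def worker_task(model, worker_index: int) -> str:
    # Ray resolves the ObjectRef argument to the actual object on the worker,
    # so the (potentially large) model is shipped through the object store
    # instead of being pickled into every task.
    return train_fn(model, worker_index)

dummy_model = {"name": "toy-model"}
model_ref = ray.put(dummy_model)                 # analogous to `ray.put(trainer.model)`
futures = [worker_task.remote(model_ref, i) for i in range(2)]
print(ray.get(futures))                          # driver blocks until all workers finish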
@@ -147,21 +159,29 @@ def _wrapping_function(
tune_queue: Queue,
) -> Any:
"""Wrapping function to run the function on the workers.
`_wrapping_function` is run on each remote worker.
`function(*args, **kwargs)` is where the actual training happens.
This function is run on the worker process.
`_wrapping_function` is run on each remote worker.
`function(*args, **kwargs)` is where the actual training happens.
"""

self._strategy.set_remote(True)

# `function` is a trainer's class method
# in the ray remote tasks, its object `trainer` will also
# be copied when the function is remoted.
# `function` is a trainer's instance method
# in the ray remote tasks, its bound instance `trainer`
# will also be copied when the function is remoted.
#
# ALERT: passing the trainer as an argument of `_wrapping_function`
# does not fillfullied our purpose. Ray remote tasks will
# does not fulfill our purpose. Ray remote tasks will
# create another copy of trainer so that
# `function.__self__ != trainer`, in which the side effect only
# happens to `function.__self__` when running
# `function(*args, **kwargs)`
# `function(*args, **kwargs)` (see SOLUTION below).
#
# SOLUTION: we find the trainer directly from `function`
# by calling `function.__self__` so that we can restore
# all the side effects happened to `function.__self__`
trainer = function.__self__
model = ray.get(model_ref)
trainer.model = model
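The ALERT/SOLUTION comment rests on two facts: a bound method carries its instance as `__self__`, and each argument shipped into a Ray task is deserialized into its own copy, so side effects land on `function.__self__` rather than on a separately passed object. A small self-contained demonstration of the behaviour the comment describes (`ToyTrainer` and `wrapping_function` are illustrative names, not the library's code):

import ray

ray.init(ignore_reinit_error=True)

class ToyTrainer:
    def __init__(self):
        self.result = None

    def fit(self, x: int) -> int:
        # The side effect we want to recover on the driver afterwards.
        self.result = x * 2
        return self.result

@ray.remote
def wrapping_function(function, trainer):
    # `function` is unpickled together with its bound instance, so calling it
    # mutates `function.__self__` -- which, as the comment above notes, is a
    # different copy from the separately passed `trainer` argument.
    function(21)
    return (function.__self__ is trainer,
            function.__self__.result,
            trainer.result)

trainer = ToyTrainer()
same_object, self_result, arg_result = ray.get(
    wrapping_function.remote(trainer.fit, trainer))
print(same_object)     # False: the worker received two independent copies
print(self_result)     # 42:   the side effect is visible via function.__self__
print(arg_result)      # None: the copy passed as a plain argument was not trained
print(trainer.result)  # None: the driver's trainer is untouched in both cases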
@@ -193,7 +213,10 @@ def _wrapping_function(

def _collect_rank_zero_results(self, trainer: "pl.Trainer",
results: Any) -> Optional["_RayOutput"]:
"""Collect the results from the rank zero process."""
"""Collect the results from the rank zero process.
This function is run on the worker process.
"""
rank_zero_debug("Finalizing the ray horovod launcher environment.")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path \
@@ -225,7 +248,10 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer",

def _recover_results_in_main_process(self, ray_output: "_RayOutput",
trainer: "pl.Trainer") -> None:
"""Recover the results in the main process."""
"""Recover the results in the main process.
This function is run on the worker process.
"""
# transfer back the best path to the trainer
if trainer.checkpoint_callback:
trainer.checkpoint_callback.best_model_path = str(
68 changes: 51 additions & 17 deletions ray_lightning/launchers/ray_launcher.py
@@ -50,7 +50,10 @@ def launch(self,
*args: Any,
trainer: Optional["pl.Trainer"] = None,
**kwargs: Any) -> Any:
"""Launches the function on the workers from the driver node."""
"""Launches the function on the workers from the driver node.
This function is run on the driver process.
"""
self.setup_workers()
ray_output = self.run_function_on_workers(
function, *args, trainer=trainer, **kwargs)
@@ -66,8 +69,9 @@ def launch(self,
return return_value

def setup_workers(self, tune_enabled: bool = True) -> None:
"""Creates the Ray actors and sets up PTL Trainer environment
on the worker nodes.
"""Creates the Ray actors and sets up PTL Trainer environment.
This function is run on the driver process.
"""
self._workers = [
self._create_worker() for _ in range(self._strategy.num_workers)
@@ -99,15 +103,21 @@ def setup_workers(self, tune_enabled: bool = True) -> None:
self.tune_queue = Queue(actor_options={"num_cpus": 0})

def _create_worker(self) -> ray.actor.ActorHandle:
"""Creates Ray actor workers."""
"""Creates Ray actor workers.
This function is run on the driver process.
"""
worker = RayExecutor.options(
num_cpus=self._strategy.num_cpus_per_worker,
num_gpus=self._strategy.num_gpus_per_worker,
resources=self._strategy.additional_resources_per_worker).remote()
return worker

def teardown_workers(self):
"""Tears down the Ray actors and PTL Trainer environment"""
"""Tears down the Ray actors and PTL Trainer environment
This function is run on the driver process.
"""
if self.tune_queue:
# Shutdown the queue.
self.tune_queue.shutdown()
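`_create_worker` above is ordinary Ray actor creation with per-actor resource requests. A standalone sketch of the same idea (the `Worker` actor and the resource numbers are placeholders, not the commit's `RayExecutor`):

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class Worker:
    """Placeholder actor standing in for ray_lightning's RayExecutor."""

    def get_node_ip(self) -> str:
        # Each actor reports the node it was scheduled on; the launcher uses
        # this kind of information to build its rank mapping.
        return ray.util.get_node_ip_address()

# Mirror RayExecutor.options(num_cpus=..., num_gpus=..., resources=...).remote()
workers = [Worker.options(num_cpus=1, num_gpus=0).remote() for _ in range(2)]
print(ray.get([w.get_node_ip.remote() for w in workers]))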
@@ -119,7 +129,8 @@ def teardown_workers(self):

def get_local_ranks(self) -> List[Optional[Tuple[int, int]]]:
"""Creates a mapping of global ranks to local ranks/node ranks.
this method is to run on the worker nodes.
This function is run on the driver process.
"""
# Get the local ranks for all the workers and store as a list.
# First get the IP address of each remote worker.
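As a rough illustration of that mapping (not the commit's exact implementation): collect one node IP per worker in global-rank order, then number the workers on each node to obtain a (local_rank, node_rank) pair per global rank. The helper name and IP addresses below are made up.

from collections import defaultdict
from typing import Dict, List, Tuple

def build_global_to_local(node_ips: List[str]) -> List[Tuple[int, int]]:
    """Index = global rank; value = (local_rank, node_rank). Illustrative only."""
    node_rank_of_ip: Dict[str, int] = {}
    workers_seen_on_node: Dict[str, int] = defaultdict(int)
    global_to_local: List[Tuple[int, int]] = []

    for ip in node_ips:  # one entry per worker, in global-rank order
        if ip not in node_rank_of_ip:
            node_rank_of_ip[ip] = len(node_rank_of_ip)   # first worker seen on this node
        local_rank = workers_seen_on_node[ip]
        workers_seen_on_node[ip] += 1
        global_to_local.append((local_rank, node_rank_of_ip[ip]))

    return global_to_local

# Two workers on 10.0.0.1 and one on 10.0.0.2:
print(build_global_to_local(["10.0.0.1", "10.0.0.1", "10.0.0.2"]))
# [(0, 0), (1, 0), (0, 1)]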
@@ -146,7 +157,10 @@ def get_local_ranks(self) -> List[Optional[Tuple[int, int]]]:
return global_to_local

def _setup_env_vars(self):
"""Sets environment variables for all workers."""
"""Sets environment variables for all workers.
This function is run on the driver process.
"""
# Get rank 0 worker address and port for DDP connection.
os.environ["MASTER_ADDR"] = self._master_addr
os.environ["MASTER_PORT"] = self._master_port
@@ -162,6 +176,9 @@ def _setup_env_vars(self):

def _share_cuda_visible_devices(self):
"""Sets CUDA_VISIBLE_DEVICES on all workers.
This function is run on the driver process.
For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs
visible to all workers on that worker's node.
This allows GPU workers on the same node to communicate with one
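Concretely, the behaviour described here amounts to: take the union of the GPU IDs assigned to all workers on a node and export that set as CUDA_VISIBLE_DEVICES on each of those workers. A hypothetical sketch of the bookkeeping (the input format and names are assumptions, not the commit's code):

from collections import defaultdict
from typing import Dict, List, Set, Tuple

def shared_cuda_visible_devices(
        workers: List[Tuple[str, List[int]]]) -> Dict[int, str]:
    """Given (node_ip, gpu_ids) per worker, return the CUDA_VISIBLE_DEVICES
    string each worker should set so it sees every GPU used on its node."""
    gpus_on_node: Dict[str, Set[int]] = defaultdict(set)
    for node_ip, gpu_ids in workers:
        gpus_on_node[node_ip].update(gpu_ids)

    env_per_worker: Dict[int, str] = {}
    for worker_index, (node_ip, _) in enumerate(workers):
        # Each worker would set os.environ["CUDA_VISIBLE_DEVICES"] to this value.
        env_per_worker[worker_index] = ",".join(
            str(gpu) for gpu in sorted(gpus_on_node[node_ip]))
    return env_per_worker

# Two workers sharing a node (GPUs 0 and 1) and one worker on another node (GPU 0):
print(shared_cuda_visible_devices([("n1", [0]), ("n1", [1]), ("n2", [0])]))
# {0: '0,1', 1: '0,1', 2: '0'}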
@@ -207,7 +224,10 @@ def run_function_on_workers(self,
trainer: Optional["pl.Trainer"] = None,
**kwargs: Any):
"""launch a function on all workers.
The actual training parts are run inside `_wrapping_function`
This function is run on the driver process.
The actual training parts are run inside `_wrapping_function`
"""
# put the model as the ray object
# and remove the model temporarily from the args
@@ -240,21 +260,29 @@ def _wrapping_function(
tune_queue: Queue,
) -> Any:
"""Wraps the function to run on the workers.
`results = function(*args, **kwargs)` is where the
actual training parts are run.
This function is run on the worker process.
`results = function(*args, **kwargs)` is where the
actual training parts are run.
"""
self._strategy.set_remote(True)
self._strategy.set_global_to_local(global_to_local)

# `function` is a trainer's class method
# in the ray remote tasks, its object `trainer` will also
# be copied when the function is remoted.
# `function` is a trainer's instance method
# in the ray remote tasks, its bound instance `trainer`
# will also be copied when the function is remoted.
#
# ALERT: passing the trainer as an argument of `_wrapping_function`
# does not fillfullied our purpose. Ray remote tasks will
# does not fulfill our purpose. Ray remote tasks will
# create another copy of trainer so that
# `function.__self__ != trainer`, in which the side effect only
# happens to `function.__self__` when running
# `function(*args, **kwargs)`
# `function(*args, **kwargs)` (see SOLUTION below).
#
# SOLUTION: we find the trainer directly from `function`
# by calling `function.__self__` so that we can restore
# all the side effects happened to `function.__self__`
trainer = function.__self__
trainer.model = model_ref
args = tuple([model_ref] + list(args[1:]))
@@ -284,7 +312,10 @@ def _wrapping_function(

def _collect_rank_zero_results(self, trainer: "pl.Trainer",
results: Any) -> Optional["_RayOutput"]:
"""Collects the results from the worker node 0."""
"""Collects the results from the worker node 0.
This function is run on the worker process.
"""
rank_zero_debug("Finalizing the Ray launcher environment.")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path \
@@ -316,7 +347,10 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer",

def _recover_results_in_main_process(self, ray_output: "_RayOutput",
trainer: "pl.Trainer") -> None:
"""Recovers the results in the main process."""
"""Recovers the results in the main process.
This function is run on the worker process.
"""
# transfer back the best path to the trainer
if trainer.checkpoint_callback:
trainer.checkpoint_callback.best_model_path = str(
