This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit 371cb19

Authored by JiahaoYao and amogkam
[Ray lightning 1.6] update the change according to the comment in #163 (#195)
* adding the change (based on #163)
* Update ray_lightning/launchers/ray_horovod_launcher.py
* Update ray_lightning/launchers/ray_launcher.py

Co-authored-by: Amog Kamsetty <[email protected]>
1 parent 13468d7 commit 371cb19

File tree

4 files changed (+208, -66 lines changed)


ray_lightning/launchers/ray_horovod_launcher.py

Lines changed: 42 additions & 16 deletions
@@ -48,7 +48,8 @@ def __init__(self, strategy: "Strategy") -> None:
     @property
     def global_rank(self) -> int:
         """Return the global rank of the current process.
-        run on the worker node.
+
+        This function is run on the worker process.
         """
         if not hvd.is_initialized():
             return 0
@@ -57,7 +58,8 @@ def global_rank(self) -> int:
     @property
     def local_rank(self) -> int:
         """Return the local rank of the current process.
-        run on the worker node.
+
+        This function is run on the worker process.
         """
         if not hvd.is_initialized():
             return 0
@@ -66,7 +68,8 @@ def local_rank(self) -> int:
     @property
     def world_size(self) -> int:
         """Return the world size of the current process.
-        run on the worker node.
+
+        This function is run on the worker process.
         """
         if not hvd.is_initialized():
             return self.num_workers
@@ -81,13 +84,18 @@ def launch(self,
                *args: Any,
                trainer: Optional["pl.Trainer"] = None,
                **kwargs: Any) -> Any:
-        """Launch the function on the workers and collect the results."""
+        """Launch the function on the workers and collect the results.
+
+        This function is run on the driver process.
+        """
         ray_output = self.run_function_on_workers(
             function, *args, trainer=trainer, **kwargs)
 
         if trainer is None:
             raise NotImplementedError(
-                "Ray launcher does not support trainer is None!")
+                "Ray launcher does not support trainer is None! "
+                "Did you override the `trainer` variable? "
+                "If not, please help file an issue on Github.")
         self._recover_results_in_main_process(ray_output, trainer)
         return_value = ray_output.trainer_results
 
@@ -99,8 +107,11 @@ def run_function_on_workers(self,
                                 trainer: Optional["pl.Trainer"] = None,
                                 **kwargs: Any):
         """Run the function on the workers and collect the results.
-        `executor.run_remote` is used to launch multiple ray remote tasks
-        to distributed training the model using the horovod backend.
+
+        This function is run on the driver process.
+
+        `executor.run_remote` is used to launch multiple ray remote tasks
+        to distributed training the model using the horovod backend.
         """
 
         # put the model as the ray object
@@ -109,6 +120,7 @@ def run_function_on_workers(self,
         model = trainer.model
         model_ref = ray.put(model)
         trainer.model = None
+        # the model always be at the 0th position in the args
         new_args = tuple([None] + list(args[1:]))
 
         # remove the executor temporarily from the args
@@ -147,21 +159,29 @@ def _wrapping_function(
             tune_queue: Queue,
     ) -> Any:
         """Wrapping function to run the function on the workers.
-        `_wrapping_function` is run on each remote worker.
-        `function(*args, **kwargs)` is where the actual training happens.
+
+        This function is run on the worker process.
+
+        `_wrapping_function` is run on each remote worker.
+        `function(*args, **kwargs)` is where the actual training happens.
         """
 
         self._strategy.set_remote(True)
 
-        # `function` is a trainer's class method
-        # in the ray remote tasks, its object `trainer` will also
-        # be copied when the function is remoted.
+        # `function` is a trainer's instance method
+        # in the ray remote tasks, its bound instance `trainer`
+        # will also be copied when the function is remoted.
+        #
         # ALERT: passing the trainer as an argument of `_wrapping_function`
-        # does not fillfullied our purpose. Ray remote tasks will
+        # does not fulfill our purpose. Ray remote tasks will
         # create another copy of trainer so that
         # `function.__self__ != trainer`, in which the side effect only
         # happens to `function.__self__` when running
-        # `function(*args, **kwargs)`
+        # `function(*args, **kwargs)` (see SOLUTION below).
+        #
+        # SOLUTION: we find the trainer directly from `function`
+        # by calling `function.__self__` so that we can restore
+        # all the side effects happened to `function.__self__`
         trainer = function.__self__
         model = ray.get(model_ref)
         trainer.model = model
@@ -193,7 +213,10 @@ def _wrapping_function(
 
     def _collect_rank_zero_results(self, trainer: "pl.Trainer",
                                    results: Any) -> Optional["_RayOutput"]:
-        """Collect the results from the rank zero process."""
+        """Collect the results from the rank zero process.
+
+        This function is run on the worker process.
+        """
         rank_zero_debug("Finalizing the ray horovod launcher environment.")
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = checkpoint_callback.best_model_path \
@@ -225,7 +248,10 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer",
 
     def _recover_results_in_main_process(self, ray_output: "_RayOutput",
                                          trainer: "pl.Trainer") -> None:
-        """Recover the results in the main process."""
+        """Recover the results in the main process.
+
+        This function is run on the worker process.
+        """
         # transfer back the best path to the trainer
         if trainer.checkpoint_callback:
             trainer.checkpoint_callback.best_model_path = str(
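
Note on the ALERT / SOLUTION comments in the `_wrapping_function` hunk above: the behavior they describe can be reproduced outside the launcher. The sketch below is not part of this commit; `Trainer` and `demo` are hypothetical stand-ins. It shows that a bound method and its owning object, passed to a Ray task as separate arguments, are deserialized into two independent copies, so the training side effects are only visible through `function.__self__`.

# Minimal sketch (not part of this commit) of the ALERT/SOLUTION behavior.
import ray

ray.init(ignore_reinit_error=True)


class Trainer:  # hypothetical stand-in for pl.Trainer
    def __init__(self):
        self.model = None

    def fit(self):
        self.model = "trained"  # side effect on the bound instance


@ray.remote
def demo(function, trainer):
    function()  # mutates function.__self__ inside the worker
    return (function.__self__ is trainer,  # False: two separate copies
            function.__self__.model,       # "trained"
            trainer.model)                 # still None


trainer = Trainer()
print(ray.get(demo.remote(trainer.fit, trainer)))
# -> (False, 'trained', None): this is why the launchers recover the trainer
#    from `function.__self__` inside `_wrapping_function`.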

ray_lightning/launchers/ray_launcher.py

Lines changed: 51 additions & 17 deletions
@@ -50,7 +50,10 @@ def launch(self,
                *args: Any,
                trainer: Optional["pl.Trainer"] = None,
                **kwargs: Any) -> Any:
-        """Launches the function on the workers from the driver node."""
+        """Launches the function on the workers from the driver node.
+
+        This function is run on the driver process.
+        """
         self.setup_workers()
         ray_output = self.run_function_on_workers(
             function, *args, trainer=trainer, **kwargs)
@@ -66,8 +69,9 @@ def launch(self,
         return return_value
 
     def setup_workers(self, tune_enabled: bool = True) -> None:
-        """Creates the Ray actors and sets up PTL Trainer environment
-        on the worker nodes.
+        """Creates the Ray actors and sets up PTL Trainer environment.
+
+        This function is run on the driver process.
         """
         self._workers = [
             self._create_worker() for _ in range(self._strategy.num_workers)
@@ -99,15 +103,21 @@ def setup_workers(self, tune_enabled: bool = True) -> None:
             self.tune_queue = Queue(actor_options={"num_cpus": 0})
 
     def _create_worker(self) -> ray.actor.ActorHandle:
-        """Creates Ray actor workers."""
+        """Creates Ray actor workers.
+
+        This function is run on the driver process.
+        """
         worker = RayExecutor.options(
             num_cpus=self._strategy.num_cpus_per_worker,
             num_gpus=self._strategy.num_gpus_per_worker,
             resources=self._strategy.additional_resources_per_worker).remote()
         return worker
 
     def teardown_workers(self):
-        """Tears down the Ray actors and PTL Trainer environment"""
+        """Tears down the Ray actors and PTL Trainer environment
+
+        This function is run on the driver process.
+        """
         if self.tune_queue:
             # Shutdown the queue.
             self.tune_queue.shutdown()
@@ -119,7 +129,8 @@ def teardown_workers(self):
 
     def get_local_ranks(self) -> List[Optional[Tuple[int, int]]]:
         """Creates a mapping of global ranks to local ranks/node ranks.
-        this method is to run on the worker nodes.
+
+        This function is run on the driver process.
         """
         # Get the local ranks for all the workers and store as a list.
         # First get the IP address of each remote worker.
@@ -146,7 +157,10 @@ def get_local_ranks(self) -> List[Optional[Tuple[int, int]]]:
         return global_to_local
 
     def _setup_env_vars(self):
-        """Sets environment variables for all workers."""
+        """Sets environment variables for all workers.
+
+        This function is run on the driver process.
+        """
         # Get rank 0 worker address and port for DDP connection.
         os.environ["MASTER_ADDR"] = self._master_addr
         os.environ["MASTER_PORT"] = self._master_port
@@ -162,6 +176,9 @@ def _setup_env_vars(self):
 
     def _share_cuda_visible_devices(self):
         """Sets CUDA_VISIBLE_DEVICES on all workers.
+
+        This function is run on the driver process.
+
         For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs
         visible to all workers on that worker's node.
         This allows GPU workers on the same node to communicate with one
@@ -207,7 +224,10 @@ def run_function_on_workers(self,
                                 trainer: Optional["pl.Trainer"] = None,
                                 **kwargs: Any):
         """launch a function on all workers.
-        The actual training parts are run inside `_wrapping_function`
+
+        This function is run on the driver process.
+
+        The actual training parts are run inside `_wrapping_function`
         """
         # put the model as the ray object
         # and remove the model temporarily from the args
@@ -240,21 +260,29 @@ def _wrapping_function(
             tune_queue: Queue,
     ) -> Any:
         """Wraps the function to run on the workers.
-        `results = function(*args, **kwargs)` is where the
-        actual training parts are run.
+
+        This function is run on the worker process.
+
+        `results = function(*args, **kwargs)` is where the
+        actual training parts are run.
         """
         self._strategy.set_remote(True)
         self._strategy.set_global_to_local(global_to_local)
 
-        # `function` is a trainer's class method
-        # in the ray remote tasks, its object `trainer` will also
-        # be copied when the function is remoted.
+        # `function` is a trainer's instance method
+        # in the ray remote tasks, its bound instance `trainer`
+        # will also be copied when the function is remoted.
+        #
         # ALERT: passing the trainer as an argument of `_wrapping_function`
-        # does not fillfullied our purpose. Ray remote tasks will
+        # does not fulfill our purpose. Ray remote tasks will
        # create another copy of trainer so that
         # `function.__self__ != trainer`, in which the side effect only
         # happens to `function.__self__` when running
-        # `function(*args, **kwargs)`
+        # `function(*args, **kwargs)` (see SOLUTION below).
+        #
+        # SOLUTION: we find the trainer directly from `function`
+        # by calling `function.__self__` so that we can restore
+        # all the side effects happened to `function.__self__`
         trainer = function.__self__
         trainer.model = model_ref
         args = tuple([model_ref] + list(args[1:]))
@@ -284,7 +312,10 @@ def _wrapping_function(
 
     def _collect_rank_zero_results(self, trainer: "pl.Trainer",
                                    results: Any) -> Optional["_RayOutput"]:
-        """Collects the results from the worker node 0."""
+        """Collects the results from the worker node 0.
+
+        This function is run on the worker process.
+        """
         rank_zero_debug("Finalizing the Ray launcher environment.")
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = checkpoint_callback.best_model_path \
@@ -316,7 +347,10 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer",
 
     def _recover_results_in_main_process(self, ray_output: "_RayOutput",
                                          trainer: "pl.Trainer") -> None:
-        """Recovers the results in the main process."""
+        """Recovers the results in the main process.
+
+        This function is run on the worker process.
+        """
         # transfer back the best path to the trainer
         if trainer.checkpoint_callback:
             trainer.checkpoint_callback.best_model_path = str(
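
Both launchers also share the hand-off hinted at by the "# put the model as the ray object" comments: the model is placed in the Ray object store once, stripped from the driver-side arguments, and fetched again inside the worker. The sketch below is a rough illustration of that pattern, not the repository's code; `worker_fn` and `launch_sketch` are hypothetical names.

# Rough sketch (not part of this commit) of the model hand-off pattern.
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def worker_fn(args):
    # The ObjectRef is nested inside `args`, so Ray does not resolve it
    # automatically; the worker fetches the model explicitly, mirroring
    # `model = ray.get(model_ref)` in the horovod launcher.
    model_ref = args[0]
    model = ray.get(model_ref)
    return f"worker received: {model}"


def launch_sketch(model, num_workers=2):
    model_ref = ray.put(model)  # one shared copy in the object store
    new_args = (model_ref, )    # only the reference travels with the task
    futures = [worker_fn.remote(new_args) for _ in range(num_workers)]
    return ray.get(futures)


print(launch_sketch({"weights": [0.1, 0.2]}))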
