@@ -50,7 +50,10 @@ def launch(self,
                *args: Any,
                trainer: Optional["pl.Trainer"] = None,
                **kwargs: Any) -> Any:
-        """Launches the function on the workers from the driver node."""
+        """Launches the function on the workers from the driver node.
+
+        This function is run on the driver process.
+        """
         self.setup_workers()
         ray_output = self.run_function_on_workers(
             function, *args, trainer=trainer, **kwargs)
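
For readers unfamiliar with the driver/worker split above: `launch` stays on the driver, creates the Ray actors, and fans the training function out to them. A minimal sketch of that pattern follows, with hypothetical names that are not part of the ray_lightning API:

import ray

ray.init()

@ray.remote
class SketchExecutor:
    # Stand-in for the launcher's per-worker actor.
    def execute(self, fn, *args, **kwargs):
        # Runs an arbitrary callable inside the worker process.
        return fn(*args, **kwargs)

def launch_on_workers(fn, num_workers=2):
    # Driver side: one actor per worker, run `fn` everywhere, gather results.
    workers = [SketchExecutor.remote() for _ in range(num_workers)]
    futures = [w.execute.remote(fn, rank) for rank, w in enumerate(workers)]
    return ray.get(futures)

print(launch_on_workers(lambda rank: f"hello from rank {rank}"))
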
@@ -66,8 +69,9 @@ def launch(self,
         return return_value

     def setup_workers(self, tune_enabled: bool = True) -> None:
-        """Creates the Ray actors and sets up PTL Trainer environment
-        on the worker nodes.
+        """Creates the Ray actors and sets up PTL Trainer environment.
+
+        This function is run on the driver process.
         """
         self._workers = [
             self._create_worker() for _ in range(self._strategy.num_workers)
@@ -99,15 +103,21 @@ def setup_workers(self, tune_enabled: bool = True) -> None:
             self.tune_queue = Queue(actor_options={"num_cpus": 0})

     def _create_worker(self) -> ray.actor.ActorHandle:
-        """Creates Ray actor workers."""
+        """Creates Ray actor workers.
+
+        This function is run on the driver process.
+        """
         worker = RayExecutor.options(
             num_cpus=self._strategy.num_cpus_per_worker,
             num_gpus=self._strategy.num_gpus_per_worker,
             resources=self._strategy.additional_resources_per_worker).remote()
         return worker

     def teardown_workers(self):
-        """Tears down the Ray actors and PTL Trainer environment"""
+        """Tears down the Ray actors and PTL Trainer environment
+
+        This function is run on the driver process.
+        """
         if self.tune_queue:
             # Shutdown the queue.
             self.tune_queue.shutdown()
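
The `RayExecutor.options(...)` call above is standard Ray actor resource scoping. A small self-contained sketch of the same mechanism (the resource numbers are placeholders, not values from this PR):

import ray

ray.init()

@ray.remote
class Worker:
    def ping(self):
        return "ok"

# Reserve resources per actor, mirroring what `_create_worker` does with the
# strategy's num_cpus_per_worker / num_gpus_per_worker settings.
worker = Worker.options(
    num_cpus=1,
    num_gpus=0,
    resources={},  # e.g. custom resources for additional placement constraints
).remote()

print(ray.get(worker.ping.remote()))  # "ok"
ray.kill(worker)  # teardown_workers similarly releases the actors when done
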
@@ -119,7 +129,8 @@ def teardown_workers(self):

     def get_local_ranks(self) -> List[Optional[Tuple[int, int]]]:
         """Creates a mapping of global ranks to local ranks/node ranks.
-        this method is to run on the worker nodes.
+
+        This function is run on the driver process.
         """
         # Get the local ranks for all the workers and store as a list.
         # First get the IP address of each remote worker.
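
The mapping built here groups workers by node IP so that each global rank also knows its local rank and node rank. A sketch of one way to compute it; the helper name and tuple order are illustrative only:

from collections import defaultdict

def build_global_to_local(worker_ips):
    # worker_ips[i] is the node IP of the worker with global rank i.
    node_rank_of = {}                 # ip -> node rank, assigned on first sight
    local_counter = defaultdict(int)  # ip -> next local rank on that node
    mapping = []
    for ip in worker_ips:
        if ip not in node_rank_of:
            node_rank_of[ip] = len(node_rank_of)
        mapping.append((local_counter[ip], node_rank_of[ip]))  # (local_rank, node_rank)
        local_counter[ip] += 1
    return mapping

# Two workers on 10.0.0.1 and one on 10.0.0.2:
print(build_global_to_local(["10.0.0.1", "10.0.0.1", "10.0.0.2"]))
# [(0, 0), (1, 0), (0, 1)]
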
@@ -146,7 +157,10 @@ def get_local_ranks(self) -> List[Optional[Tuple[int, int]]]:
         return global_to_local

     def _setup_env_vars(self):
-        """Sets environment variables for all workers."""
+        """Sets environment variables for all workers.
+
+        This function is run on the driver process.
+        """
         # Get rank 0 worker address and port for DDP connection.
         os.environ["MASTER_ADDR"] = self._master_addr
         os.environ["MASTER_PORT"] = self._master_port
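
Besides setting MASTER_ADDR/MASTER_PORT on the driver, the launcher has to make the same DDP connection info visible inside every worker process. A hedged sketch of that push, using a hypothetical actor method rather than the launcher's own helpers:

import os
import ray

ray.init()

@ray.remote
class EnvWorker:
    def set_env_vars(self, env):
        # Applied inside the worker process so torch.distributed can read them.
        os.environ.update(env)
        return os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"]

# Hypothetical values; in the launcher they come from the rank-0 worker.
ddp_env = {"MASTER_ADDR": "10.0.0.1", "MASTER_PORT": "29500"}
workers = [EnvWorker.remote() for _ in range(2)]
print(ray.get([w.set_env_vars.remote(ddp_env) for w in workers]))
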
@@ -162,6 +176,9 @@ def _setup_env_vars(self):

     def _share_cuda_visible_devices(self):
         """Sets CUDA_VISIBLE_DEVICES on all workers.
+
+        This function is run on the driver process.
+
         For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs
         visible to all workers on that worker's node.
         This allows GPU workers on the same node to communicate with one
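
The docstring above describes a per-node union of GPU IDs: every worker on a node should see the GPUs of all co-located workers. A small sketch of that computation (the helper name and input shape are illustrative, not the launcher's internals):

from collections import defaultdict

def shared_cuda_visible_devices(node_and_gpu_ids):
    # node_and_gpu_ids[i] = (node_ip, gpu_ids) for the worker with global rank i.
    gpus_on_node = defaultdict(set)
    for node_ip, gpu_ids in node_and_gpu_ids:
        gpus_on_node[node_ip].update(gpu_ids)
    # Each worker gets the comma-joined union of GPU IDs on its own node.
    return [
        ",".join(str(g) for g in sorted(gpus_on_node[node_ip]))
        for node_ip, _ in node_and_gpu_ids
    ]

# Two workers sharing node 10.0.0.1 both see GPUs 0 and 1:
print(shared_cuda_visible_devices(
    [("10.0.0.1", [0]), ("10.0.0.1", [1]), ("10.0.0.2", [0])]))
# ['0,1', '0,1', '0']
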
@@ -207,7 +224,10 @@ def run_function_on_workers(self,
                                  trainer: Optional["pl.Trainer"] = None,
                                  **kwargs: Any):
         """launch a function on all workers.
-        The actual training parts are run inside `_wrapping_function`
+
+        This function is run on the driver process.
+
+        The actual training parts are run inside `_wrapping_function`
         """
         # put the model as the ray object
         # and remove the model temporarily from the args
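
The comment about putting the model "as the ray object" refers to `ray.put`: store the potentially large model once in the object store and hand each worker an ObjectRef instead of re-serializing the model per call. A generic sketch, not the launcher's exact code:

import ray

ray.init()

# Stand-in for a large model; stored once in the Ray object store.
model_weights = {"layer1": [0.1, 0.2], "layer2": [0.3]}
model_ref = ray.put(model_weights)

@ray.remote
def use_model(model):
    # Ray resolves a top-level ObjectRef argument to the stored value.
    return sum(len(v) for v in model.values())

print(ray.get([use_model.remote(model_ref) for _ in range(3)]))  # [3, 3, 3]
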
@@ -240,21 +260,29 @@ def _wrapping_function(
             tune_queue: Queue,
     ) -> Any:
         """Wraps the function to run on the workers.
-        `results = function(*args, **kwargs)` is where the
-        actual training parts are run.
+
+        This function is run on the worker process.
+
+        `results = function(*args, **kwargs)` is where the
+        actual training parts are run.
         """
         self._strategy.set_remote(True)
         self._strategy.set_global_to_local(global_to_local)

-        # `function` is a trainer's class method
-        # in the ray remote tasks, its object `trainer` will also
-        # be copied when the function is remoted.
+        # `function` is a trainer's instance method
+        # in the ray remote tasks, its bound instance `trainer`
+        # will also be copied when the function is remoted.
+        #
         # ALERT: passing the trainer as an argument of `_wrapping_function`
-        # does not fillfullied our purpose. Ray remote tasks will
+        # does not fulfill our purpose. Ray remote tasks will
         # create another copy of trainer so that
         # `function.__self__ != trainer`, in which the side effect only
         # happens to `function.__self__` when running
-        # `function(*args, **kwargs)`
+        # `function(*args, **kwargs)` (see SOLUTION below).
+        #
+        # SOLUTION: we find the trainer directly from `function`
+        # by calling `function.__self__` so that we can restore
+        # all the side effects happened to `function.__self__`
         trainer = function.__self__
         trainer.model = model_ref
         args = tuple([model_ref] + list(args[1:]))
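
The `function.__self__` trick above relies on two Python/Ray facts: a bound method is pickled together with its instance, and a Ray task works on a deserialized copy of that instance, so side effects must be read back from `function.__self__` (or returned) rather than observed on the driver's object. A runnable illustration with a toy trainer:

import ray

ray.init()

class ToyTrainer:
    def __init__(self):
        self.state = "initial"

    def fit(self):
        self.state = "fitted"  # side effect lands on the bound instance

@ray.remote
def wrapping_function(function):
    # `function` arrived with a *copy* of its bound instance; recover it the
    # same way the launcher does, via `__self__`.
    trainer_copy = function.__self__
    function()
    return trainer_copy.state

trainer = ToyTrainer()
assert trainer.fit.__self__ is trainer                  # bound methods carry their instance
print(ray.get(wrapping_function.remote(trainer.fit)))   # 'fitted'
print(trainer.state)  # still 'initial': the worker mutated its own copy
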
@@ -284,7 +312,10 @@ def _wrapping_function(

     def _collect_rank_zero_results(self, trainer: "pl.Trainer",
                                    results: Any) -> Optional["_RayOutput"]:
-        """Collects the results from the worker node 0."""
+        """Collects the results from the worker node 0.
+
+        This function is run on the worker process.
+        """
         rank_zero_debug("Finalizing the Ray launcher environment.")
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = checkpoint_callback.best_model_path \
@@ -316,7 +347,10 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer",

     def _recover_results_in_main_process(self, ray_output: "_RayOutput",
                                          trainer: "pl.Trainer") -> None:
-        """Recovers the results in the main process."""
+        """Recovers the results in the main process.
+
+        This function is run on the driver process.
+        """
         # transfer back the best path to the trainer
         if trainer.checkpoint_callback:
             trainer.checkpoint_callback.best_model_path = str(
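
Together, `_collect_rank_zero_results` and `_recover_results_in_main_process` form a round trip: rank 0 packs what the driver needs (for example the best checkpoint path and the fit results) into a small payload, and the driver copies those fields back onto its own trainer. A schematic sketch with a hypothetical payload type standing in for `_RayOutput`:

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class WorkerOutput:
    # Hypothetical stand-in for the launcher's `_RayOutput`.
    best_model_path: Optional[str]
    results: Any

def collect_rank_zero_results(trainer_state, results, global_rank):
    # Worker side: only rank 0 reports back to the driver.
    if global_rank != 0:
        return None
    return WorkerOutput(best_model_path=trainer_state.get("best_model_path"),
                        results=results)

def recover_results_in_main_process(output, driver_trainer_state):
    # Driver side: copy the reported fields back onto the driver's trainer.
    if output.best_model_path is not None:
        driver_trainer_state["best_model_path"] = output.best_model_path
    driver_trainer_state["results"] = output.results

out = collect_rank_zero_results({"best_model_path": "/tmp/best.ckpt"}, {"loss": 0.1}, 0)
state = {}
recover_results_in_main_process(out, state)
print(state)  # {'best_model_path': '/tmp/best.ckpt', 'results': {'loss': 0.1}}
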