 import torch.nn as nn
 import torch.optim as optim
 from torch.distributed.optim import DistributedOptimizer
+from torch.distributed.rpc import RRef
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 from torchtext.experimental.transforms import sentencepiece_tokenizer
 from transforms import PretrainedSPVocab
 from torchtext.experimental.models.utils import count_model_param
+from torch.distributed.pipeline.sync import Pipe
 
 
 def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
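
# Illustrative sketch, not part of the diff: the new Pipe import matters because
# torch.distributed.pipeline.sync.Pipe returns RRefs from forward() and therefore
# needs the RPC framework initialized even in a single process. A minimal
# single-worker setup (address, port, and worker name are placeholder values):
import os
import torch.distributed.rpc as rpc

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
rpc.init_rpc("worker", rank=0, world_size=1)
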
@@ -58,7 +60,7 @@ def evaluate(data_source, model, mask_id, pad_id, ntokens, criterion, args, devi
     return total_loss / (len(data_source) - 1)  # Set batch # to 1 for inference
 
 
-def local_step(model, data, targets, criterion, optimizer, ntokens):
+def local_step(model, data, targets, criterion, optimizer, ntokens, args):
     optimizer.zero_grad()
     output = model(data)
     loss = criterion(output.view(-1, ntokens), targets.view(-1))
@@ -69,7 +71,18 @@ def local_step(model, data, targets, criterion, optimizer, ntokens):
     return res
 
 
-def dist_step(model, data, targets, criterion, optimizer, ntokens):
+def pipe_step(model, data, targets, criterion, optimizer, ntokens, args):
+    optimizer.zero_grad()
+    output = model(data).local_value()  # Because torch.distributed.pipeline.sync.Pipe.forward returns RRef
+    loss = criterion(output.view(-1, ntokens), targets.view(-1))
+    loss.backward()
+    res = loss.item()
+    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
+    optimizer.step()
+    return res
+
+
+def rpc_step(model, data, targets, criterion, optimizer, ntokens, args):
     with dist_autograd.context() as context_id:
         output = model(data)
         loss = criterion(output.view(-1, ntokens), targets.view(-1))
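
# Illustrative sketch, not part of the diff: the remainder of rpc_step falls outside
# this hunk, but a step driven by distributed autograd and a DistributedOptimizer
# typically completes along these lines (the function name and arguments here are
# assumptions for illustration, not the file's actual code):
import torch.distributed.autograd as dist_autograd

def dist_autograd_step_sketch(model, data, targets, criterion, optimizer, ntokens):
    with dist_autograd.context() as context_id:
        output = model(data)
        loss = criterion(output.view(-1, ntokens), targets.view(-1))
        dist_autograd.backward(context_id, [loss])  # gradients are stored in the autograd context
        optimizer.step(context_id)                  # DistributedOptimizer applies them per context id
        return loss.item()
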
@@ -91,7 +104,7 @@ def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
     for batch, (data, targets) in enumerate(dataloader):
         data = data.to(devices[0])
         targets = targets.to(devices[-1])
-        loss = step_impl(model, data, targets, criterion, optimizer, ntokens)
+        loss = step_impl(model, data, targets, criterion, optimizer, ntokens, args)
 
         total_loss += loss
         if batch % args.log_interval == 0 and batch > 0:
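
# Illustrative sketch, not part of the diff: the loop above sends inputs to the
# first pipeline device and targets to the last one because the model's output is
# produced on the final shard, where the loss is computed. With an assumed
# two-stage setup this placement reduces to:
import torch

devices = ["cuda:0", "cuda:1"] if torch.cuda.device_count() >= 2 else ["cpu", "cpu"]
data = torch.zeros(35, 32, dtype=torch.long).to(devices[0])      # inputs enter the first shard
targets = torch.zeros(35, 32, dtype=torch.long).to(devices[-1])  # targets live where the output ends up
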
@@ -171,12 +184,19 @@ def text_transform(x: str) -> List:
     print("Allocating memory")
     if args.pipeline_mode == 'sp':
         model = SingleProcessPipeline(shards, devices)
-
         optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
-    else:
+    elif args.pipeline_mode == 'pipe':
+        model = Pipe(SingleProcessPipeline(shards, devices, to_device=False), chunks=args.batch_size // args.split_size)
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
+    elif args.pipeline_mode == 'cpu' or args.pipeline_mode == 'cuda':
         workers = [f"worker{i + 1}" for i in range(len(devices))]
-        model = RPCPipeline(shards, devices, workers, split_size=args.split_size, remote_base_class=(RemoteBaseCUDARPC if args.pipeline_mode == 'cuda' else RemoteBaseCPURPC))
+        if args.pipeline_mode == 'cpu':
+            impl = RemoteBaseCPURPC
+        elif args.pipeline_mode == 'cuda':
+            impl = RemoteBaseCUDARPC
+        model = RPCPipeline(shards, devices, workers, split_size=args.split_size, remote_base_class=impl)
         optimizer = DistributedOptimizer(
             optim.Adam,
             model.parameter_rrefs(),
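
# Illustrative sketch, not part of the diff: Pipe expects an nn.Sequential whose
# stages are already placed on their devices, and `chunks` sets how many
# micro-batches each input batch is split into (batch_size // split_size above).
# A standalone toy version, assuming two GPUs and the RPC initialization sketched
# earlier:
import torch
import torch.nn as nn
from torch.distributed.pipeline.sync import Pipe

stage0 = nn.Linear(16, 32).to("cuda:0")
stage1 = nn.Linear(32, 8).to("cuda:1")
toy_pipe = Pipe(nn.Sequential(stage0, stage1), chunks=4)
out = toy_pipe(torch.randn(64, 16, device="cuda:0")).local_value()  # forward returns an RRef
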
@@ -199,16 +219,25 @@ def text_transform(x: str) -> List:
 
         epoch_start_time = time.time()
         last_lr = scheduler.get_last_lr()[0] if scheduler is not None else args.lr
+
+        if args.pipeline_mode == 'sp':
+            step = local_step
+        elif args.pipeline_mode == 'pipe':
+            step = pipe_step
+        else:
+            step = rpc_step
+
+        if args.pipeline_mode == 'cpu':
+            train_devices = ["cpu"]  # Because "TensorPipe RPC backend only supports CPU tensors by default, please move your tensors to CPU before sending them over RPC"
+        else:
+            train_devices = devices
+
         train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
-              optimizer, criterion, ntokens, epoch, last_lr, args,
-              devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
-              local_step if args.pipeline_mode == 'sp' else dist_step)
+              optimizer, criterion, ntokens, epoch, last_lr, args, train_devices, step)
 
         # Turn on evaluation mode which disables dropout.
         model.eval()
-        val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args,
-                            devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
-                            text_transform)
+        val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args, train_devices, text_transform)
         val_loss_log.append(val_loss)
         print('-' * 89)
         print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
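
# Illustrative sketch, not part of the diff: the 'cpu' branch above trains on
# ["cpu"] because the TensorPipe backend rejects CUDA tensors unless a device map
# is configured. The 'cuda' mode relies on such a mapping, roughly like this
# (the worker name and GPU indices are example values):
import torch.distributed.rpc as rpc

options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256)
options.set_device_map("worker1", {0: 0})  # local cuda:0 maps to cuda:0 on worker1
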
@@ -253,7 +282,7 @@ def _forward(x):
     print('-' * 89)
 
 
-def run_worker(rank, args):
+def run_worker(rank, world_size, args):
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '29500'
     options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256)
@@ -265,7 +294,7 @@ def run_worker(rank, args):
         rpc.init_rpc(
             "master",
             rank=rank,
-            world_size=args.gpus + 1,
+            world_size=world_size,
             rpc_backend_options=options
         )
         run_main(args)
@@ -278,7 +307,7 @@ def run_worker(rank, args):
         rpc.init_rpc(
             f"worker{rank}",
             rank=rank,
-            world_size=args.gpus + 1,
+            world_size=world_size,
             rpc_backend_options=options
         )
         pass
@@ -337,5 +366,8 @@ def run_worker(rank, args):
 
     if args.pipeline_mode == 'sp':
         run_main(args)
+    elif args.pipeline_mode == 'pipe':
+        # Because torch.distributed.pipeline.sync.Pipe.forward returns RRef and requires RPC
+        mp.spawn(run_worker, args=(1, args), nprocs=1, join=True)
     else:
-        mp.spawn(run_worker, args=(args,), nprocs=args.gpus + 1, join=True)
+        mp.spawn(run_worker, args=(args.gpus + 1, args), nprocs=args.gpus + 1, join=True)
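
# Illustrative sketch, not part of the diff: mp.spawn invokes fn(i, *args) with the
# process index as the first argument, so args=(args.gpus + 1, args) lines up with
# the new run_worker(rank, world_size, args) signature. A minimal standalone
# example (function name and values are made up for illustration):
import torch.multiprocessing as mp

def _demo_worker(rank, world_size, tag):
    print(f"{tag}: rank {rank} of {world_size}")  # rank is injected by mp.spawn

if __name__ == "__main__":
    mp.spawn(_demo_worker, args=(2, "demo"), nprocs=2, join=True)
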