@@ -1,15 +1,24 @@
 import argparse
-import time
 import math
+import os
+import time
+from typing import List
+
 import torch
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+import torch.multiprocessing as mp
 import torch.nn as nn
+import torch.optim as optim
+from torch.distributed.optim import DistributedOptimizer
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+
 from data import CC100
+from dist_model import DistCrossLingualMLMTask
 from model import CrossLingualMLMTask
-from torch.utils.data import DataLoader
 from torchtext.experimental.transforms import sentencepiece_tokenizer
 from transforms import PretrainedSPVocab
-from torch.nn.utils.rnn import pad_sequence
-from typing import List


 def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
@@ -47,8 +56,28 @@ def evaluate(data_source, model, mask_id, pad_id, ntokens, criterion, args, devi
     return total_loss / (len(data_source) - 1)  # Set batch # to 1 for inference


+def step(model, data, targets, criterion, optimizer, ntokens):
+    optimizer.zero_grad()
+    output = model(data)
+    loss = criterion(output.view(-1, ntokens), targets.view(-1))
+    loss.backward()
+    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
+    optimizer.step()
+    return loss
+
+
+def dist_step(model, data, targets, criterion, optimizer, ntokens):
+    with dist_autograd.context() as context_id:
+        output = model(data)
+        loss = criterion(output.view(-1, ntokens), targets.view(-1))
+        dist_autograd.backward(context_id, [loss])
+        # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
+        optimizer.step(context_id)
+        return loss
+
+
 def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
-          optimizer, criterion, ntokens, epoch, scheduler, args, device, rank=None):
+          optimizer, criterion, ntokens, epoch, last_lr, args, device, step_impl):
     model.train()
     total_loss = 0.
     start_time = time.time()
@@ -57,33 +86,25 @@ def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
                             shuffle=False, collate_fn=lambda b: collate_batch(b, args, mask_id, pad_id, text_transform))

     for batch, (data, targets) in enumerate(dataloader):
-        optimizer.zero_grad()
-        data = data.to(device)
-        targets = targets.to(device)
-        output = model(data)
-        loss = criterion(output.view(-1, ntokens), targets.view(-1))
-        loss.backward()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
-        optimizer.step()
+        loss = step_impl(model, data.to(device), targets.to(device), criterion, optimizer, ntokens)
+
         total_loss += loss.item()
         if batch % args.log_interval == 0 and batch > 0:
             cur_loss = total_loss / args.log_interval
             elapsed = time.time() - start_time
-            if (rank is None) or rank == 0:
-                train_loss_log[-1] = cur_loss
-                print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
-                      'loss {:5.2f} | ppl {:8.2f}'.format(epoch, batch,
-                                                          len(train_data) // args.batch_size,
-                                                          scheduler.get_last_lr()[0],
-                                                          elapsed * 1000 / args.log_interval,
-                                                          cur_loss, math.exp(cur_loss)))
+            train_loss_log[-1] = cur_loss
+            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
+                  'loss {:5.2f} | ppl {:8.2f}'.format(epoch, batch,
+                                                      len(train_data) // args.batch_size,
+                                                      last_lr,
+                                                      elapsed * 1000 / args.log_interval,
+                                                      cur_loss, math.exp(cur_loss)))
             total_loss = 0
             start_time = time.time()


-def run_main(args, rank=None):
+def run_main(args):
     torch.manual_seed(args.seed)
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     # Set up tokenizer and vocab
     tokenizer = sentencepiece_tokenizer(args.spm_path)
@@ -95,11 +116,23 @@ def text_transform(x: str) -> List:
     pad_id = vocab(['pad'])[0]
     ntokens = len(vocab)

-    model = CrossLingualMLMTask(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
-    model = model.to(device)
+    if not args.dist:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model = CrossLingualMLMTask(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
+        model = model.to(device)
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
+    else:
+        device = "cpu"
+        model = DistCrossLingualMLMTask(args.split_size, ["worker1", "worker2"], ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
+        optimizer = DistributedOptimizer(
+            optim.Adam,
+            model.parameter_rrefs(),
+            lr=args.lr,
+        )
+        scheduler = None
+
     criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
-    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
     best_val_loss = None
     train_loss_log, val_loss_log = [], []

@@ -110,8 +143,9 @@ def text_transform(x: str) -> List:
         val_data = [(17, item) for item in val_data if item != '\n']  # english language type is 17 in CC100 dataset

         epoch_start_time = time.time()
+        last_lr = scheduler.get_last_lr()[0] if scheduler is not None else args.lr
         train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
-              optimizer, criterion, ntokens, epoch, scheduler, args, device, rank)
+              optimizer, criterion, ntokens, epoch, last_lr, args, device, step if not args.dist else dist_step)

         # Turn on evaluation mode which disables dropout.
         model.eval()
@@ -122,12 +156,13 @@ def text_transform(x: str) -> List:
               'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                          val_loss, math.exp(val_loss)))
         print('-' * 89)
-        if not best_val_loss or val_loss < best_val_loss:
+        if not args.dist and not best_val_loss or val_loss < best_val_loss:
             with open(args.save, 'wb') as f:
                 torch.save(model, f)
             best_val_loss = val_loss
         else:
-            scheduler.step()
+            if scheduler is not None:
+                scheduler.step()

     # Run reference XLM-R model from fairseq
     if args.eval_ref != 'None':
@@ -159,6 +194,32 @@ def _forward(x):
         print('-' * 89)


+def run_worker(rank, args):
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '29500'
+    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256)
+
+    if rank == 0:
+        rpc.init_rpc(
+            "master",
+            rank=rank,
+            world_size=args.world_size,
+            rpc_backend_options=options
+        )
+        run_main(args)
+    else:
+        rpc.init_rpc(
+            f"worker{rank}",
+            rank=rank,
+            world_size=args.world_size,
+            rpc_backend_options=options
+        )
+        pass
+
+    # block until all rpcs finish
+    rpc.shutdown()
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='PyTorch Cross-lingual XLM MLM')
     parser.add_argument('--emsize', type=int, default=768,
@@ -197,6 +258,15 @@ def _forward(x):
                         help='path to load the reference model for evaluation')
     parser.add_argument('--mask_frac', type=float, default=0.15,
                         help='the fraction of masked tokens')
+    parser.add_argument('--dist', action='store_true',
+                        help='run distributed version')
+    parser.add_argument('--world_size', type=int, default=3,
+                        help='world_size')
+    parser.add_argument('--split_size', type=int, default=8,
+                        help='split the input batch into micro-batches')
     args = parser.parse_args()

-    run_main(args)
+    if args.dist:
+        mp.spawn(run_worker, args=(args,), nprocs=args.world_size, join=True)
+    else:
+        run_main(args)
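
A note on the new dist_step path above: DistributedOptimizer never reads p.grad. Gradients are accumulated in the dist_autograd context opened for each step, which is why the backward pass goes through dist_autograd.backward(context_id, [loss]), why the optimizer is called as optimizer.step(context_id), and presumably why the gradient-clipping call is left commented out. The snippet below is a minimal, self-contained sketch of that pattern, not taken from this commit: it uses a toy nn.Linear and a single self-registered RPC process purely to show the API shape that dist_step relies on.

import os

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

# Single-process RPC setup, just enough for the distributed autograd/optimizer machinery.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29501'
rpc.init_rpc("master", rank=0, world_size=1)

# Toy stand-in for DistCrossLingualMLMTask; illustration only.
model = nn.Linear(4, 2)
# DistributedOptimizer takes RRefs to the parameters it should update.
optimizer = DistributedOptimizer(optim.Adam, [RRef(p) for p in model.parameters()], lr=0.01)

with dist_autograd.context() as context_id:
    loss = model(torch.randn(8, 4)).sum()
    # Gradients land in this context rather than in p.grad ...
    dist_autograd.backward(context_id, [loss])
    # ... so the optimizer is pointed at the context, not at .grad fields.
    optimizer.step(context_id)

rpc.shutdown()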
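On the launch side, passing --dist makes the __main__ block call mp.spawn(run_worker, args=(args,), nprocs=args.world_size, join=True) instead of run_main(args), so a single command starts every process: rank 0 registers with RPC as "master" and drives the training loop, while the remaining ranks register as "worker1", "worker2", ..., host whatever pieces DistCrossLingualMLMTask places on them, and block in rpc.shutdown() until the master finishes. Assuming the script is named cross_lingual_mlm_task.py (the file name is not shown in this diff), a distributed run would look like python cross_lingual_mlm_task.py --dist --world_size 3 --split_size 8, while omitting --dist preserves the original single-process behaviour.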