15
15
from torch .utils .data import DataLoader
16
16
17
17
from data import CC100
18
- from dist_model import DistCrossLingualMLMTask
19
18
from model import CrossLingualMLMTask
19
+ from pipeline import SingleProcessPipeline , RPCPipeline , RemoteBaseCPURPC , RemoteBaseCUDARPC
20
+ from shard_model import XLMRModelShards , MLMShards
20
21
from torchtext .experimental .transforms import sentencepiece_tokenizer
21
22
from transforms import PretrainedSPVocab
23
+ from torchtext .experimental .models .utils import count_model_param
22
24
23
25
24
26
def collate_batch (batch_data , args , mask_id , pad_id , text_transform ):
@@ -43,27 +45,28 @@ def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
43
45
return batch_data , targets
44
46
45
47
46
- def evaluate (data_source , model , mask_id , pad_id , ntokens , criterion , args , device , text_transform ):
48
+ def evaluate (data_source , model , mask_id , pad_id , ntokens , criterion , args , devices , text_transform ):
47
49
total_loss = 0.
48
50
dataloader = DataLoader (data_source , batch_size = 1 , # Set batch # to 1 for inference
49
51
shuffle = False , collate_fn = lambda b : collate_batch (b , args , mask_id , pad_id , text_transform ))
50
52
with torch .no_grad ():
51
53
for batch , (data , targets ) in enumerate (dataloader ):
52
- data = data .to (device )
53
- targets = targets .to (device )
54
+ data = data .to (devices [ 0 ] )
55
+ targets = targets .to (devices [ - 1 ] )
54
56
output = model (data )
55
57
total_loss += criterion (output .view (- 1 , ntokens ), targets .view (- 1 )).item ()
56
58
return total_loss / (len (data_source ) - 1 ) # Set batch # to 1 for inference
57
59
58
60
59
- def step (model , data , targets , criterion , optimizer , ntokens ):
61
+ def local_step (model , data , targets , criterion , optimizer , ntokens ):
60
62
optimizer .zero_grad ()
61
63
output = model (data )
62
64
loss = criterion (output .view (- 1 , ntokens ), targets .view (- 1 ))
63
65
loss .backward ()
66
+ res = loss .item ()
64
67
torch .nn .utils .clip_grad_norm_ (model .parameters (), args .clip )
65
68
optimizer .step ()
66
- return loss
69
+ return res
67
70
68
71
69
72
def dist_step (model , data , targets , criterion , optimizer , ntokens ):
@@ -73,11 +76,11 @@ def dist_step(model, data, targets, criterion, optimizer, ntokens):
73
76
dist_autograd .backward (context_id , [loss ])
74
77
# torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
75
78
optimizer .step (context_id )
76
- return loss
79
+ return loss . item ()
77
80
78
81
79
82
def train (model , mask_id , pad_id , train_loss_log , train_data , text_transform ,
80
- optimizer , criterion , ntokens , epoch , last_lr , args , device , step_impl ):
83
+ optimizer , criterion , ntokens , epoch , last_lr , args , devices , step_impl ):
81
84
model .train ()
82
85
total_loss = 0.
83
86
start_time = time .time ()
@@ -86,9 +89,11 @@ def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
86
89
shuffle = False , collate_fn = lambda b : collate_batch (b , args , mask_id , pad_id , text_transform ))
87
90
88
91
for batch , (data , targets ) in enumerate (dataloader ):
89
- loss = step_impl (model , data .to (device ), targets .to (device ), criterion , optimizer , ntokens )
92
+ data = data .to (devices [0 ])
93
+ targets = targets .to (devices [- 1 ])
94
+ loss = step_impl (model , data , targets , criterion , optimizer , ntokens )
90
95
91
- total_loss += loss . item ()
96
+ total_loss += loss
92
97
if batch % args .log_interval == 0 and batch > 0 :
93
98
cur_loss = total_loss / args .log_interval
94
99
elapsed = time .time () - start_time
@@ -116,47 +121,101 @@ def text_transform(x: str) -> List:
116
121
pad_id = vocab (['pad' ])[0 ]
117
122
ntokens = len (vocab )
118
123
119
- if not args .dist :
120
- device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
121
- model = CrossLingualMLMTask (ntokens , args .emsize , args .nhead , args .nhid , args .nlayers , args .dropout )
122
- model = model .to (device )
124
+ xlmr = XLMRModelShards (ntokens , args .emsize , args .nhead , args .nhid , args .dropout )
125
+ mlm = MLMShards (ntokens , args .emsize )
126
+ devices = [f"cuda:{ i } " for i in range (args .gpus )] if torch .cuda .is_available () else ["cpu" ]
127
+
128
+ if len (devices ) == 1 :
129
+ # In case of one device combine all layers into a single nn.Sequential
130
+ shards = [nn .Sequential (
131
+ xlmr .xlmr_embed (),
132
+ xlmr .encoder_layers (args .nlayers ),
133
+ mlm .mlm ()
134
+ )]
135
+ elif len (devices ) == 2 :
136
+ # In case of two devices split the model right in the middle and
137
+ # put the embeddings and half of encoders to the first shard and
138
+ # another half of encoders and mlm head to the second.
139
+ assert args .nlayers % 2 == 0
140
+ shards = [
141
+ nn .Sequential (
142
+ xlmr .xlmr_embed (),
143
+ xlmr .encoder_layers (args .nlayers // 2 )
144
+ ),
145
+ nn .Sequential (
146
+ xlmr .encoder_layers (args .nlayers // 2 ),
147
+ mlm .mlm ()
148
+ )
149
+ ]
150
+ else :
151
+ # In case of more that 2 devices put the embeddings and mlm head
152
+ # to the first and the last shard and split the encoders to equal
153
+ # parts among the rest of the shards
154
+ encoder_gpus = (args .gpus - 2 )
155
+ assert args .nlayers % encoder_gpus == 0
156
+ encoders_per_gpu = args .nlayers // encoder_gpus
157
+ shards = [
158
+ xlmr .xlmr_embed (),
159
+ * [xlmr .encoder_layers (encoders_per_gpu ) for _ in range (encoder_gpus )],
160
+ mlm .mlm ()
161
+ ]
162
+
163
+ print ('Shards parameters:' )
164
+ total = 0
165
+ for i , shard in enumerate (shards ):
166
+ params = count_model_param (shard )
167
+ total += params
168
+ print (f'shard{ i } = { int (params )} M' )
169
+ print (f'total = { int (total )} M' )
170
+
171
+ print ("Allocating memory" )
172
+ if args .pipeline_mode == 'sp' :
173
+ model = SingleProcessPipeline (shards , devices )
174
+
123
175
optimizer = torch .optim .Adam (model .parameters (), lr = args .lr )
124
176
scheduler = torch .optim .lr_scheduler .StepLR (optimizer , 1.0 , gamma = 0.75 )
125
177
else :
126
- device = "cpu"
127
- model = DistCrossLingualMLMTask ( args . split_size , [ "worker1" , "worker2" ], ntokens , args .emsize , args . nhead , args . nhid , args .nlayers , args . dropout )
178
+ workers = [ f"worker { i + 1 } " for i in range ( len ( devices ))]
179
+ model = RPCPipeline ( shards , devices , workers , split_size = args .split_size , remote_base_class = ( RemoteBaseCUDARPC if args .pipeline_mode == 'cuda' else RemoteBaseCPURPC ) )
128
180
optimizer = DistributedOptimizer (
129
181
optim .Adam ,
130
182
model .parameter_rrefs (),
131
183
lr = args .lr ,
132
184
)
133
185
scheduler = None
134
186
187
+ print ("Memory allocated" )
188
+ # input("Memory allocated, check nvidia-smi for memory consumption")
189
+
135
190
criterion = nn .CrossEntropyLoss (ignore_index = pad_id )
136
191
best_val_loss = None
137
192
train_loss_log , val_loss_log = [], []
138
193
139
194
for epoch in range (1 , args .epochs + 1 ):
140
- train_data = CC100 ('/datasets01/cc100/031720/' , {'*.txt' }, start_line = args .start_line , num_lines = args .num_lines )
195
+ train_data = CC100 (args . cc100_path , {'*.txt' }, start_line = args .start_line , num_lines = args .num_lines )
141
196
from torchtext .datasets import WikiText2
142
- val_data , = WikiText2 (data_select = 'valid' )
197
+ val_data = WikiText2 (split = 'valid' )
143
198
val_data = [(17 , item ) for item in val_data if item != ' \n ' ] # english language type is 17 in CC100 dataset
144
199
145
200
epoch_start_time = time .time ()
146
201
last_lr = scheduler .get_last_lr ()[0 ] if scheduler is not None else args .lr
147
202
train (model , mask_id , pad_id , train_loss_log , train_data , text_transform ,
148
- optimizer , criterion , ntokens , epoch , last_lr , args , device , step if not args .dist else dist_step )
203
+ optimizer , criterion , ntokens , epoch , last_lr , args ,
204
+ devices if args .pipeline_mode == 'sp' or args .pipeline_mode == 'cuda' else ["cpu" ],
205
+ local_step if args .pipeline_mode == 'sp' else dist_step )
149
206
150
207
# Turn on evaluation mode which disables dropout.
151
208
model .eval ()
152
- val_loss = evaluate (val_data , model , mask_id , pad_id , ntokens , criterion , args , device , text_transform )
209
+ val_loss = evaluate (val_data , model , mask_id , pad_id , ntokens , criterion , args ,
210
+ devices if args .pipeline_mode == 'sp' or args .pipeline_mode == 'cuda' else ["cpu" ],
211
+ text_transform )
153
212
val_loss_log .append (val_loss )
154
213
print ('-' * 89 )
155
214
print ('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
156
215
'valid ppl {:8.2f}' .format (epoch , (time .time () - epoch_start_time ),
157
216
val_loss , math .exp (val_loss )))
158
217
print ('-' * 89 )
159
- if not args .dist and not best_val_loss or val_loss < best_val_loss :
218
+ if args .pipeline_mode == 'sp' and not best_val_loss or val_loss < best_val_loss :
160
219
with open (args .save , 'wb' ) as f :
161
220
torch .save (model , f )
162
221
best_val_loss = val_loss
@@ -173,7 +232,7 @@ def text_transform(x: str) -> List:
173
232
def text_transform (x : str ) -> List :
174
233
return ref_model .encode (x ).tolist ()
175
234
model = ref_model .model .encoder
176
- model = model .to (device )
235
+ model = model .to (devices [ 0 ] )
177
236
# Turn on evaluation mode which disables dropout.
178
237
model .eval ()
179
238
# from fairseq XLM-R model
@@ -187,7 +246,7 @@ def _forward(x):
187
246
return nn_model (x .transpose (0 , 1 ))[0 ].transpose (0 , 1 )
188
247
return _forward
189
248
val_loss = evaluate (val_data , model_forward (model ), mask_id , pad_id , ref_ntokens ,
190
- criterion , args , device , text_transform )
249
+ criterion , args , devices [ 0 ] , text_transform )
191
250
print ('-' * 89 )
192
251
print ('| reference model | valid loss {:5.2f} | '
193
252
'valid ppl {:8.2f}' .format (val_loss , math .exp (val_loss )))
@@ -200,18 +259,26 @@ def run_worker(rank, args):
200
259
options = rpc .TensorPipeRpcBackendOptions (num_worker_threads = 256 )
201
260
202
261
if rank == 0 :
262
+ if args .pipeline_mode == 'cuda' :
263
+ for i in range (args .gpus ):
264
+ options .set_device_map ("worker" + str (i + 1 ), {i :i })
203
265
rpc .init_rpc (
204
266
"master" ,
205
267
rank = rank ,
206
- world_size = args .world_size ,
268
+ world_size = args .gpus + 1 ,
207
269
rpc_backend_options = options
208
270
)
209
271
run_main (args )
210
272
else :
273
+ if args .pipeline_mode == 'cuda' :
274
+ if rank == 1 :
275
+ options .set_device_map ("master" , {0 :0 })
276
+ else :
277
+ options .set_device_map ("worker" + str (rank - 1 ), {(rank - 1 ):(rank - 2 )})
211
278
rpc .init_rpc (
212
279
f"worker{ rank } " ,
213
280
rank = rank ,
214
- world_size = args .world_size ,
281
+ world_size = args .gpus + 1 ,
215
282
rpc_backend_options = options
216
283
)
217
284
pass
@@ -258,15 +325,17 @@ def run_worker(rank, args):
258
325
help = 'path to load the reference model for evaluation' )
259
326
parser .add_argument ('--mask_frac' , type = float , default = 0.15 ,
260
327
help = 'the fraction of masked tokens' )
261
- parser .add_argument ('--dist' , action = 'store_true' ,
262
- help = 'run distributed version' )
263
- parser .add_argument ('--world_size' , type = int , default = 3 ,
264
- help = 'world_size' )
328
+ parser .add_argument ('--cc100_path' , type = str , default = '/datasets01/cc100/031720/' ,
329
+ help = 'path to cc100' )
330
+ parser .add_argument ('--gpus' , type = int , default = 1 ,
331
+ help = 'number of GPUs to use' )
332
+ parser .add_argument ('--pipeline_mode' , type = str , default = 'sp' ,
333
+ help = 'pipeline mode, `cpu` for CPU RPC, `cuda` for CUDA RPC, `sp` for single process pipeline' )
265
334
parser .add_argument ('--split_size' , type = int , default = 8 ,
266
335
help = 'split the input batch into micro-batches' )
267
336
args = parser .parse_args ()
268
337
269
- if args .dist :
270
- mp .spawn (run_worker , args = (args ,), nprocs = args .world_size , join = True )
271
- else :
338
+ if args .pipeline_mode == 'sp' :
272
339
run_main (args )
340
+ else :
341
+ mp .spawn (run_worker , args = (args ,), nprocs = args .gpus + 1 , join = True )
0 commit comments