@@ -183,24 +183,35 @@ def update_iterations(neox_args, data_loaders):
     to do as many iterations as possible while ensuring that each example is seen *at most* train_epochs
     times.
     """
-    if neox_args.train_iters is not None:
+    if (not neox_args.do_train) or (neox_args.train_iters is not None):
         pass
     elif neox_args.train_iters is None and neox_args.train_epochs is None:
         print_rank_0(
             "ERROR:Failed to specify either train_epochs or train_iters in config file"
         )
     else:
-        train_dataloader = data_loaders["train"]
-        train_epochs = neox_args.train_epochs
-        gradient_accumulation_steps = neox_args.gradient_accumulation_steps
+        global_rank = torch.distributed.get_rank()
 
-        train_iterations = (
-            len(train_dataloader) * train_epochs
-        ) // gradient_accumulation_steps
+        if global_rank == 0:
+            train_dataloader = data_loaders["train"]
+            train_epochs = neox_args.train_epochs
+            gradient_accumulation_steps = neox_args.gradient_accumulation_steps
+
+            train_dataloader_len = len(train_dataloader)
+            train_iterations = (
+                train_dataloader_len * train_epochs
+            ) // gradient_accumulation_steps
+
+            train_iters_tensor = torch.cuda.LongTensor([train_iterations])
+        else:
+            train_iters_tensor = torch.cuda.LongTensor([0])
+
+        torch.distributed.broadcast(train_iters_tensor, src=0)
+
+        neox_args.train_iters = train_iters_tensor[0].item()
 
-        neox_args.train_iters = train_iterations
         print_rank_0(
-            f"Training for a total of {train_iterations} iterations, corresponding to {train_epochs} epochs."
+            f"Training for a total of {neox_args.train_iters} iterations, corresponding to {neox_args.train_epochs} epochs."
         )
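The new code converts an epoch budget into an iteration count and then distributes it: rank 0 measures the dataloader and computes (len(train_dataloader) * train_epochs) // gradient_accumulation_steps. With purely illustrative numbers, a 1,000-batch dataloader, train_epochs = 3, and gradient_accumulation_steps = 4 give (1000 * 3) // 4 = 750 iterations. The result is then broadcast so every rank ends up with the same neox_args.train_iters. Below is a minimal, self-contained sketch of that compute-on-rank-0-then-broadcast idiom, assuming torch.distributed is already initialized with a CUDA-capable backend such as NCCL; the helper name is hypothetical and not part of GPT-NeoX.

import torch
import torch.distributed as dist

def broadcast_scalar_from_rank0(value: int) -> int:
    """Broadcast an int computed on rank 0 so every rank returns the same value."""
    if dist.get_rank() == 0:
        tensor = torch.tensor([value], dtype=torch.long, device="cuda")
    else:
        # Placeholder; overwritten in place by the broadcast below.
        tensor = torch.tensor([0], dtype=torch.long, device="cuda")
    # Collective call: every rank must reach this line, or the job hangs.
    dist.broadcast(tensor, src=0)
    return int(tensor.item())

The broadcast matters because only rank 0 touches data_loaders["train"] in the diff (presumably the train dataloader may not exist, or may differ, on non-zero ranks). If the other ranks skipped the collective they would deadlock, and if they computed the value independently they could disagree on train_iters.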