Skip to content

Commit 774eb58

Browse files
hotfix for tp >= 2 and pp > 2 in autoitercount (#1296)
* hotfix
* precommit
---------
Co-authored-by: Quentin Anthony <[email protected]>
1 parent c1105de commit 774eb58

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

megatron/training.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -183,24 +183,35 @@ def update_iterations(neox_args, data_loaders):
183183
to do as many iterations as possible while ensuring that each example is seen *at most* train_epochs
184184
times.
185185
"""
186-
if neox_args.train_iters is not None:
186+
if (not neox_args.do_train) or (neox_args.train_iters is not None):
187187
pass
188188
elif neox_args.train_iters is None and neox_args.train_epochs is None:
189189
print_rank_0(
190190
"ERROR:Failed to specify either train_epochs or train_iters in config file"
191191
)
192192
else:
193-
train_dataloader = data_loaders["train"]
194-
train_epochs = neox_args.train_epochs
195-
gradient_accumulation_steps = neox_args.gradient_accumulation_steps
193+
global_rank = torch.distributed.get_rank()
196194

197-
train_iterations = (
198-
len(train_dataloader) * train_epochs
199-
) // gradient_accumulation_steps
195+
if global_rank == 0:
196+
train_dataloader = data_loaders["train"]
197+
train_epochs = neox_args.train_epochs
198+
gradient_accumulation_steps = neox_args.gradient_accumulation_steps
199+
200+
train_dataloader_len = len(train_dataloader)
201+
train_iterations = (
202+
train_dataloader_len * train_epochs
203+
) // gradient_accumulation_steps
204+
205+
train_iters_tensor = torch.cuda.LongTensor([train_iterations])
206+
else:
207+
train_iters_tensor = torch.cuda.LongTensor([0])
208+
209+
torch.distributed.broadcast(train_iters_tensor, src=0)
210+
211+
neox_args.train_iters = train_iters_tensor[0].item()
200212

201-
neox_args.train_iters = train_iterations
202213
print_rank_0(
203-
f"Training for a total of {train_iterations} iterations, corresponding to {train_epochs} epochs."
214+
f"Training for a total of {neox_args.train_iters} iterations, corresponding to {neox_args.train_epochs} epochs."
204215
)
205216

206217

0 commit comments

Comments (0)