Skip to content

Commit 2f34cc7

Browse files
authored
[megatron] add micro_batch_size check (modelscope#8103)
1 parent 728e167 commit 2f34cc7

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

swift/megatron/arguments/megatron_args.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,11 @@ def __post_init__(self):
645645
self.data_parallel_size = self.world_size // total_model_size
646646
# Gradient Accumulation
647647
self.num_microbatches = self.global_batch_size // self.data_parallel_size // self.micro_batch_size
648+
if self.num_microbatches == 0:
649+
raise ValueError('global_batch_size must be >= `data_parallel_size * micro_batch_size` '
650+
f'to have at least one micro-batch. global_batch_size: {self.global_batch_size}, '
651+
f'data_parallel_size: {self.data_parallel_size}, '
652+
f'micro_batch_size: {self.micro_batch_size}.')
648653

649654
def _init_teacher_model(self):
650655
if self.teacher_model is None:

0 commit comments

Comments
 (0)