Commit 12b04b5

mansimane authored and sgugger committed
Smmp batch not divisible by microbatches fix (#10778)
* Added debug prints
* Added config
* Added prints
* Added prints
* Added extra samples to SequentialDistributedSampler
* Added extra samples to SequentialDistributedSampler; updated SequentialDistributedSampler call
* Added debug prints
* Removed extra prints
* Making predictions and labels a multiple of batch size
* Updated number of microbatches
* Removed extra prints
* Made start_remainder similar to DistributedSamplerWithLoop
* Minor spacing update
* Added debug prints; added config; added prints; added prints
* Added extra samples to SequentialDistributedSampler; updated SequentialDistributedSampler call; added extra samples to SequentialDistributedSampler; added debug prints; removed extra prints; making predictions and labels a multiple of batch size; updated number of microbatches; removed extra prints; squashing redundant commits
* Made start_remainder similar to DistributedSamplerWithLoop; minor spacing update
* Test and styling
* Rename test

Co-authored-by: Sylvain Gugger <[email protected]>
1 parent 6460e9a · commit 12b04b5

4 files changed: +49 −5 lines changed


src/transformers/sagemaker/trainer_sm.py

Lines changed: 6 additions & 1 deletion
@@ -112,7 +112,12 @@ def _get_train_sampler(self):

     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
         if self.is_model_parallel_enabled:
-            return SequentialDistributedSampler(eval_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
+            return SequentialDistributedSampler(
+                eval_dataset,
+                num_replicas=smp.dp_size(),
+                rank=smp.dp_rank(),
+                batch_size=self.args.per_device_eval_batch_size,
+            )
         else:
             return super()._get_eval_sampler(eval_dataset)

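For context (my reading, not part of the diff): under SageMaker model parallelism each device-level batch is split into a fixed number of microbatches, so a short final evaluation batch may not be divisible by that count, which is the failure the commit title refers to. Passing per_device_eval_batch_size into the sampler lets it pad every rank's share up to a whole number of batches. A rough, self-contained sketch of the arithmetic with made-up numbers:

import math

dataset_len = 23      # hypothetical number of eval samples on one data-parallel rank
batch_size = 16       # per_device_eval_batch_size
microbatches = 4      # hypothetical smp microbatch count; each batch must split evenly

last_batch = dataset_len % batch_size        # 7 samples left over in the final batch
print(last_batch % microbatches == 0)        # False -> the short batch cannot be split

# Rounding the per-rank sample count up to a multiple of batch_size removes the
# short final batch entirely:
padded = math.ceil(dataset_len / batch_size) * batch_size
print(padded, padded % batch_size == 0)      # 32 True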

src/transformers/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -1783,8 +1783,8 @@ def prediction_loop(

         eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
         if not prediction_loss_only:
-            preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
-            labels_gatherer = DistributedTensorGatherer(world_size, num_examples)
+            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
+            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)

         model.eval()

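A note on the gatherer change (my interpretation, hedged): because the padded sampler can now yield more than num_examples predictions, the predictions and labels gatherers need the same make_multiple_of=batch_size padding the loss gatherer already used, otherwise the extra samples would not fit. A standalone sketch of the size arithmetic this implies; the helper name gatherer_sizes is mine, not the library's:

import math

def gatherer_sizes(world_size, num_examples, make_multiple_of=None):
    # Reserve room rounded up so each process contributes an equal slice that is
    # a multiple of make_multiple_of (when given).
    block = world_size if make_multiple_of is None else world_size * make_multiple_of
    total_samples = math.ceil(num_examples / block) * block
    return total_samples, total_samples // world_size

print(gatherer_sizes(world_size=2, num_examples=23, make_multiple_of=16))  # (32, 16)
print(gatherer_sizes(world_size=2, num_examples=23))                       # (24, 12)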

src/transformers/trainer_pt_utils.py

Lines changed: 8 additions & 2 deletions
@@ -220,7 +220,7 @@ class SequentialDistributedSampler(Sampler):
     or `reduce` resulting tensors at the end of the loop.
     """

-    def __init__(self, dataset, num_replicas=None, rank=None):
+    def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -232,8 +232,14 @@ def __init__(self, dataset, num_replicas=None, rank=None):
         self.dataset = dataset
         self.num_replicas = num_replicas
         self.rank = rank
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        num_samples = len(self.dataset)
+        # Add extra samples to make num_samples a multiple of batch_size if passed
+        if batch_size is not None:
+            self.num_samples = int(math.ceil(num_samples / (batch_size * num_replicas))) * batch_size
+        else:
+            self.num_samples = int(math.ceil(num_samples / num_replicas))
         self.total_size = self.num_samples * self.num_replicas
+        self.batch_size = batch_size

     def __iter__(self):
         indices = list(range(len(self.dataset)))

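A worked example of the new __init__ arithmetic above, using illustrative numbers:

import math

dataset_len, num_replicas, batch_size = 23, 2, 16

# Without batch_size: each replica draws ceil(23 / 2) = 12 samples.
print(math.ceil(dataset_len / num_replicas))            # 12

# With batch_size: each replica draws ceil(23 / (16 * 2)) * 16 = 16 samples,
# so every replica's share is a whole number of evaluation batches.
num_samples = math.ceil(dataset_len / (batch_size * num_replicas)) * batch_size
print(num_samples, num_samples * num_replicas)          # 16 32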

tests/test_trainer_utils.py

Lines changed: 33 additions & 0 deletions
@@ -31,6 +31,7 @@
         DistributedTensorGatherer,
         LabelSmoother,
         LengthGroupedSampler,
+        SequentialDistributedSampler,
         get_parameter_names,
     )

@@ -167,3 +168,35 @@ def test_distributed_sampler_with_loop(self):

         self.assertEqual(set(total[:length]), set(dataset))
         self.assertEqual(set(total[length:]), set(total[: (len(total) - length)]))
+
+    def test_sequential_distributed_sampler(self):
+        batch_size = 16
+        for length in [23, 64, 123]:
+            dataset = list(range(length))
+            shard1 = SequentialDistributedSampler(dataset, num_replicas=2, rank=0)
+            shard2 = SequentialDistributedSampler(dataset, num_replicas=2, rank=1)
+
+            # Sample
+            samples1 = list(shard1)
+            samples2 = list(shard2)
+
+            total = samples1 + samples2
+
+            self.assertListEqual(total[:length], dataset)
+            self.assertListEqual(total[length:], dataset[: (len(total) - length)])
+
+            # With a batch_size passed
+            shard1 = SequentialDistributedSampler(dataset, num_replicas=2, rank=0, batch_size=batch_size)
+            shard2 = SequentialDistributedSampler(dataset, num_replicas=2, rank=1, batch_size=batch_size)
+
+            # Sample
+            samples1 = list(shard1)
+            samples2 = list(shard2)
+
+            self.assertTrue(len(samples1) % batch_size == 0)
+            self.assertTrue(len(samples2) % batch_size == 0)
+
+            total = samples1 + samples2
+
+            self.assertListEqual(total[:length], dataset)
+            self.assertListEqual(total[length:], dataset[: (len(total) - length)])

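Why the test expects total[length:] to repeat the head of the dataset (my reading; the sampler's __iter__ is not shown in this diff): the sampler pads its index list by wrapping around to the start until it reaches total_size, then hands each rank one contiguous slice. An approximate standalone model of that behaviour:

def sharded_indices(length, num_replicas, rank, num_samples):
    # Wrap-around padding up to num_samples * num_replicas, then per-rank slicing.
    indices = list(range(length))
    indices += indices[: num_samples * num_replicas - length]
    return indices[rank * num_samples : (rank + 1) * num_samples]

# length=23, two replicas, batch_size=16 -> 16 samples per rank
print(sharded_indices(23, 2, 0, 16))   # [0, 1, ..., 15]
print(sharded_indices(23, 2, 1, 16))   # [16, ..., 22, 0, ..., 8]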
