@@ -670,7 +670,7 @@ def num_examples(self, dataloader: DataLoader) -> int:
670
670
"""
671
671
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
672
672
673
- Will raise an exception if the underlying dataset dese not implement method :obj:`__len__`
673
+ Will raise an exception if the underlying dataset does not implement method :obj:`__len__`
674
674
"""
675
675
return len (dataloader .dataset )
676
676
@@ -1783,8 +1783,13 @@ def prediction_loop(
1783
1783
1784
1784
eval_losses_gatherer = DistributedTensorGatherer (world_size , num_examples , make_multiple_of = batch_size )
1785
1785
if not prediction_loss_only :
1786
- preds_gatherer = DistributedTensorGatherer (world_size , num_examples , make_multiple_of = batch_size )
1787
- labels_gatherer = DistributedTensorGatherer (world_size , num_examples , make_multiple_of = batch_size )
1786
+ # The actual number of evaluation samples can be greater than num_examples in distributed settings (when we pass
1787
+ # a batch size to the sampler)
1788
+ make_multiple_of = None
1789
+ if hasattr (dataloader , "sampler" ) and isinstance (dataloader .sampler , SequentialDistributedSampler ):
1790
+ make_multiple_of = dataloader .sampler .batch_size
1791
+ preds_gatherer = DistributedTensorGatherer (world_size , num_examples , make_multiple_of = make_multiple_of )
1792
+ labels_gatherer = DistributedTensorGatherer (world_size , num_examples , make_multiple_of = make_multiple_of )
1788
1793
1789
1794
model .eval ()
1790
1795
0 commit comments