Skip to content

Commit 45dae78

Browse files
committed
Fix distributed evaluation (#10795)
* Fix distributed evaluation
* Use logger
1 parent 12b04b5 commit 45dae78

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

src/transformers/trainer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ def num_examples(self, dataloader: DataLoader) -> int:
670670
"""
671671
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
672672
673-
Will raise an exception if the underlying dataset dese not implement method :obj:`__len__`
673+
Will raise an exception if the underlying dataset does not implement method :obj:`__len__`
674674
"""
675675
return len(dataloader.dataset)
676676

@@ -1783,8 +1783,13 @@ def prediction_loop(
17831783

17841784
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
17851785
if not prediction_loss_only:
1786-
preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
1787-
labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
1786+
# The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
1787+
# a batch size to the sampler)
1788+
make_multiple_of = None
1789+
if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
1790+
make_multiple_of = dataloader.sampler.batch_size
1791+
preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
1792+
labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
17881793

17891794
model.eval()
17901795

tests/test_trainer_distributed.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ def test_trainer(self):
9797
def compute_metrics(p: EvalPrediction) -> Dict:
9898
sequential = list(range(len(dataset)))
9999
success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential
100+
if not success and training_args.local_rank == 0:
101+
logger.warning(
102+
"Predictions and/or labels do not match expected results:\n - predictions: "
103+
f"{p.predictions.tolist()}\n - labels: {p.label_ids.tolist()}\n - expected: {sequential}"
104+
)
100105
return {"success": success}
101106

102107
trainer = Trainer(

0 commit comments

Comments (0)