|
34 | 34 | from parlai.utils.misc import warn_once
|
35 | 35 | from parlai.utils.io import PathManager
|
36 | 36 | import parlai.utils.logging as logging
|
37 |
| -from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric |
| 37 | +from parlai.core.metrics import Metric, SumMetric, AverageMetric, FairseqBleuMetric |
38 | 38 | from parlai.utils.fp16 import FP16SafeCrossEntropy
|
39 | 39 | import parlai.utils.fsdp as fsdp_utils
|
40 | 40 | from parlai.utils.torch import (
|
@@ -710,28 +710,35 @@ def compute_loss(self, batch, return_output=False):
|
710 | 710 | model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
|
711 | 711 | scores, preds, *_ = model_output
|
712 | 712 | score_view = scores.reshape(-1, scores.size(-1))
|
713 |
| - loss = self.criterion(score_view, batch.label_vec.view(-1)) |
714 |
| - loss = loss.view(scores.shape[:-1]).sum(dim=1) |
715 |
| - # save loss to metrics |
| 713 | + loss_flattened = self.criterion(score_view, batch.label_vec.view(-1)) |
| 714 | + loss_per_token = loss_flattened.view(scores.shape[:-1]) |
716 | 715 | notnull = batch.label_vec.ne(self.NULL_IDX)
|
717 |
| - target_tokens = notnull.long().sum(dim=-1) |
718 |
| - correct = ((batch.label_vec == preds) * notnull).sum(dim=-1) |
719 | 716 |
|
| 717 | + # save loss to metrics |
720 | 718 | # cross entropy loss
|
721 |
| - self.record_local_metric('loss', AverageMetric.many(loss, target_tokens)) |
| 719 | + self.record_local_metric( |
| 720 | + 'loss', Metric.from_mask(loss_per_token, notnull, AverageMetric) |
| 721 | + ) |
722 | 722 | # perplexity
|
723 |
| - self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens)) |
| 723 | + self.record_local_metric( |
| 724 | + 'ppl', Metric.from_mask(loss_per_token, notnull, PPLMetric) |
| 725 | + ) |
724 | 726 | # token-wise accuracy
|
725 | 727 | self.record_local_metric(
|
726 |
| - 'token_acc', AverageMetric.many(correct, target_tokens) |
| 728 | + 'token_acc', |
| 729 | + Metric.from_mask(batch.label_vec == preds, notnull, AverageMetric), |
727 | 730 | )
|
728 | 731 | # utterance-wise exact match
|
| 732 | + num_target_tokens = notnull.long().sum(dim=-1) |
| 733 | + num_tokens_correct = ((batch.label_vec == preds) * notnull).sum(dim=-1) |
729 | 734 | self.record_local_metric(
|
730 |
| - 'token_em', AverageMetric.many(correct == target_tokens) |
| 735 | + 'token_em', AverageMetric.many(num_tokens_correct == num_target_tokens) |
731 | 736 | )
|
| 737 | + |
732 | 738 | # actually do backwards loss
|
| 739 | + loss = loss_per_token.sum(dim=1) |
733 | 740 | loss = loss.sum()
|
734 |
| - loss /= target_tokens.sum() # average loss per token |
| 741 | + loss /= num_target_tokens.sum() # average loss per token |
735 | 742 | if return_output:
|
736 | 743 | return (loss, model_output)
|
737 | 744 | else:
|
@@ -1440,7 +1447,7 @@ def set_block_list(self: TSType, block_list: Optional[SearchBlocklist]) -> TSTyp
|
1440 | 1447 |
|
1441 | 1448 | def get_output_from_current_step(self):
|
1442 | 1449 | """
|
1443 |
| - Get the outputput at the current step. |
| 1450 | + Get the output at the current step. |
1444 | 1451 | """
|
1445 | 1452 | return self.outputs[-1]
|
1446 | 1453 |
|
|
0 commit comments