Skip to content

Commit

Permalink
Add Metric.from_mask helper method (facebookresearch#3411)
Browse files Browse the repository at this point in the history
  • Loading branch information
poojasethi committed Nov 22, 2022
1 parent 94f1b9c commit 59f1d85
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 13 deletions.
24 changes: 23 additions & 1 deletion parlai/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Optional,
Set,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -272,7 +273,7 @@ def many(cls, *objs: List[TVector]) -> List[Metric]:
"""
Construct many of a Metric from the base parts.
Useful if you separately compute numerators and denomenators, etc.
Useful if you separately compute numerators and denominators, etc.
"""
lengths = [len(o) for o in objs]
objs = list(objs) # convert from tuple for inplace modification
Expand All @@ -286,6 +287,27 @@ def many(cls, *objs: List[TVector]) -> List[Metric]:
raise IndexError(f'Uneven {cls.__name__} constructions: {lengths}')
return [cls(*items) for items in zip(*objs)]

@classmethod
def from_mask(
cls, metric_per_token: torch.Tensor, mask: torch.Tensor, MyMetric: Type[Metric]
) -> List[Metric]:
"""
From token-level metrics, returns aggregate MyMetric per example in the batch.
:param metric_per_token:
a (batchsize x num_tokens) Tensor
:param mask:
a (batchsize x num_tokens) Tensor to mask out tokens that should *not* be considered in the aggregate metric calculation.
:param MyMetric:
a subclass of Metric
:return:
a (batchsize) Tensor
"""
tokens_per_ex = mask.long().sum(dim=-1)
metric_per_ex = (metric_per_token * mask).sum(dim=-1)
metrics = MyMetric.many(metric_per_ex, tokens_per_ex)
return metrics


class FixedMetric(Metric):
"""
Expand Down
31 changes: 19 additions & 12 deletions parlai/core/torch_generator_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from parlai.utils.misc import warn_once
from parlai.utils.io import PathManager
import parlai.utils.logging as logging
from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric
from parlai.core.metrics import Metric, SumMetric, AverageMetric, FairseqBleuMetric
from parlai.utils.fp16 import FP16SafeCrossEntropy
import parlai.utils.fsdp as fsdp_utils
from parlai.utils.torch import (
Expand Down Expand Up @@ -710,28 +710,35 @@ def compute_loss(self, batch, return_output=False):
model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
scores, preds, *_ = model_output
score_view = scores.reshape(-1, scores.size(-1))
loss = self.criterion(score_view, batch.label_vec.view(-1))
loss = loss.view(scores.shape[:-1]).sum(dim=1)
# save loss to metrics
loss_flattened = self.criterion(score_view, batch.label_vec.view(-1))
loss_per_token = loss_flattened.view(scores.shape[:-1])
notnull = batch.label_vec.ne(self.NULL_IDX)
target_tokens = notnull.long().sum(dim=-1)
correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)

# save loss to metrics
# cross entropy loss
self.record_local_metric('loss', AverageMetric.many(loss, target_tokens))
self.record_local_metric(
'loss', Metric.from_mask(loss_per_token, notnull, AverageMetric)
)
# perplexity
self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
self.record_local_metric(
'ppl', Metric.from_mask(loss_per_token, notnull, PPLMetric)
)
# token-wise accuracy
self.record_local_metric(
'token_acc', AverageMetric.many(correct, target_tokens)
'token_acc',
Metric.from_mask(batch.label_vec == preds, notnull, AverageMetric),
)
# utterance-wise exact match
num_target_tokens = notnull.long().sum(dim=-1)
num_tokens_correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
self.record_local_metric(
'token_em', AverageMetric.many(correct == target_tokens)
'token_em', AverageMetric.many(num_tokens_correct == num_target_tokens)
)

# actually do backwards loss
loss = loss_per_token.sum(dim=1)
loss = loss.sum()
loss /= target_tokens.sum() # average loss per token
loss /= num_target_tokens.sum() # average loss per token
if return_output:
return (loss, model_output)
else:
Expand Down Expand Up @@ -1440,7 +1447,7 @@ def set_block_list(self: TSType, block_list: Optional[SearchBlocklist]) -> TSTyp

def get_output_from_current_step(self):
"""
Get the outputput at the current step.
Get the output at the current step.
"""
return self.outputs[-1]

Expand Down

0 comments on commit 59f1d85

Please sign in to comment.