Skip to content

Commit 59f1d85

Browse files
committed
Add Metric.from_mask helper method (facebookresearch#3411)
1 parent 94f1b9c commit 59f1d85

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

parlai/core/metrics.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Optional,
2525
Set,
2626
Tuple,
27+
Type,
2728
Union,
2829
)
2930

@@ -272,7 +273,7 @@ def many(cls, *objs: List[TVector]) -> List[Metric]:
272273
"""
273274
Construct many of a Metric from the base parts.
274275
275-
Useful if you separately compute numerators and denomenators, etc.
276+
Useful if you separately compute numerators and denominators, etc.
276277
"""
277278
lengths = [len(o) for o in objs]
278279
objs = list(objs) # convert from tuple for inplace modification
@@ -286,6 +287,27 @@ def many(cls, *objs: List[TVector]) -> List[Metric]:
286287
raise IndexError(f'Uneven {cls.__name__} constructions: {lengths}')
287288
return [cls(*items) for items in zip(*objs)]
288289

290+
@classmethod
291+
def from_mask(
292+
cls, metric_per_token: torch.Tensor, mask: torch.Tensor, MyMetric: Type[Metric]
293+
) -> List[Metric]:
294+
"""
295+
From token-level metrics, returns aggregate MyMetric per example in the batch.
296+
297+
:param metric_per_token:
298+
a (batchsize x num_tokens) Tensor
299+
:param mask:
300+
a (batchsize x num_tokens) Tensor to mask out tokens that should *not* be considered in the aggregate metric calculation.
301+
:param MyMetric:
302+
a subclass of Metric
303+
:return:
304+
a (batchsize) Tensor
305+
"""
306+
tokens_per_ex = mask.long().sum(dim=-1)
307+
metric_per_ex = (metric_per_token * mask).sum(dim=-1)
308+
metrics = MyMetric.many(metric_per_ex, tokens_per_ex)
309+
return metrics
310+
289311

290312
class FixedMetric(Metric):
291313
"""

parlai/core/torch_generator_agent.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from parlai.utils.misc import warn_once
3535
from parlai.utils.io import PathManager
3636
import parlai.utils.logging as logging
37-
from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric
37+
from parlai.core.metrics import Metric, SumMetric, AverageMetric, FairseqBleuMetric
3838
from parlai.utils.fp16 import FP16SafeCrossEntropy
3939
import parlai.utils.fsdp as fsdp_utils
4040
from parlai.utils.torch import (
@@ -710,28 +710,35 @@ def compute_loss(self, batch, return_output=False):
710710
model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
711711
scores, preds, *_ = model_output
712712
score_view = scores.reshape(-1, scores.size(-1))
713-
loss = self.criterion(score_view, batch.label_vec.view(-1))
714-
loss = loss.view(scores.shape[:-1]).sum(dim=1)
715-
# save loss to metrics
713+
loss_flattened = self.criterion(score_view, batch.label_vec.view(-1))
714+
loss_per_token = loss_flattened.view(scores.shape[:-1])
716715
notnull = batch.label_vec.ne(self.NULL_IDX)
717-
target_tokens = notnull.long().sum(dim=-1)
718-
correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
719716

717+
# save loss to metrics
720718
# cross entropy loss
721-
self.record_local_metric('loss', AverageMetric.many(loss, target_tokens))
719+
self.record_local_metric(
720+
'loss', Metric.from_mask(loss_per_token, notnull, AverageMetric)
721+
)
722722
# perplexity
723-
self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
723+
self.record_local_metric(
724+
'ppl', Metric.from_mask(loss_per_token, notnull, PPLMetric)
725+
)
724726
# token-wise accuracy
725727
self.record_local_metric(
726-
'token_acc', AverageMetric.many(correct, target_tokens)
728+
'token_acc',
729+
Metric.from_mask(batch.label_vec == preds, notnull, AverageMetric),
727730
)
728731
# utterance-wise exact match
732+
num_target_tokens = notnull.long().sum(dim=-1)
733+
num_tokens_correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
729734
self.record_local_metric(
730-
'token_em', AverageMetric.many(correct == target_tokens)
735+
'token_em', AverageMetric.many(num_tokens_correct == num_target_tokens)
731736
)
737+
732738
# actually do backwards loss
739+
loss = loss_per_token.sum(dim=1)
733740
loss = loss.sum()
734-
loss /= target_tokens.sum() # average loss per token
741+
loss /= num_target_tokens.sum() # average loss per token
735742
if return_output:
736743
return (loss, model_output)
737744
else:
@@ -1440,7 +1447,7 @@ def set_block_list(self: TSType, block_list: Optional[SearchBlocklist]) -> TSTyp
14401447

14411448
def get_output_from_current_step(self):
14421449
"""
1443-
Get the outputput at the current step.
1450+
Get the output at the current step.
14441451
"""
14451452
return self.outputs[-1]
14461453

0 commit comments

Comments
 (0)