From ea813d30ab17fa08cebcb28b26fcd1cb37a0f813 Mon Sep 17 00:00:00 2001
From: yuxqiu
Date: Tue, 14 May 2024 11:02:01 -0700
Subject: [PATCH] fix: correct reference length calculation (#195)

Summary:
This PR fixes the way the brevity penalty (specifically the effective reference corpus length) is calculated in BLEU.

Previously, `len_reference` was calculated as `min([len(ref) for ref in references_tokenized])`. However, this is incorrect: according to the paper, we need to find the "best match length", not the minimum reference length. For more information, see [wikipedia - brevity penalty](https://en.wikipedia.org/wiki/BLEU#Brevity_penalty) and the [nltk implementation](https://www.nltk.org/_modules/nltk/translate/bleu_score.html#closest_ref_length).

Pull Request resolved: https://github.com/pytorch/torcheval/pull/195

Test Plan: I added another unit test to `test_bleu.py` and compared the results of the calculations to the results of the `nltk.translate.bleu_score.corpus_bleu` function to make sure the implementation is correct.

Reviewed By: galrotem

Differential Revision: D56846091

Pulled By: JKSenthil

fbshipit-source-id: 2bf1cd0ba169535a118222e60f4264259248f1fd
---
 tests/metrics/text/test_bleu.py           | 29 +++++++++++++++++++++++
 torcheval/metrics/functional/text/bleu.py |  5 +++-
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/tests/metrics/text/test_bleu.py b/tests/metrics/text/test_bleu.py
index 5451e177..a7dc54f1 100644
--- a/tests/metrics/text/test_bleu.py
+++ b/tests/metrics/text/test_bleu.py
@@ -107,3 +107,32 @@ def test_bleu_multiple_examples_per_update(self) -> None:
             num_total_updates=2,
             num_processes=2,
         )
+
+    def test_bleu_brevity(self) -> None:
+        candidates = [["the squirrel is eating the nut"], ["the cat is on mat"]]
+        references = [
+            [
+                [
+                    "a squirrel is eating a nut",
+                    "the squirrel is eating a tasty nut",
+                    "hi",
+                ]
+            ],
+            [["there is a cat on the mat", "a cat is on the mat"]],
+        ]
+        self.run_class_implementation_tests(
+            metric=BLEUScore(n_gram=4),
+            state_names={
+                "input_len",
+                "target_len",
+                "matches_by_order",
+                "possible_matches_by_order",
+            },
+            update_kwargs={
+                "input": candidates,
+                "target": references,
+            },
+            compute_result=torch.tensor(0.41650065, dtype=torch.float64),
+            num_total_updates=2,
+            num_processes=2,
+        )

diff --git a/torcheval/metrics/functional/text/bleu.py b/torcheval/metrics/functional/text/bleu.py
index bba41933..cd4bba80 100644
--- a/torcheval/metrics/functional/text/bleu.py
+++ b/torcheval/metrics/functional/text/bleu.py
@@ -88,7 +88,10 @@ def _bleu_score_update(
     references_tokenized = [ref.split() for ref in references]
 
     len_candidate = len(candidate_tokenized)
-    len_reference = min([len(ref) for ref in references_tokenized])
+    len_reference = min(
+        [len(ref) for ref in references_tokenized],
+        key=lambda ref_len: (abs(ref_len - len_candidate), ref_len),
+    )
 
     input_len += len_candidate
     target_len += len_reference
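
For illustration, here is a minimal standalone sketch of the "best match length" rule the fix implements, using sentences from the new `test_bleu_brevity` case. The `closest_ref_length` helper is named after nltk's helper of the same name but is a simplified sketch, not nltk's implementation; only Python builtins are used.

```python
def closest_ref_length(candidate_len, ref_lens):
    # Pick the reference length closest to the candidate length; ties are
    # broken toward the shorter reference, matching the
    # (abs difference, length) key used in the patch.
    return min(ref_lens, key=lambda ref_len: (abs(ref_len - candidate_len), ref_len))


candidate = "the squirrel is eating the nut".split()       # 6 tokens
references = [
    "a squirrel is eating a nut".split(),                  # 6 tokens
    "the squirrel is eating a tasty nut".split(),          # 7 tokens
    "hi".split(),                                          # 1 token
]
ref_lens = [len(ref) for ref in references]

print(min(ref_lens))                                  # 1 -> old behavior: shortest reference
print(closest_ref_length(len(candidate), ref_lens))  # 6 -> best match length per the paper
```

For the new test as a whole, the old rule would accumulate a corpus reference length of 1 + 6 = 7 tokens against 11 candidate tokens, so no brevity penalty would be applied; the best-match rule accumulates 6 + 6 = 12, so a penalty is applied, which is the behaviour the expected `compute_result` exercises.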