Skip to content

Commit d51b93b

Browse files
committed
add BLEU
1 parent e57f189 commit d51b93b

File tree

7 files changed

+253
-5
lines changed

7 files changed

+253
-5
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ absl-py
22
clu
33
jax[cpu]
44
keras-hub
5+
keras-nlp
56
pytest
67
scikit-learn

src/metrax/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
AUCROC = classification_metrics.AUCROC
2323
Average = base.Average
2424
AveragePrecisionAtK = ranking_metrics.AveragePrecisionAtK
25+
BLEU = nlp_metrics.BLEU
2526
MSE = regression_metrics.MSE
2627
Perplexity = nlp_metrics.Perplexity
2728
Precision = classification_metrics.Precision
@@ -36,7 +37,7 @@
3637
"AUCROC",
3738
"Average",
3839
"AveragePrecisionAtK",
39-
"MSE",
40+
"BLEUMSE",
4041
"Perplexity",
4142
"Precision",
4243
"RMSE",

src/metrax/metrax_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@
4040
'the quick brown fox jumps over the lazy dog',
4141
'hello beautiful world',
4242
]
43-
TOKENIZED_PREDS = [sentence.split() for sentence in STRING_PREDS]
44-
TOKENIZED_REFS = [sentence.split() for sentence in STRING_REFS]
4543

4644

4745
class MetraxTest(parameterized.TestCase):
@@ -114,7 +112,12 @@ def test_metrics_jittable(self, metric, kwargs):
114112
(
115113
'wer',
116114
metrax.WER,
117-
{'predictions': TOKENIZED_PREDS, 'references': TOKENIZED_REFS},
115+
{'predictions': STRING_PREDS, 'references': STRING_REFS},
116+
),
117+
(
118+
'bleu',
119+
metrax.BLEU,
120+
{'predictions': STRING_PREDS, 'references': STRING_REFS},
118121
),
119122
)
120123
def test_metrics_not_jittable(self, metric, kwargs):

src/metrax/nlp_metrics.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,164 @@
1515
"""A collection of different metrics for NLP models."""
1616

1717
from clu import metrics as clu_metrics
18+
import collections
19+
import math
1820
import flax
1921
import jax
2022
import jax.numpy as jnp
2123
from metrax import base
2224

2325

26+
def get_ngrams(segment: list[str], max_order: int):
  """Counts every n-gram of order 1..max_order occurring in a token segment.

  Args:
    segment: list. Tokens from which n-grams are collected.
    max_order: int. Largest n-gram length, in tokens, to include in the
      returned counts.

  Returns:
    A `collections.Counter` mapping each n-gram tuple to its occurrence count.
  """
  return collections.Counter(
      tuple(segment[start : start + order])
      for order in range(1, max_order + 1)
      for start in range(len(segment) - order + 1)
  )
40+
41+
42+
@flax.struct.dataclass
class BLEU(clu_metrics.Metric):
  r"""Computes the BLEU score for sequence generation.

  BLEU measures the similarity between a machine-generated candidate translation
  and one or more human reference translations, focusing on matching n-grams.

  It's calculated as:
  .. math::
      \text{BLEU} = \text{BP} \times \exp\left( \sum_{n=1}^{N} w_n \log p_n
      \right)

  Where:
      - :math:`p_n` is the modified n-gram precision for n-grams of order n.
      - :math:`N` is the maximum n-gram order considered (typically 4).
      - :math:`w_n` are weights for each order (typically uniform, 1/N).
      - :math:`\text{BP}` is the Brevity Penalty.

  This implementation uses uniform weights and calculates statistics
  incrementally.

  Attributes:
    max_order: Maximum n-gram order to consider.
    matches_by_order: Accumulated sum of clipped n-gram matches for each order.
    possible_matches_by_order: Accumulated sum of total n-grams in predictions
      for each order.
    translation_length: Accumulated total length of predictions.
    reference_length: Accumulated total effective reference length. For each
      prediction this is the length of the *shortest* reference.
  """

  max_order: int
  matches_by_order: jax.Array
  possible_matches_by_order: jax.Array
  translation_length: jax.Array
  reference_length: jax.Array

  @classmethod
  def empty(cls) -> 'BLEU':
    """Returns an all-zero BLEU accumulator with the default max_order of 4."""
    # NOTE(review): these scalar zeros broadcast against the
    # (max_order,)-shaped statistics from `from_model_output` when merged.
    return cls(
        max_order=4,
        matches_by_order=jnp.array(0, jnp.float32),
        possible_matches_by_order=jnp.array(0, jnp.float32),
        translation_length=jnp.array(0, jnp.float32),
        reference_length=jnp.array(0, jnp.float32),
    )

  @classmethod
  def from_model_output(
      cls,
      predictions: list[str],
      references: list[list[str]],
      max_order: int = 4,
  ) -> 'BLEU':
    """Computes BLEU statistics for a batch of predictions and references.

    Args:
      predictions: A list of predicted strings. The shape should be
        (batch_size,).
      references: A list of lists of reference strings. The shape should be
        (batch_size, num_references).
      max_order: The maximum order of n-grams to consider.

    Returns:
      A BLEU metric instance containing the statistics for this batch.

    Raises:
      ValueError: If `predictions` and `references` have different lengths.
    """
    # `zip` silently truncates mismatched batches; fail loudly instead, as
    # the docstring promises.
    if len(predictions) != len(references):
      raise ValueError(
          'predictions and references must have the same length, got '
          f'{len(predictions)} and {len(references)}.'
      )

    matches_by_order = [0] * max_order
    possible_matches_by_order = [0] * max_order
    pred_length = 0
    ref_length = 0

    for pred, ref_list in zip(predictions, references):
      # Whitespace tokenization for both the prediction and the references.
      pred = pred.split()
      ref_list = [r.split() for r in ref_list]
      pred_length += len(pred)
      # Effective reference length: the shortest reference for this prediction.
      ref_length += min(len(r) for r in ref_list)
      prediction_ngram_counts = get_ngrams(pred, max_order)
      # Counter union (`|=`) keeps the per-ngram maximum across references;
      # the intersection (`&`) below clips prediction counts to that maximum,
      # giving the "modified" (clipped) n-gram matches.
      reference_ngram_counts = collections.Counter()
      for ref in ref_list:
        reference_ngram_counts |= get_ngrams(ref, max_order)
      overlap = prediction_ngram_counts & reference_ngram_counts
      for ngram in overlap:
        matches_by_order[len(ngram) - 1] += overlap[ngram]
      # Total n-grams the prediction could have matched at each order.
      for order in range(1, max_order + 1):
        possible_matches = len(pred) - order + 1
        if possible_matches > 0:
          possible_matches_by_order[order - 1] += possible_matches

    return cls(
        max_order=max_order,
        matches_by_order=jnp.array(matches_by_order, dtype=jnp.float32),
        possible_matches_by_order=jnp.array(
            possible_matches_by_order, dtype=jnp.float32
        ),
        translation_length=jnp.array(pred_length, dtype=jnp.float32),
        reference_length=jnp.array(ref_length, dtype=jnp.float32),
    )

  def merge(self, other: 'BLEU') -> 'BLEU':
    """Accumulates statistics from another BLEU instance of the same order."""
    if self.max_order != other.max_order:
      raise ValueError(
          'BLEU metrics with different max_order cannot be merged.'
      )
    return type(self)(
        max_order=self.max_order,
        matches_by_order=(self.matches_by_order + other.matches_by_order),
        possible_matches_by_order=(
            self.possible_matches_by_order + other.possible_matches_by_order
        ),
        translation_length=(self.translation_length + other.translation_length),
        reference_length=(self.reference_length + other.reference_length),
    )

  def compute(self) -> jax.Array:
    """Returns the accumulated corpus-level BLEU score as a scalar array."""
    # Modified n-gram precision per order; divide_no_nan yields 0 where no
    # n-grams of that order were possible.
    precisions = [
        base.divide_no_nan(
            self.matches_by_order[i], self.possible_matches_by_order[i]
        )
        for i in range(self.max_order)
    ]
    # Geometric mean with uniform weights 1/max_order; any zero precision
    # zeroes the whole score (its log would be undefined).
    geo_mean = (
        math.exp(sum((1.0 / self.max_order) * math.log(p) for p in precisions))
        if precisions and min(precisions) > 0
        else 0
    )
    ratio = base.divide_no_nan(self.translation_length, self.reference_length)
    # Brevity penalty: no penalty when the candidate is longer than the
    # effective reference length, exponential decay otherwise.
    bp = 1.0 if ratio > 1.0 else math.exp(1 - 1.0 / ratio)
    return jnp.array(geo_mean * bp)
174+
175+
24176
@flax.struct.dataclass
25177
class Perplexity(clu_metrics.Metric):
26178
r"""Computes perplexity for sequence generation.

src/metrax/nlp_metrics_test.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from absl.testing import parameterized
2222
import jax.numpy as jnp
2323
import keras_hub
24+
import keras_nlp
2425
import metrax
2526
import numpy as np
2627

@@ -29,6 +30,15 @@
2930

3031
class NlpMetricsTest(parameterized.TestCase):
3132

33+
def test_bleu_empty(self):
34+
"""Tests the `empty` method of the `BLEU` class."""
35+
m = metrax.BLEU.empty()
36+
self.assertEqual(m.max_order, 4)
37+
self.assertEqual(m.matches_by_order, jnp.array(0, jnp.float32))
38+
self.assertEqual(m.possible_matches_by_order, jnp.array(0, jnp.float32))
39+
self.assertEqual(m.translation_length, jnp.array(0, jnp.float32))
40+
self.assertEqual(m.reference_length, jnp.array(0, jnp.float32))
41+
3242
def test_perplexity_empty(self):
3343
"""Tests the `empty` method of the `Perplexity` class."""
3444
m = metrax.Perplexity.empty()
@@ -41,6 +51,78 @@ def test_wer_empty(self):
4151
self.assertEqual(m.total, jnp.array(0, jnp.float32))
4252
self.assertEqual(m.count, jnp.array(0, jnp.float32))
4353

54+
def test_bleu(self):
55+
"""Tests that BLEU metric computes correct values."""
56+
references = [
57+
["He eats a sweet apple", "He is eating a tasty apple, isn't he"],
58+
[
59+
"Silicon Valley is one of my favourite shows",
60+
"Silicon Valley is the best show ever",
61+
],
62+
]
63+
predictions = [
64+
"He He He eats sweet apple which is a fruit",
65+
"I love Silicon Valley it's one of my favourite shows",
66+
]
67+
keras_metric = keras_nlp.metrics.Bleu()
68+
keras_metric.update_state(references, predictions)
69+
metrax_metric = metrax.BLEU.from_model_output(predictions, references)
70+
71+
np.testing.assert_allclose(
72+
metrax_metric.compute(),
73+
keras_metric.result(),
74+
rtol=1e-05,
75+
atol=1e-05,
76+
)
77+
78+
def test_bleu_merge(self):
79+
"""Tests that BLEU metric computes correct values using merge."""
80+
references = [
81+
["He eats a sweet apple", "He is eating a tasty apple, isn't he"],
82+
[
83+
"Silicon Valley is one of my favourite shows",
84+
"Silicon Valley is the best show ever",
85+
],
86+
]
87+
predictions = [
88+
"He He He eats sweet apple which is a fruit",
89+
"I love Silicon Valley it's one of my favourite shows",
90+
]
91+
keras_metric = keras_nlp.metrics.Bleu()
92+
keras_metric.update_state(references, predictions)
93+
metrax_metric = None
94+
for ref_list, pred in zip(references, predictions):
95+
update = metrax.BLEU.from_model_output([pred], [ref_list])
96+
metrax_metric = (
97+
update if metrax_metric is None else metrax_metric.merge(update)
98+
)
99+
100+
np.testing.assert_allclose(
101+
metrax_metric.compute(),
102+
keras_metric.result(),
103+
rtol=1e-05,
104+
atol=1e-05,
105+
)
106+
107+
def test_bleu_merge_fails_on_different_max_order(self):
108+
"""Tests that error is raised when BLEU metrics with different max_order are merged."""
109+
references = [
110+
["He eats a sweet apple", "He is eating a tasty apple, isn't he"],
111+
]
112+
predictions = [
113+
"He He He eats sweet apple which is a fruit",
114+
]
115+
order_3_metric = metrax.BLEU.from_model_output(
116+
predictions, references, max_order=3
117+
)
118+
order_4_metric = metrax.BLEU.from_model_output(
119+
predictions, references, max_order=4
120+
)
121+
122+
np.testing.assert_raises(
123+
ValueError, lambda: order_3_metric.merge(order_4_metric)
124+
)
125+
44126
@parameterized.named_parameters(
45127
(
46128
'basic',
@@ -141,4 +223,4 @@ def test_wer(self):
141223

142224

143225
if __name__ == '__main__':
144-
absltest.main()
226+
absltest.main()

src/metrax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
AUCROC = nnx_metrics.AUCROC
1919
Average = nnx_metrics.Average
2020
AveragePrecisionAtK = nnx_metrics.AveragePrecisionAtK
21+
BLEU = nnx_metrics.BLEU
2122
MSE = nnx_metrics.MSE
2223
Perplexity = nnx_metrics.Perplexity
2324
Precision = nnx_metrics.Precision
@@ -32,6 +33,7 @@
3233
"AUCROC",
3334
"Average",
3435
"AveragePrecisionAtK",
36+
"BLEU",
3537
"MSE",
3638
"Perplexity",
3739
"Precision",

src/metrax/nnx/nnx_metrics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ def __init__(self):
4646
super().__init__(metrax.AveragePrecisionAtK)
4747

4848

49+
class BLEU(NnxWrapper):
  """An NNX class for the Metrax metric BLEU.

  Exposes `metrax.BLEU` through the common `NnxWrapper` adapter, matching the
  other metric wrappers defined in this module.
  """

  def __init__(self):
    # The wrapper receives the metric class itself, not an instance.
    super().__init__(metrax.BLEU)
54+
55+
4956
class MSE(NnxWrapper):
5057
"""An NNX class for the Metrax metric MSE."""
5158

0 commit comments

Comments
 (0)