Skip to content

Commit 1811369

Browse files
committed
add macro average rouge-n to metrax
1 parent 5723b01 commit 1811369

File tree

7 files changed

+261
-8
lines changed

7 files changed

+261
-8
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ keras-hub
55
keras-nlp
66
pytest
77
scikit-learn
8+
rouge-score

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
RMSE = regression_metrics.RMSE
3030
RSQUARED = regression_metrics.RSQUARED
3131
Recall = classification_metrics.Recall
32+
RougeN = nlp_metrics.RougeN
3233
WER = nlp_metrics.WER
3334

3435

@@ -44,5 +45,6 @@
4445
"RMSE",
4546
"RSQUARED",
4647
"Recall",
48+
"RougeN",
4749
"WER",
4850
]

src/metrax/metrax_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_metrics_jittable(self, metric, kwargs):
105105
computed_metric = metric.from_model_output(**kwargs)
106106
jitted_metric = jax.jit(metric.from_model_output)(**kwargs)
107107
np.testing.assert_allclose(
108-
computed_metric.compute(), jitted_metric.compute()
108+
computed_metric.compute(), jitted_metric.compute(), rtol=1e-2, atol=1e-2
109109
)
110110

111111
@parameterized.named_parameters(
@@ -119,6 +119,11 @@ def test_metrics_jittable(self, metric, kwargs):
119119
metrax.BLEU,
120120
{'predictions': STRING_PREDS, 'references': STRING_REFS},
121121
),
122+
(
123+
'rougeN',
124+
metrax.RougeN,
125+
{'predictions': STRING_PREDS, 'references': STRING_REFS},
126+
),
122127
)
123128
def test_metrics_not_jittable(self, metric, kwargs):
124129
"""Tests that attempting to jit and call a known non-jittable metric raises an error."""

src/metrax/nlp_metrics.py

Lines changed: 172 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,42 @@
1414

1515
"""A collection of different metrics for NLP models."""
1616

17-
from clu import metrics as clu_metrics
1817
import collections
1918
import math
19+
from clu import metrics as clu_metrics
2020
import flax
2121
import jax
2222
import jax.numpy as jnp
2323
from metrax import base
2424

2525

26+
def _get_single_n_grams(segment: list[str], order: int):
27+
"""Generates a counter of n-grams from a list of tokens for a specific n.
28+
29+
Args:
30+
segment: list. Text segment from which n-grams will be extracted.
31+
order: The order of n-grams.
32+
33+
Returns:
34+
A collections.Counter mapping n-gram tuples to their counts.
35+
"""
36+
return collections.Counter(zip(*[segment[i:] for i in range(order)]))
37+
38+
2639
def _get_ngrams(segment: list[str], max_order: int):
2740
"""Extracts all n-grams up to a given maximum order from an input segment.
2841
2942
Args:
3043
segment: list. Text segment from which n-grams will be extracted.
3144
max_order: int. Maximum length in tokens of the n-grams returned by this
3245
method.
46+
47+
Returns:
48+
A collections.Counter mapping n-gram tuples to their counts for all orders.
3349
"""
3450
ngram_counts = collections.Counter()
3551
for order in range(1, max_order + 1):
36-
for i in range(0, len(segment) - order + 1):
37-
ngram = tuple(segment[i : i + order])
38-
ngram_counts[ngram] += 1
52+
ngram_counts.update(_get_single_n_grams(segment, order))
3953
return ngram_counts
4054

4155

@@ -285,6 +299,159 @@ def compute(self) -> jax.Array:
285299
)
286300

287301

@flax.struct.dataclass
class RougeN(clu_metrics.Metric):
  r"""Computes macro-averaged ROUGE-N recall, precision, and F1-score.

  This metric first calculates ROUGE-N precision, recall, and F1-score for each
  individual prediction compared against its single corresponding reference.
  These per-instance precision, recall and F1-scores are then averaged across
  all instances in the dataset/batch.

  Accumulation for Macro-Average:
    - total_precision = sum of all instance precision values.
    - total_recall = sum of all instance recall values.
    - total_f1 = sum of all instance f1 values.
    - num_examples = count of prediction-reference pairs.

  Final Macro-Averaged Metrics:
  .. math::
      \text{MacroAvgPrecision} =
      \frac{\text{total_precision}}{\text{num_examples}}
  .. math::
      \text{MacroAvgRecall} = \frac{\text{total_recall}}{\text{num_examples}}
  .. math::
      \text{MacroAvgF1} = \frac{\text{total_f1}}{\text{num_examples}}

  Note: the macro-averaged F1 is the arithmetic mean of the per-instance F1
  scores, not the harmonic mean of the macro-averaged precision and recall.

  Attributes:
    order: The specific 'N' in ROUGE-N (e.g., 1 for ROUGE-1, 2 for ROUGE-2).
    total_precision: Accumulated sum of precision scores from each instance.
    total_recall: Accumulated sum of recall scores from each instance.
    total_f1: Accumulated sum of f1 scores from each instance.
    num_examples: The number of instances (prediction-reference pairs)
      processed.
  """

  order: int
  total_precision: jax.Array
  total_recall: jax.Array
  total_f1: jax.Array
  num_examples: jax.Array

  @classmethod
  def empty(cls, order: int = 2) -> 'RougeN':
    """Creates an empty ROUGE-N metric for macro-averaging.

    Args:
      order: The order 'N' of n-grams (e.g., 2 for ROUGE-2). Must be a positive
        integer.

    Returns:
      An empty RougeN metric.

    Raises:
      ValueError: If `order` is not a positive integer.
    """
    if order < 1:
      raise ValueError(f'`order` must be a positive integer, got {order}.')
    return cls(
        order=order,
        total_precision=jnp.array(0, jnp.float32),
        total_recall=jnp.array(0, jnp.float32),
        total_f1=jnp.array(0, jnp.float32),
        num_examples=jnp.array(0, jnp.float32),
    )

  @classmethod
  def from_model_output(
      cls,
      predictions: list[str],
      references: list[str],
      order: int = 2,
  ) -> 'RougeN':
    """Computes sums of per-instance ROUGE-N scores for a batch.

    Args:
      predictions: A list of predicted strings. The shape should be (batch_size,
        ).
      references: A list of reference strings. Each prediction must have one
        corresponding reference string. The shape should be (batch_size, ).
      order: The order 'N' of n-grams to consider. Must be positive.

    Returns:
      A RougeN metric instance with accumulated per-instance scores.

    Raises:
      ValueError: If `order` is not positive, or if `predictions` and
        `references` have different lengths.
    """
    if order < 1:
      raise ValueError(f'`order` must be a positive integer, got {order}.')
    if len(predictions) != len(references):
      # zip() would silently drop the unmatched tail; fail loudly instead.
      raise ValueError(
          'predictions and references must have the same length. '
          f'Got {len(predictions)} and {len(references)}.'
      )

    total_precision = 0.0
    total_recall = 0.0
    total_f1 = 0.0
    num_examples = 0.0

    for pred_str, ref_str in zip(predictions, references):
      # Whitespace tokenization; inputs are assumed to be pre-normalized text.
      pred_tokens = pred_str.split()
      ref_tokens = ref_str.split()

      pred_ngrams_counts = _get_single_n_grams(pred_tokens, order)
      ref_ngrams_counts = _get_single_n_grams(ref_tokens, order)
      # Clipped overlap: each n-gram counts at most min(pred, ref) times.
      overlap_counts = pred_ngrams_counts & ref_ngrams_counts

      prediction_ngrams = jnp.array(sum(pred_ngrams_counts.values()))
      reference_ngrams = jnp.array(sum(ref_ngrams_counts.values()))
      overlapping_ngrams = jnp.array(sum(overlap_counts.values()))

      # divide_no_nan yields 0 for empty denominators, e.g. when a segment is
      # shorter than `order` tokens and produces no n-grams at all.
      precision = base.divide_no_nan(overlapping_ngrams, prediction_ngrams)
      recall = base.divide_no_nan(overlapping_ngrams, reference_ngrams)
      f1 = base.divide_no_nan(2 * precision * recall, precision + recall)

      total_precision += precision
      total_recall += recall
      total_f1 += f1
      num_examples += 1

    return cls(
        order=order,
        total_precision=jnp.array(total_precision, dtype=jnp.float32),
        total_recall=jnp.array(total_recall, dtype=jnp.float32),
        total_f1=jnp.array(total_f1, dtype=jnp.float32),
        num_examples=jnp.array(num_examples, dtype=jnp.float32),
    )

  def merge(self, other: 'RougeN') -> 'RougeN':
    """Merges this RougeN metric with another.

    Args:
      other: Another RougeN metric instance.

    Returns:
      A new RougeN metric instance with combined statistics.

    Raises:
      ValueError: If the two metrics were computed with different orders.
    """
    if self.order != other.order:
      raise ValueError(
          'RougeN metrics with different orders cannot be merged. '
          f'Got {self.order} and {other.order}.'
      )
    return RougeN(
        order=self.order,
        total_precision=(self.total_precision + other.total_precision),
        total_recall=(self.total_recall + other.total_recall),
        total_f1=(self.total_f1 + other.total_f1),
        num_examples=(self.num_examples + other.num_examples),
    )

  def compute(self) -> jax.Array:
    """Computes macro-averaged ROUGE-N precision, recall, and F1-score.

    Returns:
      A JAX array where:
        - index 0: macro-averaged precision
        - index 1: macro-averaged recall
        - index 2: macro-averaged f1score (the mean of the per-instance F1
          scores)
      Scores are 0.0 if num_examples is zero.
    """
    macro_avg_precision = base.divide_no_nan(
        self.total_precision, self.num_examples
    )
    macro_avg_recall = base.divide_no_nan(self.total_recall, self.num_examples)
    macro_avg_f1score = base.divide_no_nan(self.total_f1, self.num_examples)
    return jnp.stack([macro_avg_precision, macro_avg_recall, macro_avg_f1score])
288455
@flax.struct.dataclass
289456
class WER(base.Average):
290457
r"""Computes Word Error Rate (WER) for speech recognition or text generation tasks.
@@ -389,4 +556,4 @@ def _levenshtein_distance(prediction: list, reference: list) -> int:
389556
distance_matrix[i - 1][j - 1] + cost, # substitution
390557
)
391558

392-
return distance_matrix[m][n]
559+
return distance_matrix[m][n]

src/metrax/nlp_metrics_test.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_bleu(self):
6262
]
6363
predictions = [
6464
"He He He eats sweet apple which is a fruit",
65-
"I love Silicon Valley it's one of my favourite shows",
65+
"I love Silicon Valley it is one of my favourite shows",
6666
]
6767
keras_metric = keras_nlp.metrics.Bleu()
6868
keras_metric.update_state(references, predictions)
@@ -86,7 +86,7 @@ def test_bleu_merge(self):
8686
]
8787
predictions = [
8888
"He He He eats sweet apple which is a fruit",
89-
"I love Silicon Valley it's one of my favourite shows",
89+
"I love Silicon Valley it is one of my favourite shows",
9090
]
9191
keras_metric = keras_nlp.metrics.Bleu()
9292
keras_metric.update_state(references, predictions)
@@ -123,6 +123,75 @@ def test_bleu_merge_fails_on_different_max_order(self):
123123
ValueError, lambda: order_3_metric.merge(order_4_metric)
124124
)
125125

126+
def test_rougen(self):
127+
"""Tests that ROUGE-N metric computes correct values."""
128+
references = [
129+
"He eats a sweet apple",
130+
"Silicon Valley is one of my favourite shows",
131+
]
132+
predictions = [
133+
"He He He eats sweet apple which is a fruit",
134+
"I love Silicon Valley it is one of my favourite shows",
135+
]
136+
keras_metric = keras_nlp.metrics.RougeN()
137+
keras_metric.update_state(references, predictions)
138+
keras_metric_array = jnp.stack(list(keras_metric.result().values()))
139+
metrax_metric = metrax.RougeN.from_model_output(predictions, references)
140+
141+
np.testing.assert_allclose(
142+
metrax_metric.compute(),
143+
keras_metric_array,
144+
rtol=1e-05,
145+
atol=1e-05,
146+
)
147+
148+
def test_rougen_merge(self):
149+
"""Tests that ROUGE-N metric computes correct values using merge."""
150+
references = [
151+
"He eats a sweet apple",
152+
"Silicon Valley is one of my favourite shows",
153+
]
154+
predictions = [
155+
"He He He eats sweet apple which is a fruit",
156+
"I love Silicon Valley it is one of my favourite shows",
157+
]
158+
keras_metric = keras_nlp.metrics.RougeN()
159+
keras_metric.update_state(references, predictions)
160+
keras_metric_array = jnp.stack(list(keras_metric.result().values()))
161+
162+
metrax_metric = None
163+
for ref, pred in zip(references, predictions):
164+
update = metrax.RougeN.from_model_output([pred], [ref])
165+
metrax_metric = (
166+
update if metrax_metric is None else metrax_metric.merge(update)
167+
)
168+
169+
np.testing.assert_allclose(
170+
metrax_metric.compute(),
171+
keras_metric_array,
172+
rtol=1e-05,
173+
atol=1e-05,
174+
)
175+
176+
def test_rougen_merge_fails_on_different_max_order(self):
177+
"""Tests that error is raised when ROUGE-N metrics with different max_order are merged."""
178+
references = [
179+
"He eats a sweet apple",
180+
]
181+
predictions = [
182+
"He He He eats sweet apple which is a fruit",
183+
]
184+
order_3_metric = metrax.RougeN.from_model_output(
185+
predictions, references, order=3
186+
)
187+
order_4_metric = metrax.RougeN.from_model_output(
188+
predictions, references, order=4
189+
)
190+
191+
np.testing.assert_raises(
192+
ValueError, lambda: order_3_metric.merge(order_4_metric)
193+
)
194+
126195
@parameterized.named_parameters(
127196
(
128197
'basic',

src/metrax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
RMSE = nnx_metrics.RMSE
2626
RSQUARED = nnx_metrics.RSQUARED
2727
Recall = nnx_metrics.Recall
28+
RougeN = nnx_metrics.RougeN
2829
WER = nnx_metrics.WER
2930

3031

@@ -40,5 +41,6 @@
4041
"RMSE",
4142
"RSQUARED",
4243
"Recall",
44+
"RougeN",
4345
"WER",
4446
]

src/metrax/nnx/nnx_metrics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ def __init__(self):
8888
super().__init__(metrax.RMSE)
8989

9090

class RougeN(NnxWrapper):
  """An NNX class for the Metrax metric RougeN."""

  def __init__(self):
    # Delegate all state bookkeeping to NnxWrapper, parameterized by the
    # wrapped metrax metric class.
    super().__init__(metrax.RougeN)
97+
9198
class RSQUARED(NnxWrapper):
9299
"""An NNX class for the Metrax metric RSQUARED."""
93100

0 commit comments

Comments
 (0)