1414
1515"""A collection of different metrics for NLP models."""
1616
17- from clu import metrics as clu_metrics
1817import collections
1918import math
19+ from clu import metrics as clu_metrics
2020import flax
2121import jax
2222import jax .numpy as jnp
2323from metrax import base
2424
2525
26+ def _get_single_n_grams (segment : list [str ], order : int ):
27+ """Generates a counter of n-grams from a list of tokens for a specific n.
28+
29+ Args:
30+ segment: list. Text segment from which n-grams will be extracted.
31+ order: The order of n-grams.
32+
33+ Returns:
34+ A collections.Counter mapping n-gram tuples to their counts.
35+ """
36+ return collections .Counter (zip (* [segment [i :] for i in range (order )]))
37+
38+
def _get_ngrams(segment: list[str], max_order: int):
  """Extracts all n-grams up to a given maximum order from an input segment.

  Args:
    segment: list. Text segment from which n-grams will be extracted.
    max_order: int. Maximum length in tokens of the n-grams returned by this
      method.

  Returns:
    A collections.Counter mapping n-gram tuples to their counts for all orders.
  """
  ngram_counts = collections.Counter()
  for order in range(1, max_order + 1):
    # Delegate each order to the single-order helper; the old inline counting
    # loop is removed so every n-gram is counted exactly once.
    ngram_counts.update(_get_single_n_grams(segment, order))
  return ngram_counts
4054
4155
@@ -285,6 +299,159 @@ def compute(self) -> jax.Array:
285299 )
286300
287301
@flax.struct.dataclass
class RougeN(clu_metrics.Metric):
  r"""Computes macro-averaged ROUGE-N recall, precision, and F1-score.

  This metric first calculates ROUGE-N precision, recall, and F1-score for each
  individual prediction compared against its single corresponding reference.
  These per-instance precision, recall and F1-scores are then averaged across
  all instances in the dataset/batch.

  Accumulation for Macro-Average:
    - total_precision = sum of all precision values.
    - total_recall = sum of all instance recall values.
    - total_f1 = sum of all f1 values.
    - num_examples = count of prediction-reference pairs.

  Final Macro-Averaged Metrics:
  .. math::
      \text{MacroAvgPrecision} =
        \frac{\text{total_precision}}{\text{num_examples}}
  .. math::
      \text{MacroAvgRecall} = \frac{\text{total_recall}}{\text{num_examples}}
  .. math::
      \text{MacroAvgF1} = \frac{\text{total_f1}}{\text{num_examples}}

  Attributes:
    order: The specific 'N' in ROUGE-N (e.g., 1 for ROUGE-1, 2 for ROUGE-2).
    total_precision: Accumulated sum of precision scores from each instance.
    total_recall: Accumulated sum of recall scores from each instance.
    total_f1: Accumulated sum of f1 scores from each instance.
    num_examples: The number of instances (prediction-reference pairs)
      processed.
  """

  order: int
  total_precision: jax.Array
  total_recall: jax.Array
  total_f1: jax.Array
  num_examples: jax.Array

  @classmethod
  def empty(cls, order: int = 2) -> 'RougeN':
    """Creates an empty ROUGE-N metric for macro-averaging.

    Args:
      order: The order 'N' of n-grams (e.g., 2 for ROUGE-2). Must be a positive
        integer.

    Returns:
      An empty RougeN metric.
    """
    return cls(
        order=order,
        total_precision=jnp.array(0, jnp.float32),
        total_recall=jnp.array(0, jnp.float32),
        total_f1=jnp.array(0, jnp.float32),
        num_examples=jnp.array(0, jnp.float32),
    )

  @classmethod
  def from_model_output(
      cls,
      predictions: list[str],
      references: list[str],
      order: int = 2,
  ) -> 'RougeN':
    """Computes sums of per-instance ROUGE-N scores for a batch.

    Args:
      predictions: A list of predicted strings. The shape should be
        (batch_size, ).
      references: A list of reference strings. Each prediction must have one
        corresponding reference string. The shape should be (batch_size, ).
      order: The order 'N' of n-grams to consider. Must be positive.

    Returns:
      A RougeN metric instance with accumulated per-instance scores.

    Raises:
      ValueError: If `predictions` and `references` differ in length.
    """
    # A silent zip() truncation would quietly drop trailing pairs; fail loudly
    # on mismatched batches instead.
    if len(predictions) != len(references):
      raise ValueError(
          'predictions and references must have the same length. '
          f'Got {len(predictions)} and {len(references)}.'
      )

    total_precision = 0.0
    total_recall = 0.0
    total_f1 = 0.0
    num_examples = 0.0

    for pred_str, ref_str in zip(predictions, references):
      pred_ngrams_counts = _get_single_n_grams(pred_str.split(), order)
      ref_ngrams_counts = _get_single_n_grams(ref_str.split(), order)
      # Counter & Counter keeps min(count) per n-gram, i.e. clipped overlap.
      overlapping_ngrams = sum(
          (pred_ngrams_counts & ref_ngrams_counts).values()
      )
      prediction_ngrams = sum(pred_ngrams_counts.values())
      reference_ngrams = sum(ref_ngrams_counts.values())

      # Plain-Python arithmetic keeps this host-side loop cheap (no per-item
      # jax dispatch); results are wrapped in jax arrays once, below. The
      # guarded divisions mirror base.divide_no_nan: zero denominator -> 0.0.
      precision = (
          overlapping_ngrams / prediction_ngrams if prediction_ngrams else 0.0
      )
      recall = (
          overlapping_ngrams / reference_ngrams if reference_ngrams else 0.0
      )
      denominator = precision + recall
      f1 = 2.0 * precision * recall / denominator if denominator else 0.0

      total_precision += precision
      total_recall += recall
      total_f1 += f1
      num_examples += 1

    return cls(
        order=order,
        total_precision=jnp.array(total_precision, dtype=jnp.float32),
        total_recall=jnp.array(total_recall, dtype=jnp.float32),
        total_f1=jnp.array(total_f1, dtype=jnp.float32),
        num_examples=jnp.array(num_examples, dtype=jnp.float32),
    )

  def merge(self, other: 'RougeN') -> 'RougeN':
    """Merges this RougeN metric with another.

    Args:
      other: Another RougeN metric instance.

    Returns:
      A new RougeN metric instance with combined statistics.

    Raises:
      ValueError: If the two metrics were created with different orders.
    """
    if self.order != other.order:
      raise ValueError(
          'RougeN metrics with different orders cannot be merged. '
          f'Got {self.order} and {other.order}.'
      )
    return RougeN(
        order=self.order,
        total_precision=(self.total_precision + other.total_precision),
        total_recall=(self.total_recall + other.total_recall),
        total_f1=(self.total_f1 + other.total_f1),
        num_examples=(self.num_examples + other.num_examples),
    )

  def compute(self) -> jax.Array:
    """Computes macro-averaged ROUGE-N precision, recall, and F1-score.

    Returns:
      A JAX array where:
        - index 0: macro-averaged precision
        - index 1: macro-averaged recall
        - index 2: macro-averaged f1score (mean of the per-instance F1
          scores, not the F1 of the averaged precision/recall)
      Scores are 0.0 if num_examples is zero.
    """
    macro_avg_precision = base.divide_no_nan(
        self.total_precision, self.num_examples
    )
    macro_avg_recall = base.divide_no_nan(self.total_recall, self.num_examples)
    macro_avg_f1score = base.divide_no_nan(self.total_f1, self.num_examples)
    return jnp.stack([macro_avg_precision, macro_avg_recall, macro_avg_f1score])
453+
454+
288455@flax .struct .dataclass
289456class WER (base .Average ):
290457 r"""Computes Word Error Rate (WER) for speech recognition or text generation tasks.
@@ -389,4 +556,4 @@ def _levenshtein_distance(prediction: list, reference: list) -> int:
389556 distance_matrix [i - 1 ][j - 1 ] + cost , # substitution
390557 )
391558
392- return distance_matrix [m ][n ]
559+ return distance_matrix [m ][n ]
0 commit comments