|
15 | 15 | """A collection of different metrics for NLP models.""" |
16 | 16 |
|
17 | 17 | from clu import metrics as clu_metrics |
| 18 | +import collections |
| 19 | +import math |
18 | 20 | import flax |
19 | 21 | import jax |
20 | 22 | import jax.numpy as jnp |
21 | 23 | from metrax import base |
22 | 24 |
|
23 | 25 |
|
def get_ngrams(segment: list[str], max_order: int):
  """Counts every n-gram of length 1 through `max_order` found in a segment.

  Args:
    segment: Tokens to scan, in order.
    max_order: Largest n-gram length (in tokens) to collect.

  Returns:
    A `collections.Counter` mapping each n-gram, stored as a tuple of tokens,
    to the number of times it occurs in `segment`. N-grams longer than the
    segment simply do not appear.
  """
  token_count = len(segment)
  # Shorter n-grams come first, and within one length they appear in
  # left-to-right order of their starting position.
  return collections.Counter(
      tuple(segment[start : start + n])
      for n in range(1, max_order + 1)
      for start in range(token_count - n + 1)
  )
| 40 | + |
| 41 | + |
@flax.struct.dataclass
class BLEU(clu_metrics.Metric):
  r"""Computes the BLEU score for sequence generation.

  BLEU measures the similarity between a machine-generated candidate translation
  and one or more human reference translations, focusing on matching n-grams.

  It's calculated as:
  .. math::
      \text{BLEU} = \text{BP} \times \exp\left( \sum_{n=1}^{N} w_n \log p_n
      \right)

  Where:
    - :math:`p_n` is the modified n-gram precision for n-grams of order n.
    - :math:`N` is the maximum n-gram order considered (typically 4).
    - :math:`w_n` are weights for each order (typically uniform, 1/N).
    - :math:`\text{BP}` is the Brevity Penalty.

  This implementation uses uniform weights and calculates statistics
  incrementally. Inputs are tokenized by whitespace (`str.split()`).

  Attributes:
    max_order: Maximum n-gram order to consider.
    matches_by_order: Accumulated sum of clipped n-gram matches for each order.
    possible_matches_by_order: Accumulated sum of total n-grams in predictions
      for each order.
    translation_length: Accumulated total length (in tokens) of predictions.
    reference_length: Accumulated total reference length, where each example
      contributes the length of its shortest reference.
  """

  max_order: int
  matches_by_order: jax.Array
  possible_matches_by_order: jax.Array
  translation_length: jax.Array
  reference_length: jax.Array

  @classmethod
  def empty(cls) -> 'BLEU':
    """Returns a BLEU metric with no accumulated statistics (`max_order` 4)."""
    max_order = 4
    # The per-order accumulators get one slot per n-gram order so that
    # `compute()` can index them even before any batch has been merged in.
    return cls(
        max_order=max_order,
        matches_by_order=jnp.zeros((max_order,), jnp.float32),
        possible_matches_by_order=jnp.zeros((max_order,), jnp.float32),
        translation_length=jnp.array(0, jnp.float32),
        reference_length=jnp.array(0, jnp.float32),
    )

  @classmethod
  def from_model_output(
      cls,
      predictions: list[str],
      references: list[list[str]],
      max_order: int = 4,
  ) -> 'BLEU':
    """Computes BLEU statistics for a batch of predictions and references.

    Args:
      predictions: A list of predicted strings. The shape should be (batch_size,
        ).
      references: A list of lists of reference strings. The shape should be
        (batch_size, num_references).
      max_order: The maximum order of n-grams to consider.

    Returns:
      A BLEU metric instance containing the statistics for this batch.

    Raises:
      ValueError: If the shapes of `predictions` and `references` are
        incompatible, or if a prediction has no references.
    """
    # `zip` would silently drop the unmatched tail, so reject mismatches up
    # front.
    if len(predictions) != len(references):
      raise ValueError(
          'predictions and references must have the same batch size, got '
          f'{len(predictions)} and {len(references)}.'
      )

    matches_by_order = [0] * max_order
    possible_matches_by_order = [0] * max_order
    pred_length = 0
    ref_length = 0

    for index, (pred, ref_list) in enumerate(zip(predictions, references)):
      pred_tokens = pred.split()
      ref_tokens_list = [r.split() for r in ref_list]
      if not ref_tokens_list:
        raise ValueError(
            f'references[{index}] is empty; each prediction needs at least '
            'one reference.'
        )

      pred_length += len(pred_tokens)
      # The shortest reference is used as this example's reference length.
      ref_length += min(len(r) for r in ref_tokens_list)

      prediction_ngram_counts = get_ngrams(pred_tokens, max_order)
      # Counter union keeps, for each n-gram, its maximum count over all
      # references; this is the clipping ceiling for that n-gram.
      reference_ngram_counts = collections.Counter()
      for ref_tokens in ref_tokens_list:
        reference_ngram_counts |= get_ngrams(ref_tokens, max_order)

      # Counter intersection takes the per-n-gram minimum, i.e. the clipped
      # number of matches.
      overlap = prediction_ngram_counts & reference_ngram_counts
      for ngram, count in overlap.items():
        matches_by_order[len(ngram) - 1] += count

      for order in range(1, max_order + 1):
        possible_matches = len(pred_tokens) - order + 1
        if possible_matches > 0:
          possible_matches_by_order[order - 1] += possible_matches

    return cls(
        max_order=max_order,
        matches_by_order=jnp.array(matches_by_order, dtype=jnp.float32),
        possible_matches_by_order=jnp.array(
            possible_matches_by_order, dtype=jnp.float32
        ),
        translation_length=jnp.array(pred_length, dtype=jnp.float32),
        reference_length=jnp.array(ref_length, dtype=jnp.float32),
    )

  def merge(self, other: 'BLEU') -> 'BLEU':
    """Adds the statistics of `other` to this metric.

    Args:
      other: Another BLEU metric accumulated with the same `max_order`.

    Returns:
      A new BLEU metric holding the element-wise sum of both sets of
      statistics.

    Raises:
      ValueError: If the two metrics were built with different `max_order`.
    """
    if self.max_order != other.max_order:
      raise ValueError(
          'BLEU metrics with different max_order cannot be merged.'
      )
    return type(self)(
        max_order=self.max_order,
        matches_by_order=(self.matches_by_order + other.matches_by_order),
        possible_matches_by_order=(
            self.possible_matches_by_order + other.possible_matches_by_order
        ),
        translation_length=(self.translation_length + other.translation_length),
        reference_length=(self.reference_length + other.reference_length),
    )

  def compute(self) -> jax.Array:
    """Computes the BLEU score from the accumulated statistics.

    Returns:
      A scalar float32 array: the uniform-weight geometric mean of the
      per-order precisions times the brevity penalty. The score is 0 when any
      order has zero precision or when nothing has been accumulated.
    """
    precisions = [
        float(
            base.divide_no_nan(
                self.matches_by_order[i], self.possible_matches_by_order[i]
            )
        )
        for i in range(self.max_order)
    ]
    # log(0) is undefined, so a single zero precision zeroes the whole score.
    if precisions and min(precisions) > 0:
      geo_mean = math.exp(
          sum((1.0 / self.max_order) * math.log(p) for p in precisions)
      )
    else:
      geo_mean = 0.0

    ratio = float(
        base.divide_no_nan(self.translation_length, self.reference_length)
    )
    if ratio > 1.0:
      bp = 1.0
    elif ratio > 0.0:
      bp = math.exp(1 - 1.0 / ratio)
    else:
      # No translated tokens (or no reference tokens): the penalty's limit as
      # the ratio approaches 0 is 0, and this avoids dividing by zero.
      bp = 0.0

    return jnp.array(geo_mean * bp, dtype=jnp.float32)
| 174 | + |
| 175 | + |
24 | 176 | @flax.struct.dataclass |
25 | 177 | class Perplexity(clu_metrics.Metric): |
26 | 178 | r"""Computes perplexity for sequence generation. |
|
0 commit comments