|
22 | 22 |
|
@flax.struct.dataclass
class Perplexity(clu_metrics.Metric):
  r"""Computes perplexity for sequence generation.

  Perplexity is a measurement of how well a probability distribution predicts a
  sample. It is defined as the exponentiation of the cross-entropy. A low
  perplexity indicates the probability distribution is good at predicting the
  sample.

  For language models, it can be interpreted as the weighted average branching
  factor of the model - how many equally likely words can be selected at each
  step.

  Given a sequence of :math:`N` tokens, perplexity is calculated as:

  .. math::
    Perplexity = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log P(x_i|x_{<i})\right)

  When sample weights :math:`w_i` are provided:

  .. math::
    Perplexity = \exp\left(-\frac{\sum_{i=1}^{N} w_i\log P(x_i|x_{<i})}{\sum_{i=1}^{N} w_i}\right)

  where:
    - :math:`P(x_i|x_{<i})` is the predicted probability of token :math:`x_i`
      given previous tokens
    - :math:`w_i` are sample weights
    - :math:`N` is the sequence length

  Lower perplexity indicates better prediction - the model is less "perplexed"
  by the data.
  """

  # Sum over merged batches of (batch_size * per-batch mean cross-entropy).
  aggregate_crossentropy: jax.Array
  # Total number of sequences folded into `aggregate_crossentropy`.
  num_samples: jax.Array

  @classmethod
  def empty(cls) -> "Perplexity":
    """Returns a zero-initialized metric; the identity element for `merge`."""
    return cls(
        aggregate_crossentropy=jnp.array(0, jnp.float32),
        num_samples=jnp.array(0, jnp.float32),
    )

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      sample_weights: jax.Array | None = None,
  ) -> "Perplexity":
    """Computes the metric from one batch of model output.

    Args:
      predictions: A floating point tensor of token probabilities produced by
        the model, of shape (batch_size, seq_len, vocab_size). Values are
        renormalized over the vocabulary axis, so they need not sum exactly
        to 1. NOTE(review): a zero probability on the true token yields a
        -inf log-prob and hence an infinite perplexity.
      labels: Integer token ids (true values), shape (batch_size, seq_len).
      sample_weights: An optional tensor representing the weight of each
        token. The shape should be (batch_size, seq_len).

    Returns:
      A new Perplexity metric holding this batch's statistics.

    Raises:
      ValueError: If type of `labels` is wrong or the shapes of `predictions`
        and `labels` are incompatible.
    """
    # Renormalize so each (batch, position) is a distribution over the vocab.
    predictions = predictions / jnp.sum(predictions, axis=-1, keepdims=True)
    labels_one_hot = jax.nn.one_hot(labels, predictions.shape[-1], axis=-1)
    log_prob = jnp.log(predictions)
    # Per-token cross-entropy: -log P(label) at each (batch, position).
    crossentropy = -jnp.sum(labels_one_hot * log_prob, axis=-1)

    if sample_weights is not None:
      # Weighted mean of the per-token cross-entropy over the whole batch,
      # normalized by the total weight (not per sequence).
      crossentropy = crossentropy * sample_weights
      crossentropy = jnp.sum(crossentropy) / jnp.sum(sample_weights)
    else:
      # Unweighted mean over every token in the batch.
      crossentropy = jnp.mean(crossentropy)

    # Scale by batch size so `merge` + `compute` recover a batch-size-weighted
    # mean cross-entropy across differently sized batches.
    batch_size = jnp.array(labels.shape[0])
    return cls(
        aggregate_crossentropy=batch_size * crossentropy,
        num_samples=batch_size,
    )

  def merge(self, other: "Perplexity") -> "Perplexity":
    """Accumulates statistics from another Perplexity instance."""
    return type(self)(
        aggregate_crossentropy=(
            self.aggregate_crossentropy + other.aggregate_crossentropy
        ),
        num_samples=self.num_samples + other.num_samples,
    )

  def compute(self) -> jax.Array:
    """Returns exp of the mean cross-entropy across all merged batches."""
    return jnp.exp(self.aggregate_crossentropy / self.num_samples)
| 118 | + |
| 119 | + |
@flax.struct.dataclass
class WER(clu_metrics.Average):
  r"""Computes Word Error Rate (WER) for speech recognition or text generation tasks.

  Word Error Rate measures the edit distance between reference texts and
  predictions, normalized by the length of the reference texts. It is
  calculated as:

  .. math::
    WER = \frac{S + D + I}{N}

  where:
    - S is the number of substitutions
    - D is the number of deletions
    - I is the number of insertions
    - N is the number of words in the reference

  A lower WER indicates better performance, with 0 being perfect.

  This implementation accepts both pre-tokenized inputs (lists of tokens) and
  untokenized strings. When strings are provided, they are tokenized by
  splitting on whitespace.
  """

  @classmethod
  def from_model_output(
      cls,
      predictions: str | list[str],
      references: str | list[str],
  ) -> "WER":
    """Computes the WER statistics for one prediction/reference pair.

    Args:
      predictions: Either a string or a list of tokens in the predicted
        sequence. Strings are tokenized by splitting on whitespace.
      references: Either a string or a list of tokens in the reference
        sequence. Strings are tokenized by splitting on whitespace.

    Returns:
      New WER metric instance.

    Raises:
      ValueError: If inputs are not properly formatted or are empty.
    """
    if isinstance(predictions, str):
      predictions = predictions.split()
    if isinstance(references, str):
      references = references.split()

    # Checked after tokenization so a whitespace-only string (which splits to
    # []) is rejected too, instead of producing a zero `count` that would
    # divide by zero in Average.compute().
    if not predictions or not references:
      raise ValueError("predictions and references must not be empty")

    edit_distance = cls._levenshtein_distance(predictions, references)
    reference_length = len(references)

    return cls(
        total=jnp.array(edit_distance, dtype=jnp.float32),
        count=jnp.array(reference_length, dtype=jnp.float32),
    )

  @staticmethod
  def _levenshtein_distance(prediction: list, reference: list) -> int:
    """Computes the Levenshtein (edit) distance between two token sequences.

    Uses a two-row dynamic program (Wagner-Fischer with row reuse), so memory
    is O(len(reference)) instead of O(len(prediction) * len(reference)).

    Args:
      prediction: List of tokens in the predicted sequence.
      reference: List of tokens in the reference sequence.

    Returns:
      The minimum number of edits needed to transform prediction into
      reference.
    """
    m, n = len(prediction), len(reference)

    # Edge cases: transforming to/from an empty sequence costs its length.
    if m == 0:
      return n
    if n == 0:
      return m

    # `previous[j]` is the distance between prediction[:i-1] and
    # reference[:j]; `current` is filled in for prediction[:i].
    previous = list(range(n + 1))
    for i in range(1, m + 1):
      current = [i] + [0] * n
      for j in range(1, n + 1):
        cost = 0 if prediction[i - 1] == reference[j - 1] else 1
        current[j] = min(
            previous[j] + 1,  # deletion
            current[j - 1] + 1,  # insertion
            previous[j - 1] + cost,  # substitution
        )
      previous = current

    return previous[n]
0 commit comments