Skip to content

Commit a767587

Browse files
committed
Add WER metric
1 parent 61e5dc2 commit a767587

File tree

3 files changed

+244
-91
lines changed

3 files changed

+244
-91
lines changed

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from metrax.nlp_metrics import (
2222
Perplexity,
23+
WER
2324
)
2425
from metrax.ranking_metrics import (
2526
AveragePrecisionAtK,
@@ -40,4 +41,5 @@
4041
"Recall",
4142
"RMSE",
4243
"RSQUARED",
44+
"WER",
4345
]

src/metrax/nlp_metrics.py

Lines changed: 193 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -22,95 +22,197 @@
2222

2323
@flax.struct.dataclass
class Perplexity(clu_metrics.Metric):
  r"""Computes perplexity for sequence generation.

  Perplexity is a measurement of how well a probability distribution predicts a
  sample. It is defined as the exponentiation of the cross-entropy. A low
  perplexity indicates the probability distribution is good at predicting the
  sample.

  For language models, it can be interpreted as the weighted average branching
  factor of the model - how many equally likely words can be selected at each
  step.

  Given a sequence of :math:`N` tokens, perplexity is calculated as:

  .. math::
      Perplexity = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log P(x_i|x_{<i})\right)

  When sample weights :math:`w_i` are provided:

  .. math::
      Perplexity = \exp\left(-\frac{\sum_{i=1}^{N} w_i\log P(x_i|x_{<i})}{\sum_{i=1}^{N} w_i}\right)

  where:
    - :math:`P(x_i|x_{<i})` is the predicted probability of token :math:`x_i`
      given previous tokens
    - :math:`w_i` are sample weights
    - :math:`N` is the sequence length

  Lower perplexity indicates better prediction - the model is less "perplexed"
  by the data.
  """

  # Sum over batches of (batch_size * mean cross-entropy of the batch).
  aggregate_crossentropy: jax.Array
  # Total number of sequences accumulated across merged updates.
  num_samples: jax.Array

  @classmethod
  def empty(cls) -> "Perplexity":
    """Returns a zero-valued metric suitable as a merge identity."""
    return cls(
        aggregate_crossentropy=jnp.array(0, jnp.float32),
        num_samples=jnp.array(0, jnp.float32),
    )

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      sample_weights: jax.Array | None = None,
  ) -> "Perplexity":
    """Computes a Perplexity update from one batch of model outputs.

    Args:
      predictions: A floating point tensor of per-token scores generated from
        the model; it is renormalized over the vocabulary axis here, so the
        scores need not sum to one. The shape should be (batch_size, seq_len,
        vocab_size).
      labels: True token ids. The shape should be (batch_size, seq_len).
      sample_weights: An optional tensor representing the weight of each
        token. The shape should be (batch_size, seq_len).

    Returns:
      New Perplexity metric for this batch; combine batches with `merge`.
    """
    # Renormalize so each vocabulary slice is a probability distribution.
    predictions = predictions / jnp.sum(predictions, axis=-1, keepdims=True)
    labels_one_hot = jax.nn.one_hot(labels, predictions.shape[-1], axis=-1)
    log_prob = jnp.log(predictions)
    # Per-token negative log-likelihood of the true token.
    crossentropy = -jnp.sum(labels_one_hot * log_prob, axis=-1)

    if sample_weights is not None:
      crossentropy = crossentropy * sample_weights
      # Normalize by the total weight instead of the token count.
      crossentropy = jnp.sum(crossentropy) / jnp.sum(sample_weights)
    else:
      crossentropy = jnp.mean(crossentropy)

    # Scale by batch size so `merge` + `compute` yield a per-sequence mean.
    batch_size = jnp.array(labels.shape[0])
    return cls(
        aggregate_crossentropy=(batch_size * crossentropy),
        num_samples=batch_size,
    )

  def merge(self, other: "Perplexity") -> "Perplexity":
    """Accumulates another Perplexity update into this one."""
    return type(self)(
        aggregate_crossentropy=(
            self.aggregate_crossentropy + other.aggregate_crossentropy
        ),
        num_samples=self.num_samples + other.num_samples,
    )

  def compute(self) -> jax.Array:
    """Returns exp(mean cross-entropy) over all merged batches."""
    return jnp.exp(self.aggregate_crossentropy / self.num_samples)
118+
119+
120+
@flax.struct.dataclass
class WER(clu_metrics.Average):
  r"""Computes Word Error Rate (WER) for speech recognition or text generation tasks.

  Word Error Rate measures the edit distance between reference texts and
  predictions, normalized by the length of the reference texts. It is
  calculated as:

  .. math::
      WER = \frac{S + D + I}{N}

  where:
    - S is the number of substitutions
    - D is the number of deletions
    - I is the number of insertions
    - N is the number of words in the reference

  A lower WER indicates better performance, with 0 being perfect.

  This implementation accepts both pre-tokenized inputs (lists of tokens) and
  untokenized strings. When strings are provided, they are tokenized by
  splitting on whitespace.
  """

  @classmethod
  def from_model_output(
      cls,
      predictions: str | list[str],
      references: str | list[str],
  ) -> "WER":
    """Computes a WER update for one prediction/reference pair.

    Args:
      predictions: Either a string or a list of tokens in the predicted
        sequence. A string is tokenized by splitting on whitespace.
      references: Either a string or a list of tokens in the reference
        sequence. A string is tokenized by splitting on whitespace.

    Returns:
      New WER metric instance.

    Raises:
      ValueError: If `predictions` or `references` is empty.
    """
    if not predictions or not references:
      raise ValueError("predictions and references must not be empty")

    if isinstance(predictions, str):
      predictions = predictions.split()
    if isinstance(references, str):
      references = references.split()

    edit_distance = cls._levenshtein_distance(predictions, references)
    reference_length = len(references)

    # `compute()` (inherited from Average) divides total by count, so merging
    # per-pair updates yields a WER weighted by each reference's length.
    return cls(
        total=jnp.array(edit_distance, dtype=jnp.float32),
        count=jnp.array(reference_length, dtype=jnp.float32),
    )

  @staticmethod
  def _levenshtein_distance(prediction: list, reference: list) -> int:
    """Computes the Levenshtein (edit) distance between two token sequences.

    Uses a two-row dynamic program, so memory is O(len(reference)) rather
    than O(len(prediction) * len(reference)).

    Args:
      prediction: List of tokens in the predicted sequence.
      reference: List of tokens in the reference sequence.

    Returns:
      The minimum number of edits needed to transform prediction into
      reference.
    """
    m, n = len(prediction), len(reference)

    # Handle edge cases: transforming to/from an empty sequence costs the
    # length of the other sequence.
    if m == 0:
      return n
    if n == 0:
      return m

    # previous_row[j] is the distance between prediction[:i-1] and
    # reference[:j]; row 0 is the cost of inserting j reference tokens.
    previous_row = list(range(n + 1))
    for i in range(1, m + 1):
      current_row = [i] + [0] * n
      for j in range(1, n + 1):
        cost = 0 if prediction[i - 1] == reference[j - 1] else 1
        current_row[j] = min(
            previous_row[j] + 1,  # deletion
            current_row[j - 1] + 1,  # insertion
            previous_row[j - 1] + cost,  # substitution
        )
      previous_row = current_row

    return previous_row[n]

src/metrax/nlp_metrics_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ def test_perplexity_empty(self):
3030
self.assertEqual(m.aggregate_crossentropy, jnp.array(0, jnp.float32))
3131
self.assertEqual(m.num_samples, jnp.array(0, jnp.float32))
3232

33+
def test_wer_empty(self):
  """Tests the `empty` method of the `WER` class."""
  metric = metrax.WER.empty()
  zero = jnp.array(0, jnp.float32)
  self.assertEqual(metric.total, zero)
  self.assertEqual(metric.count, zero)
38+
3339
@parameterized.named_parameters(
3440
(
3541
'basic',
@@ -68,6 +74,49 @@ def test_perplexity(self, y_true, y_pred, sample_weights):
6874
atol=1e-05,
6975
)
7076

77+
def test_wer(self):
  """Tests that WER metric computes correct values with tokenized inputs."""
  string_preds = [
      "the cat sat on the mat",
      "a quick brown fox jumps over the lazy dog",
      "hello world",
  ]
  string_refs = [
      "the cat sat on the hat",
      "the quick brown fox jumps over the lazy dog",
      "hello beautiful world",
  ]
  tokenized_preds = [sentence.split() for sentence in string_preds]
  tokenized_refs = [sentence.split() for sentence in string_refs]

  # Accumulate per-pair metrax updates and mirror them into the keras_hub
  # reference metric.
  metrax_token_metric = None
  keras_metric = keras_hub.metrics.EditDistance(normalize=True)
  for pred, ref in zip(tokenized_preds, tokenized_refs):
    update = metrax.WER.from_model_output(pred, ref)
    keras_metric.update_state(ref, pred)
    if metrax_token_metric is None:
      metrax_token_metric = update
    else:
      metrax_token_metric = metrax_token_metric.merge(update)

  np.testing.assert_allclose(
      metrax_token_metric.compute(),
      keras_metric.result(),
      rtol=1e-05,
      atol=1e-05,
      err_msg="String-based WER should match keras_hub EditDistance",
  )

  # Feed the raw (untokenized) strings and check they agree with the
  # tokenized result.
  metrax_string_metric = None
  for pred, ref in zip(string_preds, string_refs):
    update = metrax.WER.from_model_output(predictions=pred, references=ref)
    if metrax_string_metric is None:
      metrax_string_metric = update
    else:
      metrax_string_metric = metrax_string_metric.merge(update)

  np.testing.assert_allclose(
      metrax_string_metric.compute(),
      metrax_token_metric.compute(),
      rtol=1e-05,
      atol=1e-05,
      err_msg="String input and tokenized input should produce the same WER",
  )
119+
71120

72121
if __name__ == '__main__':
73122
absltest.main()

0 commit comments

Comments
 (0)