Skip to content

Commit 8876108

Browse files
committed
Add WER metric
1 parent 61e5dc2 commit 8876108

File tree

3 files changed

+152
-0
lines changed

3 files changed

+152
-0
lines changed

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from metrax.nlp_metrics import (
2222
Perplexity,
23+
WER
2324
)
2425
from metrax.ranking_metrics import (
2526
AveragePrecisionAtK,
@@ -40,4 +41,5 @@
4041
"Recall",
4142
"RMSE",
4243
"RSQUARED",
44+
"WER",
4345
]

src/metrax/nlp_metrics.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,104 @@ def merge(self, other: 'Perplexity') -> 'Perplexity':
114114

115115
def compute(self) -> jax.Array:
  """Returns the perplexity: exp of the mean per-sample cross-entropy."""
  mean_crossentropy = self.aggregate_crossentropy / self.num_samples
  return jnp.exp(mean_crossentropy)
117+
118+
119+
@flax.struct.dataclass
class WER(clu_metrics.Average):
  r"""Computes Word Error Rate (WER) for speech recognition or text generation tasks.

  Word Error Rate measures the edit distance between reference texts and
  predictions, normalized by the length of the reference texts. It is
  calculated as:

  .. math::
      WER = \frac{S + D + I}{N}

  where:
      - S is the number of substitutions
      - D is the number of deletions
      - I is the number of insertions
      - N is the number of words in the reference

  A lower WER indicates better performance, with 0 being perfect.

  This implementation accepts both pre-tokenized inputs (lists of tokens) and
  untokenized strings. When strings are provided, they are tokenized by
  splitting on whitespace.
  """

  @classmethod
  def from_model_output(
      cls,
      predictions: str | list[str],
      references: str | list[str],
  ) -> "WER":
    """Creates a WER metric update for one prediction/reference pair.

    Accumulation happens through `clu_metrics.Average`: `total` carries the
    edit distance and `count` the reference length, so merged instances
    compute the corpus-level WER (total edits / total reference words).

    Args:
      predictions: Either a string or a list of tokens in the predicted
        sequence. A string is tokenized by splitting on whitespace.
      references: Either a string or a list of tokens in the reference
        sequence. A string is tokenized by splitting on whitespace.

    Returns:
      New WER metric instance.

    Raises:
      ValueError: If the reference is empty after tokenization — WER is
        undefined for an empty reference (it is the normalizing denominator).
        An empty prediction is valid: every reference word then counts as a
        deletion, giving a WER of 1.0.
    """
    # Normalize to token lists BEFORE validating, so a whitespace-only
    # reference string is correctly rejected instead of producing count=0.
    if isinstance(predictions, str):
      predictions = predictions.split()
    if isinstance(references, str):
      references = references.split()

    if not references:
      raise ValueError("references must not be empty")

    edit_distance = cls._levenshtein_distance(predictions, references)

    return cls(
        total=jnp.array(edit_distance, dtype=jnp.float32),
        count=jnp.array(len(references), dtype=jnp.float32),
    )

  @staticmethod
  def _levenshtein_distance(prediction: list, reference: list) -> int:
    """Computes the Levenshtein (edit) distance between two token sequences.

    Classic dynamic program kept to two rows at a time, so memory is
    O(len(reference)) rather than O(len(prediction) * len(reference)).

    Args:
      prediction: List of tokens in the predicted sequence.
      reference: List of tokens in the reference sequence.

    Returns:
      The minimum number of single-token edits (insertions, deletions,
      substitutions) needed to transform prediction into reference.
    """
    m, n = len(prediction), len(reference)

    # Edge cases: against an empty sequence the distance is the other's length.
    if m == 0:
      return n
    if n == 0:
      return m

    # previous_row[j] = distance between prediction[:i-1] and reference[:j].
    previous_row = list(range(n + 1))
    for i in range(1, m + 1):
      current_row = [i] + [0] * n
      for j in range(1, n + 1):
        cost = 0 if prediction[i - 1] == reference[j - 1] else 1
        current_row[j] = min(
            previous_row[j] + 1,  # deletion
            current_row[j - 1] + 1,  # insertion
            previous_row[j - 1] + cost,  # substitution
        )
      previous_row = current_row

    return previous_row[n]

src/metrax/nlp_metrics_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ def test_perplexity_empty(self):
3030
self.assertEqual(m.aggregate_crossentropy, jnp.array(0, jnp.float32))
3131
self.assertEqual(m.num_samples, jnp.array(0, jnp.float32))
3232

33+
def test_wer_empty(self):
  """Verifies that a freshly-created empty `WER` metric has zeroed state."""
  metric = metrax.WER.empty()
  for accumulator in (metric.total, metric.count):
    self.assertEqual(accumulator, jnp.array(0, jnp.float32))
38+
3339
@parameterized.named_parameters(
3440
(
3541
'basic',
@@ -68,6 +74,49 @@ def test_perplexity(self, y_true, y_pred, sample_weights):
6874
atol=1e-05,
6975
)
7076

77+
def test_wer(self):
  """Tests that WER metric computes correct values with tokenized and untokenized inputs."""
  string_preds = [
      "the cat sat on the mat",
      "a quick brown fox jumps over the lazy dog",
      "hello world",
  ]
  string_refs = [
      "the cat sat on the hat",
      "the quick brown fox jumps over the lazy dog",
      "hello beautiful world",
  ]

  # Tokenized pass: accumulate metrax updates while mirroring each update
  # into the keras_hub reference implementation.
  keras_metric = keras_hub.metrics.EditDistance(normalize=True)
  metrax_token_metric = None
  for hyp, ref in zip(string_preds, string_refs):
    hyp_tokens = hyp.split()
    ref_tokens = ref.split()
    keras_metric.update_state(ref_tokens, hyp_tokens)
    step = metrax.WER.from_model_output(hyp_tokens, ref_tokens)
    if metrax_token_metric is None:
      metrax_token_metric = step
    else:
      metrax_token_metric = metrax_token_metric.merge(step)

  np.testing.assert_allclose(
      metrax_token_metric.compute(),
      keras_metric.result(),
      rtol=1e-05,
      atol=1e-05,
      err_msg="String-based WER should match keras_hub EditDistance",
  )

  # Untokenized pass: raw sentences must yield the identical metric value.
  metrax_string_metric = None
  for hyp, ref in zip(string_preds, string_refs):
    step = metrax.WER.from_model_output(predictions=hyp, references=ref)
    if metrax_string_metric is None:
      metrax_string_metric = step
    else:
      metrax_string_metric = metrax_string_metric.merge(step)

  np.testing.assert_allclose(
      metrax_string_metric.compute(),
      metrax_token_metric.compute(),
      rtol=1e-05,
      atol=1e-05,
      err_msg="String input and tokenized input should produce the same WER",
  )
119+
71120

72121
if __name__ == '__main__':
73122
absltest.main()

0 commit comments

Comments
 (0)