Skip to content

Commit 6db29d0

Browse files
committed
Add WER metric (#28)
1 parent 21e0837 commit 6db29d0

File tree

3 files changed

+235
-2
lines changed

3 files changed

+235
-2
lines changed

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
RMSE,
2222
RSQUARED,
2323
Recall,
24+
WER,
2425
)
2526

2627
__all__ = [
@@ -32,4 +33,5 @@
3233
"Recall",
3334
"AUCPR",
3435
"AUCROC",
36+
"WER",
3537
]

src/metrax/metrics.py

Lines changed: 136 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ class RSQUARED(clu_metrics.Metric):
165165
count: jax.Array
166166
sum_of_squared_error: jax.Array
167167
sum_of_squared_label: jax.Array
168-
168+
169169

170170
@classmethod
171171
def empty(cls) -> 'RSQUARED':
@@ -436,7 +436,7 @@ class AUCPR(clu_metrics.Metric):
436436
false_positives: jax.Array
437437
false_negatives: jax.Array
438438
num_thresholds: int
439-
439+
440440
@classmethod
441441
def empty(cls) -> 'AUCPR':
442442
return cls(
@@ -795,3 +795,137 @@ def merge(self, other: 'Perplexity') -> 'Perplexity':
795795

796796
def compute(self) -> jax.Array:
797797
return jnp.exp(self.aggregate_crossentropy / self.num_samples)
798+
799+
800+
@flax.struct.dataclass
class WER(clu_metrics.Metric):
  r"""Computes Word Error Rate (WER) for speech recognition or text generation tasks.

  Word Error Rate measures the edit distance between reference texts and
  predictions, normalized by the length of the reference texts. It is
  calculated as:

  .. math::
      WER = \frac{S + D + I}{N}

  where:
      - S is the number of substitutions
      - D is the number of deletions
      - I is the number of insertions
      - N is the number of words in the reference

  A lower WER indicates better performance, with 0 being perfect.

  Attributes:
    total_edit_distance: Sum of edit distances across all samples.
    total_reference_length: Sum of reference lengths across all samples.
  """

  total_edit_distance: jax.Array
  total_reference_length: jax.Array

  @classmethod
  def empty(cls) -> 'WER':
    return cls(
        total_edit_distance=jnp.array(0, jnp.float32),
        total_reference_length=jnp.array(0, jnp.float32))

  @classmethod
  def from_model_output(
      cls,
      predictions: list[str] | list[list],
      references: list[str] | list[list],
  ) -> 'WER':
    """Updates the metric.

    Args:
      predictions: A list of predicted texts/transcriptions or tokenized
        sequences.
      references: A list of reference texts/transcriptions or tokenized
        sequences.

    Returns:
      Updated WER metric.

    Raises:
      ValueError: If either input is empty or their lengths do not match.
    """
    if not predictions or not references:
      raise ValueError('predictions and references must not be empty')

    if len(predictions) != len(references):
      raise ValueError(
          f'Length mismatch: predictions has {len(predictions)} items, '
          f'but references has {len(references)} items'
      )

    total_edit_distance = 0
    total_reference_length = 0

    for pred, ref in zip(predictions, references):
      # Strings are tokenized on whitespace; pre-tokenized sequences are
      # used as-is.
      pred_tokens = pred.split() if isinstance(pred, str) else pred
      ref_tokens = ref.split() if isinstance(ref, str) else ref

      total_edit_distance += cls._levenshtein_distance(pred_tokens, ref_tokens)
      total_reference_length += len(ref_tokens)

    return cls(
        total_edit_distance=jnp.array(total_edit_distance, dtype=jnp.float32),
        total_reference_length=jnp.array(
            total_reference_length, dtype=jnp.float32
        ),
    )

  @staticmethod
  def _levenshtein_distance(prediction: list, reference: list) -> int:
    """Computes the Levenshtein (edit) distance between two token sequences.

    Uses the classic two-row dynamic program, requiring
    O(len(reference)) memory instead of the full
    O(len(prediction) * len(reference)) matrix; the result is identical.

    Args:
      prediction: List of tokens in the predicted sequence.
      reference: List of tokens in the reference sequence.

    Returns:
      The minimum number of edits needed to transform prediction into
      reference.
    """
    m, n = len(prediction), len(reference)

    # Edge cases: transforming to/from an empty sequence costs the length
    # of the other sequence.
    if m == 0:
      return n
    if n == 0:
      return m

    # previous_row[j] is the distance between the first i-1 prediction
    # tokens and the first j reference tokens.
    previous_row = list(range(n + 1))
    for i in range(1, m + 1):
      current_row = [i] + [0] * n
      for j in range(1, n + 1):
        cost = 0 if prediction[i - 1] == reference[j - 1] else 1
        current_row[j] = min(
            previous_row[j] + 1,         # deletion
            current_row[j - 1] + 1,      # insertion
            previous_row[j - 1] + cost,  # substitution
        )
      previous_row = current_row

    return previous_row[n]

  def merge(self, other: 'WER') -> 'WER':
    return type(self)(
        total_edit_distance=self.total_edit_distance
        + other.total_edit_distance,
        total_reference_length=self.total_reference_length
        + other.total_reference_length,
    )

  def compute(self) -> jax.Array:
    # Total edits over total reference words; _divide_no_nan yields 0
    # rather than NaN when no reference words were accumulated.
    return _divide_no_nan(
        self.total_edit_distance, self.total_reference_length
    )

src/metrax/metrics_test.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,103 @@ def test_perplexity(self, y_true, y_pred, sample_weights):
455455
atol=1e-05,
456456
)
457457

458+
def test_wer_empty(self):
  """Tests the `empty` method of the `WER` class."""
  metric = metrax.WER.empty()
  self.assertEqual(metric.total_edit_distance, jnp.array(0, jnp.float32))
  self.assertEqual(metric.total_reference_length, jnp.array(0, jnp.float32))
463+
464+
def test_wer(self):
  """Tests that WER metric computes correct values."""
  # String inputs; the metric tokenizes them on whitespace itself.
  pairs = [
      ("the cat sat on the mat", "the cat sat on the hat"),
      (
          "a quick brown fox jumps over the lazy dog",
          "the quick brown fox jumps over the lazy dog",
      ),
      ("hello world", "hello beautiful world"),
  ]

  # Accumulate one per-sample update at a time; seeding with `empty()`
  # adds zero to both counters, so the result is unchanged.
  metric = metrax.WER.empty()
  for prediction, reference in pairs:
    metric = metric.merge(
        metrax.WER.from_model_output(
            predictions=[prediction],
            references=[reference],
        )
    )

  # 3 total edits across 18 reference words.
  np.testing.assert_allclose(
      metric.compute(),
      jnp.array(3 / 18, dtype=jnp.float32),
      rtol=1e-05,
      atol=1e-05,
  )
492+
493+
def test_wer_with_tokens(self):
  """Tests that WER metric computes correct values with tokenized inputs."""
  # Pre-tokenized inputs (lists of word strings) bypass the metric's own
  # whitespace tokenization.
  predictions = [
      "the cat sat on the mat".split(),
      "a quick brown fox jumps over the lazy dog".split(),
      "hello world".split(),
  ]
  references = [
      "the cat sat on the hat".split(),
      "the quick brown fox jumps over the lazy dog".split(),
      "hello beautiful world".split(),
  ]

  # Build the metric one sample at a time via merge, starting from the
  # zero-valued empty metric.
  metric = metrax.WER.empty()
  for prediction, reference in zip(predictions, references):
    metric = metric.merge(
        metrax.WER.from_model_output(
            predictions=[prediction],
            references=[reference],
        )
    )

  # 3 total edits across 18 reference words.
  np.testing.assert_allclose(
      metric.compute(),
      jnp.array(3 / 18, dtype=jnp.float32),
      rtol=1e-05,
      atol=1e-05,
  )
521+
522+
def test_wer_merge(self):
  """Tests the merge functionality of the WER metric."""
  first = metrax.WER.from_model_output(
      predictions=["the cat sat on the mat"],
      references=["the cat sat on the hat"],
  )
  second = metrax.WER.from_model_output(
      predictions=[
          "a quick brown fox jumps over the lazy dog",
          "hello world",
      ],
      references=[
          "the quick brown fox jumps over the lazy dog",
          "hello beautiful world",
      ],
  )

  merged = first.merge(second)

  # Merged counters: 3 total edits across 18 reference words.
  np.testing.assert_allclose(
      merged.compute(),
      jnp.array(3 / 18, dtype=jnp.float32),
      rtol=1e-05,
      atol=1e-05,
  )
554+
458555

459556
if __name__ == '__main__':
460557
os.environ['XLA_FLAGS'] = (

0 commit comments

Comments
 (0)