Skip to content

Commit 5cbea54

Browse files
committed
Add WER metric (#28)
1 parent 44644ee commit 5cbea54

File tree

3 files changed

+245
-0
lines changed

3 files changed

+245
-0
lines changed

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
RMSE,
2222
RSQUARED,
2323
Recall,
24+
WER,
2425
)
2526

2627
__all__ = [
@@ -32,4 +33,5 @@
3233
"Recall",
3334
"AUCPR",
3435
"AUCROC",
36+
"WER",
3537
]

src/metrax/metrics.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,3 +795,140 @@ def merge(self, other: 'Perplexity') -> 'Perplexity':
795795

796796
def compute(self) -> jax.Array:
797797
return jnp.exp(self.aggregate_crossentropy / self.num_samples)
798+
799+
800+
@flax.struct.dataclass
class WER(clu_metrics.Metric):
  r"""Computes Word Error Rate (WER) for speech recognition or text generation tasks.

  Word Error Rate measures the edit distance between reference texts and
  predictions, normalized by the length of the reference texts. It is
  calculated as:

  .. math::
      WER = \frac{S + D + I}{N}

  where:
    - S is the number of substitutions
    - D is the number of deletions
    - I is the number of insertions
    - N is the number of words in the reference

  A lower WER indicates better performance, with 0 being perfect.

  Attributes:
    total_edit_distance: Sum of edit distances across all samples.
    total_reference_length: Sum of reference lengths across all samples.
  """

  total_edit_distance: jax.Array
  total_reference_length: jax.Array

  @classmethod
  def empty(cls) -> 'WER':
    """Returns a WER metric with zeroed accumulators."""
    return cls(
        total_edit_distance=jnp.array(0, jnp.float32),
        total_reference_length=jnp.array(0, jnp.float32))

  @classmethod
  def from_model_output(
      cls,
      predictions: list[str] | list[list],
      references: list[str] | list[list],
  ) -> 'WER':
    """Updates the metric.

    Args:
      predictions: A list of predicted texts/transcriptions or tokenized
        sequences. Strings are tokenized on whitespace; lists are used as-is.
      references: A list of reference texts/transcriptions or tokenized
        sequences, aligned index-by-index with ``predictions``.

    Returns:
      Updated WER metric.

    Raises:
      ValueError: If inputs are empty or their lengths do not match.
    """
    if not predictions or not references:
      raise ValueError('predictions and references must not be empty')

    if len(predictions) != len(references):
      raise ValueError(
          f'Length mismatch: predictions has {len(predictions)} items, '
          f'but references has {len(references)} items'
      )

    total_edit_distance = 0
    total_reference_length = 0

    for pred, ref in zip(predictions, references):
      # Strings need whitespace tokenization; token lists are used directly.
      pred_tokens = pred.split() if isinstance(pred, str) else pred
      ref_tokens = ref.split() if isinstance(ref, str) else ref

      total_edit_distance += cls._levenshtein_distance(pred_tokens, ref_tokens)
      total_reference_length += len(ref_tokens)

    return cls(
        total_edit_distance=jnp.array(total_edit_distance, dtype=jnp.float32),
        total_reference_length=jnp.array(
            total_reference_length, dtype=jnp.float32),
    )

  @staticmethod
  def _levenshtein_distance(prediction: list, reference: list) -> int:
    """Computes the Levenshtein (edit) distance between two token sequences.

    Uses the two-row dynamic program, which needs only O(min(m, n)) memory
    instead of the full O(m * n) distance matrix while producing the same
    result.

    Args:
      prediction: List of tokens in the predicted sequence.
      reference: List of tokens in the reference sequence.

    Returns:
      The minimum number of edits needed to transform prediction into
      reference.
    """
    # Levenshtein distance is symmetric, so keep the DP row on the shorter
    # sequence to minimize memory.
    if len(prediction) < len(reference):
      prediction, reference = reference, prediction

    # Edge case: an empty sequence costs one insertion per remaining token.
    if not reference:
      return len(prediction)

    # previous_row[j] = distance between prediction[:i-1] and reference[:j].
    previous_row = list(range(len(reference) + 1))
    for i, pred_token in enumerate(prediction, start=1):
      current_row = [i]  # distance from prediction[:i] to empty reference
      for j, ref_token in enumerate(reference, start=1):
        cost = 0 if pred_token == ref_token else 1
        current_row.append(min(
            previous_row[j] + 1,        # deletion
            current_row[j - 1] + 1,     # insertion
            previous_row[j - 1] + cost  # substitution
        ))
      previous_row = current_row

    return previous_row[-1]

  def merge(self, other: 'WER') -> 'WER':
    """Accumulates statistics from another WER instance of the same metric."""
    return type(self)(
        total_edit_distance=(
            self.total_edit_distance + other.total_edit_distance),
        total_reference_length=(
            self.total_reference_length + other.total_reference_length),
    )

  def compute(self) -> jax.Array:
    """Returns aggregate WER; _divide_no_nan guards the zero-length case."""
    return _divide_no_nan(
        self.total_edit_distance, self.total_reference_length
    )

src/metrax/metrics_test.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,5 +414,111 @@ def test_perplexity(self, y_true, y_pred, sample_weights):
414414
)
415415

416416

417+
def test_wer_empty(self):
418+
"""Tests the `empty` method of the `WER` class."""
419+
m = metrax.WER.empty()
420+
self.assertEqual(m.total_edit_distance, jnp.array(0, jnp.float32))
421+
self.assertEqual(m.total_reference_length, jnp.array(0, jnp.float32))
422+
423+
def test_wer(self):
424+
"""Tests that WER metric computes correct values."""
425+
# Test with string inputs
426+
predictions = [
427+
"the cat sat on the mat",
428+
"a quick brown fox jumps over the lazy dog",
429+
"hello world"
430+
]
431+
references = [
432+
"the cat sat on the hat", # 1 substitution (mat->hat), 6 total words
433+
"the quick brown fox jumps over the lazy dog", # 1 substitution (a->the), 9 total words
434+
"hello beautiful world" # 1 insertion (beautiful), 3 total words
435+
]
436+
437+
# Expected individual WERs: 1/6, 1/9, 1/3
438+
# Total edit distance: 1 + 1 + 1 = 3
439+
# Total reference length: 6 + 9 + 3 = 18
440+
# Expected WER: 3/18 = 0.1667
441+
442+
metric = None
443+
for pred, ref in zip(predictions, references):
444+
update = metrax.WER.from_model_output(
445+
predictions=[pred],
446+
references=[ref],
447+
)
448+
metric = update if metric is None else metric.merge(update)
449+
450+
np.testing.assert_allclose(
451+
metric.compute(),
452+
jnp.array(3/18, dtype=jnp.float32),
453+
rtol=1e-05,
454+
atol=1e-05,
455+
)
456+
457+
def test_wer_with_tokens(self):
458+
"""Tests that WER metric computes correct values with tokenized inputs."""
459+
# Test with token inputs (lists instead of strings)
460+
tokenized_preds = [
461+
["the", "cat", "sat", "on", "the", "mat"],
462+
["a", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"],
463+
["hello", "world"]
464+
]
465+
tokenized_refs = [
466+
["the", "cat", "sat", "on", "the", "hat"],
467+
["the", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"],
468+
["hello", "beautiful", "world"]
469+
]
470+
471+
metric = None
472+
for pred, ref in zip(tokenized_preds, tokenized_refs):
473+
update = metrax.WER.from_model_output(
474+
predictions=[pred],
475+
references=[ref],
476+
)
477+
metric = update if metric is None else metric.merge(update)
478+
479+
np.testing.assert_allclose(
480+
metric.compute(),
481+
jnp.array(3/18, dtype=jnp.float32),
482+
rtol=1e-05,
483+
atol=1e-05,
484+
)
485+
486+
def test_wer_merge(self):
487+
"""Tests the merge functionality of the WER metric."""
488+
predictions1 = ["the cat sat on the mat"]
489+
references1 = ["the cat sat on the hat"] # 1/6 WER
490+
491+
predictions2 = [
492+
"a quick brown fox jumps over the lazy dog",
493+
"hello world"
494+
]
495+
references2 = [
496+
"the quick brown fox jumps over the lazy dog",
497+
"hello beautiful world"
498+
] # (1+1)/(9+3) = 2/12 WER
499+
500+
# Create and compute first metric
501+
metric1 = metrax.WER.from_model_output(
502+
predictions=predictions1,
503+
references=references1,
504+
)
505+
506+
# Create and compute second metric
507+
metric2 = metrax.WER.from_model_output(
508+
predictions=predictions2,
509+
references=references2,
510+
)
511+
512+
# Merge and compute
513+
merged_metric = metric1.merge(metric2)
514+
515+
np.testing.assert_allclose(
516+
merged_metric.compute(),
517+
jnp.array(3/18, dtype=jnp.float32),
518+
rtol=1e-05,
519+
atol=1e-05,
520+
)
521+
522+
417523
if __name__ == '__main__':
418524
absltest.main()

0 commit comments

Comments
 (0)