@@ -414,5 +414,111 @@ def test_perplexity(self, y_true, y_pred, sample_weights):
414414 )
415415
416416
def test_wer_empty(self):
    """Verify that `WER.empty()` yields zero-valued accumulators."""
    empty_metric = metrax.WER.empty()
    zero = jnp.array(0, jnp.float32)
    # Both running totals of a freshly created metric must compare equal to 0.
    self.assertEqual(empty_metric.total_edit_distance, zero)
    self.assertEqual(empty_metric.total_reference_length, zero)
422+
def test_wer(self):
    """Check the aggregate WER over three sentence pairs given as strings."""
    # Each entry is (prediction, reference); every pair differs by one word.
    sentence_pairs = [
        # "mat" vs. "hat": 1 edit, reference has 6 words.
        ("the cat sat on the mat", "the cat sat on the hat"),
        # "a" vs. "the": 1 edit, reference has 9 words.
        (
            "a quick brown fox jumps over the lazy dog",
            "the quick brown fox jumps over the lazy dog",
        ),
        # Prediction omits "beautiful": 1 edit, reference has 3 words.
        ("hello world", "hello beautiful world"),
    ]

    # Per-pair WERs are 1/6, 1/9 and 1/3, but the metric is expected to
    # aggregate totals: (1 + 1 + 1) edits / (6 + 9 + 3) reference words.
    accumulated = None
    for hypothesis, target in sentence_pairs:
        batch_metric = metrax.WER.from_model_output(
            predictions=[hypothesis],
            references=[target],
        )
        if accumulated is None:
            accumulated = batch_metric
        else:
            accumulated = accumulated.merge(batch_metric)

    actual = accumulated.compute()
    expected = jnp.array(3 / 18, dtype=jnp.float32)
    np.testing.assert_allclose(actual, expected, rtol=1e-05, atol=1e-05)
456+
def test_wer_with_tokens(self):
    """Check the aggregate WER when inputs are lists of tokens, not strings."""
    # Each entry is (prediction tokens, reference tokens); every pair
    # differs by a single token.
    token_pairs = [
        (
            ["the", "cat", "sat", "on", "the", "mat"],
            ["the", "cat", "sat", "on", "the", "hat"],
        ),
        (
            ["a", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"],
            ["the", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"],
        ),
        (
            ["hello", "world"],
            ["hello", "beautiful", "world"],
        ),
    ]

    # Feed one pair at a time and fold the partial results together.
    accumulated = None
    for hypothesis_tokens, target_tokens in token_pairs:
        batch_metric = metrax.WER.from_model_output(
            predictions=[hypothesis_tokens],
            references=[target_tokens],
        )
        if accumulated is None:
            accumulated = batch_metric
        else:
            accumulated = accumulated.merge(batch_metric)

    # 3 edits in total over 6 + 9 + 3 = 18 reference tokens.
    actual = accumulated.compute()
    expected = jnp.array(3 / 18, dtype=jnp.float32)
    np.testing.assert_allclose(actual, expected, rtol=1e-05, atol=1e-05)
485+
def test_wer_merge(self):
    """Check that merging two WER metrics combines their totals."""
    # First metric: a single pair with 1 edit over 6 reference words.
    single_pair_metric = metrax.WER.from_model_output(
        predictions=["the cat sat on the mat"],
        references=["the cat sat on the hat"],
    )

    # Second metric: two pairs with (1 + 1) edits over (9 + 3) reference
    # words.
    two_pair_metric = metrax.WER.from_model_output(
        predictions=[
            "a quick brown fox jumps over the lazy dog",
            "hello world",
        ],
        references=[
            "the quick brown fox jumps over the lazy dog",
            "hello beautiful world",
        ],
    )

    combined = single_pair_metric.merge(two_pair_metric)

    # Merged totals: 3 edits over 18 reference words.
    actual = combined.compute()
    expected = jnp.array(3 / 18, dtype=jnp.float32)
    np.testing.assert_allclose(actual, expected, rtol=1e-05, atol=1e-05)
521+
522+
# Run the test suite through absltest when this file is executed as a script.
if __name__ == '__main__':
    absltest.main()
0 commit comments