@@ -165,7 +165,7 @@ class RSQUARED(clu_metrics.Metric):
165165 count : jax .Array
166166 sum_of_squared_error : jax .Array
167167 sum_of_squared_label : jax .Array
168-
168+
169169
170170 @classmethod
171171 def empty (cls ) -> 'RSQUARED' :
@@ -436,7 +436,7 @@ class AUCPR(clu_metrics.Metric):
436436 false_positives : jax .Array
437437 false_negatives : jax .Array
438438 num_thresholds : int
439-
439+
440440 @classmethod
441441 def empty (cls ) -> 'AUCPR' :
442442 return cls (
@@ -795,3 +795,137 @@ def merge(self, other: 'Perplexity') -> 'Perplexity':
795795
796796 def compute (self ) -> jax .Array :
797797 return jnp .exp (self .aggregate_crossentropy / self .num_samples )
798+
799+
800+ @flax .struct .dataclass
801+ class WER (clu_metrics .Metric ):
802+ r"""Computes Word Error Rate (WER) for speech recognition or text generation tasks.
803+
804+ Word Error Rate measures the edit distance between reference texts and predictions,
805+ normalized by the length of the reference texts. It is calculated as:
806+
807+ .. math::
808+ WER = \frac{S + D + I}{N}
809+
810+ where:
811+ - S is the number of substitutions
812+ - D is the number of deletions
813+ - I is the number of insertions
814+ - N is the number of words in the reference
815+
816+ A lower WER indicates better performance, with 0 being perfect.
817+
818+ Attributes:
819+ total_edit_distance: Sum of edit distances across all samples.
820+ total_reference_length: Sum of reference lengths across all samples.
821+ """
822+
823+ total_edit_distance : jax .Array
824+ total_reference_length : jax .Array
825+
826+ @classmethod
827+ def empty (cls ) -> 'WER' :
828+ return cls (
829+ total_edit_distance = jnp .array (0 , jnp .float32 ),
830+ total_reference_length = jnp .array (0 , jnp .float32 ))
831+
832+ @classmethod
833+ def from_model_output (
834+ cls ,
835+ predictions : list [str ] | list [list ],
836+ references : list [str ] | list [list ],
837+ ) -> 'WER' :
838+ """Updates the metric.
839+
840+ Args:
841+ predictions: A list of predicted texts/transcriptions or tokenized sequences.
842+ references: A list of reference texts/transcriptions or tokenized sequences.
843+
844+ Returns:
845+ Updated WER metric.
846+
847+ Raises:
848+ ValueError: If inputs are not properly formatted or are empty.
849+ """
850+ if not predictions or not references :
851+ raise ValueError ('predictions and references must not be empty' )
852+
853+ if len (predictions ) != len (references ):
854+ raise ValueError (
855+ f'Length mismatch: predictions has { len (predictions )} items, '
856+ f'but references has { len (references )} items'
857+ )
858+
859+ # Determine if inputs are strings that need tokenization or already tokenized
860+ total_edit_distance = 0
861+ total_reference_length = 0
862+
863+ for pred , ref in zip (predictions , references ):
864+ # Convert to tokens if needed
865+ pred_tokens = pred .split () if isinstance (pred , str ) else pred
866+ ref_tokens = ref .split () if isinstance (ref , str ) else ref
867+
868+ edit_distance = cls ._levenshtein_distance (pred_tokens , ref_tokens )
869+
870+ total_edit_distance += edit_distance
871+ total_reference_length += len (ref_tokens )
872+
873+ return cls (
874+ total_edit_distance = jnp .array (total_edit_distance , dtype = jnp .float32 ),
875+ total_reference_length = jnp .array (total_reference_length , dtype = jnp .float32 ),
876+ )
877+
878+ @staticmethod
879+ def _levenshtein_distance (prediction : list , reference : list ) -> int :
880+ """Computes the Levenshtein (edit) distance between two token sequences.
881+
882+ Args:
883+ prediction: List of tokens in the predicted sequence.
884+ reference: List of tokens in the reference sequence.
885+
886+ Returns:
887+ The minimum number of edits needed to transform prediction into reference.
888+ """
889+ m , n = len (prediction ), len (reference )
890+
891+ # Handle edge cases
892+ if m == 0 :
893+ return n
894+ if n == 0 :
895+ return m
896+
897+ # Create distance matrix
898+ distance_matrix = [[0 for _ in range (n + 1 )] for _ in range (m + 1 )]
899+
900+ # Initialize first row and column
901+ for i in range (m + 1 ):
902+ distance_matrix [i ][0 ] = i
903+ for j in range (n + 1 ):
904+ distance_matrix [0 ][j ] = j
905+
906+ # Fill the matrix
907+ for i in range (1 , m + 1 ):
908+ for j in range (1 , n + 1 ):
909+ if prediction [i - 1 ] == reference [j - 1 ]:
910+ cost = 0
911+ else :
912+ cost = 1
913+
914+ distance_matrix [i ][j ] = min (
915+ distance_matrix [i - 1 ][j ] + 1 , # deletion
916+ distance_matrix [i ][j - 1 ] + 1 , # insertion
917+ distance_matrix [i - 1 ][j - 1 ] + cost # substitution
918+ )
919+
920+ return distance_matrix [m ][n ]
921+
922+ def merge (self , other : 'WER' ) -> 'WER' :
923+ return type (self )(
924+ total_edit_distance = self .total_edit_distance + other .total_edit_distance ,
925+ total_reference_length = self .total_reference_length + other .total_reference_length ,
926+ )
927+
928+ def compute (self ) -> jax .Array :
929+ return _divide_no_nan (
930+ self .total_edit_distance , self .total_reference_length
931+ )
0 commit comments