@@ -114,3 +114,104 @@ def merge(self, other: 'Perplexity') -> 'Perplexity':
114114
115115 def compute (self ) -> jax .Array :
116116 return jnp .exp (self .aggregate_crossentropy / self .num_samples )
117+
118+
119+ @flax .struct .dataclass
120+ class WER (clu_metrics .Average ):
121+ r"""Computes Word Error Rate (WER) for speech recognition or text generation tasks.
122+
123+ Word Error Rate measures the edit distance between reference texts and predictions,
124+ normalized by the length of the reference texts. It is calculated as:
125+
126+ .. math::
127+ WER = \frac{S + D + I}{N}
128+
129+ where:
130+ - S is the number of substitutions
131+ - D is the number of deletions
132+ - I is the number of insertions
133+ - N is the number of words in the reference
134+
135+ A lower WER indicates better performance, with 0 being perfect.
136+
137+ This implementation accepts both pre-tokenized inputs (lists of tokens) and untokenized
138+ strings. When strings are provided, they are tokenized by splitting on whitespace.
139+ """
140+
141+ @classmethod
142+ def from_model_output (
143+ cls ,
144+ predictions : list [str ],
145+ references : list [str ],
146+ ) -> "WER" :
147+ """Updates the metric.
148+
149+ Args:
150+ prediction: Either a string or a list of tokens in the predicted sequence.
151+ reference: Either a string or a list of tokens in the reference sequence.
152+
153+ Returns:
154+ New WER metric instance.
155+
156+ Raises:
157+ ValueError: If inputs are not properly formatted or are empty.
158+ """
159+ if not predictions or not references :
160+ raise ValueError ("predictions and references must not be empty" )
161+
162+ if isinstance (predictions , str ):
163+ predictions = predictions .split ()
164+ if isinstance (references , str ):
165+ references = references .split ()
166+
167+ edit_distance = cls ._levenshtein_distance (predictions , references )
168+ reference_length = len (references )
169+
170+ return cls (
171+ total = jnp .array (edit_distance , dtype = jnp .float32 ),
172+ count = jnp .array (reference_length , dtype = jnp .float32 ),
173+ )
174+
175+ @staticmethod
176+ def _levenshtein_distance (prediction : list , reference : list ) -> int :
177+ """Computes the Levenshtein (edit) distance between two token sequences.
178+
179+ Args:
180+ prediction: List of tokens in the predicted sequence.
181+ reference: List of tokens in the reference sequence.
182+
183+ Returns:
184+ The minimum number of edits needed to transform prediction into reference.
185+ """
186+ m , n = len (prediction ), len (reference )
187+
188+ # Handle edge cases
189+ if m == 0 :
190+ return n
191+ if n == 0 :
192+ return m
193+
194+ # Create distance matrix
195+ distance_matrix = [[0 for _ in range (n + 1 )] for _ in range (m + 1 )]
196+
197+ # Initialize first row and column
198+ for i in range (m + 1 ):
199+ distance_matrix [i ][0 ] = i
200+ for j in range (n + 1 ):
201+ distance_matrix [0 ][j ] = j
202+
203+ # Fill the matrix
204+ for i in range (1 , m + 1 ):
205+ for j in range (1 , n + 1 ):
206+ if prediction [i - 1 ] == reference [j - 1 ]:
207+ cost = 0
208+ else :
209+ cost = 1
210+
211+ distance_matrix [i ][j ] = min (
212+ distance_matrix [i - 1 ][j ] + 1 , # deletion
213+ distance_matrix [i ][j - 1 ] + 1 , # insertion
214+ distance_matrix [i - 1 ][j - 1 ] + cost , # substitution
215+ )
216+
217+ return distance_matrix [m ][n ]
0 commit comments