11import os
22import re
33import string
4+
45import numpy as np
56from loguru import logger as eval_logger
7+
68from lmms_eval .llm_judge import ServerConfig , get_server
79
810API_TYPE = os .getenv ("API_TYPE" , "openai" )
@@ -36,23 +38,23 @@ def remove_punctuation_except_apostrophe(text):
3638
3739def ami_doc_to_audio (doc ):
3840 """Extract audio from AMI dataset document
39-
41+
4042 AMI dataset uses AudioDecoder type with get_all_samples() method.
4143 Returns audio array and sampling rate (16kHz for AMI).
4244 """
4345 audio_file = doc .get ("audio" )
44-
46+
4547 if not audio_file :
4648 eval_logger .warning (f"No audio found in document. Available keys: { list (doc .keys ())} " )
4749 return []
48-
50+
4951 try :
5052 # AMI uses AudioDecoder type with get_all_samples() method
5153 if hasattr (audio_file , "get_all_samples" ):
5254 decoded_audio = audio_file .get_all_samples ()
5355 else :
5456 decoded_audio = audio_file
55-
57+
5658 # Extract array - check for data attribute first (AudioSamples object)
5759 if hasattr (decoded_audio , "data" ):
5860 # AudioSamples object from torchcodec
@@ -68,13 +70,13 @@ def ami_doc_to_audio(doc):
6870 audio_array = decoded_audio .array
6971 else :
7072 audio_array = decoded_audio
71-
73+
7274 # Convert torch tensor to numpy if needed
7375 if hasattr (audio_array , "cpu" ) and hasattr (audio_array , "numpy" ):
7476 audio_array = audio_array .cpu ().numpy ()
7577 elif hasattr (audio_array , "numpy" ):
7678 audio_array = audio_array .numpy ()
77-
79+
7880 # Ensure it's a numpy array and flatten if needed
7981 if not isinstance (audio_array , np .ndarray ):
8082 try :
@@ -86,27 +88,28 @@ def ami_doc_to_audio(doc):
8688 audio_array = np .array (audio_array .tolist ())
8789 else :
8890 raise
89-
91+
9092 # Ensure it's 1D array (flatten if multi-channel)
9193 if audio_array .ndim > 1 :
9294 audio_array = audio_array .flatten ()
93-
95+
9496 # Ensure float32 dtype for librosa compatibility
9597 if audio_array .dtype != np .float32 :
9698 audio_array = audio_array .astype (np .float32 )
97-
99+
98100 # Get sampling rate (AMI is 16kHz)
99101 sampling_rate = getattr (audio_file , "_desired_sample_rate" , 16000 )
100-
102+
101103 eval_logger .debug (f"Audio array shape: { audio_array .shape } , dtype: { audio_array .dtype } , sampling_rate: { sampling_rate } " )
102-
104+
103105 return [{"array" : audio_array , "sampling_rate" : sampling_rate }]
104-
106+
105107 except Exception as e :
106108 eval_logger .error (f"Error extracting audio: { e } " )
107109 eval_logger .error (f"Audio type: { type (audio_file )} , attributes: { dir (audio_file )} " )
108110 # Re-raise to help debug
109111 import traceback
112+
110113 eval_logger .error (f"Traceback: { traceback .format_exc ()} " )
111114 return []
112115
@@ -115,14 +118,14 @@ def ami_doc_to_text(doc, lmms_eval_specific_kwargs):
115118 """Generate prompt for the audio model"""
116119 pre_prompt = lmms_eval_specific_kwargs .get ("pre_prompt" , "" )
117120 post_prompt = lmms_eval_specific_kwargs .get ("post_prompt" , "" )
118-
121+
119122 # Get meeting context if needed
120123 meeting_id = get_column_value (doc , ["meeting_id" ])
121124 speaker_id = get_column_value (doc , ["speaker_id" ])
122-
125+
123126 # Default prompt for speech recognition
124127 default_prompt = "Please transcribe the following audio. Only provide the transcription without any additional explanation or formatting."
125-
128+
126129 return f"{ pre_prompt } { default_prompt } { post_prompt } "
127130
128131
@@ -132,35 +135,35 @@ def ami_process_results_asr(doc, results):
132135 Calculates Word Error Rate (WER) - case insensitive.
133136 """
134137 scores = []
135-
138+
136139 # Get ground truth
137140 ground_truth = get_column_value (doc , ["text" , "transcript" , "transcription" ])
138141 if not ground_truth :
139142 eval_logger .warning ("No ground truth text found in document" )
140143 return {"wer" : 1.0 }
141-
144+
142145 # Normalize: strip and lowercase for case-insensitive comparison
143146 ground_truth = ground_truth .strip ().lower ()
144-
147+
145148 # Remove all punctuation except apostrophe
146149 ground_truth = remove_punctuation_except_apostrophe (ground_truth )
147-
150+
148151 for pred in results :
149152 prediction = pred .strip () if isinstance (pred , str ) else str (pred )
150-
153+
151154 # Extract transcription from various formats
152155 prediction = extract_transcription (prediction )
153-
156+
154157 # Normalize: strip and lowercase for case-insensitive comparison
155158 prediction = prediction .strip ().lower ()
156-
159+
157160 # Remove all punctuation except apostrophe
158161 prediction = remove_punctuation_except_apostrophe (prediction )
159-
162+
160163 # Calculate Word Error Rate
161164 wer = calculate_wer (ground_truth , prediction )
162165 scores .append (wer )
163-
166+
164167 avg_wer = sum (scores ) / len (scores ) if scores else 1.0
165168 return {"wer" : avg_wer }
166169
@@ -172,50 +175,47 @@ def extract_transcription(text):
172175 """
173176 if not isinstance (text , str ):
174177 return str (text )
175-
178+
176179 text = text .strip ()
177-
180+
178181 # Pattern 1: XML-style tags
179182 for tag in ["<answer>" , "<response>" , "<result>" , "<transcription>" , "<text>" ]:
180183 closing_tag = tag .replace ("<" , "</" )
181184 pattern = f"{ re .escape (tag )} \\ s*([\\ s\\ S]*?)\\ s*{ re .escape (closing_tag )} "
182185 match = re .search (pattern , text , re .IGNORECASE )
183186 if match :
184187 return match .group (1 ).strip ()
185-
188+
186189 # Pattern 2: "The transcription of the audio is:" followed by text in quotes
187190 patterns = [
188191 r"(?:the\s+)?transcription\s+(?:of\s+)?(?:the\s+)?audio\s+is\s*:\s*['\"](.+?)['\"]\s*\.?\s*$" ,
189192 r"(?:the\s+)?original\s+content\s+(?:of\s+)?(?:this\s+)?audio\s+is\s*:\s*['\"](.+?)['\"]\s*\.?\s*$" ,
190193 r"(?:the\s+)?(?:audio|speech)\s+(?:content|transcription|says)\s*:\s*['\"](.+?)['\"]\s*\.?\s*$" ,
191194 ]
192-
195+
193196 for pattern in patterns :
194197 match = re .search (pattern , text , re .IGNORECASE | re .DOTALL )
195198 if match :
196199 return match .group (1 ).strip ()
197-
200+
198201 # Pattern 3: Text enclosed in quotes (single or double)
199- quote_patterns = [
200- r"^['\"](.+?)['\"]\s*\.?\s*$" , # Entire text in quotes
201- r"['\"]([^'\"]{20,})['\"]" # Long text in quotes (at least 20 chars)
202- ]
203-
202+ quote_patterns = [r"^['\"](.+?)['\"]\s*\.?\s*$" , r"['\"]([^'\"]{20,})['\"]" ] # Entire text in quotes # Long text in quotes (at least 20 chars)
203+
204204 for pattern in quote_patterns :
205205 match = re .search (pattern , text , re .DOTALL )
206206 if match :
207207 return match .group (1 ).strip ()
208-
208+
209209 # Pattern 4: Remove common prefixes
210210 prefixes_to_remove = [
211211 r"^(?:here\s+is\s+)?(?:the\s+)?transcription\s*(?:of\s+(?:the\s+)?audio)?\s*:\s*" ,
212212 r"^(?:the\s+)?(?:audio|speech)\s+(?:says|contains)\s*:\s*" ,
213213 r"^(?:answer|response|result)\s*:\s*" ,
214214 ]
215-
215+
216216 for prefix in prefixes_to_remove :
217217 text = re .sub (prefix , "" , text , flags = re .IGNORECASE )
218-
218+
219219 return text .strip ()
220220
221221
@@ -232,28 +232,28 @@ def calculate_wer(reference, hypothesis):
232232 # Split into words
233233 ref_words = reference .split ()
234234 hyp_words = hypothesis .split ()
235-
235+
236236 # Build edit distance matrix
237237 n , m = len (ref_words ), len (hyp_words )
238238 dp = [[0 ] * (m + 1 ) for _ in range (n + 1 )]
239-
239+
240240 # Initialize
241241 for i in range (n + 1 ):
242242 dp [i ][0 ] = i
243243 for j in range (m + 1 ):
244244 dp [0 ][j ] = j
245-
245+
246246 # Dynamic programming
247247 for i in range (1 , n + 1 ):
248248 for j in range (1 , m + 1 ):
249- if ref_words [i - 1 ] == hyp_words [j - 1 ]:
250- dp [i ][j ] = dp [i - 1 ][j - 1 ]
249+ if ref_words [i - 1 ] == hyp_words [j - 1 ]:
250+ dp [i ][j ] = dp [i - 1 ][j - 1 ]
251251 else :
252- substitution = dp [i - 1 ][j - 1 ] + 1
253- insertion = dp [i ][j - 1 ] + 1
254- deletion = dp [i - 1 ][j ] + 1
252+ substitution = dp [i - 1 ][j - 1 ] + 1
253+ insertion = dp [i ][j - 1 ] + 1
254+ deletion = dp [i - 1 ][j ] + 1
255255 dp [i ][j ] = min (substitution , insertion , deletion )
256-
256+
257257 # Calculate WER
258258 wer = dp [n ][m ] / max (n , 1 ) # Avoid division by zero
259259 return wer
@@ -311,13 +311,7 @@ def ami_process_results_llm_judge(doc, results):
311311
312312 custom_config = ServerConfig (model_name = JUDGE_MODEL_VERSION , temperature = 0.5 , max_tokens = 10 )
313313
314- request = Request (
315- messages = [
316- {"role" : "system" , "content" : "You are a helpful assistant who evaluates speech recognition quality." },
317- {"role" : "user" , "content" : formatted_prompt }
318- ],
319- config = custom_config
320- )
314+ request = Request (messages = [{"role" : "system" , "content" : "You are a helpful assistant who evaluates speech recognition quality." }, {"role" : "user" , "content" : formatted_prompt }], config = custom_config )
321315
322316 response = server .evaluate (request )
323317
@@ -349,16 +343,16 @@ def ami_aggregate_results(results):
349343 return 0.0
350344
351345 total_count = len (results )
352-
346+
353347 # If results are WER scores, return average
354348 if all (isinstance (r , (int , float )) for r in results ):
355349 avg_score = sum (results ) / total_count
356350 eval_logger .info (f"AMI evaluation: Average score: { avg_score :.4f} " )
357351 return avg_score
358-
352+
359353 # If results are boolean (correct/incorrect), calculate accuracy
360354 correct_count = sum (results )
361355 accuracy = correct_count / total_count if total_count > 0 else 0.0
362356 eval_logger .info (f"AMI evaluation: { correct_count } /{ total_count } correct, accuracy: { accuracy :.4f} " )
363-
357+
364358 return accuracy
0 commit comments