Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions lmms_eval/models/simple/audio_flamingo_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import numpy as np
import soundfile as sf
import torch
import transformers
from accelerate import Accelerator, DistributedType
from loguru import logger as eval_logger
from tqdm import tqdm
import transformers
from transformers import AutoProcessor

try:
Expand Down Expand Up @@ -53,11 +53,7 @@ def __init__(
self.device_map = f"cuda:{accelerator.local_process_index}"

if AudioFlamingo3ForConditionalGeneration is None:
raise ImportError(
"AudioFlamingo3ForConditionalGeneration is not available in transformers "
f"{transformers.__version__}. Please upgrade transformers/accelerate in this env, e.g. "
"`pip install -U transformers accelerate`."
)
raise ImportError("AudioFlamingo3ForConditionalGeneration is not available in transformers " f"{transformers.__version__}. Please upgrade transformers/accelerate in this env, e.g. " "`pip install -U transformers accelerate`.")

self._model = AudioFlamingo3ForConditionalGeneration.from_pretrained(
pretrained,
Expand Down
106 changes: 50 additions & 56 deletions lmms_eval/tasks/ami/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import re
import string

import numpy as np
from loguru import logger as eval_logger

from lmms_eval.llm_judge import ServerConfig, get_server

API_TYPE = os.getenv("API_TYPE", "openai")
Expand Down Expand Up @@ -36,23 +38,23 @@ def remove_punctuation_except_apostrophe(text):

def ami_doc_to_audio(doc):
"""Extract audio from AMI dataset document

AMI dataset uses AudioDecoder type with get_all_samples() method.
Returns audio array and sampling rate (16kHz for AMI).
"""
audio_file = doc.get("audio")

if not audio_file:
eval_logger.warning(f"No audio found in document. Available keys: {list(doc.keys())}")
return []

try:
# AMI uses AudioDecoder type with get_all_samples() method
if hasattr(audio_file, "get_all_samples"):
decoded_audio = audio_file.get_all_samples()
else:
decoded_audio = audio_file

# Extract array - check for data attribute first (AudioSamples object)
if hasattr(decoded_audio, "data"):
# AudioSamples object from torchcodec
Expand All @@ -68,13 +70,13 @@ def ami_doc_to_audio(doc):
audio_array = decoded_audio.array
else:
audio_array = decoded_audio

# Convert torch tensor to numpy if needed
if hasattr(audio_array, "cpu") and hasattr(audio_array, "numpy"):
audio_array = audio_array.cpu().numpy()
elif hasattr(audio_array, "numpy"):
audio_array = audio_array.numpy()

# Ensure it's a numpy array and flatten if needed
if not isinstance(audio_array, np.ndarray):
try:
Expand All @@ -86,27 +88,28 @@ def ami_doc_to_audio(doc):
audio_array = np.array(audio_array.tolist())
else:
raise

# Ensure it's 1D array (flatten if multi-channel)
if audio_array.ndim > 1:
audio_array = audio_array.flatten()

# Ensure float32 dtype for librosa compatibility
if audio_array.dtype != np.float32:
audio_array = audio_array.astype(np.float32)

# Get sampling rate (AMI is 16kHz)
sampling_rate = getattr(audio_file, "_desired_sample_rate", 16000)

eval_logger.debug(f"Audio array shape: {audio_array.shape}, dtype: {audio_array.dtype}, sampling_rate: {sampling_rate}")

return [{"array": audio_array, "sampling_rate": sampling_rate}]

except Exception as e:
eval_logger.error(f"Error extracting audio: {e}")
eval_logger.error(f"Audio type: {type(audio_file)}, attributes: {dir(audio_file)}")
# Re-raise to help debug
import traceback

eval_logger.error(f"Traceback: {traceback.format_exc()}")
return []

Expand All @@ -115,14 +118,14 @@ def ami_doc_to_text(doc, lmms_eval_specific_kwargs):
"""Generate prompt for the audio model"""
pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "")
post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "")

# Get meeting context if needed
meeting_id = get_column_value(doc, ["meeting_id"])
speaker_id = get_column_value(doc, ["speaker_id"])

# Default prompt for speech recognition
default_prompt = "Please transcribe the following audio. Only provide the transcription without any additional explanation or formatting."

return f"{pre_prompt}{default_prompt}{post_prompt}"


Expand All @@ -132,35 +135,35 @@ def ami_process_results_asr(doc, results):
Calculates Word Error Rate (WER) - case insensitive.
"""
scores = []

# Get ground truth
ground_truth = get_column_value(doc, ["text", "transcript", "transcription"])
if not ground_truth:
eval_logger.warning("No ground truth text found in document")
return {"wer": 1.0}

# Normalize: strip and lowercase for case-insensitive comparison
ground_truth = ground_truth.strip().lower()

# Remove all punctuation except apostrophe
ground_truth = remove_punctuation_except_apostrophe(ground_truth)

for pred in results:
prediction = pred.strip() if isinstance(pred, str) else str(pred)

# Extract transcription from various formats
prediction = extract_transcription(prediction)

# Normalize: strip and lowercase for case-insensitive comparison
prediction = prediction.strip().lower()

# Remove all punctuation except apostrophe
prediction = remove_punctuation_except_apostrophe(prediction)

# Calculate Word Error Rate
wer = calculate_wer(ground_truth, prediction)
scores.append(wer)

avg_wer = sum(scores) / len(scores) if scores else 1.0
return {"wer": avg_wer}

Expand All @@ -172,50 +175,47 @@ def extract_transcription(text):
"""
if not isinstance(text, str):
return str(text)

text = text.strip()

# Pattern 1: XML-style tags
for tag in ["<answer>", "<response>", "<result>", "<transcription>", "<text>"]:
closing_tag = tag.replace("<", "</")
pattern = f"{re.escape(tag)}\\s*([\\s\\S]*?)\\s*{re.escape(closing_tag)}"
match = re.search(pattern, text, re.IGNORECASE)
if match:
return match.group(1).strip()

# Pattern 2: "The transcription of the audio is:" followed by text in quotes
patterns = [
r"(?:the\s+)?transcription\s+(?:of\s+)?(?:the\s+)?audio\s+is\s*:\s*['\"](.+?)['\"]\s*\.?\s*$",
r"(?:the\s+)?original\s+content\s+(?:of\s+)?(?:this\s+)?audio\s+is\s*:\s*['\"](.+?)['\"]\s*\.?\s*$",
r"(?:the\s+)?(?:audio|speech)\s+(?:content|transcription|says)\s*:\s*['\"](.+?)['\"]\s*\.?\s*$",
]

for pattern in patterns:
match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
if match:
return match.group(1).strip()

# Pattern 3: Text enclosed in quotes (single or double)
quote_patterns = [
r"^['\"](.+?)['\"]\s*\.?\s*$", # Entire text in quotes
r"['\"]([^'\"]{20,})['\"]" # Long text in quotes (at least 20 chars)
]

quote_patterns = [r"^['\"](.+?)['\"]\s*\.?\s*$", r"['\"]([^'\"]{20,})['\"]"] # Entire text in quotes # Long text in quotes (at least 20 chars)

for pattern in quote_patterns:
match = re.search(pattern, text, re.DOTALL)
if match:
return match.group(1).strip()

# Pattern 4: Remove common prefixes
prefixes_to_remove = [
r"^(?:here\s+is\s+)?(?:the\s+)?transcription\s*(?:of\s+(?:the\s+)?audio)?\s*:\s*",
r"^(?:the\s+)?(?:audio|speech)\s+(?:says|contains)\s*:\s*",
r"^(?:answer|response|result)\s*:\s*",
]

for prefix in prefixes_to_remove:
text = re.sub(prefix, "", text, flags=re.IGNORECASE)

return text.strip()


Expand All @@ -232,28 +232,28 @@ def calculate_wer(reference, hypothesis):
# Split into words
ref_words = reference.split()
hyp_words = hypothesis.split()

# Build edit distance matrix
n, m = len(ref_words), len(hyp_words)
dp = [[0] * (m + 1) for _ in range(n + 1)]

# Initialize
for i in range(n + 1):
dp[i][0] = i
for j in range(m + 1):
dp[0][j] = j

# Dynamic programming
for i in range(1, n + 1):
for j in range(1, m + 1):
if ref_words[i-1] == hyp_words[j-1]:
dp[i][j] = dp[i-1][j-1]
if ref_words[i - 1] == hyp_words[j - 1]:
dp[i][j] = dp[i - 1][j - 1]
else:
substitution = dp[i-1][j-1] + 1
insertion = dp[i][j-1] + 1
deletion = dp[i-1][j] + 1
substitution = dp[i - 1][j - 1] + 1
insertion = dp[i][j - 1] + 1
deletion = dp[i - 1][j] + 1
dp[i][j] = min(substitution, insertion, deletion)

# Calculate WER
wer = dp[n][m] / max(n, 1) # Avoid division by zero
return wer
Expand Down Expand Up @@ -311,13 +311,7 @@ def ami_process_results_llm_judge(doc, results):

custom_config = ServerConfig(model_name=JUDGE_MODEL_VERSION, temperature=0.5, max_tokens=10)

request = Request(
messages=[
{"role": "system", "content": "You are a helpful assistant who evaluates speech recognition quality."},
{"role": "user", "content": formatted_prompt}
],
config=custom_config
)
request = Request(messages=[{"role": "system", "content": "You are a helpful assistant who evaluates speech recognition quality."}, {"role": "user", "content": formatted_prompt}], config=custom_config)

response = server.evaluate(request)

Expand Down Expand Up @@ -349,16 +343,16 @@ def ami_aggregate_results(results):
return 0.0

total_count = len(results)

# If results are WER scores, return average
if all(isinstance(r, (int, float)) for r in results):
avg_score = sum(results) / total_count
eval_logger.info(f"AMI evaluation: Average score: {avg_score:.4f}")
return avg_score

# If results are boolean (correct/incorrect), calculate accuracy
correct_count = sum(results)
accuracy = correct_count / total_count if total_count > 0 else 0.0
eval_logger.info(f"AMI evaluation: {correct_count}/{total_count} correct, accuracy: {accuracy:.4f}")

return accuracy
Loading