Skip to content

Commit 3fc5184

Browse files
feat: integrate Neptune long-video benchmark tasks (#1187)
* LMM-271: [P0][Benchmark] Neptune long-video benchmark integration...
* fix(neptune): cap chat video frame loading and document missing full videos
* style: auto-fix lint (black + isort)

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 0059a43 commit 3fc5184

20 files changed

Lines changed: 727 additions & 295 deletions

File tree

docs/current_tasks.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,10 @@ python -m lmms_eval --tasks list_with_num
345345
- [LongTimescope](https://longtimescope.github.io/) (longtimescope)
346346
- [LongVT](https://longvt-bench.github.io/) (longvt) - Tool-based long video understanding
347347
- [LongVideoBench](https://github.com/longvideobench/LongVideoBench) (longvideobench)
348+
- [NEPTUNE](https://github.com/google-deepmind/neptune) (neptune)
349+
- Video-path subsets: neptune_full_v, neptune_mma_v, neptune_mmh_v
350+
- Frame-sampled subsets: neptune_full_i, neptune_mma_i, neptune_mmh_i
351+
- Example: `python -m lmms_eval --model qwen2_5_vl --tasks neptune_full_v --limit 5 --batch_size 1`
348352
- [MovieChat](https://github.com/rese1f/MovieChat) (moviechat)
349353
- Global Mode for entire video (moviechat_global)
350354
- Breakpoint Mode for specific moments (moviechat_breakpoint)

lmms_eval/models/chat/openai.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def build_payload_for_index(global_index: int) -> dict:
181181
temperature = request_gen_kwargs.get("temperature", 0)
182182

183183
payload = {
184-
"messages": chat_messages.to_openai_messages(),
184+
"messages": chat_messages.to_openai_messages(video_kwargs={"nframes": self.max_frames_num}),
185185
"model": self.model_version,
186186
"max_tokens": max_new_tokens,
187187
"temperature": temperature,
@@ -212,6 +212,7 @@ def build_payload_for_index(global_index: int) -> dict:
212212
cursor += 1
213213
continue
214214

215+
assert payload is not None
215216
future = executor.submit(process_single_request, request_index, payload)
216217
in_flight[future] = request_index
217218
cursor += 1

lmms_eval/models/simple/audio_flamingo_3.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import numpy as np
66
import soundfile as sf
77
import torch
8+
import transformers
89
from accelerate import Accelerator, DistributedType
910
from loguru import logger as eval_logger
1011
from tqdm import tqdm
11-
import transformers
1212
from transformers import AutoProcessor
1313

1414
try:
@@ -53,11 +53,7 @@ def __init__(
5353
self.device_map = f"cuda:{accelerator.local_process_index}"
5454

5555
if AudioFlamingo3ForConditionalGeneration is None:
56-
raise ImportError(
57-
"AudioFlamingo3ForConditionalGeneration is not available in transformers "
58-
f"{transformers.__version__}. Please upgrade transformers/accelerate in this env, e.g. "
59-
"`pip install -U transformers accelerate`."
60-
)
56+
raise ImportError("AudioFlamingo3ForConditionalGeneration is not available in transformers " f"{transformers.__version__}. Please upgrade transformers/accelerate in this env, e.g. " "`pip install -U transformers accelerate`.")
6157

6258
self._model = AudioFlamingo3ForConditionalGeneration.from_pretrained(
6359
pretrained,

lmms_eval/tasks/ami/utils.py

Lines changed: 50 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import os
22
import re
33
import string
4+
45
import numpy as np
56
from loguru import logger as eval_logger
7+
68
from lmms_eval.llm_judge import ServerConfig, get_server
79

810
API_TYPE = os.getenv("API_TYPE", "openai")
@@ -36,23 +38,23 @@ def remove_punctuation_except_apostrophe(text):
3638

3739
def ami_doc_to_audio(doc):
3840
"""Extract audio from AMI dataset document
39-
41+
4042
AMI dataset uses AudioDecoder type with get_all_samples() method.
4143
Returns audio array and sampling rate (16kHz for AMI).
4244
"""
4345
audio_file = doc.get("audio")
44-
46+
4547
if not audio_file:
4648
eval_logger.warning(f"No audio found in document. Available keys: {list(doc.keys())}")
4749
return []
48-
50+
4951
try:
5052
# AMI uses AudioDecoder type with get_all_samples() method
5153
if hasattr(audio_file, "get_all_samples"):
5254
decoded_audio = audio_file.get_all_samples()
5355
else:
5456
decoded_audio = audio_file
55-
57+
5658
# Extract array - check for data attribute first (AudioSamples object)
5759
if hasattr(decoded_audio, "data"):
5860
# AudioSamples object from torchcodec
@@ -68,13 +70,13 @@ def ami_doc_to_audio(doc):
6870
audio_array = decoded_audio.array
6971
else:
7072
audio_array = decoded_audio
71-
73+
7274
# Convert torch tensor to numpy if needed
7375
if hasattr(audio_array, "cpu") and hasattr(audio_array, "numpy"):
7476
audio_array = audio_array.cpu().numpy()
7577
elif hasattr(audio_array, "numpy"):
7678
audio_array = audio_array.numpy()
77-
79+
7880
# Ensure it's a numpy array and flatten if needed
7981
if not isinstance(audio_array, np.ndarray):
8082
try:
@@ -86,27 +88,28 @@ def ami_doc_to_audio(doc):
8688
audio_array = np.array(audio_array.tolist())
8789
else:
8890
raise
89-
91+
9092
# Ensure it's 1D array (flatten if multi-channel)
9193
if audio_array.ndim > 1:
9294
audio_array = audio_array.flatten()
93-
95+
9496
# Ensure float32 dtype for librosa compatibility
9597
if audio_array.dtype != np.float32:
9698
audio_array = audio_array.astype(np.float32)
97-
99+
98100
# Get sampling rate (AMI is 16kHz)
99101
sampling_rate = getattr(audio_file, "_desired_sample_rate", 16000)
100-
102+
101103
eval_logger.debug(f"Audio array shape: {audio_array.shape}, dtype: {audio_array.dtype}, sampling_rate: {sampling_rate}")
102-
104+
103105
return [{"array": audio_array, "sampling_rate": sampling_rate}]
104-
106+
105107
except Exception as e:
106108
eval_logger.error(f"Error extracting audio: {e}")
107109
eval_logger.error(f"Audio type: {type(audio_file)}, attributes: {dir(audio_file)}")
108110
# Re-raise to help debug
109111
import traceback
112+
110113
eval_logger.error(f"Traceback: {traceback.format_exc()}")
111114
return []
112115

@@ -115,14 +118,14 @@ def ami_doc_to_text(doc, lmms_eval_specific_kwargs):
115118
"""Generate prompt for the audio model"""
116119
pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "")
117120
post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "")
118-
121+
119122
# Get meeting context if needed
120123
meeting_id = get_column_value(doc, ["meeting_id"])
121124
speaker_id = get_column_value(doc, ["speaker_id"])
122-
125+
123126
# Default prompt for speech recognition
124127
default_prompt = "Please transcribe the following audio. Only provide the transcription without any additional explanation or formatting."
125-
128+
126129
return f"{pre_prompt}{default_prompt}{post_prompt}"
127130

128131

@@ -132,35 +135,35 @@ def ami_process_results_asr(doc, results):
132135
Calculates Word Error Rate (WER) - case insensitive.
133136
"""
134137
scores = []
135-
138+
136139
# Get ground truth
137140
ground_truth = get_column_value(doc, ["text", "transcript", "transcription"])
138141
if not ground_truth:
139142
eval_logger.warning("No ground truth text found in document")
140143
return {"wer": 1.0}
141-
144+
142145
# Normalize: strip and lowercase for case-insensitive comparison
143146
ground_truth = ground_truth.strip().lower()
144-
147+
145148
# Remove all punctuation except apostrophe
146149
ground_truth = remove_punctuation_except_apostrophe(ground_truth)
147-
150+
148151
for pred in results:
149152
prediction = pred.strip() if isinstance(pred, str) else str(pred)
150-
153+
151154
# Extract transcription from various formats
152155
prediction = extract_transcription(prediction)
153-
156+
154157
# Normalize: strip and lowercase for case-insensitive comparison
155158
prediction = prediction.strip().lower()
156-
159+
157160
# Remove all punctuation except apostrophe
158161
prediction = remove_punctuation_except_apostrophe(prediction)
159-
162+
160163
# Calculate Word Error Rate
161164
wer = calculate_wer(ground_truth, prediction)
162165
scores.append(wer)
163-
166+
164167
avg_wer = sum(scores) / len(scores) if scores else 1.0
165168
return {"wer": avg_wer}
166169

@@ -172,50 +175,47 @@ def extract_transcription(text):
172175
"""
173176
if not isinstance(text, str):
174177
return str(text)
175-
178+
176179
text = text.strip()
177-
180+
178181
# Pattern 1: XML-style tags
179182
for tag in ["<answer>", "<response>", "<result>", "<transcription>", "<text>"]:
180183
closing_tag = tag.replace("<", "</")
181184
pattern = f"{re.escape(tag)}\\s*([\\s\\S]*?)\\s*{re.escape(closing_tag)}"
182185
match = re.search(pattern, text, re.IGNORECASE)
183186
if match:
184187
return match.group(1).strip()
185-
188+
186189
# Pattern 2: "The transcription of the audio is:" followed by text in quotes
187190
patterns = [
188191
r"(?:the\s+)?transcription\s+(?:of\s+)?(?:the\s+)?audio\s+is\s*:\s*['\"](.+?)['\"]\s*\.?\s*$",
189192
r"(?:the\s+)?original\s+content\s+(?:of\s+)?(?:this\s+)?audio\s+is\s*:\s*['\"](.+?)['\"]\s*\.?\s*$",
190193
r"(?:the\s+)?(?:audio|speech)\s+(?:content|transcription|says)\s*:\s*['\"](.+?)['\"]\s*\.?\s*$",
191194
]
192-
195+
193196
for pattern in patterns:
194197
match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
195198
if match:
196199
return match.group(1).strip()
197-
200+
198201
# Pattern 3: Text enclosed in quotes (single or double)
199-
quote_patterns = [
200-
r"^['\"](.+?)['\"]\s*\.?\s*$", # Entire text in quotes
201-
r"['\"]([^'\"]{20,})['\"]" # Long text in quotes (at least 20 chars)
202-
]
203-
202+
quote_patterns = [r"^['\"](.+?)['\"]\s*\.?\s*$", r"['\"]([^'\"]{20,})['\"]"] # Entire text in quotes # Long text in quotes (at least 20 chars)
203+
204204
for pattern in quote_patterns:
205205
match = re.search(pattern, text, re.DOTALL)
206206
if match:
207207
return match.group(1).strip()
208-
208+
209209
# Pattern 4: Remove common prefixes
210210
prefixes_to_remove = [
211211
r"^(?:here\s+is\s+)?(?:the\s+)?transcription\s*(?:of\s+(?:the\s+)?audio)?\s*:\s*",
212212
r"^(?:the\s+)?(?:audio|speech)\s+(?:says|contains)\s*:\s*",
213213
r"^(?:answer|response|result)\s*:\s*",
214214
]
215-
215+
216216
for prefix in prefixes_to_remove:
217217
text = re.sub(prefix, "", text, flags=re.IGNORECASE)
218-
218+
219219
return text.strip()
220220

221221

@@ -232,28 +232,28 @@ def calculate_wer(reference, hypothesis):
232232
# Split into words
233233
ref_words = reference.split()
234234
hyp_words = hypothesis.split()
235-
235+
236236
# Build edit distance matrix
237237
n, m = len(ref_words), len(hyp_words)
238238
dp = [[0] * (m + 1) for _ in range(n + 1)]
239-
239+
240240
# Initialize
241241
for i in range(n + 1):
242242
dp[i][0] = i
243243
for j in range(m + 1):
244244
dp[0][j] = j
245-
245+
246246
# Dynamic programming
247247
for i in range(1, n + 1):
248248
for j in range(1, m + 1):
249-
if ref_words[i-1] == hyp_words[j-1]:
250-
dp[i][j] = dp[i-1][j-1]
249+
if ref_words[i - 1] == hyp_words[j - 1]:
250+
dp[i][j] = dp[i - 1][j - 1]
251251
else:
252-
substitution = dp[i-1][j-1] + 1
253-
insertion = dp[i][j-1] + 1
254-
deletion = dp[i-1][j] + 1
252+
substitution = dp[i - 1][j - 1] + 1
253+
insertion = dp[i][j - 1] + 1
254+
deletion = dp[i - 1][j] + 1
255255
dp[i][j] = min(substitution, insertion, deletion)
256-
256+
257257
# Calculate WER
258258
wer = dp[n][m] / max(n, 1) # Avoid division by zero
259259
return wer
@@ -311,13 +311,7 @@ def ami_process_results_llm_judge(doc, results):
311311

312312
custom_config = ServerConfig(model_name=JUDGE_MODEL_VERSION, temperature=0.5, max_tokens=10)
313313

314-
request = Request(
315-
messages=[
316-
{"role": "system", "content": "You are a helpful assistant who evaluates speech recognition quality."},
317-
{"role": "user", "content": formatted_prompt}
318-
],
319-
config=custom_config
320-
)
314+
request = Request(messages=[{"role": "system", "content": "You are a helpful assistant who evaluates speech recognition quality."}, {"role": "user", "content": formatted_prompt}], config=custom_config)
321315

322316
response = server.evaluate(request)
323317

@@ -349,16 +343,16 @@ def ami_aggregate_results(results):
349343
return 0.0
350344

351345
total_count = len(results)
352-
346+
353347
# If results are WER scores, return average
354348
if all(isinstance(r, (int, float)) for r in results):
355349
avg_score = sum(results) / total_count
356350
eval_logger.info(f"AMI evaluation: Average score: {avg_score:.4f}")
357351
return avg_score
358-
352+
359353
# If results are boolean (correct/incorrect), calculate accuracy
360354
correct_count = sum(results)
361355
accuracy = correct_count / total_count if total_count > 0 else 0.0
362356
eval_logger.info(f"AMI evaluation: {correct_count}/{total_count} correct, accuracy: {accuracy:.4f}")
363-
357+
364358
return accuracy

0 commit comments

Comments
 (0)