-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathtranscriber.py
More file actions
545 lines (440 loc) · 18.1 KB
/
transcriber.py
File metadata and controls
545 lines (440 loc) · 18.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
"""Transcription pipeline wrapper."""
import ipaddress
import os
import shutil
import socket
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import httpx
import pandas as pd
# SSRF Protection: Block internal/private IP ranges and metadata endpoints
BLOCKED_HOSTS = {
    "localhost",
    "127.0.0.1",
    "0.0.0.0",
    "169.254.169.254",  # AWS/GCP metadata
    "metadata.google.internal",  # GCP metadata
    "metadata.azure.internal",  # Azure metadata
}
ALLOWED_SCHEMES = {"http", "https"}


def _is_blocked_ip(ip: "ipaddress.IPv4Address | ipaddress.IPv6Address") -> bool:
    """Return True if *ip* must never be contacted by a server-side fetch.

    ``not ip.is_global`` is the catch-all: it also rejects special-purpose
    ranges that ``is_private`` does not cover, notably 100.64.0.0/10
    (shared/CGNAT space, which contains Alibaba Cloud's 100.100.100.200
    metadata endpoint). The explicit flags are kept so the intent stays
    obvious across Python versions whose range tables differ; multicast is
    added because ``is_global`` may be True for some multicast addresses.
    """
    return (
        not ip.is_global
        or ip.is_private
        or ip.is_loopback
        or ip.is_link_local
        or ip.is_reserved
        or ip.is_multicast
        or ip.is_unspecified
    )


def validate_audio_url(url: str) -> None:
    """Validate audio URL for SSRF protection.

    Checks, in order: the scheme is http/https, a hostname is present, the
    hostname is not on ``BLOCKED_HOSTS``, and none of the addresses it resolves
    to is private/internal/special-purpose.

    Args:
        url: URL to validate.

    Raises:
        ValueError: If URL is invalid or points to a blocked host.
    """
    parsed = urlparse(url)

    # Check scheme
    if parsed.scheme not in ALLOWED_SCHEMES:
        raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Must be http or https.")

    # Check hostname exists
    if not parsed.hostname:
        raise ValueError("Invalid URL: no hostname")
    hostname = parsed.hostname.lower()

    # Check against blocklist
    if hostname in BLOCKED_HOSTS:
        raise ValueError(f"Blocked host: {hostname}")

    # Resolve hostname and check every resulting address.
    try:
        resolved_ips = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
    except socket.gaierror:
        # Deliberate fail-open: if DNS resolution fails here, let it proceed
        # (httpx will fail with a better error when it resolves the name itself).
        # NOTE(review): the HTTP client re-resolves the name, so a DNS answer that
        # changes between this check and the request (rebinding) is not covered.
        return

    for _family, _type, _proto, _canonname, sockaddr in resolved_ips:
        ip_str = sockaddr[0]
        ip = ipaddress.ip_address(ip_str)
        if _is_blocked_ip(ip):
            raise ValueError(f"Blocked IP address: {ip_str} (resolved from {hostname})")
def _ensure_ffmpeg() -> None:
    """Ensure ffmpeg is available, using bundled version if needed.

    If ``ffmpeg`` is already on PATH nothing is done. Otherwise the binary
    shipped with ``imageio_ffmpeg`` is exposed under the plain name ``ffmpeg``
    through a symlink in a private temp directory, and that directory is
    prepended to ``os.environ["PATH"]``. This is best effort: any failure is
    reported on stderr and swallowed.
    """
    if shutil.which("ffmpeg"):
        return  # System ffmpeg available

    try:
        import stat

        from imageio_ffmpeg import get_ffmpeg_exe

        ffmpeg_path = get_ffmpeg_exe()

        # The bundled binary has a versioned name (e.g. ffmpeg-linux-x86_64-v7.0.2),
        # so expose it as "ffmpeg" via a symlink. The directory name is per-user
        # so another local account cannot pre-create it in the shared temp dir
        # and swap the link for its own binary.
        uid = os.getuid() if hasattr(os, "getuid") else None
        dir_name = "murmurai-ffmpeg" if uid is None else f"murmurai-ffmpeg-{uid}"
        symlink_dir = Path(tempfile.gettempdir()) / dir_name
        symlink_dir.mkdir(mode=0o700, exist_ok=True)
        if uid is not None:
            # lstat: a pre-planted symlink-to-directory is rejected as well.
            st = symlink_dir.lstat()
            if not stat.S_ISDIR(st.st_mode) or st.st_uid != uid or st.st_mode & 0o022:
                raise RuntimeError(f"refusing to use insecure directory {symlink_dir}")

        symlink_path = symlink_dir / "ffmpeg"
        target = str(ffmpeg_path)
        if not (symlink_path.is_symlink() and os.readlink(symlink_path) == target):
            # Create the link under a process-unique name and atomically rename it
            # over the final path. Workers importing this module concurrently then
            # never see a missing link and never hit FileExistsError, which the
            # previous unlink-then-symlink sequence could.
            tmp_link = symlink_dir / f"ffmpeg.{os.getpid()}.tmp"
            if tmp_link.is_symlink() or tmp_link.exists():
                tmp_link.unlink()
            tmp_link.symlink_to(target)
            os.replace(tmp_link, symlink_path)

        # Add to PATH
        os.environ["PATH"] = str(symlink_dir) + os.pathsep + os.environ.get("PATH", "")
    except Exception as e:
        # Note: Can't use logger here - runs before logging is setup
        import sys

        print(f"[murmurai] WARNING: Could not setup bundled ffmpeg: {e}", file=sys.stderr)
# Ensure ffmpeg is available BEFORE importing murmurai-core.
# This call runs once, at import time of this module. When no system ffmpeg is
# on PATH it may prepend a directory holding an "ffmpeg" symlink to
# os.environ["PATH"]. The murmurai import that follows is deliberately placed
# after it (hence the E402 suppressions) — presumably because murmurai-core
# locates ffmpeg when imported or first used; confirm before reordering.
_ensure_ffmpeg()
import murmurai as murmurai_core # type: ignore[import-untyped] # noqa: E402
from murmurai_server.config import get_settings # noqa: E402
from murmurai_server.logging import get_logger # noqa: E402
from murmurai_server.model_manager import ModelManager # noqa: E402
@dataclass
class TranscribeOptions:
"""Options for transcription pipeline."""
# Model selection (None = use server default from settings)
model: str | None = None
# Language
language: str | None = None
# Task
task: str = "transcribe"
# Speaker diarization
speaker_labels: bool = False
speakers_expected: int | None = None
min_speakers: int | None = None
max_speakers: int | None = None
diarize_model: str = "pyannote/speaker-diarization-community-1"
return_speaker_embeddings: bool = False
# Decoding parameters (🔴 MODEL PARAMS - trigger model reload if non-default)
temperature: float = 0.0
temperature_increment_on_fallback: float | None = 0.2
beam_size: int = 5
best_of: int = 5
patience: float = 1.0
length_penalty: float = 1.0
suppress_tokens: str | None = None
logprob_threshold: float | None = -1.0
# Prompt engineering
initial_prompt: str | None = None
hotwords: str | None = None
# Output control
word_timestamps: bool = False
return_char_alignments: bool = False
suppress_numerals: bool = False
interpolate_method: str = "nearest"
# Hallucination filtering
compression_ratio_threshold: float = 2.4
no_speech_threshold: float = 0.6
condition_on_previous_text: bool = False
# VAD parameters (🔴 MODEL PARAMS - trigger model reload if non-default)
vad_method: str = "pyannote"
vad_onset: float = 0.5
vad_offset: float = 0.363
chunk_size: int = 30
# Subtitle/segment formatting
segment_resolution: str = "sentence"
max_line_width: int | None = None
max_line_count: int | None = None
highlight_words: bool = False
def has_custom_asr_options(self) -> bool:
"""Check if any ASR options differ from defaults."""
return (
self.beam_size != 5
or self.best_of != 5
or self.patience != 1.0
or self.length_penalty != 1.0
or self.temperature != 0.0
or self.temperature_increment_on_fallback != 0.2
or self.compression_ratio_threshold != 2.4
or (self.logprob_threshold is not None and self.logprob_threshold != -1.0)
or self.no_speech_threshold != 0.6
or self.condition_on_previous_text
or self.suppress_numerals
or self.initial_prompt is not None
or self.hotwords is not None
or self.suppress_tokens is not None
)
def has_custom_vad_options(self) -> bool:
"""Check if any VAD options differ from defaults."""
return self.vad_onset != 0.5 or self.vad_offset != 0.363 or self.chunk_size != 30
def build_asr_options(self) -> dict[str, Any] | None:
"""Build ASR options dict for model loading, or None if defaults."""
if not self.has_custom_asr_options():
return None
# Build temperatures list
temps = [self.temperature]
if self.temperature_increment_on_fallback is not None:
t = self.temperature + self.temperature_increment_on_fallback
while t <= 1.0:
temps.append(t)
t += self.temperature_increment_on_fallback
# Parse suppress_tokens
suppress_tokens = [-1] # Default
if self.suppress_tokens:
suppress_tokens = [int(t.strip()) for t in self.suppress_tokens.split(",") if t.strip()]
return {
"beam_size": self.beam_size,
"best_of": self.best_of,
"patience": self.patience,
"length_penalty": self.length_penalty,
"temperatures": temps,
"compression_ratio_threshold": self.compression_ratio_threshold,
"log_prob_threshold": self.logprob_threshold or -1.0,
"no_speech_threshold": self.no_speech_threshold,
"condition_on_previous_text": self.condition_on_previous_text,
"suppress_numerals": self.suppress_numerals,
"initial_prompt": self.initial_prompt,
"hotwords": self.hotwords,
"suppress_tokens": suppress_tokens,
}
def build_vad_options(self) -> dict[str, Any] | None:
"""Build VAD options dict for model loading, or None if defaults."""
if not self.has_custom_vad_options():
return None
return {
"vad_onset": self.vad_onset,
"vad_offset": self.vad_offset,
"chunk_size": self.chunk_size,
}
def convert_pyannote_to_murmurai(diarization: Any) -> pd.DataFrame:
    """Convert pyannote Annotation to diarize_segments format.

    Args:
        diarization: pyannote.core.Annotation or pyannote 4.x DiarizeOutput
            (anything with a ``speaker_diarization`` attribute is unwrapped first).

    Returns:
        pandas DataFrame with columns: start, end, speaker (one row per track).
        The columns are present even when there are no tracks.
    """
    # pyannote 4.x returns DiarizeOutput, extract the annotation
    annotation = getattr(diarization, "speaker_diarization", diarization)
    rows = [
        {"start": turn.start, "end": turn.end, "speaker": speaker}
        for turn, _, speaker in annotation.itertracks(yield_label=True)
    ]
    # Pin the schema: pd.DataFrame([]) has no columns at all, so an empty
    # diarization would otherwise break any consumer that indexes "start" etc.
    return pd.DataFrame(rows, columns=["start", "end", "speaker"])
def audio_file_suffix(url: str, default: str = ".mp3") -> str:
    """Return the file extension of *url*'s path component, or *default*.

    Only the path is inspected, so the hostname (``https://example.com`` would
    otherwise yield ``.com``), query strings and fragments never leak into it.
    """
    return Path(urlparse(url).path).suffix or default


async def download_audio(url: str) -> Path:
    """Download audio from URL to temporary file with streaming.

    The SSRF check (``validate_audio_url``) is applied to the initial URL and
    to every redirect hop, and runs in a worker thread because it performs
    blocking DNS resolution. On any failure during the download the partial
    temporary file is removed.

    Args:
        url: URL to download audio from.

    Returns:
        Path to the downloaded temporary file (caller is responsible for
        deleting it).

    Raises:
        ValueError: If URL (or a redirect target) fails SSRF validation.
        httpx.HTTPError: If download fails.
    """
    import asyncio

    async def _validate_request(request: httpx.Request) -> None:
        # httpx invokes request hooks for every request it sends, including
        # each redirect, so a public URL cannot bounce us to an internal host.
        await asyncio.to_thread(validate_audio_url, str(request.url))

    # SSRF protection: validate URL before creating the client at all
    await asyncio.to_thread(validate_audio_url, url)

    # Determine file extension from URL path or default to .mp3
    suffix = audio_file_suffix(url)

    async with httpx.AsyncClient(
        timeout=300.0, event_hooks={"request": [_validate_request]}
    ) as client:
        async with client.stream("GET", url, follow_redirects=True) as response:
            response.raise_for_status()
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
            temp_path = Path(temp_file.name)
            try:
                with temp_file:
                    async for chunk in response.aiter_bytes():
                        temp_file.write(chunk)
            except BaseException:
                # Includes cancellation: don't leave partial downloads behind.
                temp_path.unlink(missing_ok=True)
                raise
            return temp_path
def transcribe(
    audio_path: Path,
    options: TranscribeOptions,
    progress_callback: Any = None,
) -> dict[str, Any]:
    """Run transcription pipeline.

    Steps, in order: build ASR/VAD option dicts and fetch a matching model from
    ``ModelManager``; load the audio; transcribe; optionally align for
    word-level timestamps; optionally run speaker diarization and assign
    speakers; finally convert with ``format_result``.

    Args:
        audio_path: Path to audio file.
        options: Transcription options.
        progress_callback: Optional callback(progress: float) for progress updates.
            Called with 0.1 (audio loaded), 0.5 (transcription done), 0.8 (after
            alignment, only when ``options.word_timestamps``), 0.95 (after
            diarization, only when ``options.speaker_labels``) and 1.0 (complete).

    Returns:
        Formatted transcript result with words and utterances, as produced by
        ``format_result``.
    """
    settings = get_settings()
    logger = get_logger()
    # Log job start
    logger.info(
        f"Job started: {audio_path.name}",
        extra={
            "language": options.language or "auto-detect",
            "speaker_labels": options.speaker_labels,
            "word_timestamps": options.word_timestamps,
        },
    )
    logger.debug(f" Language: {options.language or 'auto-detect'}")
    logger.debug(f" Speaker labels: {options.speaker_labels}")
    logger.debug(f" Word timestamps: {options.word_timestamps}")
    # Build ASR/VAD options from request (None if defaults = fast path)
    asr_options = options.build_asr_options()
    vad_options = options.build_vad_options()
    # Log if using custom options (will trigger model reload).
    # A non-default vad_method also counts, although it is not part of vad_options.
    if asr_options or vad_options or options.vad_method != "pyannote":
        logger.info("Using custom ASR/VAD options (may trigger model reload)...")
        if asr_options:
            logger.debug(
                f" ASR: beam_size={asr_options.get('beam_size')}, temps={asr_options.get('temperatures')}"
            )
        if vad_options:
            logger.debug(
                f" VAD: onset={vad_options.get('vad_onset')}, offset={vad_options.get('vad_offset')}"
            )
    # Get model (fast path if defaults, slow path if custom options or model)
    model = ModelManager.get_model(
        model_name=options.model,
        asr_options=asr_options,
        vad_options=vad_options,
        vad_method=options.vad_method,
    )
    # Use request language, fall back to config default, then auto-detect
    # (auto-detect presumably happens when both are None — confirm in murmurai-core).
    effective_language = options.language or settings.language
    # Load audio
    audio = murmurai_core.load_audio(str(audio_path))
    if progress_callback:
        progress_callback(0.1)  # Audio loaded
    # Build transcription kwargs (only runtime params supported by transcribe())
    transcribe_kwargs: dict[str, Any] = {
        "batch_size": settings.batch_size,
        "language": effective_language,
    }
    # Add optional runtime parameters (only passed when non-default)
    if options.task != "transcribe":
        transcribe_kwargs["task"] = options.task
    if options.chunk_size != 30:
        transcribe_kwargs["chunk_size"] = options.chunk_size
    # Transcribe (ASR/VAD options are baked into the model)
    result = model.transcribe(audio, **transcribe_kwargs)
    if progress_callback:
        progress_callback(0.5)  # Transcription done
    # Get detected or specified language. Read before alignment, which replaces
    # `result` with the aligner's output.
    detected_language = result["language"]
    # Align for word-level timestamps (if enabled)
    if options.word_timestamps:
        align_model, metadata = ModelManager.get_align_model(detected_language)
        # NOTE(review): device is hard-coded to "cuda" rather than taken from
        # settings; presumably alignment fails on CPU-only hosts — confirm.
        result = murmurai_core.align(
            result["segments"],
            align_model,
            metadata,
            audio,
            device="cuda",
            return_char_alignments=options.return_char_alignments,
            interpolate_method=options.interpolate_method,
        )
        if progress_callback:
            progress_callback(0.8)  # Alignment done
    # Speaker diarization (if requested)
    speaker_embeddings = None
    if options.speaker_labels:
        diarize_pipeline = ModelManager.get_diarize_model(options.diarize_model)
        # Determine min/max speakers: speakers_expected fills whichever bound was
        # not given. Uses truthiness, so an explicit 0 is also replaced.
        min_spk = options.min_speakers
        max_spk = options.max_speakers
        if options.speakers_expected is not None:
            min_spk = min_spk or options.speakers_expected
            max_spk = max_spk or options.speakers_expected
        # Pass waveform dict to avoid file re-read
        # pyannote 4.x expects torch Tensor, murmurai returns numpy array.
        # audio[None, :] adds a leading channel axis; assumes load_audio returns a
        # 1-D array already at 16 kHz (matching "sample_rate" below) — TODO confirm.
        import torch
        waveform = torch.from_numpy(audio[None, :])
        diarization = diarize_pipeline(
            {"waveform": waveform, "sample_rate": 16000},
            min_speakers=min_spk,
            max_speakers=max_spk,
            return_embeddings=options.return_speaker_embeddings,
        )
        # Extract speaker embeddings if requested.
        # NOTE(review): only populated if the pipeline output has an `embeddings`
        # attribute that is a mapping of speaker -> array-like (.tolist());
        # verify against the pyannote version in use, otherwise this stays None.
        if options.return_speaker_embeddings and hasattr(diarization, "embeddings"):
            speaker_embeddings = {
                speaker: emb.tolist() for speaker, emb in diarization.embeddings.items()
            }
        # Convert pyannote output to murmurai format
        diarize_segments = convert_pyannote_to_murmurai(diarization)
        result = murmurai_core.assign_word_speakers(diarize_segments, result)
        if progress_callback:
            progress_callback(0.95)  # Diarization done
    formatted = format_result(
        result,
        detected_language,
        speaker_embeddings,
        speaker_labels=options.speaker_labels,
        word_timestamps=options.word_timestamps,
    )
    if progress_callback:
        progress_callback(1.0)  # Complete
    # Log job completion
    segment_count = len(result.get("segments", []))
    word_count = len(formatted.get("words", []))
    logger.info(
        f"Job completed: {segment_count} segments, {word_count} words",
        extra={
            "segments": segment_count,
            "words": word_count,
            "language": detected_language,
        },
    )
    return formatted
def format_result(
result: dict[str, Any],
language: str,
speaker_embeddings: dict[str, list[float]] | None = None,
speaker_labels: bool = False,
word_timestamps: bool = False,
) -> dict[str, Any]:
"""Format result to API response format.
Args:
result: Raw result with segments.
language: Detected/specified language code.
speaker_embeddings: Optional speaker embedding vectors.
speaker_labels: Whether speaker diarization was requested.
word_timestamps: Whether word-level timestamps were requested.
Returns:
Formatted transcript with words and utterances.
"""
words: list[dict[str, Any]] = []
utterances: list[dict[str, Any]] = []
for segment in result.get("segments", []):
# Only include speaker if diarization was requested
speaker = segment.get("speaker") if speaker_labels else None
utterance_words: list[dict[str, Any]] = []
for word in segment.get("words", []):
word_data: dict[str, Any] = {
"text": word.get("word", ""),
"start": int(word.get("start", 0) * 1000), # Convert to ms
"end": int(word.get("end", 0) * 1000),
"confidence": word.get("score", 0.0),
}
# Only include speaker if diarization was requested and speaker exists
if speaker:
word_data["speaker"] = speaker
words.append(word_data)
utterance_words.append(word_data)
# Build utterance from segment
utterance: dict[str, Any] = {
"text": segment.get("text", "").strip(),
"start": int(segment.get("start", 0) * 1000),
"end": int(segment.get("end", 0) * 1000),
"words": utterance_words,
}
# Only include speaker if diarization was requested and speaker exists
if speaker:
utterance["speaker"] = speaker
# Only include confidence if we have word-level data
if utterance_words:
utterance["confidence"] = sum(w["confidence"] for w in utterance_words) / len(
utterance_words
)
utterances.append(utterance)
# Calculate overall metrics
full_text = " ".join(s.get("text", "").strip() for s in result.get("segments", []))
# Audio duration: use word-level if available, otherwise use utterance end times
if words:
audio_duration = max((w["end"] for w in words), default=0)
elif utterances:
audio_duration = max((u["end"] for u in utterances), default=0)
else:
audio_duration = 0
formatted: dict[str, Any] = {
"text": full_text,
"words": words,
"utterances": utterances,
"audio_duration": audio_duration,
"language_code": language,
}
# Only include confidence if we have word-level data
if words:
formatted["confidence"] = sum(w["confidence"] for w in words) / len(words)
# Include speaker embeddings if available
if speaker_embeddings:
formatted["speaker_embeddings"] = speaker_embeddings
return formatted