Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,4 @@ transcriptions-old/
temp/
users.db
start.sh
.DS_Store
49 changes: 38 additions & 11 deletions src/murmurai_server/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,15 @@ class ModelManager:

@classmethod
def _hash_options(
cls, asr_options: dict | None, vad_options: dict | None, vad_method: str
cls,
model_name: str | None,
asr_options: dict | None,
vad_options: dict | None,
vad_method: str,
) -> str:
"""Create hash key for model options."""
data = {
"model": model_name,
"asr": asr_options or {},
"vad": vad_options or {},
"vad_method": vad_method,
Expand All @@ -68,13 +73,16 @@ def _hash_options(
@classmethod
def get_model(
cls,
model_name: str | None = None,
asr_options: dict | None = None,
vad_options: dict | None = None,
vad_method: str = "pyannote",
) -> Any:
"""Get model with specified options, using cache when possible.

Args:
model_name: Whisper model name (e.g., "base", "large-v3-turbo").
None = use server default from settings.
asr_options: Custom ASR options dict. None = use defaults (fast path).
vad_options: Custom VAD options dict. None = use defaults.
vad_method: VAD method ("pyannote" or "silero").
Expand All @@ -88,12 +96,20 @@ def get_model(
"""
settings = get_settings()

# Fast path: use default model (no custom options)
if asr_options is None and vad_options is None and vad_method == settings.vad_method:
# Resolve model name (None = server default)
effective_model = model_name or settings.model

# Fast path: use default model (no custom options, default model)
if (
effective_model == settings.model
and asr_options is None
and vad_options is None
and vad_method == settings.vad_method
):
return cls._get_default_model()

# Slow path: get/create model with custom options
return cls._get_custom_model(asr_options, vad_options, vad_method)
# Slow path: get/create model with custom options or different model
return cls._get_custom_model(effective_model, asr_options, vad_options, vad_method)

@classmethod
def _get_default_model(cls) -> Any:
Expand All @@ -120,16 +136,27 @@ def _get_default_model(cls) -> Any:

@classmethod
def _get_custom_model(
cls, asr_options: dict | None, vad_options: dict | None, vad_method: str
cls,
model_name: str,
asr_options: dict | None,
vad_options: dict | None,
vad_method: str,
) -> Any:
"""Get or load model with custom options (may be slow on cache miss)."""
"""Get or load model with custom options or model name (may be slow on cache miss).

Args:
model_name: Whisper model name to load (already resolved from settings default).
asr_options: Custom ASR options dict. None = use defaults.
vad_options: Custom VAD options dict. None = use defaults.
vad_method: VAD method ("pyannote" or "silero").
"""
settings = get_settings()
logger = get_logger()

# Build full options (merge with defaults)
full_asr = {**DEFAULT_ASR_OPTIONS, **(asr_options or {})}
full_vad = {**DEFAULT_VAD_OPTIONS, **(vad_options or {})}
options_key = cls._hash_options(full_asr, full_vad, vad_method)
options_key = cls._hash_options(model_name, full_asr, full_vad, vad_method)

with cls._lock:
# Check cache
Expand All @@ -138,14 +165,14 @@ def _get_custom_model(
return cls._custom_models[options_key]

# Load new model with custom options
logger.info(f"Loading custom model (key={options_key})...")
logger.info(f"Loading custom model: {model_name} (key={options_key})...")
logger.info(f" VAD method: {vad_method}")
logger.info(
f" ASR: beam_size={full_asr.get('beam_size')}, temps={full_asr.get('temperatures')}"
)

model = murmurai_core.load_model(
settings.model,
model_name,
Comment thread
Odrec marked this conversation as resolved.
device="cuda",
compute_type=settings.compute_type,
asr_options=full_asr,
Expand All @@ -162,7 +189,7 @@ def _get_custom_model(
torch.cuda.empty_cache()

cls._custom_models[options_key] = model
logger.info(f"Custom model loaded and cached: {options_key}")
logger.info(f"Custom model '{model_name}' loaded and cached: {options_key}")
return model

@classmethod
Expand Down
40 changes: 40 additions & 0 deletions src/murmurai_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,28 @@
)
from murmurai_server.transcriber import TranscribeOptions, download_audio, transcribe # noqa: E402

# Allow-list of valid Whisper model names (prevents path traversal attacks).
# Only names in this set may be supplied via the per-request `model` form field;
# anything else is rejected before it can reach the model loader.
ALLOWED_MODELS = frozenset(
    (
        "tiny", "tiny.en",
        "base", "base.en",
        "small", "small.en",
        "medium", "medium.en",
        "large", "large-v1", "large-v2", "large-v3", "large-v3-turbo",
        "distil-large-v2", "distil-large-v3",
        "deepdml/faster-whisper-large-v3-turbo-ct2",
    )
)


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
Expand Down Expand Up @@ -253,6 +275,15 @@ async def submit_transcript(
audio_url: Annotated[
str | None, Form(description="URL to download audio from", examples=[""])
] = None,
# Model selection (overrides server default)
model: Annotated[
str | None,
Form(
description="Whisper model to use (e.g., 'base', 'small', 'medium', 'large-v2', "
"'large-v3', 'large-v3-turbo'). Empty = server default.",
examples=[""],
),
] = None,
# All optional parameters with defaults
language_code: Annotated[
str | None, Form(description="Language code (auto-detect if empty)", examples=[""])
Expand Down Expand Up @@ -347,6 +378,7 @@ async def submit_transcript(
# Sanitize nullable string fields (convert empty strings to None)
# This handles Swagger UI sending "" instead of omitting the field
audio_url = audio_url if audio_url else None
model = model if model else None
language_code = language_code if language_code else None
initial_prompt = initial_prompt if initial_prompt else None
hotwords = hotwords if hotwords else None
Expand All @@ -367,6 +399,13 @@ async def submit_transcript(
max_line_width_int: int | None = int(max_line_width) if max_line_width else None
max_line_count_int: int | None = int(max_line_count) if max_line_count else None

# Validate model name against allow-list (prevents path traversal attacks)
if model and model not in ALLOWED_MODELS:
raise HTTPException(
status_code=400,
detail=f"Invalid model: '{model}'. Allowed models: {', '.join(sorted(ALLOWED_MODELS))}",
)

# Validate input
if not file and not audio_url:
raise HTTPException(status_code=400, detail="Either file or audio_url is required")
Expand Down Expand Up @@ -409,6 +448,7 @@ async def submit_transcript(

# Build options (all params already have defaults from Form)
options = TranscribeOptions(
model=model,
language=language_code,
task=task,
speaker_labels=speaker_labels,
Expand Down
6 changes: 5 additions & 1 deletion src/murmurai_server/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def _ensure_ffmpeg() -> None:
class TranscribeOptions:
"""Options for transcription pipeline."""

# Model selection (None = use server default from settings)
model: str | None = None

# Language
language: str | None = None

Expand Down Expand Up @@ -335,8 +338,9 @@ def transcribe(
f" VAD: onset={vad_options.get('vad_onset')}, offset={vad_options.get('vad_offset')}"
)

# Get model (fast path if defaults, slow path if custom options)
# Get model (fast path if defaults, slow path if custom options or model)
model = ModelManager.get_model(
model_name=options.model,
asr_options=asr_options,
vad_options=vad_options,
vad_method=options.vad_method,
Expand Down
89 changes: 89 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,95 @@ async def test_submit_transcript_with_options(
data = response.json()
assert data["language_code"] == "en"

@pytest.mark.asyncio
async def test_submit_transcript_with_model(
    self, async_client: AsyncClient, auth_headers: dict, tmp_path: Path
):
    """POST /v1/transcript forwards an explicit per-request model to the worker."""
    audio_file = tmp_path / "test.mp3"
    audio_file.touch()

    with patch("murmurai_server.server.download_audio", new_callable=AsyncMock) as download_mock:
        with patch("murmurai_server.server.process_transcription") as process_mock:
            download_mock.return_value = audio_file

            resp = await async_client.post(
                "/v1/transcript",
                headers=auth_headers,
                data={"audio_url": "https://example.com/test.mp3", "model": "base"},
            )

    # Request is accepted and queued as usual.
    assert resp.status_code == 200
    payload = resp.json()
    assert "id" in payload
    assert payload["status"] == "queued"

    # The queued background task must receive the requested model name.
    process_mock.assert_called_once()
    assert process_mock.call_args.kwargs["options"].model == "base"

@pytest.mark.asyncio
async def test_submit_transcript_with_empty_model_uses_default(
    self, async_client: AsyncClient, auth_headers: dict, tmp_path: Path
):
    """Test POST /v1/transcript with empty model string uses server default.

    Swagger UI submits "" rather than omitting the field, so the endpoint
    must coerce an empty `model` to None before building TranscribeOptions.
    """
    test_audio = tmp_path / "test.mp3"
    test_audio.touch()

    with (
        patch("murmurai_server.server.download_audio", new_callable=AsyncMock) as mock_dl,
        patch("murmurai_server.server.process_transcription") as mock_process,
    ):
        mock_dl.return_value = test_audio

        response = await async_client.post(
            "/v1/transcript",
            headers=auth_headers,
            data={
                "audio_url": "https://example.com/test.mp3",
                "model": "",  # empty string should be treated as None (server default)
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "id" in data
        assert data["status"] == "queued"

        # Verify model is None (will use server default)
        mock_process.assert_called_once()
        call_kwargs = mock_process.call_args.kwargs
        assert call_kwargs["options"].model is None

@pytest.mark.asyncio
async def test_submit_transcript_with_invalid_model_rejected(
    self, async_client: AsyncClient, auth_headers: dict, tmp_path: Path
):
    """A model name outside the allow-list is rejected with 400 (blocks path traversal)."""
    audio_file = tmp_path / "test.mp3"
    audio_file.touch()

    with patch("murmurai_server.server.download_audio", new_callable=AsyncMock) as download_mock:
        download_mock.return_value = audio_file

        resp = await async_client.post(
            "/v1/transcript",
            headers=auth_headers,
            data={"audio_url": "https://example.com/test.mp3", "model": "../../etc/passwd"},
        )

    # Validation must fail before any work is queued.
    assert resp.status_code == 400
    assert "Invalid model" in resp.json()["detail"]

@pytest.mark.asyncio
async def test_submit_transcript_no_audio(self, async_client: AsyncClient, auth_headers: dict):
"""Test POST /v1/transcript without audio fails."""
Expand Down
Loading