Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions src/murmurai_server/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,15 @@ class ModelManager:

@classmethod
def _hash_options(
cls, asr_options: dict | None, vad_options: dict | None, vad_method: str
cls,
model_name: str | None,
asr_options: dict | None,
vad_options: dict | None,
vad_method: str,
) -> str:
"""Create hash key for model options."""
data = {
"model": model_name,
"asr": asr_options or {},
"vad": vad_options or {},
"vad_method": vad_method,
Expand All @@ -68,13 +73,16 @@ def _hash_options(
@classmethod
def get_model(
cls,
model_name: str | None = None,
asr_options: dict | None = None,
vad_options: dict | None = None,
vad_method: str = "pyannote",
) -> Any:
"""Get model with specified options, using cache when possible.

Args:
model_name: Whisper model name (e.g., "base", "large-v3-turbo").
None = use server default from settings.
asr_options: Custom ASR options dict. None = use defaults (fast path).
vad_options: Custom VAD options dict. None = use defaults.
vad_method: VAD method ("pyannote" or "silero").
Expand All @@ -88,12 +96,20 @@ def get_model(
"""
settings = get_settings()

# Fast path: use default model (no custom options)
if asr_options is None and vad_options is None and vad_method == settings.vad_method:
# Resolve model name (None = server default)
effective_model = model_name or settings.model

# Fast path: use default model (no custom options, default model)
if (
effective_model == settings.model
and asr_options is None
and vad_options is None
and vad_method == settings.vad_method
):
return cls._get_default_model()

# Slow path: get/create model with custom options
return cls._get_custom_model(asr_options, vad_options, vad_method)
# Slow path: get/create model with custom options or different model
return cls._get_custom_model(effective_model, asr_options, vad_options, vad_method)

@classmethod
def _get_default_model(cls) -> Any:
Expand All @@ -120,16 +136,27 @@ def _get_default_model(cls) -> Any:

@classmethod
def _get_custom_model(
cls, asr_options: dict | None, vad_options: dict | None, vad_method: str
cls,
model_name: str,
asr_options: dict | None,
vad_options: dict | None,
vad_method: str,
) -> Any:
"""Get or load model with custom options (may be slow on cache miss)."""
"""Get or load model with custom options or model name (may be slow on cache miss).

Args:
model_name: Whisper model name to load (already resolved from settings default).
asr_options: Custom ASR options dict. None = use defaults.
vad_options: Custom VAD options dict. None = use defaults.
vad_method: VAD method ("pyannote" or "silero").
"""
settings = get_settings()
logger = get_logger()

# Build full options (merge with defaults)
full_asr = {**DEFAULT_ASR_OPTIONS, **(asr_options or {})}
full_vad = {**DEFAULT_VAD_OPTIONS, **(vad_options or {})}
options_key = cls._hash_options(full_asr, full_vad, vad_method)
options_key = cls._hash_options(model_name, full_asr, full_vad, vad_method)

with cls._lock:
# Check cache
Expand All @@ -138,14 +165,14 @@ def _get_custom_model(
return cls._custom_models[options_key]

# Load new model with custom options
logger.info(f"Loading custom model (key={options_key})...")
logger.info(f"Loading custom model: {model_name} (key={options_key})...")
logger.info(f" VAD method: {vad_method}")
logger.info(
f" ASR: beam_size={full_asr.get('beam_size')}, temps={full_asr.get('temperatures')}"
)

model = murmurai_core.load_model(
settings.model,
model_name,
Comment thread
Odrec marked this conversation as resolved.
device="cuda",
compute_type=settings.compute_type,
asr_options=full_asr,
Expand All @@ -162,7 +189,7 @@ def _get_custom_model(
torch.cuda.empty_cache()

cls._custom_models[options_key] = model
logger.info(f"Custom model loaded and cached: {options_key}")
logger.info(f"Custom model '{model_name}' loaded and cached: {options_key}")
return model

@classmethod
Expand Down
11 changes: 11 additions & 0 deletions src/murmurai_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,15 @@ async def submit_transcript(
audio_url: Annotated[
str | None, Form(description="URL to download audio from", examples=[""])
] = None,
# Model selection (overrides server default)
model: Annotated[
str | None,
Form(
description="Whisper model to use (e.g., 'base', 'small', 'medium', 'large-v2', "
"'large-v3', 'large-v3-turbo'). Empty = server default.",
examples=[""],
),
] = None,
# All optional parameters with defaults
language_code: Annotated[
str | None, Form(description="Language code (auto-detect if empty)", examples=[""])
Expand Down Expand Up @@ -347,6 +356,7 @@ async def submit_transcript(
# Sanitize nullable string fields (convert empty strings to None)
# This handles Swagger UI sending "" instead of omitting the field
audio_url = audio_url if audio_url else None
model = model if model else None
language_code = language_code if language_code else None
initial_prompt = initial_prompt if initial_prompt else None
hotwords = hotwords if hotwords else None
Expand Down Expand Up @@ -409,6 +419,7 @@ async def submit_transcript(

# Build options (all params already have defaults from Form)
options = TranscribeOptions(
model=model,
language=language_code,
task=task,
speaker_labels=speaker_labels,
Expand Down
6 changes: 5 additions & 1 deletion src/murmurai_server/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def _ensure_ffmpeg() -> None:
class TranscribeOptions:
"""Options for transcription pipeline."""

# Model selection (None = use server default from settings)
model: str | None = None

# Language
language: str | None = None

Expand Down Expand Up @@ -335,8 +338,9 @@ def transcribe(
f" VAD: onset={vad_options.get('vad_onset')}, offset={vad_options.get('vad_offset')}"
)

# Get model (fast path if defaults, slow path if custom options)
# Get model (fast path if defaults, slow path if custom options or model)
model = ModelManager.get_model(
model_name=options.model,
asr_options=asr_options,
vad_options=vad_options,
vad_method=options.vad_method,
Expand Down
50 changes: 50 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,56 @@ async def test_submit_transcript_with_options(
data = response.json()
assert data["language_code"] == "en"

@pytest.mark.asyncio
async def test_submit_transcript_with_model(
self, async_client: AsyncClient, auth_headers: dict, tmp_path: Path
):
"""Test POST /v1/transcript with per-request model selection."""
test_audio = tmp_path / "test.mp3"
test_audio.touch()

with patch("murmurai_server.server.download_audio", new_callable=AsyncMock) as mock_dl:
mock_dl.return_value = test_audio

response = await async_client.post(
"/v1/transcript",
headers=auth_headers,
data={
"audio_url": "https://example.com/test.mp3",
"model": "base",
},
)

assert response.status_code == 200
data = response.json()
assert "id" in data
assert data["status"] == "queued"

@pytest.mark.asyncio
async def test_submit_transcript_with_empty_model_uses_default(
self, async_client: AsyncClient, auth_headers: dict, tmp_path: Path
):
"""Test POST /v1/transcript with empty model string uses server default."""
test_audio = tmp_path / "test.mp3"
test_audio.touch()

with patch("murmurai_server.server.download_audio", new_callable=AsyncMock) as mock_dl:
mock_dl.return_value = test_audio

response = await async_client.post(
"/v1/transcript",
headers=auth_headers,
data={
"audio_url": "https://example.com/test.mp3",
"model": "", # empty string should be treated as None (server default)
},
)

assert response.status_code == 200
data = response.json()
assert "id" in data
assert data["status"] == "queued"
Comment thread
Odrec marked this conversation as resolved.

@pytest.mark.asyncio
async def test_submit_transcript_no_audio(self, async_client: AsyncClient, auth_headers: dict):
"""Test POST /v1/transcript without audio fails."""
Expand Down
Loading