Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:

- uses: actions/setup-python@v5
with:
python-version: "3.x"
python-version: "3.13"

- name: Install
run: |
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 3.1.0

- Refactor to dynamically load models

## 3.0.2

- Set `--data-dir /data` in Docker run script
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "wyoming-faster-whisper"
version = "3.0.2"
version = "3.1.0"
description = "Wyoming Server for Faster Whisper"
readme = "README.md"
requires-python = ">=3.8"
Expand Down
192 changes: 30 additions & 162 deletions wyoming_faster_whisper/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import platform
import re
from functools import partial
from typing import Any, Optional

import faster_whisper
from wyoming.info import AsrModel, AsrProgram, Attribution, Info
from wyoming.server import AsyncServer, AsyncTcpServer

from . import __version__
from .const import AUTO_LANGUAGE, AUTO_MODEL, PARAKEET_LANGUAGES, SttLibrary
from .dispatch_handler import DispatchEventHandler
from .models import ModelLoader

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -106,55 +107,31 @@ async def main() -> None:
)
_LOGGER.debug(args)

# Automatic configuration
stt_library = SttLibrary(args.stt_library)
if stt_library == SttLibrary.AUTO:
if args.model == AUTO_MODEL:
if args.language in ("en", AUTO_LANGUAGE):
# Prefer parakeet
try:
from .sherpa_handler import SherpaModel

stt_library = SttLibrary.SHERPA
except ImportError:
stt_library = SttLibrary.FASTER_WHISPER
elif args.language == "ru":
# Prefer GigaAM via onnx-asr
try:
from .sherpa_handler import SherpaModel

stt_library = SttLibrary.ONNX_ASR
except ImportError:
stt_library = SttLibrary.FASTER_WHISPER
else:
# Default to faster-whisper if model is provided
stt_library = SttLibrary.FASTER_WHISPER

_LOGGER.debug("Speech-to-text library automatically selected: %s", stt_library)
args.stt_library = SttLibrary(args.stt_library)

machine = platform.machine().lower()
is_arm = ("arm" in machine) or ("aarch" in machine)
if args.model == AUTO_MODEL:
args.model = guess_model(stt_library, args.language, is_arm)
_LOGGER.debug("Model automatically selected: %s", args.model)

if args.beam_size <= 0:
args.beam_size = 1 if is_arm else 5
_LOGGER.debug("Beam size automatically selected: %s", args.beam_size)

# Resolve model name
model_name = args.model
match = re.match(r"^(tiny|base|small|medium)[.-]int8$", args.model)
if match:
model_match = re.match(r"^(tiny|base|small|medium)[.-]int8$", args.model)
if model_match:
# Original models re-uploaded to huggingface
model_size = match.group(1)
model_size = model_match.group(1)
model_name = f"{model_size}-int8"
args.model = f"rhasspy/faster-whisper-{model_name}"

if args.language == AUTO_LANGUAGE:
# Whisper does not understand auto
args.language = None

if args.model == AUTO_MODEL:
args.model = None

wyoming_info = Info(
asr=[
AsrProgram(
Expand Down Expand Up @@ -190,38 +167,22 @@ async def main() -> None:
],
)

# Load model
_LOGGER.debug("Loading %s", args.model)
whisper_model: Any = None

if stt_library == SttLibrary.SHERPA:
# Use Sherpa ONNX with nemo
from .sherpa_handler import SherpaModel # noqa: F811

whisper_model = SherpaModel(args.model, args.download_dir)
elif stt_library == SttLibrary.TRANSFORMERS:
# Use HuggingFace transformers
from .transformers_whisper import TransformersWhisperModel

whisper_model = TransformersWhisperModel(
args.model, args.download_dir, args.local_files_only
)
elif stt_library == SttLibrary.ONNX_ASR:
# Use onnx-asr
from .onnx_asr_handler import OnnxAsrModel
loader = ModelLoader(
preferred_stt_library=args.stt_library,
preferred_language=args.language,
download_dir=args.download_dir,
local_files_only=args.local_files_only,
model=args.model,
compute_type=args.compute_type,
device=args.device,
beam_size=args.beam_size,
cpu_threads=args.cpu_threads,
initial_prompt=args.initial_prompt,
)

whisper_model = OnnxAsrModel(
args.model, args.download_dir, args.local_files_only
)
else:
# Use faster-whisper
whisper_model = faster_whisper.WhisperModel(
args.model,
download_root=args.download_dir,
device=args.device,
compute_type=args.compute_type,
cpu_threads=args.cpu_threads,
)
# Load model
_LOGGER.debug("Pre-loading transcriber")
await loader.load_transcriber()

server = AsyncServer.from_uri(args.uri)

Expand All @@ -239,106 +200,13 @@ async def main() -> None:
_LOGGER.debug("Zeroconf discovery enabled")

_LOGGER.info("Ready")
model_lock = asyncio.Lock()

if stt_library == SttLibrary.SHERPA:
from .sherpa_handler import SherpaEventHandler

await server.run(
partial(
SherpaEventHandler,
wyoming_info,
args.language,
args.beam_size,
whisper_model,
model_lock,
)
)
elif stt_library == SttLibrary.TRANSFORMERS:
# Use HuggingFace transformers
from .transformers_whisper import (
TransformersWhisperEventHandler,
TransformersWhisperModel,
)

assert isinstance(whisper_model, TransformersWhisperModel)

await server.run(
partial(
TransformersWhisperEventHandler,
wyoming_info,
args.language,
args.beam_size,
whisper_model,
model_lock,
)
)
elif stt_library == SttLibrary.ONNX_ASR:
# Use onnx-asr
from .onnx_asr_handler import OnnxAsrEventHandler, OnnxAsrModel

assert isinstance(whisper_model, OnnxAsrModel)

await server.run(
partial(
OnnxAsrEventHandler,
wyoming_info,
args.language,
args.beam_size,
whisper_model,
model_lock,
)
await server.run(
partial(
DispatchEventHandler,
wyoming_info,
loader,
)
else:
# faster-whisper
from .faster_whisper_handler import FasterWhisperEventHandler

assert isinstance(whisper_model, faster_whisper.WhisperModel)
await server.run(
partial(
FasterWhisperEventHandler,
wyoming_info,
args,
whisper_model,
model_lock,
initial_prompt=args.initial_prompt,
)
)


# -----------------------------------------------------------------------------


def guess_model(stt_library: SttLibrary, language: Optional[str], is_arm: bool) -> str:
"""Automatically guess STT model id."""
if stt_library == SttLibrary.SHERPA:
if language == "en":
return "sherpa-onnx-nemo-parakeet-tdt-0.6b-v2-int8"

# Non-English
return "sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8"

if stt_library == SttLibrary.TRANSFORMERS:
if language == "en":
if is_arm:
return "openai/whisper-tiny.en"

return "openai/whisper-base.en"

# Non-English
if is_arm:
return "openai/whisper-tiny"

return "openai/whisper-base"

if stt_library == SttLibrary.ONNX_ASR:
return "gigaam-v2-rnnt"

# faster-whisper
if is_arm:
return "tiny-int8"

return "base-int8"
)


# -----------------------------------------------------------------------------
Expand Down
17 changes: 17 additions & 0 deletions wyoming_faster_whisper/const.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Constants."""

from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Optional, Union


class SttLibrary(str, Enum):
Expand Down Expand Up @@ -43,3 +46,17 @@ class SttLibrary(str, Enum):
"ru",
"uk",
}


class Transcriber(ABC):
"""Base class for transcribers."""

@abstractmethod
def transcribe(
self,
wav_path: Union[str, Path],
language: Optional[str],
beam_size: int = 5,
initial_prompt: Optional[str] = None,
) -> str:
pass
Loading