Skip to content

Commit d207827

Browse files
committed
Add Wespeaker embeddings model
1 parent b841a8d commit d207827

6 files changed

Lines changed: 224 additions & 12 deletions

File tree

src/onnx_asr/adapters.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from onnx_asr.asr import Asr, TimestampedResult
1515
from onnx_asr.preprocessors.resampler import Resampler
16+
from onnx_asr.se import SpeakerEmbedding
1617
from onnx_asr.utils import SampleRates, read_wav_files
1718
from onnx_asr.vad import SegmentResult, TimestampedSegmentResult, Vad
1819

@@ -230,3 +231,50 @@ def _recognize_batch(
230231
self.asr, waveforms, waveforms_len, self.asr._get_sample_rate(), {**kwargs}, **self._vadargs
231232
)
232233
)
234+
235+
236+
class SeAdapter:
237+
"""Speaker Embedding adapter class."""
238+
239+
se: SpeakerEmbedding
240+
resampler: Resampler
241+
242+
def __init__(self, se: SpeakerEmbedding, resampler: Resampler):
243+
"""Create SE adapter."""
244+
self.se = se
245+
self.resampler = resampler
246+
247+
def embedding(
248+
self,
249+
waveform: str | Path | npt.NDArray[np.float32] | list[str | Path | npt.NDArray[np.float32]],
250+
*,
251+
sample_rate: SampleRates = 16_000,
252+
channel: int | Literal["mean"] | None = None,
253+
) -> npt.NDArray[np.float32]:
254+
"""Compute speaker embedding (single or batch).
255+
256+
Args:
257+
waveform: Path to wav file (only PCM_U8, PCM_16, PCM_24 and PCM_32 formats are supported)
258+
or Numpy array with PCM waveform.
259+
A list of file paths or numpy arrays for batch recognition are also supported.
260+
sample_rate: Sample rate for Numpy arrays in waveform.
261+
channel: Channel selector for multi-channel audio.
262+
263+
Returns:
264+
speaker embedding results.
265+
266+
Raises:
267+
utils.AudioLoadingError: Audio loading error (onnx-asr specific).
268+
FileNotFoundError: File not found error.
269+
wave.Error: WAV file reading error.
270+
OSError: Other IO errors.
271+
272+
"""
273+
if isinstance(waveform, list) and not waveform:
274+
return np.array(None, dtype=np.float32)
275+
276+
waveform_batch = waveform if isinstance(waveform, list) else [waveform]
277+
result = self.se.embedding(*self.resampler(*read_wav_files(waveform_batch, sample_rate, channel)))
278+
if isinstance(waveform, list):
279+
return result
280+
return result.squeeze(0)

src/onnx_asr/loader.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77

88
import onnxruntime as rt
99

10-
from onnx_asr.adapters import TextResultsAsrAdapter
10+
from onnx_asr.adapters import SeAdapter, TextResultsAsrAdapter
1111
from onnx_asr.asr import Asr, Preprocessor
1212
from onnx_asr.models.gigaam import GigaamV2Ctc, GigaamV2Rnnt, GigaamV3E2eCtc, GigaamV3E2eRnnt
1313
from onnx_asr.models.kaldi import KaldiTransducer
1414
from onnx_asr.models.nemo import NemoConformerAED, NemoConformerCtc, NemoConformerRnnt, NemoConformerTdt
1515
from onnx_asr.models.pyannote import PyAnnoteVad
1616
from onnx_asr.models.silero import SileroVad
1717
from onnx_asr.models.tone import TOneCtc
18+
from onnx_asr.models.wespeaker import WespeakerEmbeddings
1819
from onnx_asr.models.whisper import WhisperHf, WhisperOrt
1920
from onnx_asr.onnx import OnnxSessionOptions, get_onnx_providers, update_onnx_providers
2021
from onnx_asr.preprocessors.numpy_preprocessor import (
@@ -26,6 +27,7 @@
2627
from onnx_asr.preprocessors.preprocessor import ConcurrentPreprocessor, IdentityPreprocessor, OnnxPreprocessor
2728
from onnx_asr.preprocessors.resampler import Resampler
2829
from onnx_asr.resolver import Resolver
30+
from onnx_asr.se import SpeakerEmbedding
2931
from onnx_asr.utils import (
3032
ModelNotSupportedError,
3133
)
@@ -82,7 +84,7 @@
8284

8385

8486
def create_asr_resolver(
85-
model: str, local_dir: str | Path | None = None, *, offline: bool | None = None
87+
model: str | None = None, local_dir: str | Path | None = None, *, offline: bool | None = None
8688
) -> Resolver[AsrTypes]:
8789
"""Create resolver for ASR models."""
8890
model_types: dict[str, type[AsrTypes]] = {
@@ -120,13 +122,20 @@ def create_asr_resolver(
120122

121123

122124
def create_vad_resolver(
123-
model: str, local_dir: str | Path | None = None, *, offline: bool | None = None
125+
model: str | None = None, local_dir: str | Path | None = None, *, offline: bool | None = None
124126
) -> Resolver[VadTypes]:
125127
"""Create resolver for VAD models."""
126128
model_types: dict[str, type[VadTypes]] = {"silero": SileroVad, "pyannote": PyAnnoteVad}
127129
return Resolver(model_types, model, local_dir, offline=offline)
128130

129131

132+
def create_se_resolver(
133+
model: str | None = None, local_dir: str | Path | None = None, *, offline: bool | None = None
134+
) -> Resolver[WespeakerEmbeddings]:
135+
"""Create resolver for SE models."""
136+
return Resolver(WespeakerEmbeddings, model, local_dir, offline=offline)
137+
138+
130139
class PreprocessorRuntimeConfig(OnnxSessionOptions, total=False):
131140
"""Preprocessor runtime config."""
132141

@@ -206,30 +215,34 @@ def _create_preprocessor(self, name: str) -> Preprocessor:
206215
def _create_resampler(self, sample_rate: Literal[8000, 16000]) -> Resampler:
207216
return Resampler(sample_rate, self.resampler_config)
208217

218+
def _create_asr_adapter(self, asr: Asr) -> TextResultsAsrAdapter:
219+
return TextResultsAsrAdapter(asr, self._create_resampler(asr._get_sample_rate()))
220+
221+
def _create_se_adapter(self, se: SpeakerEmbedding) -> SeAdapter:
222+
return SeAdapter(se, self._create_resampler(se._get_sample_rate()))
223+
209224
def create_asr(
210225
self,
211-
model: str,
226+
model: str | ModelNames | ModelTypes | None = None,
212227
local_dir: str | Path | None = None,
213228
*,
214229
quantization: str | None = None,
215230
offline: bool | None = None,
216231
config: OnnxSessionOptions | None = None,
217-
) -> Asr:
232+
) -> TextResultsAsrAdapter:
218233
"""Create ASR model."""
219234
resolver = create_asr_resolver(model, local_dir, offline=offline)
220235
if config is None:
221236
config = update_onnx_providers(
222237
self.default_onnx_config, excluded_providers=resolver.model_type._get_excluded_providers()
223238
)
224-
return resolver.model_type(resolver.resolve_model(quantization=quantization), self._create_preprocessor, config)
225-
226-
def create_adapter(self, asr: Asr) -> TextResultsAsrAdapter:
227-
"""Create ASR adapter."""
228-
return TextResultsAsrAdapter(asr, self._create_resampler(asr._get_sample_rate()))
239+
return self._create_asr_adapter(
240+
resolver.model_type(resolver.resolve_model(quantization=quantization), self._create_preprocessor, config)
241+
)
229242

230243
def create_vad(
231244
self,
232-
model: str,
245+
model: str | VadNames | None = None,
233246
local_dir: str | Path | None = None,
234247
*,
235248
quantization: str | None = None,
@@ -244,6 +257,25 @@ def create_vad(
244257
)
245258
return resolver.model_type(resolver.resolve_model(quantization=quantization), config)
246259

260+
def create_se(
261+
self,
262+
model: str | None = None,
263+
local_dir: str | Path | None = None,
264+
*,
265+
quantization: str | None = None,
266+
offline: bool | None = None,
267+
config: OnnxSessionOptions | None = None,
268+
) -> SeAdapter:
269+
"""Create SE model."""
270+
resolver = create_se_resolver(model, local_dir, offline=offline)
271+
if config is None:
272+
config = update_onnx_providers(
273+
self.default_onnx_config, excluded_providers=resolver.model_type._get_excluded_providers()
274+
)
275+
return self._create_se_adapter(
276+
resolver.model_type(resolver.resolve_model(quantization=quantization), self._create_preprocessor, config)
277+
)
278+
247279

248280
def load_model(
249281
model: str | ModelNames | ModelTypes,
@@ -304,7 +336,7 @@ def load_model(
304336
)
305337

306338
manager = Manager(sess_options, providers, provider_options, preprocessor_config, resampler_config)
307-
return manager.create_adapter(manager.create_asr(model, path, quantization=quantization, config=asr_config))
339+
return manager.create_asr(model, path, quantization=quantization, config=asr_config)
308340

309341

310342
def load_vad(

src/onnx_asr/models/wespeaker.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Wespeaker SE implementation."""
2+
3+
from collections.abc import Callable
4+
from pathlib import Path
5+
6+
import numpy as np
7+
import numpy.typing as npt
8+
import onnxruntime as rt
9+
10+
from onnx_asr.asr import Preprocessor
11+
from onnx_asr.onnx import OnnxSessionOptions
12+
from onnx_asr.se import SpeakerEmbedding
13+
from onnx_asr.utils import is_float32_array
14+
15+
16+
class WespeakerEmbeddings(SpeakerEmbedding):
17+
"""Wespeaker embeddings model."""
18+
19+
def __init__(
20+
self,
21+
model_files: dict[str, Path],
22+
preprocessor_factory: Callable[[str], Preprocessor],
23+
onnx_options: OnnxSessionOptions,
24+
):
25+
"""Create model.
26+
27+
Args:
28+
model_files: Dict with paths to model files.
29+
preprocessor_factory: Factory for preprocessor creation.
30+
onnx_options: Options for onnxruntime InferenceSession.
31+
32+
"""
33+
self._model = rt.InferenceSession(model_files["model"], **onnx_options)
34+
self._preprocessor = preprocessor_factory("wespeaker")
35+
36+
@staticmethod
37+
def _get_excluded_providers() -> list[str]:
38+
return []
39+
40+
@staticmethod
41+
def _get_model_files(quantization: str | None = None) -> dict[str, str]:
42+
suffix = "?" + quantization if quantization else ""
43+
return {"config": "config.yaml", "model": f"*{suffix}.onnx"}
44+
45+
def embedding(
46+
self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64]
47+
) -> npt.NDArray[np.float32]:
48+
"""Compute speaker embedding."""
49+
features, _ = self._preprocessor(waveforms, waveforms_len)
50+
features -= features.mean(axis=1, keepdims=True)
51+
(embs,) = self._model.run(["embs"], {"feats": features})
52+
assert is_float32_array(embs)
53+
return embs

src/onnx_asr/se.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Base Speaker Embedding classes."""
2+
3+
from typing import Literal, Protocol
4+
5+
import numpy as np
6+
import numpy.typing as npt
7+
8+
9+
class SpeakerEmbedding(Protocol):
10+
"""Speaker Embedding protocol."""
11+
12+
@staticmethod
13+
def _get_sample_rate() -> Literal[8_000, 16_000]:
14+
return 16_000
15+
16+
def embedding(
17+
self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64]
18+
) -> npt.NDArray[np.float32]:
19+
"""Compute speaker embedding."""
20+
...

tests/onnx_asr/test_embedding.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import numpy as np
2+
import pytest
3+
4+
from onnx_asr.adapters import SeAdapter
5+
from onnx_asr.loader import Manager
6+
7+
8+
@pytest.fixture(scope="module", params=["wespeaker/wespeaker-voxceleb-resnet34"])
9+
def model(request: pytest.FixtureRequest) -> SeAdapter:
10+
manager = Manager()
11+
return manager.create_se(request.param)
12+
13+
14+
def test_embedding(model: SeAdapter) -> None:
15+
rng = np.random.default_rng(0)
16+
waveform = rng.random((1 * 16_000), dtype=np.float32)
17+
18+
result = model.embedding(waveform)
19+
assert isinstance(result, np.ndarray)
20+
assert result.dtype == np.float32
21+
assert result.ndim == 1
22+
23+
24+
def test_empty_embedding(model: SeAdapter) -> None:
25+
result = model.embedding([])
26+
assert isinstance(result, np.ndarray)
27+
assert result.dtype == np.float32
28+
assert result.ndim == 0
29+
30+
31+
def test_embedding_batch(model: SeAdapter) -> None:
32+
rng = np.random.default_rng(0)
33+
waveform1 = rng.random((2 * 16_000), dtype=np.float32)
34+
waveform2 = rng.random((1 * 16_000), dtype=np.float32)
35+
36+
result = model.embedding([waveform1, waveform2])
37+
assert isinstance(result, np.ndarray)
38+
assert result.dtype == np.float32
39+
assert result.ndim == 2
40+
assert result.shape[0] == 2

tests/onnx_asr/test_resolver.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
ModelTypes,
1111
VadNames,
1212
create_asr_resolver,
13+
create_se_resolver,
1314
create_vad_resolver,
1415
)
1516
from onnx_asr.models.kaldi import KaldiTransducer
1617
from onnx_asr.models.nemo import NemoConformerAED
1718
from onnx_asr.models.pyannote import PyAnnoteVad
1819
from onnx_asr.models.silero import SileroVad
1920
from onnx_asr.models.tone import TOneCtc
21+
from onnx_asr.models.wespeaker import WespeakerEmbeddings
2022
from onnx_asr.models.whisper import WhisperHf
2123
from onnx_asr.resolver import Resolver
2224
from onnx_asr.utils import (
@@ -208,3 +210,20 @@ def test_resolve_vad_file_not_found_error() -> None:
208210
loader = create_vad_resolver("silero")
209211
with pytest.raises(ModelFileNotFoundError):
210212
loader.resolve_model(quantization="xxx")
213+
214+
215+
@pytest.mark.parametrize("model", ["wespeaker/wespeaker-voxceleb-resnet34"])
216+
def test_se(model: str) -> None:
217+
loader = create_se_resolver(model)
218+
assert issubclass(loader.model_type, WespeakerEmbeddings)
219+
assert not loader.offline
220+
assert loader.local_dir is None
221+
assert isinstance(loader.repo_id, str)
222+
223+
224+
def test_se_with_path(tmp_path: Path) -> None:
225+
loader = create_se_resolver(local_dir=tmp_path)
226+
assert issubclass(loader.model_type, WespeakerEmbeddings)
227+
assert loader.offline
228+
assert loader.local_dir == tmp_path
229+
assert loader.repo_id is None

0 commit comments

Comments
 (0)