77
88import onnxruntime as rt
99
10- from onnx_asr .adapters import TextResultsAsrAdapter
10+ from onnx_asr .adapters import SeAdapter , TextResultsAsrAdapter
1111from onnx_asr .asr import Asr , Preprocessor
1212from onnx_asr .models .gigaam import GigaamV2Ctc , GigaamV2Rnnt , GigaamV3E2eCtc , GigaamV3E2eRnnt
1313from onnx_asr .models .kaldi import KaldiTransducer
1414from onnx_asr .models .nemo import NemoConformerAED , NemoConformerCtc , NemoConformerRnnt , NemoConformerTdt
1515from onnx_asr .models .pyannote import PyAnnoteVad
1616from onnx_asr .models .silero import SileroVad
1717from onnx_asr .models .tone import TOneCtc
18+ from onnx_asr .models .wespeaker import WespeakerEmbeddings
1819from onnx_asr .models .whisper import WhisperHf , WhisperOrt
1920from onnx_asr .onnx import OnnxSessionOptions , get_onnx_providers , update_onnx_providers
2021from onnx_asr .preprocessors .numpy_preprocessor import (
2627from onnx_asr .preprocessors .preprocessor import ConcurrentPreprocessor , IdentityPreprocessor , OnnxPreprocessor
2728from onnx_asr .preprocessors .resampler import Resampler
2829from onnx_asr .resolver import Resolver
30+ from onnx_asr .se import SpeakerEmbedding
2931from onnx_asr .utils import (
3032 ModelNotSupportedError ,
3133)
8284
8385
8486def create_asr_resolver (
85- model : str , local_dir : str | Path | None = None , * , offline : bool | None = None
87+ model : str | None = None , local_dir : str | Path | None = None , * , offline : bool | None = None
8688) -> Resolver [AsrTypes ]:
8789 """Create resolver for ASR models."""
8890 model_types : dict [str , type [AsrTypes ]] = {
@@ -120,13 +122,20 @@ def create_asr_resolver(
120122
121123
122124def create_vad_resolver (
123- model : str , local_dir : str | Path | None = None , * , offline : bool | None = None
125+ model : str | None = None , local_dir : str | Path | None = None , * , offline : bool | None = None
124126) -> Resolver [VadTypes ]:
125127 """Create resolver for VAD models."""
126128 model_types : dict [str , type [VadTypes ]] = {"silero" : SileroVad , "pyannote" : PyAnnoteVad }
127129 return Resolver (model_types , model , local_dir , offline = offline )
128130
129131
132+ def create_se_resolver (
133+ model : str | None = None , local_dir : str | Path | None = None , * , offline : bool | None = None
134+ ) -> Resolver [WespeakerEmbeddings ]:
135+ """Create resolver for SE models."""
136+ return Resolver (WespeakerEmbeddings , model , local_dir , offline = offline )
137+
138+
130139class PreprocessorRuntimeConfig (OnnxSessionOptions , total = False ):
131140 """Preprocessor runtime config."""
132141
@@ -206,30 +215,34 @@ def _create_preprocessor(self, name: str) -> Preprocessor:
206215 def _create_resampler (self , sample_rate : Literal [8000 , 16000 ]) -> Resampler :
207216 return Resampler (sample_rate , self .resampler_config )
208217
218+ def _create_asr_adapter (self , asr : Asr ) -> TextResultsAsrAdapter :
219+ return TextResultsAsrAdapter (asr , self ._create_resampler (asr ._get_sample_rate ()))
220+
221+ def _create_se_adapter (self , se : SpeakerEmbedding ) -> SeAdapter :
222+ return SeAdapter (se , self ._create_resampler (se ._get_sample_rate ()))
223+
209224 def create_asr (
210225 self ,
211- model : str ,
226+ model : str | ModelNames | ModelTypes | None = None ,
212227 local_dir : str | Path | None = None ,
213228 * ,
214229 quantization : str | None = None ,
215230 offline : bool | None = None ,
216231 config : OnnxSessionOptions | None = None ,
217- ) -> Asr :
232+ ) -> TextResultsAsrAdapter :
218233 """Create ASR model."""
219234 resolver = create_asr_resolver (model , local_dir , offline = offline )
220235 if config is None :
221236 config = update_onnx_providers (
222237 self .default_onnx_config , excluded_providers = resolver .model_type ._get_excluded_providers ()
223238 )
224- return resolver .model_type (resolver .resolve_model (quantization = quantization ), self ._create_preprocessor , config )
225-
226- def create_adapter (self , asr : Asr ) -> TextResultsAsrAdapter :
227- """Create ASR adapter."""
228- return TextResultsAsrAdapter (asr , self ._create_resampler (asr ._get_sample_rate ()))
239+ return self ._create_asr_adapter (
240+ resolver .model_type (resolver .resolve_model (quantization = quantization ), self ._create_preprocessor , config )
241+ )
229242
230243 def create_vad (
231244 self ,
232- model : str ,
245+ model : str | VadNames | None = None ,
233246 local_dir : str | Path | None = None ,
234247 * ,
235248 quantization : str | None = None ,
@@ -244,6 +257,25 @@ def create_vad(
244257 )
245258 return resolver .model_type (resolver .resolve_model (quantization = quantization ), config )
246259
260+ def create_se (
261+ self ,
262+ model : str | None = None ,
263+ local_dir : str | Path | None = None ,
264+ * ,
265+ quantization : str | None = None ,
266+ offline : bool | None = None ,
267+ config : OnnxSessionOptions | None = None ,
268+ ) -> SeAdapter :
269+ """Create SE model."""
270+ resolver = create_se_resolver (model , local_dir , offline = offline )
271+ if config is None :
272+ config = update_onnx_providers (
273+ self .default_onnx_config , excluded_providers = resolver .model_type ._get_excluded_providers ()
274+ )
275+ return self ._create_se_adapter (
276+ resolver .model_type (resolver .resolve_model (quantization = quantization ), self ._create_preprocessor , config )
277+ )
278+
247279
248280def load_model (
249281 model : str | ModelNames | ModelTypes ,
@@ -304,7 +336,7 @@ def load_model(
304336 )
305337
306338 manager = Manager (sess_options , providers , provider_options , preprocessor_config , resampler_config )
307- return manager .create_adapter ( manager . create_asr (model , path , quantization = quantization , config = asr_config ) )
339+ return manager .create_asr (model , path , quantization = quantization , config = asr_config )
308340
309341
310342def load_vad (
0 commit comments