Skip to content

Commit 03f13c0

Browse files
committed
Change load_model parameters for preprocessor
1 parent 17457ee commit 03f13c0

3 files changed

Lines changed: 21 additions & 25 deletions

File tree

examples/performance benchmark.ipynb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@
2020
"import onnx_asr\n",
2121
"from onnx_asr.utils import read_wav_files\n",
2222
"\n",
23-
"model = onnx_asr.load_model(\n",
24-
" \"gigaam-v3-ctc\", providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"], max_preprocessing_threads=None\n",
25-
")"
23+
"model = onnx_asr.load_model(\"gigaam-v3-ctc\", providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])"
2624
]
2725
},
2826
{

src/onnx_asr/loader.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ def load_model( # noqa: C901
163163
providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
164164
provider_options: Sequence[dict[Any, Any]] | None = None,
165165
cpu_preprocessing: bool = True,
166-
max_preprocessing_threads: int | None = 1,
166+
preprocessor_config: PreprocessorRuntimeConfig | None = None,
167+
resampler_config: OnnxSessionOptions | None = None,
167168
) -> TextResultsAsrAdapter:
168169
"""Load ASR model.
169170
@@ -187,7 +188,8 @@ def load_model( # noqa: C901
187188
providers: Optional providers for onnxruntime.
188189
provider_options: Optional provider_options for onnxruntime.
189190
cpu_preprocessing: Run preprocessors on CPU.
190-
max_preprocessing_threads: Max parallel preprocessing threads (None - auto, 1 - without parallel processing).
191+
preprocessor_config: Preprocessor ONNX and concurrency config.
192+
resampler_config: Resampler ONNX config.
191193
192194
Returns:
193195
ASR model class.
@@ -281,24 +283,25 @@ def load_model( # noqa: C901
281283
case _:
282284
raise ModelNotSupportedError(model)
283285

284-
onnx_options: OnnxSessionOptions = {
286+
onnx_options: PreprocessorRuntimeConfig = {
285287
"sess_options": sess_options,
286288
"providers": providers or rt.get_available_providers(),
287289
"provider_options": provider_options,
288290
}
289291

290-
preprocessing_onnx_options: OnnxSessionOptions = {"sess_options": sess_options} if cpu_preprocessing else onnx_options
291-
if max_preprocessing_threads != 1:
292-
preprocessing_sess_options = preprocessing_onnx_options["sess_options"] or rt.SessionOptions()
293-
preprocessing_sess_options.intra_op_num_threads = 1
294-
preprocessing_onnx_options["sess_options"] = preprocessing_sess_options
292+
if resampler_config is None:
293+
resampler_config = {"sess_options": sess_options} if cpu_preprocessing else onnx_options
294+
295+
if preprocessor_config is None:
296+
preprocessor_config = {"sess_options": sess_options} if cpu_preprocessing else onnx_options
297+
preprocessor_config |= {"max_concurrent_workers": 1}
295298

296299
return TextResultsAsrAdapter(
297300
model_type(
298301
_find_files(path, repo_id, model_type._get_model_files(quantization)),
299-
AsrRuntimeConfig(onnx_options, PreprocessorRuntimeConfig(preprocessing_onnx_options, max_preprocessing_threads)),
302+
AsrRuntimeConfig(onnx_options, preprocessor_config),
300303
),
301-
Resampler(model_type._get_sample_rate(), preprocessing_onnx_options),
304+
Resampler(model_type._get_sample_rate(), resampler_config),
302305
)
303306

304307

src/onnx_asr/preprocessors/preprocessor.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""ASR preprocessor implementations."""
22

33
from concurrent.futures import ThreadPoolExecutor
4-
from dataclasses import dataclass, field
54
from importlib.resources import files
65
from pathlib import Path
76

@@ -12,12 +11,11 @@
1211
from onnx_asr.utils import OnnxSessionOptions, is_float32_array, is_int64_array
1312

1413

15-
@dataclass()
16-
class PreprocessorRuntimeConfig:
14+
class PreprocessorRuntimeConfig(OnnxSessionOptions, total=False):
1715
"""Preprocessor runtime config."""
1816

19-
onnx_options: OnnxSessionOptions = field(default_factory=OnnxSessionOptions)
20-
max_concurrent_workers: int | None = 1
17+
max_concurrent_workers: int | None
18+
"""Max parallel preprocessing threads (None - auto, 1 - without parallel processing)."""
2119

2220

2321
class Preprocessor:
@@ -31,15 +29,12 @@ def __init__(self, name: str, runtime_config: PreprocessorRuntimeConfig):
3129
runtime_config: Runtime configuration.
3230
3331
"""
32+
self._max_concurrent_workers = runtime_config.pop("max_concurrent_workers", 1)
3433
if name == "identity":
3534
self._preprocessor = None
36-
return
37-
38-
filename = str(Path(name).with_suffix(".onnx"))
39-
self._preprocessor = rt.InferenceSession(
40-
files(__package__).joinpath(filename).read_bytes(), **runtime_config.onnx_options
41-
)
42-
self._max_concurrent_workers = runtime_config.max_concurrent_workers
35+
else:
36+
filename = str(Path(name).with_suffix(".onnx"))
37+
self._preprocessor = rt.InferenceSession(files(__package__).joinpath(filename).read_bytes(), **runtime_config)
4338

4439
def _preprocess(
4540
self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64]

0 commit comments

Comments
 (0)