Skip to content

Commit 5228031

Browse files
committed
Fix typing with onnxruntime 1.22
1 parent a695e46 commit 5228031

9 files changed

Lines changed: 63 additions & 29 deletions

File tree

src/onnx_asr/models/gigaam.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import onnxruntime as rt
88

99
from onnx_asr.asr import _AsrWithCtcDecoding, _AsrWithDecoding, _AsrWithTransducerDecoding
10-
from onnx_asr.utils import OnnxSessionOptions
10+
from onnx_asr.utils import OnnxSessionOptions, is_float32_array, is_int32_array
1111

1212

1313
class _GigaamV2(_AsrWithDecoding):
@@ -48,6 +48,7 @@ def _encode(
4848
self, features: npt.NDArray[np.float32], features_lens: npt.NDArray[np.int64]
4949
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
5050
(log_probs,) = self._model.run(["log_probs"], {"features": features, "feature_lengths": features_lens})
51+
assert is_float32_array(log_probs)
5152
return log_probs, (features_lens - 1) // self._subsampling_factor + 1
5253

5354

@@ -91,6 +92,7 @@ def _encode(
9192
encoder_out, encoder_out_lens = self._encoder.run(
9293
["encoded", "encoded_len"], {"audio_signal": features, "length": features_lens}
9394
)
95+
assert is_float32_array(encoder_out) and is_int32_array(encoder_out_lens)
9496
return encoder_out, encoder_out_lens.astype(np.int64)
9597

9698
def _create_state(self) -> _STATE_TYPE:
@@ -102,8 +104,10 @@ def _create_state(self) -> _STATE_TYPE:
102104
def _decode(
103105
self, prev_tokens: list[int], prev_state: _STATE_TYPE, encoder_out: npt.NDArray[np.float32]
104106
) -> tuple[npt.NDArray[np.float32], int, _STATE_TYPE]:
105-
decoder_out, *state = self._decoder.run(
107+
decoder_out, state1, state2 = self._decoder.run(
106108
["dec", "h", "c"], {"x": [[[self._blank_idx, *prev_tokens][-1]]], "h.1": prev_state[0], "c.1": prev_state[1]}
107109
)
110+
assert is_float32_array(decoder_out) and is_float32_array(state1) and is_float32_array(state2)
108111
(joint,) = self._joiner.run(["joint"], {"enc": encoder_out[None, :, None], "dec": decoder_out.transpose(0, 2, 1)})
109-
return np.squeeze(joint), -1, tuple(state)
112+
assert is_float32_array(joint)
113+
return np.squeeze(joint), -1, (state1, state2)

src/onnx_asr/models/kaldi.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import onnxruntime as rt
88

99
from onnx_asr.asr import _AsrWithTransducerDecoding
10-
from onnx_asr.utils import OnnxSessionOptions
10+
from onnx_asr.utils import OnnxSessionOptions, is_float32_array, is_int64_array
1111

1212
_STATE_TYPE = dict[tuple[int, ...], npt.NDArray[np.float32]]
1313

@@ -59,6 +59,7 @@ def _encode(
5959
encoder_out, encoder_out_lens = self._encoder.run(
6060
["encoder_out", "encoder_out_lens"], {"x": features, "x_lens": features_lens}
6161
)
62+
assert is_float32_array(encoder_out) and is_int64_array(encoder_out_lens)
6263
return encoder_out.transpose(0, 2, 1), encoder_out_lens
6364

6465
def _create_state(self) -> _STATE_TYPE:
@@ -68,7 +69,9 @@ def _decode(
6869
self, prev_tokens: list[int], prev_state: _STATE_TYPE, encoder_out: npt.NDArray[np.float32]
6970
) -> tuple[npt.NDArray[np.float32], int, _STATE_TYPE]:
7071
(decoder_out,) = self._decoder.run(["decoder_out"], {"y": [[-1, self._blank_idx, *prev_tokens][-self.CONTEXT_SIZE :]]})
72+
assert is_float32_array(decoder_out)
7173
(logit,) = self._joiner.run(["logit"], {"encoder_out": encoder_out[None, :], "decoder_out": decoder_out})
74+
assert is_float32_array(logit)
7275
return np.squeeze(logit), -1, prev_state
7376

7477

@@ -82,8 +85,10 @@ def _decode(
8285

8386
decoder_out = prev_state.get(context)
8487
if decoder_out is None:
85-
(decoder_out,) = self._decoder.run(["decoder_out"], {"y": [context]})
86-
prev_state[context] = decoder_out
88+
(_decoder_out,) = self._decoder.run(["decoder_out"], {"y": [context]})
89+
assert is_float32_array(_decoder_out)
90+
prev_state[context] = (decoder_out := _decoder_out)
8791

8892
(logit,) = self._joiner.run(["logit"], {"encoder_out": encoder_out[None, :], "decoder_out": decoder_out})
93+
assert is_float32_array(logit)
8994
return np.squeeze(logit), -1, prev_state

src/onnx_asr/models/nemo.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import onnxruntime as rt
88

99
from onnx_asr.asr import _AsrWithCtcDecoding, _AsrWithDecoding, _AsrWithTransducerDecoding
10-
from onnx_asr.utils import OnnxSessionOptions
10+
from onnx_asr.utils import OnnxSessionOptions, is_float32_array
1111

1212

1313
class _NemoConformer(_AsrWithDecoding):
@@ -47,6 +47,7 @@ def _encode(
4747
self, features: npt.NDArray[np.float32], features_lens: npt.NDArray[np.int64]
4848
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
4949
(logprobs,) = self._model.run(["logprobs"], {"audio_signal": features, "length": features_lens})
50+
assert is_float32_array(logprobs)
5051
return logprobs, (features_lens - 1) // self._subsampling_factor + 1
5152

5253

@@ -86,7 +87,7 @@ def _encode(
8687
encoder_out, encoder_out_lens = self._encoder.run(
8788
["outputs", "encoded_lengths"], {"audio_signal": features, "length": features_lens}
8889
)
89-
return encoder_out, encoder_out_lens
90+
return encoder_out, encoder_out_lens # type: ignore
9091

9192
def _create_state(self) -> _STATE_TYPE:
9293
shapes = {x.name: x.shape for x in self._decoder_joint.get_inputs()}
@@ -98,7 +99,7 @@ def _create_state(self) -> _STATE_TYPE:
9899
def _decode(
99100
self, prev_tokens: list[int], prev_state: _STATE_TYPE, encoder_out: npt.NDArray[np.float32]
100101
) -> tuple[npt.NDArray[np.float32], int, _STATE_TYPE]:
101-
outputs, *state = self._decoder_joint.run(
102+
outputs, state1, state2 = self._decoder_joint.run(
102103
["outputs", "output_states_1", "output_states_2"],
103104
{
104105
"encoder_outputs": encoder_out[None, :, None],
@@ -108,7 +109,8 @@ def _decode(
108109
"input_states_2": prev_state[1],
109110
},
110111
)
111-
return np.squeeze(outputs), -1, tuple(state)
112+
assert is_float32_array(outputs) and is_float32_array(state1) and is_float32_array(state2)
113+
return np.squeeze(outputs), -1, (state1, state2)
112114

113115

114116
class NemoConformerTdt(NemoConformerRnnt):

src/onnx_asr/models/pyannote.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
"""PyAnnote VAD implementation."""
22

3-
import typing
43
from pathlib import Path
54

65
import numpy as np
76
import numpy.typing as npt
87
import onnxruntime as rt
98

10-
from onnx_asr.utils import OnnxSessionOptions
9+
from onnx_asr.utils import OnnxSessionOptions, is_float32_array
1110
from onnx_asr.vad import Vad
1211

1312

@@ -31,4 +30,5 @@ def _get_model_files(quantization: str | None = None) -> dict[str, str]:
3130

3231
def _encode(self, waveforms: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
3332
(logits,) = self._model.run(["logits"], {"input_values": waveforms[:, None]})
34-
return typing.cast(npt.NDArray[np.float32], logits)
33+
assert is_float32_array(logits)
34+
return logits

src/onnx_asr/models/silero.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Silero VAD implementation."""
22

3-
import typing
43
from collections.abc import Iterable, Iterator
54
from itertools import chain
65
from pathlib import Path
@@ -9,7 +8,7 @@
98
import numpy.typing as npt
109
import onnxruntime as rt
1110

12-
from onnx_asr.utils import OnnxSessionOptions
11+
from onnx_asr.utils import OnnxSessionOptions, is_float32_array
1312
from onnx_asr.vad import Vad
1413

1514

@@ -44,8 +43,10 @@ def _encode(self, waveforms: npt.NDArray[np.float32]) -> Iterator[npt.NDArray[np
4443

4544
def process(frame: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
4645
nonlocal state
47-
output, state = self._model.run(["output", "stateN"], {"input": frame, "state": state, "sr": [self.SAMPLE_RATE]})
48-
return typing.cast(npt.NDArray[np.float32], output[:, 0])
46+
output, new_state = self._model.run(["output", "stateN"], {"input": frame, "state": state, "sr": [self.SAMPLE_RATE]})
47+
assert is_float32_array(output) and is_float32_array(new_state)
48+
state = new_state
49+
return output[:, 0]
4950

5051
yield process(np.pad(waveforms[:, : self.HOP_SIZE], ((0, 0), (self.CONTEXT_SIZE, 0))))
5152

src/onnx_asr/models/whisper.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import onnxruntime as rt
1212

1313
from onnx_asr.asr import Asr, TimestampedResult
14-
from onnx_asr.utils import OnnxSessionOptions
14+
from onnx_asr.utils import OnnxSessionOptions, is_float32_array, is_int32_array
1515

1616

1717
@typing.no_type_check
@@ -129,7 +129,8 @@ def _decoding(
129129
"decoder_input_ids": tokens.astype(np.int32),
130130
},
131131
)
132-
return typing.cast(npt.NDArray[np.int32], sequences)[:, 0, :].astype(np.int64)
132+
assert is_int32_array(sequences)
133+
return sequences[:, 0, :].astype(np.int64)
133134

134135

135136
class WhisperHf(_Whisper):
@@ -162,11 +163,13 @@ def _preprocessor_name(self) -> str:
162163
def _encode(self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64]) -> npt.NDArray[np.float32]:
163164
input_features = super()._encode(waveforms, waveforms_len)
164165
(last_hidden_state,) = self._encoder.run(["last_hidden_state"], {"input_features": input_features})
165-
return typing.cast(npt.NDArray[np.float32], last_hidden_state)
166+
assert is_float32_array(last_hidden_state)
167+
return last_hidden_state
166168

167169
def _decode(self, tokens: npt.NDArray[np.int64], encoder_out: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
168170
(logits,) = self._decoder.run(["logits"], {"input_ids": tokens, "encoder_hidden_states": encoder_out})
169-
return typing.cast(npt.NDArray[np.float32], logits)
171+
assert is_float32_array(logits)
172+
return logits
170173

171174
def _decoding(
172175
self, input_features: npt.NDArray[np.float32], tokens: npt.NDArray[np.int64], max_length: int = 448

src/onnx_asr/preprocessors/preprocessor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy.typing as npt
88
import onnxruntime as rt
99

10-
from onnx_asr.utils import OnnxSessionOptions
10+
from onnx_asr.utils import OnnxSessionOptions, is_float32_array, is_int64_array
1111

1212

1313
class Preprocessor:
@@ -33,4 +33,5 @@ def __call__(
3333
features, features_lens = self._preprocessor.run(
3434
["features", "features_lens"], {"waveforms": waveforms, "waveforms_lens": waveforms_lens}
3535
)
36+
assert is_float32_array(features) and is_int64_array(features_lens)
3637
return features, features_lens

src/onnx_asr/preprocessors/resampler.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy.typing as npt
77
import onnxruntime as rt
88

9-
from onnx_asr.utils import OnnxSessionOptions, SampleRates
9+
from onnx_asr.utils import OnnxSessionOptions, SampleRates, is_float32_array, is_int64_array
1010

1111

1212
class Resampler:
@@ -27,9 +27,12 @@ def __call__(
2727
self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64], sample_rate: SampleRates
2828
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
2929
"""Resample waveform to 16 kHz."""
30-
if sample_rate != 16_000:
31-
waveforms, waveforms_lens = self._preprocessor.run(
32-
["resampled", "resampled_lens"],
33-
{"waveforms": waveforms, "waveforms_lens": waveforms_lens, "sample_rate": [sample_rate]},
34-
)
35-
return waveforms, waveforms_lens
30+
if sample_rate == 16_000:
31+
return waveforms, waveforms_lens
32+
33+
resampled, resampled_lens = self._preprocessor.run(
34+
["resampled", "resampled_lens"],
35+
{"waveforms": waveforms, "waveforms_lens": waveforms_lens, "sample_rate": [sample_rate]},
36+
)
37+
assert is_float32_array(resampled) and is_int64_array(resampled_lens)
38+
return resampled, resampled_lens

src/onnx_asr/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,21 @@ def is_supported_sample_rate(sample_rate: int) -> TypeGuard[SampleRates]:
1616
return sample_rate in get_args(SampleRates)
1717

1818

19+
def is_float32_array(x: object) -> TypeGuard[npt.NDArray[np.float32]]:
20+
"""Numpy array is float32."""
21+
return isinstance(x, np.ndarray) and x.dtype == np.float32
22+
23+
24+
def is_int32_array(x: object) -> TypeGuard[npt.NDArray[np.int32]]:
25+
"""Numpy array is int32."""
26+
return isinstance(x, np.ndarray) and x.dtype == np.int32
27+
28+
29+
def is_int64_array(x: object) -> TypeGuard[npt.NDArray[np.int64]]:
30+
"""Numpy array is int64."""
31+
return isinstance(x, np.ndarray) and x.dtype == np.int64
32+
33+
1934
class SupportedOnlyMonoAudioError(ValueError):
2035
"""Supported only mono audio error."""
2136

0 commit comments

Comments
 (0)