77import onnxruntime as rt
88
99from onnx_asr .asr import _AsrWithCtcDecoding , _AsrWithDecoding , _AsrWithTransducerDecoding
10- from onnx_asr .utils import OnnxSessionOptions
10+ from onnx_asr .utils import OnnxSessionOptions , is_float32_array
1111
1212
1313class _NemoConformer (_AsrWithDecoding ):
@@ -47,6 +47,7 @@ def _encode(
4747 self , features : npt .NDArray [np .float32 ], features_lens : npt .NDArray [np .int64 ]
4848 ) -> tuple [npt .NDArray [np .float32 ], npt .NDArray [np .int64 ]]:
4949 (logprobs ,) = self ._model .run (["logprobs" ], {"audio_signal" : features , "length" : features_lens })
50+ assert is_float32_array (logprobs )
5051 return logprobs , (features_lens - 1 ) // self ._subsampling_factor + 1
5152
5253
@@ -86,7 +87,7 @@ def _encode(
8687 encoder_out , encoder_out_lens = self ._encoder .run (
8788 ["outputs" , "encoded_lengths" ], {"audio_signal" : features , "length" : features_lens }
8889 )
89- return encoder_out , encoder_out_lens
90+ return encoder_out , encoder_out_lens # type: ignore
9091
9192 def _create_state (self ) -> _STATE_TYPE :
9293 shapes = {x .name : x .shape for x in self ._decoder_joint .get_inputs ()}
@@ -98,7 +99,7 @@ def _create_state(self) -> _STATE_TYPE:
9899 def _decode (
99100 self , prev_tokens : list [int ], prev_state : _STATE_TYPE , encoder_out : npt .NDArray [np .float32 ]
100101 ) -> tuple [npt .NDArray [np .float32 ], int , _STATE_TYPE ]:
101- outputs , * state = self ._decoder_joint .run (
102+ outputs , state1 , state2 = self ._decoder_joint .run (
102103 ["outputs" , "output_states_1" , "output_states_2" ],
103104 {
104105 "encoder_outputs" : encoder_out [None , :, None ],
@@ -108,7 +109,8 @@ def _decode(
108109 "input_states_2" : prev_state [1 ],
109110 },
110111 )
111- return np .squeeze (outputs ), - 1 , tuple (state )
112+ assert is_float32_array (outputs ) and is_float32_array (state1 ) and is_float32_array (state2 )
113+ return np .squeeze (outputs ), - 1 , (state1 , state2 )
112114
113115
114116class NemoConformerTdt (NemoConformerRnnt ):
0 commit comments