22
33import re
44from abc import ABC , abstractmethod
5- from collections .abc import Iterator
5+ from collections .abc import Iterable
66from pathlib import Path
77from typing import Any
88
@@ -46,7 +46,7 @@ def recognize(
4646class _AsrWithDecoding (Asr ):
4747 DECODE_SPACE_PATTERN = re .compile (r"\A\u2581|\u2581\B|(\u2581)\b" )
4848
49- def __init__ (self , preprocessor_name : Preprocessor . PreprocessorNames , vocab_path : Path , ** kwargs ):
49+ def __init__ (self , preprocessor_name : str , vocab_path : Path , ** kwargs : Any ):
5050 self ._preprocessor = Preprocessor (preprocessor_name , ** kwargs )
5151 with Path (vocab_path ).open ("rt" ) as f :
5252 tokens = {token : int (id ) for token , id in (line .strip ("\n " ).split (" " ) for line in f .readlines ())}
@@ -59,7 +59,7 @@ def _encode(
5959 ) -> tuple [npt .NDArray [np .float32 ], npt .NDArray [np .int64 ]]: ...
6060
6161 @abstractmethod
62- def _decoding (self , encoder_out : npt .NDArray [np .float32 ], encoder_out_lens : npt .NDArray [np .int64 ]) -> Iterator [list [int ]]: ...
62+ def _decoding (self , encoder_out : npt .NDArray [np .float32 ], encoder_out_lens : npt .NDArray [np .int64 ]) -> Iterable [list [int ]]: ...
6363
6464 def _decode_tokens (self , tokens : list [int ]) -> str :
6565 text = "" .join ([self ._vocab [i ] for i in tokens ])
@@ -70,7 +70,7 @@ def _recognize_batch(self, waveforms: list[npt.NDArray[np.float32]], language: s
7070
7171
7272class _AsrWithCtcDecoding (_AsrWithDecoding ):
73- def _decoding (self , encoder_out : npt .NDArray [np .float32 ], encoder_out_lens : npt .NDArray [np .int64 ]) -> Iterator [list [int ]]:
73+ def _decoding (self , encoder_out : npt .NDArray [np .float32 ], encoder_out_lens : npt .NDArray [np .int64 ]) -> Iterable [list [int ]]:
7474 assert encoder_out .shape [- 1 ] <= len (self ._vocab )
7575
7676 for log_probs , log_probs_len in zip (encoder_out , encoder_out_lens , strict = True ):
@@ -82,21 +82,21 @@ def _decoding(self, encoder_out: npt.NDArray[np.float32], encoder_out_lens: npt.
8282
8383class _AsrWithRnntDecoding (_AsrWithDecoding ):
8484 @abstractmethod
85- def _create_state (self ) -> Any : ...
85+ def _create_state (self ) -> tuple : ...
8686
8787 @property
8888 @abstractmethod
8989 def _max_tokens_per_step (self ) -> int : ...
9090
9191 @abstractmethod
9292 def _decode (
93- self , prev_tokens : list [int ], prev_state : Any , encoder_out : npt .NDArray [np .float32 ]
94- ) -> tuple [npt .NDArray [np .float32 ], Any ]: ...
93+ self , prev_tokens : list [int ], prev_state : tuple , encoder_out : npt .NDArray [np .float32 ]
94+ ) -> tuple [npt .NDArray [np .float32 ], tuple ]: ...
9595
96- def _decoding (self , encoder_out : npt .NDArray [np .float32 ], encoder_out_lens : npt .NDArray [np .int64 ]) -> Iterator [list [int ]]:
96+ def _decoding (self , encoder_out : npt .NDArray [np .float32 ], encoder_out_lens : npt .NDArray [np .int64 ]) -> Iterable [list [int ]]:
9797 for encodings , encodings_len in zip (encoder_out , encoder_out_lens , strict = True ):
9898 prev_state = self ._create_state ()
99- tokens = []
99+ tokens : list [ int ] = []
100100
101101 for t in range (encodings_len ):
102102 emitted_tokens = 0
0 commit comments