Skip to content

Commit 6bd7139

Browse files
committed
Add MyPy. Fix errors.
1 parent bf3624d commit 6bd7139

12 files changed

Lines changed: 225 additions & 122 deletions

File tree

.github/workflows/python-package.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@ jobs:
2626
with:
2727
python-version: ${{ matrix.python-version }}
2828
- name: Install dependencies
29-
run: pdm install
29+
run: pdm sync
3030
- name: Lint code with Ruff
3131
run: pdm run ruff check --output-format=github
3232
- name: Check code formatting with Ruff
3333
run: pdm run ruff format --diff
34+
- name: Check types with MyPy
35+
run: pdm run mypy .
3436
- name: Test with pytest
3537
run: pdm run pytest

pdm.lock

Lines changed: 133 additions & 65 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ onnx-asr = "onnx_asr.cli:run"
4848
build = [
4949
"onnx>=1.17.0",
5050
"onnxscript>=0.2.5",
51-
"torch~=2.6.0",
52-
"torchaudio~=2.6.0",
51+
"torch>=2.6.0",
52+
"torchaudio>=2.6.0",
5353
]
5454
asrs = [
5555
"kaldi-native-fbank>=1.21.1",
@@ -62,7 +62,7 @@ test = [
6262
{ include-group = "build" },
6363
{ include-group = "asrs" },
6464
]
65-
lint = ["ruff>=0.11.6"]
65+
lint = ["ruff>=0.11.6", "mypy>=1.15.0"]
6666

6767
[tool.pdm]
6868
distribution = true
@@ -76,19 +76,33 @@ source-includes = ["preprocessors", "tests"]
7676
[tool.pdm.scripts]
7777
build_preprocessors = { call = "preprocessors.build:build" }
7878
post_install = { composite = ["build_preprocessors"] }
79-
pre_build = { composite = ["pdm install --with build"] }
79+
pre_build = { composite = ["pdm sync --group build"] }
80+
lint = { composite = ["ruff format --diff", "ruff check", "mypy ."] }
8081

8182
[[tool.pdm.source]]
8283
name = "torch-cpu"
8384
url = "https://download.pytorch.org/whl/cpu"
8485
include_packages = ["torch*"]
8586

87+
[tool.mypy]
88+
python_version = "3.10"
89+
warn_return_any = true
90+
warn_unused_configs = true
91+
disallow_untyped_defs = true
92+
pretty = true
93+
exclude = ['^preprocessors.', '^tests.']
94+
95+
[[tool.mypy.overrides]]
96+
module = ["onnxruntime.*"]
97+
follow_untyped_imports = true
98+
8699
[tool.ruff]
87100
line-length = 130
88101
indent-width = 4
89102
target-version = "py310"
90103

91104
[tool.ruff.lint]
105+
exclude = ["*.ipynb"]
92106
select = [
93107
"B", # flake8-bugbear
94108
"C4", # flake8-comprehensions
@@ -114,6 +128,7 @@ select = [
114128
"W", # pycodestyle
115129
"YTT", # flake8-2020
116130
]
131+
ignore = ["D203", "D213"]
117132

118133
[tool.ruff.lint.per-file-ignores]
119134
"tests/*" = ["D100", "D103", "D104"]
@@ -123,5 +138,6 @@ select = [
123138
filterwarnings = [
124139
"ignore::DeprecationWarning:onnxscript.*",
125140
"ignore::DeprecationWarning:google.protobuf.*",
141+
"ignore::DeprecationWarning:torchmetrics.*",
126142
"ignore::FutureWarning:onnxscript.*",
127143
]

src/onnx_asr/asr.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import re
44
from abc import ABC, abstractmethod
5-
from collections.abc import Iterator
5+
from collections.abc import Iterable
66
from pathlib import Path
77
from typing import Any
88

@@ -46,7 +46,7 @@ def recognize(
4646
class _AsrWithDecoding(Asr):
4747
DECODE_SPACE_PATTERN = re.compile(r"\A\u2581|\u2581\B|(\u2581)\b")
4848

49-
def __init__(self, preprocessor_name: Preprocessor.PreprocessorNames, vocab_path: Path, **kwargs):
49+
def __init__(self, preprocessor_name: str, vocab_path: Path, **kwargs: Any):
5050
self._preprocessor = Preprocessor(preprocessor_name, **kwargs)
5151
with Path(vocab_path).open("rt") as f:
5252
tokens = {token: int(id) for token, id in (line.strip("\n").split(" ") for line in f.readlines())}
@@ -59,7 +59,7 @@ def _encode(
5959
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: ...
6060

6161
@abstractmethod
62-
def _decoding(self, encoder_out: npt.NDArray[np.float32], encoder_out_lens: npt.NDArray[np.int64]) -> Iterator[list[int]]: ...
62+
def _decoding(self, encoder_out: npt.NDArray[np.float32], encoder_out_lens: npt.NDArray[np.int64]) -> Iterable[list[int]]: ...
6363

6464
def _decode_tokens(self, tokens: list[int]) -> str:
6565
text = "".join([self._vocab[i] for i in tokens])
@@ -70,7 +70,7 @@ def _recognize_batch(self, waveforms: list[npt.NDArray[np.float32]], language: s
7070

7171

7272
class _AsrWithCtcDecoding(_AsrWithDecoding):
73-
def _decoding(self, encoder_out: npt.NDArray[np.float32], encoder_out_lens: npt.NDArray[np.int64]) -> Iterator[list[int]]:
73+
def _decoding(self, encoder_out: npt.NDArray[np.float32], encoder_out_lens: npt.NDArray[np.int64]) -> Iterable[list[int]]:
7474
assert encoder_out.shape[-1] <= len(self._vocab)
7575

7676
for log_probs, log_probs_len in zip(encoder_out, encoder_out_lens, strict=True):
@@ -82,21 +82,21 @@ def _decoding(self, encoder_out: npt.NDArray[np.float32], encoder_out_lens: npt.
8282

8383
class _AsrWithRnntDecoding(_AsrWithDecoding):
8484
@abstractmethod
85-
def _create_state(self) -> Any: ...
85+
def _create_state(self) -> tuple: ...
8686

8787
@property
8888
@abstractmethod
8989
def _max_tokens_per_step(self) -> int: ...
9090

9191
@abstractmethod
9292
def _decode(
93-
self, prev_tokens: list[int], prev_state: Any, encoder_out: npt.NDArray[np.float32]
94-
) -> tuple[npt.NDArray[np.float32], Any]: ...
93+
self, prev_tokens: list[int], prev_state: tuple, encoder_out: npt.NDArray[np.float32]
94+
) -> tuple[npt.NDArray[np.float32], tuple]: ...
9595

96-
def _decoding(self, encoder_out: npt.NDArray[np.float32], encoder_out_lens: npt.NDArray[np.int64]) -> Iterator[list[int]]:
96+
def _decoding(self, encoder_out: npt.NDArray[np.float32], encoder_out_lens: npt.NDArray[np.int64]) -> Iterable[list[int]]:
9797
for encodings, encodings_len in zip(encoder_out, encoder_out_lens, strict=True):
9898
prev_state = self._create_state()
99-
tokens = []
99+
tokens: list[int] = []
100100

101101
for t in range(encodings_len):
102102
emitted_tokens = 0

src/onnx_asr/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from onnx_asr.loader import ModelNames, ModelTypes
99

1010

11-
def run():
11+
def run() -> None:
1212
"""Run CLI for ASR models."""
1313
parser = argparse.ArgumentParser(prog="onnx_asr", description="Automatic Speech Recognition in Python using ONNX models.")
1414
parser.add_argument(

src/onnx_asr/loader.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Sequence
44
from pathlib import Path
5-
from typing import Any, Literal, get_args
5+
from typing import Literal, get_args
66

77
import onnxruntime as rt
88

@@ -39,7 +39,17 @@
3939
ModelVersions = Literal["int8"] | None
4040

4141

42-
def _get_model_class(model: str):
42+
def _get_model_class(
43+
model: str,
44+
) -> (
45+
type[GigaamV2Ctc]
46+
| type[GigaamV2Rnnt]
47+
| type[KaldiTransducer]
48+
| type[NemoConformerCtc]
49+
| type[NemoConformerRnnt]
50+
| type[WhisperOrt]
51+
| type[WhisperHf]
52+
):
4353
match model.split("/"):
4454
case ("gigaam-v2-ctc",):
4555
return GigaamV2Ctc
@@ -61,10 +71,10 @@ def _get_model_class(model: str):
6171
raise ValueError(f"Model '{model}' not supported!") # noqa: TRY003
6272

6373

64-
def _resolve_paths(path: str | Path, model_files: dict[str, str]):
74+
def _resolve_paths(path: str | Path, model_files: dict[str, str]) -> dict[str, Path]:
6575
assert Path(path).is_dir(), f"The path '{path}' is not a directory."
6676

67-
def find(filename):
77+
def find(filename: str) -> Path:
6878
files = list(Path(path).glob(filename))
6979
assert len(files) > 0, f"File '{filename}' not found in path '{path}'."
7080
assert len(files) == 1, f"Found more than 1 file '{filename}' found in path '{path}'."
@@ -73,7 +83,7 @@ def find(filename):
7383
return {key: find(filename) for key, filename in model_files.items()}
7484

7585

76-
def _download_model(model: ModelNames, files: list[str]) -> str:
86+
def _download_model(model: str, files: list[str]) -> str:
7787
from huggingface_hub import snapshot_download
7888

7989
match model:
@@ -94,7 +104,7 @@ def load_model(
94104
model: str | ModelNames | ModelTypes,
95105
path: str | Path | None = None,
96106
quantization: str | None = None,
97-
providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
107+
providers: Sequence[str | tuple[str, dict]] | None = None,
98108
) -> Asr:
99109
"""Load ASR model.
100110
@@ -122,7 +132,7 @@ def load_model(
122132
assert model in get_args(ModelNames) or model.startswith("onnx-community/"), (
123133
"If the path is not specified, you must specify a specific model name."
124134
)
125-
path = _download_model(model, list(files.values())) # type: ignore
135+
path = _download_model(model, list(files.values()))
126136

127137
if providers is None:
128138
providers = rt.get_available_providers()

src/onnx_asr/models/gigaam.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""GigaAM v2 model implementations."""
22

33
from pathlib import Path
4+
from typing import Any
45

56
import numpy as np
67
import numpy.typing as npt
@@ -10,7 +11,7 @@
1011

1112

1213
class _GigaamV2(_AsrWithDecoding):
13-
def __init__(self, model_files: dict[str, Path], **kwargs):
14+
def __init__(self, model_files: dict[str, Path], **kwargs: Any):
1415
super().__init__("gigaam", model_files["vocab"], **kwargs)
1516

1617
@staticmethod
@@ -21,7 +22,7 @@ def _get_model_files(quantization: str | None = None) -> dict[str, str]:
2122
class GigaamV2Ctc(_AsrWithCtcDecoding, _GigaamV2):
2223
"""GigaAM v2 CTC model implementation."""
2324

24-
def __init__(self, model_files: dict[str, Path], **kwargs):
25+
def __init__(self, model_files: dict[str, Path], **kwargs: Any):
2526
"""Create GigaAM v2 CTC model.
2627
2728
Args:
@@ -50,7 +51,7 @@ class GigaamV2Rnnt(_AsrWithRnntDecoding, _GigaamV2):
5051
PRED_HIDDEN = 320
5152
STATE_TYPE = tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]
5253

53-
def __init__(self, model_files: dict[str, Path], **kwargs):
54+
def __init__(self, model_files: dict[str, Path], **kwargs: Any):
5455
"""Create GigaAM v2 RNN-T model.
5556
5657
Args:

src/onnx_asr/models/kaldi.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Kaldi model implementations."""
22

33
from pathlib import Path
4+
from typing import Any
45

56
import numpy as np
67
import numpy.typing as npt
@@ -14,7 +15,7 @@ class KaldiTransducer(_AsrWithRnntDecoding):
1415

1516
CONTEXT_SIZE = 2
1617

17-
def __init__(self, model_files: dict[str, Path], **kwargs):
18+
def __init__(self, model_files: dict[str, Path], **kwargs: Any):
1819
"""Create Kaldi Transducer model.
1920
2021
Args:
@@ -49,12 +50,12 @@ def _encode(
4950
)
5051
return encoder_out.transpose(0, 2, 1), encoder_out_lens
5152

52-
def _create_state(self) -> None:
53-
return None
53+
def _create_state(self) -> tuple:
54+
return ()
5455

5556
def _decode(
56-
self, prev_tokens: list[int], prev_state: None, encoder_out: npt.NDArray[np.float32]
57-
) -> tuple[npt.NDArray[np.float32], None]:
57+
self, prev_tokens: list[int], prev_state: tuple, encoder_out: npt.NDArray[np.float32]
58+
) -> tuple[npt.NDArray[np.float32], tuple]:
5859
(decoder_out,) = self._decoder.run(["decoder_out"], {"y": [[-1, self._blank_idx, *prev_tokens][-self.CONTEXT_SIZE :]]})
5960
(logit,) = self._joiner.run(["logit"], {"encoder_out": encoder_out[None, :], "decoder_out": decoder_out})
60-
return np.squeeze(logit), None
61+
return np.squeeze(logit), prev_state

src/onnx_asr/models/nemo.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""NeMo model implementations."""
22

33
from pathlib import Path
4+
from typing import Any
45

56
import numpy as np
67
import numpy.typing as npt
@@ -10,7 +11,7 @@
1011

1112

1213
class _NemoConformer(_AsrWithDecoding):
13-
def __init__(self, model_files: dict[str, Path], **kwargs):
14+
def __init__(self, model_files: dict[str, Path], **kwargs: Any):
1415
super().__init__("nemo", model_files["vocab"], **kwargs)
1516

1617
@staticmethod
@@ -21,7 +22,7 @@ def _get_model_files(quantization: str | None = None) -> dict[str, str]:
2122
class NemoConformerCtc(_AsrWithCtcDecoding, _NemoConformer):
2223
"""NeMo Conformer CTC model implementations."""
2324

24-
def __init__(self, model_files: dict[str, Path], **kwargs):
25+
def __init__(self, model_files: dict[str, Path], **kwargs: Any):
2526
"""Create NeMo Conformer CTC model.
2627
2728
Args:
@@ -57,7 +58,7 @@ class NemoConformerRnnt(_AsrWithRnntDecoding, _NemoConformer):
5758
MAX_TOKENS_PER_STEP = 10
5859
STATE_TYPE = tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]
5960

60-
def __init__(self, model_files: dict[str, Path], **kwargs):
61+
def __init__(self, model_files: dict[str, Path], **kwargs: Any):
6162
"""Create NeMo Conformer RNN-T model.
6263
6364
Args:

0 commit comments

Comments
 (0)