Skip to content

Commit 79b4cf1

Browse files
committed
Add mypy to pre-commit
Add ruff rules
1 parent 3970805 commit 79b4cf1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+475
-435
lines changed

.pre-commit-config.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,8 @@ repos:
2626
entry: ruff format
2727
language: system
2828
types: [ python ]
29+
- id: mypy
30+
name: mypy
31+
entry: mypy src
32+
pass_filenames: false
33+
language: system

pyproject.toml

+30-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ dependencies = [
3030
"pydantic-extra-types==2.9.0",
3131
"pydantic[email]==2.9.2",
3232
"PyYAML==6.0.1",
33-
"ruff==0.9.5"
33+
"ruff==0.11.0"
3434
]
3535
dynamic = ["version"]
3636

@@ -75,6 +75,7 @@ select = [
7575
"PLE", # Pylint error
7676
"PLW", # Pylint warning
7777
"RUF", # Ruff-specific rules
78+
"T20", # flake8-print
7879
"UP", # pyupgrade
7980
"W", # pycodestyle warning
8081
]
@@ -93,3 +94,31 @@ ignore = [
9394
[tool.ruff.format]
9495
docstring-code-format = true
9596
docstring-code-line-length = "dynamic"
97+
98+
[tool.mypy]
99+
plugins = ["pydantic.mypy"]
100+
files = [
101+
"src"
102+
]
103+
mypy_path = [
104+
"src"
105+
]
106+
follow_imports = "skip"
107+
check_untyped_defs = true
108+
disallow_any_generics = true
109+
disallow_untyped_defs = true
110+
follow_untyped_imports = false
111+
ignore_missing_imports = true
112+
strict_equality = true
113+
warn_redundant_casts = true
114+
warn_unreachable = true
115+
warn_unused_ignores = true
116+
disable_error_code = [
117+
"import-untyped",
118+
"misc",
119+
"operator",
120+
"override",
121+
"return",
122+
"type-arg",
123+
"union-attr"
124+
]

src/demo_app/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def create_app(version: int = 1) -> Quart:
2727
return app
2828

2929

30-
def _register_blueprints(app, version: int):
30+
def _register_blueprints(app: Quart, version: int) -> None:
3131
from demo_app.api.auth.auth import bp_auth
3232
from demo_app.api.user.user import bp_user
3333
from demo_app.handlers.error_handlers import bp_error_handler

src/demo_app/api/user/models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class SocialLinks(BaseModel):
3030

3131

3232
class Preferences(BaseModel):
33-
theme: UserTheme | None = UserTheme.LIGHT_MODE.value
33+
theme: UserTheme | None = UserTheme.LIGHT_MODE
3434
language: str | None = None
3535
font_size: int | None = Field(None, ge=8, le=40, multiple_of=2)
3636

src/demo_app/handlers/error_handlers.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,30 @@
11
import json
22
from dataclasses import dataclass
3+
from typing import Any
34

45
from quart import Blueprint, Response, jsonify, make_response, request
56
from quart import current_app as app
67
from quart_schema import RequestSchemaValidationError
7-
from werkzeug.exceptions import NotFound
8+
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
89

910
bp_error_handler = Blueprint("error_handler", __name__)
1011

1112

1213
@dataclass
1314
class Error:
1415
code: int
15-
message: str
16+
message: Any
1617
request_id: str
1718

1819

1920
@bp_error_handler.app_errorhandler(400)
20-
async def handle_bad_request_error(error) -> Response:
21+
async def handle_bad_request_error(error: BadRequest) -> Response:
2122
err = Error(code=400, message=str(error), request_id=request.headers["X-Request-ID"])
2223
return await make_response(jsonify({"error": err}), 400)
2324

2425

2526
@bp_error_handler.app_errorhandler(401)
26-
async def handle_unauthorized_request(error) -> Response:
27+
async def handle_unauthorized_request(error: Unauthorized) -> Response:
2728
err = Error(code=401, message="Login required", request_id=request.headers["X-Request-ID"])
2829
return await make_response(jsonify({"error": err}), 401)
2930

src/demo_app/handlers/request_handlers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
@bp_request_handler.before_app_request
9-
async def before_request():
9+
async def before_request() -> None:
1010
if not request.headers.get("X-Request-ID"):
1111
request.headers["X-Request-ID"] = str(uuid.uuid4())
1212

src/demo_app/main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
DEFAULT_PORT = 5000
66

77

8-
def parse_pargs():
8+
def parse_pargs() -> argparse.Namespace:
99
parser = argparse.ArgumentParser()
1010
parser.add_argument("-p", "--port", dest="port", type=int, default=DEFAULT_PORT)
1111
return parser.parse_args()

src/openapi_test_client/__init__.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def is_external_project() -> bool:
2828
return bool(os.environ.get(ENV_VAR_PACKAGE_DIR, "") or not Path.cwd().is_relative_to(_PROJECT_ROOT_DIR))
2929

3030

31-
def find_external_package_dir(current_dir: Path | None = None, missing_ok: bool = False) -> Path:
31+
def find_external_package_dir(current_dir: Path | None = None, missing_ok: bool = False) -> Path | None:
3232
"""Find an external package directory
3333
3434
An external directory should have .api_test_client hidden file
@@ -45,17 +45,17 @@ def find_external_package_dir(current_dir: Path | None = None, missing_ok: bool
4545
f"Please create a project directory and start from there"
4646
)
4747

48-
def search_parent_dirs(dir: Path):
48+
def search_parent_dirs(dir: Path) -> Path | None:
4949
if (dir / filename).exists():
5050
return dir
5151

5252
parent = dir.parent
5353
if parent == dir:
54-
return
54+
return None
5555

5656
return search_parent_dirs(parent)
5757

58-
def search_child_dirs(dir: Path):
58+
def search_child_dirs(dir: Path) -> Path | None:
5959
if hidden_files := glob.glob(f"**/{filename}", root_dir=dir, recursive=True):
6060
module_paths = [(dir / x).parent for x in hidden_files]
6161
if len(module_paths) > 1:
@@ -82,7 +82,7 @@ def get_package_dir() -> Path:
8282
return Path(api_client_package_dir).resolve()
8383
else:
8484
try:
85-
return find_external_package_dir()
85+
return find_external_package_dir() or _PACKAGE_DIR
8686
except FileNotFoundError:
8787
# Initial script run from an external location. The directory hasn't been setup yet
8888
return _PACKAGE_DIR
+1-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1 @@
1-
from typing import TYPE_CHECKING, TypeVar
2-
3-
from openapi_test_client.clients.base import OpenAPIClient
4-
5-
if TYPE_CHECKING:
6-
APIClientType = TypeVar("APIClientType", bound=OpenAPIClient)
1+
from .base import OpenAPIClient

src/openapi_test_client/clients/base.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import importlib
44
import inspect
55
import json
6-
from typing import TYPE_CHECKING
6+
from typing import Any, TypeVar
77

88
from common_libs.clients.rest_client import RestClient
99
from common_libs.logging import get_logger
@@ -12,12 +12,10 @@
1212
from openapi_test_client.libraries.api.api_spec import OpenAPISpec
1313
from openapi_test_client.libraries.common.misc import get_module_name_by_file_path
1414

15-
if TYPE_CHECKING:
16-
from openapi_test_client.clients import APIClientType
17-
18-
1915
logger = get_logger(__name__)
2016

17+
T = TypeVar("T", bound="OpenAPIClient")
18+
2119

2220
class OpenAPIClient:
2321
"""Base class for all clients"""
@@ -50,11 +48,11 @@ def base_url(self) -> str:
5048
return self._base_url
5149

5250
@base_url.setter
53-
def base_url(self, url: str):
51+
def base_url(self, url: str) -> None:
5452
self._base_url = url
5553

56-
@staticmethod
57-
def get_client(app_name: str, **init_options) -> APIClientType:
54+
@classmethod
55+
def get_client(cls: type[T], app_name: str, **init_options: Any) -> T:
5856
"""Get API client for the app
5957
6058
:param app_name: App name
@@ -68,13 +66,13 @@ def get_client(app_name: str, **init_options) -> APIClientType:
6866

6967
client_module_name = get_module_name_by_file_path(client_file)
7068
mod = importlib.import_module(client_module_name)
71-
clients = [
69+
client_classes: list[type[T]] = [
7270
x
7371
for x in mod.__dict__.values()
7472
if inspect.isclass(x) and issubclass(x, OpenAPIClient) and x is not OpenAPIClient
7573
]
76-
if len(clients) != 1:
74+
if len(client_classes) != 1:
7775
raise RuntimeError(f"Unable to locate the API client for {app_name} from {mod}")
7876

79-
api_client: type[APIClientType] = clients[0]
80-
return api_client(**init_options)
77+
APIClientClass = client_classes[0]
78+
return APIClientClass(**init_options)

src/openapi_test_client/clients/demo_app/api/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
from .base import DemoAppBaseAPI
44

5-
API_CLASSES = init_api_classes(DemoAppBaseAPI)
5+
API_CLASSES = init_api_classes(DemoAppBaseAPI) # type: ignore[type-abstract]

src/openapi_test_client/clients/demo_app/api/base/demo_app_api.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Any
44

55
from common_libs.clients.rest_client import RestResponse
66
from requests.exceptions import RequestException
@@ -24,9 +24,9 @@ def post_request_hook(
2424
endpoint: Endpoint,
2525
response: RestResponse | None,
2626
request_exception: RequestException | None,
27-
*path_params,
28-
**params,
29-
):
27+
*path_params: Any,
28+
**params: Any,
29+
) -> None:
3030
super().post_request_hook(endpoint, response, request_exception, *path_params, **params)
3131
if response and response.ok:
3232
if endpoint in self.api_client.Auth.endpoints:

src/openapi_test_client/clients/demo_app/api/request_hooks/post_request.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Any
44

55
if TYPE_CHECKING:
66
from common_libs.clients.rest_client import RestResponse
@@ -10,8 +10,8 @@
1010

1111

1212
def do_something_after_request(
13-
api_client: DemoAppAPIClient, endpoint: Endpoint, r: RestResponse, *path_params, **params
14-
):
13+
api_client: DemoAppAPIClient, endpoint: Endpoint, r: RestResponse, *path_params: Any, **params: Any
14+
) -> None:
1515
"""This is a template of the post-request hook that will be called right after making a request
1616
1717
To enable this hook, call this function inside the base API class's post_request_hook():
@@ -31,10 +31,10 @@ def do_something_after_request(
3131
>>> do_something_after_request(self.api_client, endpoint, response, *path_params, *params)
3232
"""
3333
# Do something after request
34-
pass
34+
...
3535

3636

37-
def manage_auth_session(api_client: DemoAppAPIClient, endpoint: Endpoint, r: RestResponse):
37+
def manage_auth_session(api_client: DemoAppAPIClient, endpoint: Endpoint, r: RestResponse) -> None:
3838
"""Manage auth after successful login/logout
3939
4040
:param api_client: API client
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Any
44

55
if TYPE_CHECKING:
66
from openapi_test_client.clients.demo_app import DemoAppAPIClient
77
from openapi_test_client.libraries.api import Endpoint
88

99

10-
def do_something_before_request(api_client: DemoAppAPIClient, endpoint: Endpoint, **params):
10+
def do_something_before_request(api_client: DemoAppAPIClient, endpoint: Endpoint, **params: Any) -> None:
1111
"""This is a template of the pre-request hook that will be called right before making a request
1212
1313
To enable this hook, call this function inside the base API class's pre_request_hook():
@@ -17,4 +17,4 @@ def do_something_before_request(api_client: DemoAppAPIClient, endpoint: Endpoint
1717
>>> do_something_before_request(self.api_client, endpoint, *params)
1818
"""
1919
# Do something before request
20-
pass
20+
...

src/openapi_test_client/clients/demo_app/api/request_hooks/request_wrapper.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,10 @@
22

33
from collections.abc import Callable
44
from functools import wraps
5-
from typing import TYPE_CHECKING, ParamSpec
5+
from typing import ParamSpec
66

77
from common_libs.clients.rest_client import RestResponse
88

9-
if TYPE_CHECKING:
10-
from openapi_test_client.libraries.api import EndpointFunc
11-
12-
139
P = ParamSpec("P")
1410

1511

@@ -30,9 +26,9 @@ def do_something_before_and_after_request(f: Callable[P, RestResponse]) -> Calla
3026
"""
3127

3228
@wraps(f)
33-
def wrapper(endpoint_func: EndpointFunc, *args: P.args, **kwargs: P.kwargs) -> RestResponse:
29+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> RestResponse:
3430
# Do something before request
35-
r = f(endpoint_func, *args, **kwargs)
31+
r = f(*args, **kwargs)
3632
# Do something after request
3733
return r
3834

src/openapi_test_client/libraries/api/api_classes/__init__.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,13 @@
33
import inspect
44
import itertools
55
from pathlib import Path
6-
from typing import TYPE_CHECKING, TypeVar
76

87
from openapi_test_client.libraries.common.misc import import_module_from_file_path
98

109
from .base import APIBase
1110

12-
if TYPE_CHECKING:
13-
APIClassType = TypeVar("APIClassType", bound=APIBase)
1411

15-
16-
def init_api_classes(base_api_class: type[APIClassType]) -> list[type[APIClassType]]:
12+
def init_api_classes(base_api_class: type[APIBase]) -> list[type[APIBase]]:
1713
"""Initialize API classes and return a list of API classes.
1814
1915
- A list of Endpoint objects for an API class is available via its `endpoints` attribute
@@ -37,6 +33,7 @@ def init_api_classes(base_api_class: type[APIClassType]) -> list[type[APIClassTy
3733
from openapi_test_client.libraries.api.api_functions import EndpointFunc, EndpointHandler
3834

3935
previous_frame = inspect.currentframe().f_back
36+
assert previous_frame
4037
caller_file_path = inspect.getframeinfo(previous_frame).filename
4138
assert caller_file_path.endswith("__init__.py"), (
4239
f"API classes must be initialized in __init__.py. Unexpectedly called from {caller_file_path}"
@@ -50,7 +47,7 @@ def init_api_classes(base_api_class: type[APIClassType]) -> list[type[APIClassTy
5047
api_class.endpoints = []
5148
for attr_name, attr in api_class.__dict__.items():
5249
if isinstance(attr, EndpointHandler):
53-
endpoint_func = getattr(api_class, attr_name)
50+
endpoint_func: EndpointFunc = getattr(api_class, attr_name)
5451
assert isinstance(endpoint_func, EndpointFunc)
5552
api_class.endpoints.append(endpoint_func.endpoint)
5653

@@ -59,10 +56,10 @@ def init_api_classes(base_api_class: type[APIClassType]) -> list[type[APIClassTy
5956
itertools.chain(*(x.endpoints for x in api_classes if x.endpoints)),
6057
key=lambda x: (x.tags, x.method, x.path),
6158
)
62-
return sorted(api_classes, key=lambda x: x.TAGs)
59+
return sorted(api_classes, key=lambda x: x.TAGs) # type: ignore[arg-type, return-value]
6360

6461

65-
def get_api_classes(api_class_dir: Path, base_api_class: type[APIClassType]) -> list[type[APIClassType]]:
62+
def get_api_classes(api_class_dir: Path, base_api_class: type[APIBase]) -> list[type[APIBase]]:
6663
"""Get all API classes defined under the given API class directory"""
6764
assert api_class_dir.is_dir()
6865

0 commit comments

Comments
 (0)