3
3
import inspect
4
4
import itertools
5
5
from pathlib import Path
6
- from typing import TYPE_CHECKING , TypeVar
7
6
8
7
from openapi_test_client .libraries .common .misc import import_module_from_file_path
9
8
10
9
from .base import APIBase
11
10
12
- if TYPE_CHECKING :
13
- APIClassType = TypeVar ("APIClassType" , bound = APIBase )
14
11
15
-
16
- def init_api_classes (base_api_class : type [APIClassType ]) -> list [type [APIClassType ]]:
12
+ def init_api_classes (base_api_class : type [APIBase ]) -> list [type [APIBase ]]:
17
13
"""Initialize API classes and return a list of API classes.
18
14
19
15
- A list of Endpoint objects for an API class is available via its `endpoints` attribute
@@ -37,6 +33,7 @@ def init_api_classes(base_api_class: type[APIClassType]) -> list[type[APIClassTy
37
33
from openapi_test_client .libraries .api .api_functions import EndpointFunc , EndpointHandler
38
34
39
35
previous_frame = inspect .currentframe ().f_back
36
+ assert previous_frame
40
37
caller_file_path = inspect .getframeinfo (previous_frame ).filename
41
38
assert caller_file_path .endswith ("__init__.py" ), (
42
39
f"API classes must be initialized in __init__.py. Unexpectedly called from { caller_file_path } "
@@ -50,7 +47,7 @@ def init_api_classes(base_api_class: type[APIClassType]) -> list[type[APIClassTy
50
47
api_class .endpoints = []
51
48
for attr_name , attr in api_class .__dict__ .items ():
52
49
if isinstance (attr , EndpointHandler ):
53
- endpoint_func = getattr (api_class , attr_name )
50
+ endpoint_func : EndpointFunc = getattr (api_class , attr_name )
54
51
assert isinstance (endpoint_func , EndpointFunc )
55
52
api_class .endpoints .append (endpoint_func .endpoint )
56
53
@@ -59,10 +56,10 @@ def init_api_classes(base_api_class: type[APIClassType]) -> list[type[APIClassTy
59
56
itertools .chain (* (x .endpoints for x in api_classes if x .endpoints )),
60
57
key = lambda x : (x .tags , x .method , x .path ),
61
58
)
62
- return sorted (api_classes , key = lambda x : x .TAGs )
59
+ return sorted (api_classes , key = lambda x : x .TAGs ) # type: ignore[arg-type, return-value]
63
60
64
61
65
- def get_api_classes (api_class_dir : Path , base_api_class : type [APIClassType ]) -> list [type [APIClassType ]]:
62
+ def get_api_classes (api_class_dir : Path , base_api_class : type [APIBase ]) -> list [type [APIBase ]]:
66
63
"""Get all API classes defined under the given API class directory"""
67
64
assert api_class_dir .is_dir ()
68
65
0 commit comments