Skip to content

Commit 1e9d3e8

Browse files
committed
Improve the decorator typing
This should make the type checking of decorated functions better and hence more likely to catch bugs. Note that the validation decorators cannot yet be improved as they inject keyword arguments, which is not supported as per https://peps.python.org/pep-0612/#concatenating-keyword-parameters .
1 parent 93f2620 commit 1e9d3e8

File tree

2 files changed

+23
-24
lines changed

2 files changed

+23
-24
lines changed

src/quart_schema/documentation.py

+13-15
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
from typing import Callable, Dict, Optional, Tuple
2-
3-
from quart import ResponseReturnValue as QuartResponseReturnValue
1+
from typing import Callable, Dict, Optional, Tuple, TypeVar
42

53
from .typing import Model
64
from .validation import (
@@ -11,8 +9,10 @@
119
QUART_SCHEMA_RESPONSE_ATTRIBUTE,
1210
)
1311

12+
T = TypeVar("T", bound=Callable)
13+
1414

15-
def document_querystring(model_class: Model) -> Callable:
15+
def document_querystring(model_class: Model) -> Callable[[T], T]:
1616
"""Document the request querystring arguments.
1717
1818
Add the request querystring **model_class** to the openapi
@@ -25,15 +25,15 @@ def document_querystring(model_class: Model) -> Callable:
2525
2626
"""
2727

28-
def decorator(func: Callable) -> Callable:
28+
def decorator(func: T) -> T:
2929
setattr(func, QUART_SCHEMA_QUERYSTRING_ATTRIBUTE, model_class)
3030

3131
return func
3232

3333
return decorator
3434

3535

36-
def document_headers(model_class: Model) -> Callable:
36+
def document_headers(model_class: Model) -> Callable[[T], T]:
3737
"""Document the request headers.
3838
3939
Add the request **model_class** to the openapi generated
@@ -46,7 +46,7 @@ def document_headers(model_class: Model) -> Callable:
4646
4747
"""
4848

49-
def decorator(func: Callable) -> Callable:
49+
def decorator(func: T) -> T:
5050
setattr(func, QUART_SCHEMA_HEADERS_ATTRIBUTE, model_class)
5151

5252
return func
@@ -58,7 +58,7 @@ def document_request(
5858
model_class: Model,
5959
*,
6060
source: DataSource = DataSource.JSON,
61-
) -> Callable:
61+
) -> Callable[[T], T]:
6262
"""Document the request data.
6363
6464
Add the request **model_class** to the openapi generated
@@ -72,7 +72,7 @@ def document_request(
7272
encoded).
7373
"""
7474

75-
def decorator(func: Callable) -> Callable:
75+
def decorator(func: T) -> T:
7676
setattr(func, QUART_SCHEMA_REQUEST_ATTRIBUTE, (model_class, source))
7777

7878
return func
@@ -84,7 +84,7 @@ def document_response(
8484
model_class: Model,
8585
status_code: int = 200,
8686
headers_model_class: Optional[Model] = None,
87-
) -> Callable:
87+
) -> Callable[[T], T]:
8888
"""Document the response data.
8989
9090
Add the response **model_class**, and its corresponding (optional)
@@ -103,9 +103,7 @@ def document_response(
103103
104104
"""
105105

106-
def decorator(
107-
func: Callable[..., QuartResponseReturnValue]
108-
) -> Callable[..., QuartResponseReturnValue]:
106+
def decorator(func: T) -> T:
109107
schemas = getattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, {})
110108
schemas[status_code] = (model_class, headers_model_class)
111109
setattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, schemas)
@@ -122,15 +120,15 @@ def document(
122120
request_source: DataSource = DataSource.JSON,
123121
headers: Optional[Model] = None,
124122
responses: Dict[int, Tuple[Model, Optional[Model]]],
125-
) -> Callable:
123+
) -> Callable[[T], T]:
126124
"""Document the route.
127125
128126
This is a shorthand combination of of the document_querystring,
129127
document_request, document_headers, and document_response
130128
decorators. Please see the docstrings for those decorators.
131129
"""
132130

133-
def decorator(func: Callable) -> Callable:
131+
def decorator(func: T) -> T:
134132
if querystring is not None:
135133
func = document_querystring(querystring)(func)
136134
if request is not None:

src/quart_schema/extension.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import defaultdict
66
from functools import wraps
77
from types import new_class
8-
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
8+
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, TypeVar, Union
99

1010
import click
1111
import humps
@@ -48,6 +48,7 @@
4848
except ImportError:
4949
to_builtins = None
5050

51+
T = TypeVar("T", bound=Callable)
5152

5253
SecurityScheme = Union[
5354
APIKeySecurityScheme,
@@ -150,7 +151,7 @@ def default(object_: Any) -> Any:
150151
return JSONProvider(app)
151152

152153

153-
def hide(func: Callable) -> Callable:
154+
def hide(func: T) -> T:
154155
"""Mark the func as hidden.
155156
156157
This will prevent the route from being included in the
@@ -398,7 +399,7 @@ async def decorator(result: ResponseReturnValue) -> QuartResponseReturnValue:
398399
return decorator
399400

400401

401-
def operation_id(operationid: str) -> Callable:
402+
def operation_id(operationid: str) -> Callable[[T], T]:
402403
"""Override the operationId of the route.
403404
404405
This allows for overriding the operationId, which is normally calculated from the
@@ -409,15 +410,15 @@ def operation_id(operationid: str) -> Callable:
409410
410411
"""
411412

412-
def decorator(func: Callable) -> Callable:
413+
def decorator(func: T) -> T:
413414
setattr(func, QUART_SCHEMA_OPERATION_ID_ATTRIBUTE, str(operationid))
414415

415416
return func
416417

417418
return decorator
418419

419420

420-
def tag(tags: Iterable[str]) -> Callable:
421+
def tag(tags: Iterable[str]) -> Callable[[T], T]:
421422
"""Add tag names to the route.
422423
423424
This allows for tags to be associated with the route, thereby
@@ -428,22 +429,22 @@ def tag(tags: Iterable[str]) -> Callable:
428429
429430
"""
430431

431-
def decorator(func: Callable) -> Callable:
432+
def decorator(func: T) -> T:
432433
setattr(func, QUART_SCHEMA_TAG_ATTRIBUTE, set(tags))
433434

434435
return func
435436

436437
return decorator
437438

438439

439-
def deprecate(func: Callable) -> Callable:
440+
def deprecate(func: T) -> T:
440441
"""Mark endpoint as deprecated."""
441442
setattr(func, QUART_SCHEMA_DEPRECATED, True)
442443

443444
return func
444445

445446

446-
def security_scheme(schemes: Iterable[Dict[str, List[str]]]) -> Callable:
447+
def security_scheme(schemes: Iterable[Dict[str, List[str]]]) -> Callable[[T], T]:
447448
"""Add security schemes to the route.
448449
449450
Allows security schemes to be associated with this route. Security
@@ -455,7 +456,7 @@ def security_scheme(schemes: Iterable[Dict[str, List[str]]]) -> Callable:
455456
456457
"""
457458

458-
def decorator(func: Callable) -> Callable:
459+
def decorator(func: T) -> T:
459460
setattr(func, QUART_SCHEMA_SECURITY_ATTRIBUTE, schemes)
460461

461462
return func

0 commit comments

Comments
 (0)