Skip to content

Commit 7008e24

Browse files
committed
Add misc changes:
- Update type annotations - Add decorator for restricting endpoint function calls via the owner class
1 parent 427cb4f commit 7008e24

File tree

1 file changed

+60
-22
lines changed
  • src/openapi_test_client/libraries/api/api_functions

1 file changed

+60
-22
lines changed

src/openapi_test_client/libraries/api/api_functions/endpoints.py

+60-22
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from __future__ import annotations
22

3+
import inspect
34
from collections.abc import Callable, Sequence
45
from copy import deepcopy
56
from dataclasses import dataclass
67
from functools import partial, update_wrapper, wraps
78
from threading import RLock
8-
from typing import TYPE_CHECKING, Any, ClassVar, ParamSpec, TypeAlias, TypeVar, cast
9+
from typing import TYPE_CHECKING, Any, ClassVar, Concatenate, ParamSpec, TypeAlias, TypeVar, cast
910

1011
from common_libs.ansi_colors import ColorCodes, color
1112
from common_libs.clients.rest_client import RestResponse
@@ -33,7 +34,14 @@
3334

3435
P = ParamSpec("P")
3536
R = TypeVar("R")
36-
_EndpointFunc = TypeVar("_EndpointFunc", bound=Callable[..., RestResponse]) # For making IDE happy
37+
38+
39+
_EndpointFunc = TypeVar(
40+
# TODO: Remove this
41+
# A workaound for https://youtrack.jetbrains.com/issue/PY-57765
42+
"_EndpointFunc",
43+
bound=Callable[..., RestResponse],
44+
)
3745
EndpointFunction: TypeAlias = _EndpointFunc | "EndpointFunc"
3846
EndpointDecorator: TypeAlias = Callable[[EndpointFunction], EndpointFunction]
3947

@@ -108,7 +116,7 @@ class endpoint:
108116
>>>
109117
>>> class AuthAPI(DemoAppBaseAPI):
110118
>>> @endpoint.post("/v1/login")
111-
>>> def login(self, *, username: str = Unset, password: str = Unset, **params):
119+
>>> def login(self, *, username: str = Unset, password: str = Unset, **kwargs: Any) -> RestResponse:
112120
>>> ...
113121
>>>
114122
>>> client = DemoAppAPIClient()
@@ -274,37 +282,51 @@ def undocumented(obj: EndpointHandler | type[APIBase] | EndpointFunction) -> End
274282
"""Mark an endpoint as undocumented. If an API class is decorated, all endpoints on the class will be
275283
automatically marked as undocumented.
276284
The flag value is available with an Endpoint object's is_documented attribute
285+
286+
:param obj: Endpoint handler or API class
287+
NOTE: EndpointFunction type was added for mypy only
277288
"""
289+
assert isinstance(obj, EndpointHandler) or (inspect.isclass(obj) and issubclass(obj, APIBase))
278290
obj.is_documented = False
279291
return cast(EndpointFunction, obj)
280292

281293
@staticmethod
282294
def is_public(obj: EndpointHandler | EndpointFunction) -> EndpointFunction:
283295
"""Mark an endpoint as a public API that does not require authentication.
284296
The flag value is available with an Endpoint object's is_public attribute
297+
298+
:param obj: Endpoint handler
299+
NOTE: EndpointFunction type was added for mypy only
285300
"""
301+
assert isinstance(obj, EndpointHandler)
286302
obj.is_public = True
287303
return cast(EndpointFunction, obj)
288304

289305
@staticmethod
290306
def is_deprecated(obj: EndpointHandler | type[APIBase] | EndpointFunction) -> EndpointFunction:
291307
"""Mark an endpoint as a deprecated API. If an API class is decorated, all endpoints on the class will be
292308
automatically marked as deprecated.
309+
310+
:param obj: Endpoint handler or API class
311+
NOTE: EndpointFunction type was added for mypy only
293312
"""
313+
assert isinstance(obj, EndpointHandler) or (inspect.isclass(obj) and issubclass(obj, APIBase))
294314
obj.is_deprecated = True
295315
return cast(EndpointFunction, obj)
296316

297317
@staticmethod
298318
def content_type(content_type: str) -> Callable[..., EndpointFunction]:
299-
"""Explicitly set Content-Type for this endpoint"""
319+
"""Explicitly set Content-Type for this endpoint
320+
321+
:param content_type: Content type to explicitly set
322+
"""
300323

301-
def decorator_with_arg(
302-
obj: EndpointHandler | EndpointFunction,
303-
) -> EndpointHandler | EndpointFunction:
304-
obj.content_type = content_type
305-
return obj
324+
def wrapper(endpoint_handler: EndpointHandler) -> EndpointHandler:
325+
assert isinstance(endpoint_handler, EndpointHandler)
326+
endpoint_handler.content_type = content_type
327+
return endpoint_handler
306328

307-
return cast(Callable[..., EndpointFunction], decorator_with_arg)
329+
return cast(Callable[..., EndpointFunction], wrapper)
308330

309331
@staticmethod
310332
def decorator(
@@ -332,17 +354,16 @@ def decorator(
332354
>>> ...
333355
"""
334356

335-
@wraps(f)
336357
def wrapper(*args: Any, **kwargs: Any) -> EndpointHandler | Callable[[EndpointHandler], EndpointHandler]:
337358
if not kwargs and args and len(args) == 1 and isinstance(args[0], EndpointHandler):
338359
# This is a regular decorator
339360
endpoint_handler: EndpointHandler = args[0]
340-
endpoint_handler.register_decorator(f)
361+
endpoint_handler.register_decorator(cast(EndpointDecorator, f))
341362
return endpoint_handler
342363
else:
343364
# The decorator takes arguments
344365
def _wrapper(endpoint_handler: EndpointHandler) -> EndpointHandler:
345-
endpoint_handler.register_decorator(partial(f, *args, **kwargs))
366+
endpoint_handler.register_decorator(cast(EndpointDecorator, partial(f, *args, **kwargs)))
346367
return endpoint_handler
347368

348369
return _wrapper
@@ -405,11 +426,13 @@ def __init__(
405426
self.path = path
406427
self.use_query_string = use_query_string
407428
self.requests_lib_options = requests_lib_options
408-
self.content_type = None # Will be set by @endpoint.content_type decorator (or application/json by default)
429+
430+
# Will be set via @endpoint.<decorator_name>
431+
self.content_type: str | None = None # application/json by default
409432
self.is_public = False
410433
self.is_documented = True
411434
self.is_deprecated = False
412-
self.__decorators: list[Callable[..., Any]] = []
435+
self.__decorators: list[EndpointDecorator] = []
413436

414437
def __get__(self, instance: APIBase | None, owner: type[APIBase]) -> EndpointFunc:
415438
"""Return an EndpointFunc object"""
@@ -421,19 +444,30 @@ def __get__(self, instance: APIBase | None, owner: type[APIBase]) -> EndpointFun
421444
)
422445
EndpointFuncClass = type(endpoint_func_name, (EndpointFunc,), {})
423446
endpoint_func = EndpointFuncClass(self, instance, owner)
424-
EndpointHandler._endpoint_functions[key] = endpoint_func
425-
return cast(EndpointFunc, update_wrapper(endpoint_func, self.original_func))
447+
EndpointHandler._endpoint_functions[key] = update_wrapper(endpoint_func, self.original_func)
448+
return cast(EndpointFunc, endpoint_func)
426449

427450
@property
428-
def decorators(self) -> list[Callable[..., Any]]:
451+
def decorators(self) -> list[EndpointDecorator]:
429452
"""Returns decorators that should be applied on an endpoint function"""
430453
return self.__decorators
431454

432-
def register_decorator(self, *decorator: Callable[..., Any]) -> None:
455+
def register_decorator(self, *decorator: EndpointDecorator) -> None:
433456
"""Register a decorator that will be applied on an endpoint function"""
434457
self.__decorators.extend([d for d in decorator])
435458

436459

460+
def requires_instance(f: Callable[Concatenate[EndpointFunc, P], R]) -> Callable[Concatenate[EndpointFunc, P], R]:
461+
@wraps(f)
462+
def wrapper(self: EndpointFunc, *args: P.args, **kwargs: P.kwargs) -> R:
463+
if self._instance is None:
464+
func_name = self._original_func.__name__ if f.__name__ == "__call__" else f.__name__
465+
raise TypeError(f"You can not access {func_name}() directly through the {self._owner.__name__} class.")
466+
return f(self, *args, **kwargs)
467+
468+
return wrapper
469+
470+
437471
class EndpointFunc:
438472
"""Endpoint function class
439473
@@ -459,9 +493,9 @@ def __init__(self, endpoint_handler: EndpointHandler, instance: APIBase | None,
459493
# Control a retry in a request wrapper to prevent a loop
460494
self.retried = False
461495

496+
self._instance: APIBase | None = instance
497+
self._owner: type[APIBase] = owner
462498
self._original_func: Callable[..., RestResponse] = endpoint_handler.original_func
463-
self._instance = instance
464-
self._owner = owner
465499
self._use_query_string = endpoint_handler.use_query_string
466500
self._requests_lib_options = endpoint_handler.requests_lib_options
467501

@@ -497,6 +531,7 @@ def __init__(self, endpoint_handler: EndpointHandler, instance: APIBase | None,
497531
def __repr__(self) -> str:
498532
return f"{super().__repr__()}\n(mapped to: {self._original_func!r})"
499533

534+
@requires_instance
500535
def __call__(
501536
self,
502537
*path_params: Any,
@@ -506,7 +541,7 @@ def __call__(
506541
with_hooks: bool | None = True,
507542
validate: bool | None = None,
508543
**params: Any,
509-
) -> RestResponse | None:
544+
) -> RestResponse:
510545
"""Make an API call to the endpoint
511546
512547
:param path_params: Path parameters
@@ -619,6 +654,7 @@ def docs(self) -> None:
619654
else:
620655
print("Docs not available") # noqa: T201
621656

657+
@requires_instance
622658
def with_retry(
623659
self,
624660
*args: Any,
@@ -638,6 +674,7 @@ def with_retry(
638674
f = retry_on(condition, num_retry=num_retry, retry_after=retry_after, safe_methods_only=False)(self)
639675
return f(*args, **kwargs)
640676

677+
@requires_instance
641678
def with_lock(self, *args: Any, lock_name: str | None = None, **kwargs: Any) -> RestResponse:
642679
"""Make an API call with lock
643680
@@ -664,4 +701,5 @@ def get_usage(self) -> str | None:
664701

665702
if TYPE_CHECKING:
666703
# For making IDE happy
704+
# TODO: Remove this
667705
EndpointFunc: TypeAlias = _EndpointFunc | EndpointFunc # type: ignore[no-redef]

0 commit comments

Comments
 (0)