Skip to content

Commit

Permalink
NEW: support non-Pydantic arguments in Payload and FormData, reso…
Browse files Browse the repository at this point in the history
…lves #77
  • Loading branch information
eigenein committed Nov 10, 2023
1 parent 3d34a74 commit 05bffac
Show file tree
Hide file tree
Showing 11 changed files with 268 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ format: format/ruff

.PHONY: format/ruff
format/ruff:
poetry run ruff check --fix combadge tests
poetry run ruff format combadge tests
poetry run ruff check --fix combadge tests

.PHONY: test
test:
Expand Down
8 changes: 1 addition & 7 deletions combadge/support/http/abc/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,4 @@ def append_form_field(self, name: str, value: Any) -> None: # noqa: D102
class ContainsPayload(ABC):
"""SOAP request payload."""

payload: Optional[dict] = None

def ensure_payload(self) -> dict:
"""Ensure that the payload is initialized and return it."""
if self.payload is None:
self.payload = {}
return self.payload
payload: Optional[Any] = None
27 changes: 21 additions & 6 deletions combadge/support/http/markers/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from inspect import BoundArguments
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar

from pydantic import BaseModel
from typing_extensions import Annotated, TypeAlias, override

from combadge.core.markers.method import MethodMarker
Expand All @@ -19,6 +18,7 @@
ContainsQueryParams,
ContainsUrlPath,
)
from combadge.support.shared.functools import get_type_adapter

_T = TypeVar("_T")

Expand Down Expand Up @@ -144,8 +144,18 @@ class Payload(ParameterMarker[ContainsPayload]):
by_alias: bool = False

@override
def __call__(self, request: ContainsPayload, value: BaseModel) -> None: # noqa: D102
request.ensure_payload().update(value.model_dump(by_alias=self.by_alias, exclude_unset=self.exclude_unset))
def __call__(self, request: ContainsPayload, value: Any) -> None: # noqa: D102
value = get_type_adapter(type(value)).dump_python(
value,
by_alias=self.by_alias,
exclude_unset=self.exclude_unset,
)
if request.payload is None:
request.payload = value
elif isinstance(request.payload, dict):
request.payload.update(value) # merge into the existing payload
else:
raise ValueError(f"attempting to merge {type(value)} into {type(request.payload)}")

def __class_getitem__(cls, item: type[Any]) -> Any:
return Annotated[item, cls()]
Expand Down Expand Up @@ -178,7 +188,9 @@ class Field(ParameterMarker[ContainsPayload]):

@override
def __call__(self, request: ContainsPayload, value: Any) -> None: # noqa: D102
request.ensure_payload()[self.name] = value.value if isinstance(value, Enum) else value
if request.payload is None:
request.payload = {}
request.payload[self.name] = value.value if isinstance(value, Enum) else value


if not TYPE_CHECKING:
Expand All @@ -201,8 +213,11 @@ class FormData(ParameterMarker[ContainsFormData]):
__slots__ = ()

@override
def __call__(self, request: ContainsFormData, value: BaseModel) -> None: # noqa: D102
for item_name, item_value in value.model_dump(by_alias=True).items():
def __call__(self, request: ContainsFormData, value: Any) -> None: # noqa: D102
value = get_type_adapter(type(value)).dump_python(value, by_alias=True)
if not isinstance(value, dict):
raise TypeError(f"form data requires a dictionary, got {type(value)}")
for item_name, item_value in value.items():
request.append_form_field(item_name, item_value)

def __class_getitem__(cls, item: type[Any]) -> Any:
Expand Down
10 changes: 10 additions & 0 deletions combadge/support/shared/functools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from functools import lru_cache
from typing import Any

from pydantic import TypeAdapter


@lru_cache(maxsize=None)
def get_type_adapter(type_: Any) -> TypeAdapter[Any]:
"""Get cached type adapter for the given type."""
return TypeAdapter(type_)
89 changes: 89 additions & 0 deletions docs/support/models.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Models

Combadge is built on top of [Pydantic](https://docs.pydantic.dev/), hence Pydantic models are natively supported in service protocols.

However, thanks to the Pydantic's [`TypeAdapter`](https://docs.pydantic.dev/latest/api/type_adapter/), Combadge automatically supports:

## Built-in Python types

```python title="builtin.py" hl_lines="12 17"
from typing_extensions import Annotated, Protocol

from combadge.core.markers import Extract
from combadge.support.httpx.backends.sync import HttpxBackend
from combadge.support.http.markers import Payload, http_method, path
from httpx import Client


class Httpbin(Protocol):
@http_method("POST")
@path("/anything")
def post_anything(self, foo: Payload[int]) -> Annotated[int, Extract("data")]:
...


backend = HttpxBackend(Client(base_url="https://httpbin.org"))
assert backend[Httpbin].post_anything(42) == 42
```

## Standard [dataclasses](https://docs.python.org/3/library/dataclasses.html)

```python title="dataclasses.py" hl_lines="10-12 15-17 23 28"
from dataclasses import dataclass

from typing_extensions import Protocol

from combadge.support.httpx.backends.sync import HttpxBackend
from combadge.support.http.markers import Payload, http_method, path
from httpx import Client


@dataclass
class Request:
foo: int


@dataclass
class Response:
data: str


class Httpbin(Protocol):
@http_method("POST")
@path("/anything")
def post_anything(self, foo: Payload[Request]) -> Response:
...


backend = HttpxBackend(Client(base_url="https://httpbin.org"))
assert backend[Httpbin].post_anything(Request(42)) == Response(data='{"foo": 42}')
```

## [Typed dictionaries](https://docs.python.org/3/library/typing.html#typing.TypedDict)

```python title="typed_dict.py" hl_lines="8-9 12-13 19 24"
from typing_extensions import Protocol, TypedDict

from combadge.support.httpx.backends.sync import HttpxBackend
from combadge.support.http.markers import Payload, http_method, path
from httpx import Client


class Request(TypedDict):
foo: int


class Response(TypedDict):
data: str


class Httpbin(Protocol):
@http_method("POST")
@path("/anything")
def post_anything(self, foo: Payload[Request]) -> Response:
...


backend = HttpxBackend(Client(base_url="https://httpbin.org"))
assert backend[Httpbin].post_anything({"foo": 42}) == {"data": '{"foo": 42}'}
```
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ nav:
- Backends:
- support/httpx.md
- support/zeep.md
- support/models.md
- support/handling-errors.md
- tags.md
- Cookbook:
Expand All @@ -33,6 +34,7 @@ theme:
- content.action.edit
- content.code.annotate
- content.code.copy
- navigation.expand
- navigation.footer
- navigation.indexes
- navigation.instant
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
interactions:
- request:
body: '42'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '2'
content-type:
- application/json
host:
- httpbin.org
user-agent:
- python-httpx/0.25.1
method: POST
uri: https://httpbin.org/anything
response:
content: "{\n \"args\": {}, \n \"data\": \"42\", \n \"files\": {}, \n \"form\":
{}, \n \"headers\": {\n \"Accept\": \"*/*\", \n \"Accept-Encoding\":
\"gzip, deflate\", \n \"Content-Length\": \"2\", \n \"Content-Type\":
\"application/json\", \n \"Host\": \"httpbin.org\", \n \"User-Agent\":
\"python-httpx/0.25.1\", \n \"X-Amzn-Trace-Id\": \"Root=1-654e3c0b-67b7ec060f0795b446f2cac2\"\n
\ }, \n \"json\": 42, \n \"method\": \"POST\", \n \"origin\": \"86.94.162.190\",
\n \"url\": \"https://httpbin.org/anything\"\n}\n"
headers:
Access-Control-Allow-Credentials:
- 'true'
Access-Control-Allow-Origin:
- '*'
Connection:
- keep-alive
Content-Length:
- '462'
Content-Type:
- application/json
Date:
- Fri, 10 Nov 2023 14:19:55 GMT
Server:
- gunicorn/19.9.0
http_version: HTTP/1.1
status_code: 200
version: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
interactions:
- request:
body: '{"foo": 42}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '11'
content-type:
- application/json
host:
- httpbin.org
user-agent:
- python-httpx/0.25.1
method: POST
uri: https://httpbin.org/anything
response:
content: "{\n \"args\": {}, \n \"data\": \"{\\\"foo\\\": 42}\", \n \"files\":
{}, \n \"form\": {}, \n \"headers\": {\n \"Accept\": \"*/*\", \n \"Accept-Encoding\":
\"gzip, deflate\", \n \"Content-Length\": \"11\", \n \"Content-Type\":
\"application/json\", \n \"Host\": \"httpbin.org\", \n \"User-Agent\":
\"python-httpx/0.25.1\", \n \"X-Amzn-Trace-Id\": \"Root=1-654e3e5b-7a9b143d4b620fdc7f26dca5\"\n
\ }, \n \"json\": {\n \"foo\": 42\n }, \n \"method\": \"POST\", \n \"origin\":
\"86.94.162.190\", \n \"url\": \"https://httpbin.org/anything\"\n}\n"
headers:
Access-Control-Allow-Credentials:
- 'true'
Access-Control-Allow-Origin:
- '*'
Connection:
- keep-alive
Content-Length:
- '491'
Content-Type:
- application/json
Date:
- Fri, 10 Nov 2023 14:29:47 GMT
Server:
- gunicorn/19.9.0
http_version: HTTP/1.1
status_code: 200
version: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
interactions:
- request:
body: '{"foo": 42}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '11'
content-type:
- application/json
host:
- httpbin.org
user-agent:
- python-httpx/0.25.1
method: POST
uri: https://httpbin.org/anything
response:
content: "{\n \"args\": {}, \n \"data\": \"{\\\"foo\\\": 42}\", \n \"files\":
{}, \n \"form\": {}, \n \"headers\": {\n \"Accept\": \"*/*\", \n \"Accept-Encoding\":
\"gzip, deflate\", \n \"Content-Length\": \"11\", \n \"Content-Type\":
\"application/json\", \n \"Host\": \"httpbin.org\", \n \"User-Agent\":
\"python-httpx/0.25.1\", \n \"X-Amzn-Trace-Id\": \"Root=1-654e3f28-6dc22fa87c99307a215093d4\"\n
\ }, \n \"json\": {\n \"foo\": 42\n }, \n \"method\": \"POST\", \n \"origin\":
\"86.94.162.190\", \n \"url\": \"https://httpbin.org/anything\"\n}\n"
headers:
Access-Control-Allow-Credentials:
- 'true'
Access-Control-Allow-Origin:
- '*'
Connection:
- keep-alive
Content-Length:
- '491'
Content-Type:
- application/json
Date:
- Fri, 10 Nov 2023 14:33:12 GMT
Server:
- gunicorn/19.9.0
http_version: HTTP/1.1
status_code: 200
version: 1
1 change: 1 addition & 0 deletions tests/integration/test_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ def _generate_params(path: Path) -> Iterator[NamedTuple]:
)
@pytest.mark.vcr(decode_compressed_response=True)
def test_documentation_snippet(snippet: str) -> None:
__tracebackhide__ = True
exec(dedent(snippet), {})
10 changes: 5 additions & 5 deletions tests/integration/test_number_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class NumberTooLargeResponse(RootModel, ErrorResponse):
root: Literal["number too large"]


class TestFault(BaseSoapFault):
class _TestFault(BaseSoapFault):
code: Literal["SOAP-ENV:Server"]
message: Literal["Test Fault"]

Expand All @@ -40,7 +40,7 @@ class SupportsNumberConversion(SupportsService, Protocol):
def number_to_words(
self,
request: Annotated[NumberToWordsRequest, Payload(by_alias=True)],
) -> Union[NumberTooLargeResponse, NumberToWordsResponse, TestFault]:
) -> Union[NumberTooLargeResponse, NumberToWordsResponse, _TestFault]:
raise NotImplementedError


Expand All @@ -50,7 +50,7 @@ class SupportsNumberConversionAsync(SupportsService, Protocol):
async def number_to_words(
self,
request: Annotated[NumberToWordsRequest, Payload(by_alias=True)],
) -> Union[NumberTooLargeResponse, NumberToWordsResponse, TestFault]:
) -> Union[NumberTooLargeResponse, NumberToWordsResponse, _TestFault]:
raise NotImplementedError


Expand Down Expand Up @@ -91,14 +91,14 @@ def test_sad_path_scalar_response(number_conversion_service: SupportsNumberConve
def test_sad_path_web_fault(number_conversion_service: SupportsNumberConversion) -> None:
# Note: the cassette is manually patched to return the SOAP fault.
response = number_conversion_service.number_to_words(NumberToWordsRequest(number=42))
with pytest.raises(TestFault.Error):
with pytest.raises(_TestFault.Error):
response.raise_for_result()


@pytest.mark.vcr()
async def test_happy_path_scalar_response_async(number_conversion_service_async: SupportsNumberConversionAsync) -> None:
response = await number_conversion_service_async.number_to_words(NumberToWordsRequest(number=42))
assert_type(response, Union[NumberToWordsResponse, NumberTooLargeResponse, TestFault])
assert_type(response, Union[NumberToWordsResponse, NumberTooLargeResponse, _TestFault])

response = response.unwrap()
assert_type(response, NumberToWordsResponse)
Expand Down

0 comments on commit 05bffac

Please sign in to comment.