Skip to content

Commit 427cb4f

Browse files
committed
Add tests for Pydantic models
1 parent 182b9ca commit 427cb4f

File tree

8 files changed

+248
-29
lines changed

8 files changed

+248
-29
lines changed

src/openapi_test_client/libraries/api/api_functions/utils/param_type.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -327,13 +327,18 @@ def is_optional_type(tp: Any) -> bool:
327327
return False
328328

329329

330-
def is_union_type(tp: Any) -> bool:
330+
def is_union_type(tp: Any, exclude_optional: bool = False) -> bool:
331331
"""Check if the type annotation is a Union type
332332
333333
:param tp: Type annotation
334+
:param exclude_optional: Exclude Optional[] type
334335
"""
335336
origin_type = get_origin(tp)
336-
return origin_type in [Union, UnionType]
337+
is_union = origin_type in [Union, UnionType]
338+
if exclude_optional:
339+
return is_union and NoneType not in get_args(tp)
340+
else:
341+
return is_union
337342

338343

339344
def is_deprecated_param(tp: Any) -> bool:

src/openapi_test_client/libraries/api/api_functions/utils/pydantic_model.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
1+
import datetime
12
import inspect
23
import ipaddress
34
import os
45
from collections.abc import Generator
56
from contextlib import contextmanager
67
from dataclasses import Field as DataclassField
7-
from datetime import date, datetime, time, timedelta
88
from pathlib import Path
99
from types import EllipsisType
1010
from typing import Any, TypeVar, get_args
1111
from uuid import UUID
1212

1313
from pydantic import (
1414
AnyHttpUrl,
15-
Base64Bytes,
1615
Base64Str,
17-
Base64UrlBytes,
1816
Base64UrlStr,
1917
EmailStr,
2018
Field,
@@ -27,23 +25,23 @@
2725

2826
import openapi_test_client.libraries.api.api_functions.utils.param_model as param_model_util
2927
import openapi_test_client.libraries.api.api_functions.utils.param_type as param_type_util
30-
from openapi_test_client.libraries.api.types import Constraint, DataclassModel, EndpointModel, Format, ParamModel
28+
from openapi_test_client.libraries.api.types import Constraint, DataclassModel, EndpointModel, File, Format, ParamModel
3129

3230
T = TypeVar("T")
3331

3432

3533
# TODO: Update this if needed
3634
PARAM_FORMAT_AND_TYPE_MAP = {
3735
"uuid": UUID,
38-
"date-time": datetime,
39-
"date": date,
40-
"time": time,
41-
"duration": timedelta,
42-
"binary": bytes,
36+
"date-time": datetime.datetime,
37+
"date": datetime.date,
38+
"time": datetime.time,
39+
"duration": datetime.timedelta,
40+
"binary": bytes | File,
4341
"byte": bytes,
4442
"path": Path,
45-
"base64": Base64Str | Base64Bytes,
46-
"base64url": Base64UrlStr | Base64UrlBytes,
43+
"base64": Base64Str,
44+
"base64url": Base64UrlStr,
4745
"email": EmailStr,
4846
"name-email": NameEmail,
4947
"uri": AnyHttpUrl,

src/openapi_test_client/libraries/common/json_encoder.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import json
23
from dataclasses import asdict, is_dataclass
34
from datetime import datetime
@@ -12,10 +13,12 @@ class CustomJsonEncoder(json.JSONEncoder):
1213
def default(self, obj: Any) -> Any:
1314
if isinstance(obj, UUID | Decimal):
1415
return str(obj)
16+
elif isinstance(obj, bytes):
17+
return base64.b64encode(obj).decode("utf-8")
1518
elif isinstance(obj, datetime):
1619
return obj.isoformat()
1720
elif is_dataclass(obj) and not isinstance(obj, type):
18-
return asdict(obj)
21+
return {k: self.default(v) for k, v in asdict(obj).items()}
1922
elif isinstance(obj, BaseModel):
2023
return obj.model_dump(mode="json", exclude_unset=True)
2124
else:

tests/conftest.py

+16
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from pytest import Item, TempPathFactory
99
from pytest_mock import MockerFixture
1010

11+
from openapi_test_client.libraries.api.types import File
12+
1113

1214
def pytest_make_parametrize_id(val: Any, argname: str) -> str:
1315
return f"{argname}={val!r}"
@@ -34,3 +36,17 @@ def _mock_sys_path_and_modules(mocker: MockerFixture) -> None:
3436
def temp_dir(tmp_path_factory: TempPathFactory) -> Path:
3537
current_test_name = os.environ["PYTEST_CURRENT_TEST"].rsplit(" ", 1)[0]
3638
return tmp_path_factory.mktemp(clean_obj_name(current_test_name))
39+
40+
41+
@pytest.fixture(scope="session")
42+
def image_data() -> bytes:
43+
return ( # https://evanhahn.com/worlds-smallest-png/
44+
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR"
45+
b"\x00\x00\x00\x01\x00\x00\x00\x01\x01\x00\x00\x00\x007n\xf9$\x00\x00\x00\nIDATx"
46+
b"\x01c`\x00\x00\x00\x02\x00\x01su\x01\x18\x00\x00\x00\x00IEND\xaeB`\x82"
47+
)
48+
49+
50+
@pytest.fixture(scope="session")
51+
def image_file(image_data: bytes) -> File:
52+
return File(filename="test_image.png", content=image_data, content_type="image/png")

tests/integration/test_api_users.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,8 @@ def test_get_users(api_client: DemoAppAPIClient, validation_mode: bool) -> None:
4444

4545

4646
@pytest.mark.parametrize("validation_mode", [False, True])
47-
def test_upload_image(api_client: DemoAppAPIClient, validation_mode: bool) -> None:
47+
def test_upload_image(api_client: DemoAppAPIClient, validation_mode: bool, image_data: bytes) -> None:
4848
"""Check basic client/server functionality of upload user image API"""
49-
image_data = (
50-
# https://evanhahn.com/worlds-smallest-png/
51-
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR"
52-
b"\x00\x00\x00\x01\x00\x00\x00\x01\x01\x00\x00\x00\x007n\xf9$\x00\x00\x00\nIDATx"
53-
b"\x01c`\x00\x00\x00\x02\x00\x01su\x01\x18\x00\x00\x00\x00IEND\xaeB`\x82"
54-
)
5549
file = File(filename="test_image.png", content=image_data, content_type="image/png")
5650
r = api_client.Users.upload_image(file=file, description="test image", validate=validation_mode)
5751
assert r.status_code == 201

tests/unit/conftest.py

+78-6
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,24 @@
22
import os
33
import shutil
44
from collections.abc import Generator
5-
from dataclasses import dataclass
5+
from dataclasses import dataclass, make_dataclass
66
from pathlib import Path
7-
from typing import Any
7+
from typing import Annotated, Any, cast
88

99
import pytest
10+
from _pytest.fixtures import SubRequest
1011
from pytest import FixtureRequest
1112
from pytest_mock import MockerFixture
1213

14+
import openapi_test_client.libraries.api.api_functions.utils.param_type as param_type_util
1315
from openapi_test_client import ENV_VAR_PACKAGE_DIR
1416
from openapi_test_client.clients.base import OpenAPIClient
1517
from openapi_test_client.clients.demo_app import DemoAppAPIClient
1618
from openapi_test_client.clients.demo_app.api.auth import AuthAPI
17-
from openapi_test_client.libraries.api.api_client_generator import (
18-
setup_external_directory,
19-
)
19+
from openapi_test_client.libraries.api.api_client_generator import setup_external_directory
20+
from openapi_test_client.libraries.api.api_functions.utils.pydantic_model import PARAM_FORMAT_AND_TYPE_MAP
2021
from openapi_test_client.libraries.api.api_spec import OpenAPISpec
21-
from openapi_test_client.libraries.api.types import ParamModel, Unset
22+
from openapi_test_client.libraries.api.types import File, Format, Optional, ParamModel, Unset
2223
from tests.unit import helper
2324

2425

@@ -96,3 +97,74 @@ class Model(ParamModel):
9697
inner_param2: str = Unset
9798

9899
return Model
100+
101+
102+
@pytest.fixture
103+
def NewParamModel(request: SubRequest) -> type[ParamModel]:
104+
"""A new dataclass param model generated with requested field data via indirect parametrization
105+
106+
The fixture can be take the field data in various shapes as follows:
107+
- Just one field:
108+
- Only field type (field name and the default value will be automatically set)
109+
- As tuple (field name, field type) or (field name, field type, default value)
110+
- Multiple fields: List of above
111+
"""
112+
if not hasattr(request, "param"):
113+
raise ValueError(f"{NewParamModel.__name__} fixture must be used as an indirect parametrization")
114+
115+
def add_field(field_data: Any | tuple[str, Any] | tuple[str, Any, Any], idx: int = 1) -> None:
116+
if isinstance(field_data, tuple):
117+
assert len(field_data) <= 3, f"Invalid field: {field_data}. Each field must be given as 2 or 3 items"
118+
if len(field_data) == 1:
119+
fields.append((f"field_{idx}", field_data, Unset))
120+
elif len(field_data) >= 2:
121+
fields.append(field_data)
122+
else:
123+
fields.append((f"field{idx}", field_data, Unset))
124+
125+
requested_field_data = request.param
126+
fields: list[Any | tuple[str, Any] | tuple[str, Any, Any]] = []
127+
if isinstance(requested_field_data, list):
128+
for i, requested_field in enumerate(requested_field_data, start=1):
129+
add_field(requested_field, idx=i)
130+
else:
131+
add_field(requested_field_data)
132+
133+
param_model = cast(type[ParamModel], make_dataclass("Model", fields, bases=(ParamModel,)))
134+
return param_model
135+
136+
137+
@pytest.fixture(scope="session")
138+
def ParamModelWithParamFormats() -> type[ParamModel]:
139+
"""A a dataclass param model that has fields with various param formats we support"""
140+
fields = [
141+
("uuid", Optional[Annotated[str, Format("uuid")]], Unset),
142+
("date_time", Optional[Annotated[str, Format("date-time")]], Unset),
143+
("date", Optional[Annotated[str, Format("date")]], Unset),
144+
("time", Optional[Annotated[str, Format("time")]], Unset),
145+
("duration", Optional[Annotated[str, Format("duration")]], Unset),
146+
("binary", Optional[Annotated[str, Format("binary")]], Unset),
147+
("file", Optional[Annotated[File, Format("binary")]], Unset),
148+
("byte", Optional[Annotated[str, Format("byte")]], Unset),
149+
("path", Optional[Annotated[str, Format("path")]], Unset),
150+
("base64", Optional[Annotated[str, Format("base64")]], Unset),
151+
("base64url", Optional[Annotated[str, Format("base64url")]], Unset),
152+
("email", Optional[Annotated[str, Format("email")]], Unset),
153+
("name_email", Optional[Annotated[str, Format("name-email")]], Unset),
154+
("uri", Optional[Annotated[str, Format("uri")]], Unset),
155+
("ipv4", Optional[Annotated[str, Format("ipv4")]], Unset),
156+
("ipv6", Optional[Annotated[str, Format("ipv6")]], Unset),
157+
("ipvanyaddress", Optional[Annotated[str, Format("ipvanyaddress")]], Unset),
158+
("ipvanyinterface", Optional[Annotated[str, Format("ipvanyinterface")]], Unset),
159+
("ipvanynetwork", Optional[Annotated[str, Format("ipvanynetwork")]], Unset),
160+
("phone", Optional[Annotated[str, Format("phone")]], Unset),
161+
]
162+
param_model = cast(type[ParamModel], make_dataclass("Model", fields, bases=(ParamModel,)))
163+
164+
# Make sure the model covers all Pydantic specific types we support
165+
annotated_types = [param_type_util.get_annotated_type(t) for _, t, _ in fields]
166+
param_formats = [x.__metadata__[0].value for x in annotated_types]
167+
undefined_formats = set(PARAM_FORMAT_AND_TYPE_MAP.keys()).difference(set(param_formats))
168+
assert not undefined_formats, f"Missing test coverage for these formats: {undefined_formats}"
169+
170+
return param_model

tests/unit/test_param_type.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def test_get_type_annotation_as_str(tp: Any, expected_tp_str: str, is_optional:
109109
(str | int, str | int),
110110
(str | None, str),
111111
(str | int | None, str | int),
112+
(Annotated[str, Constraint(min_len=5)] | Annotated[int, Constraint(min=5)], str | int),
112113
(list[str], str),
113114
(list[dict[str, Any]], dict[str, Any]),
114115
(Annotated[str, "meta"], str),
@@ -225,6 +226,7 @@ def test_is_optional_type(tp: Any, is_optional_type: bool) -> None:
225226
assert param_type_util.is_optional_type(tp) is is_optional_type
226227

227228

229+
@pytest.mark.parametrize("exclude_optional", [False, True])
228230
@pytest.mark.parametrize(
229231
("tp", "is_union_type"),
230232
[
@@ -240,12 +242,16 @@ def test_is_optional_type(tp: Any, is_optional_type: bool) -> None:
240242
(Optional[Annotated[str | int, "meta"]], True),
241243
],
242244
)
243-
def test_is_union_type(tp: Any, is_union_type: bool) -> None:
245+
def test_is_union_type(tp: Any, is_union_type: bool, exclude_optional: bool) -> None:
244246
"""Verify that we can check whether a given type annotation itself is a union type or not
245247
246248
Note: Optional[] is also considered as union
247249
"""
248-
assert param_type_util.is_union_type(tp) is is_union_type
250+
if exclude_optional:
251+
is_union = is_union_type and NoneType not in get_args(tp)
252+
else:
253+
is_union = is_union_type
254+
assert param_type_util.is_union_type(tp, exclude_optional=exclude_optional) is is_union
249255

250256

251257
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)