Skip to content

Commit 51ea4cc

Browse files
committed
refactor(models): move schema, parameters and paginated response to models_utils
1 parent 8a071e4 commit 51ea4cc

File tree

4 files changed

+140
-122
lines changed

4 files changed

+140
-122
lines changed

pygitguardian/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
CreateTeamInvitation,
3333
CreateTeamMember,
3434
CreateTeamMemberParameter,
35-
CursorPaginatedResponse,
3635
DeleteMember,
3736
Detail,
3837
Document,
@@ -66,6 +65,7 @@
6665
UpdateTeam,
6766
UpdateTeamSource,
6867
)
68+
from .models_utils import CursorPaginatedResponse
6969
from .sca_models import (
7070
ComputeSCAFilesResult,
7171
SCAScanAllOutput,

pygitguardian/models.py

Lines changed: 23 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,18 @@
44
from dataclasses import dataclass, field
55
from datetime import date, datetime
66
from enum import Enum
7-
from typing import (
8-
TYPE_CHECKING,
9-
Any,
10-
ClassVar,
11-
Dict,
12-
Generic,
13-
List,
14-
Literal,
15-
Optional,
16-
Type,
17-
TypeVar,
18-
cast,
19-
)
7+
from typing import Any, Dict, List, Literal, Optional, Type, cast
208
from uuid import UUID
219

2210
import marshmallow_dataclass
2311
from marshmallow import (
24-
EXCLUDE,
25-
Schema,
2612
ValidationError,
2713
fields,
2814
post_dump,
2915
post_load,
3016
pre_load,
3117
validate,
3218
)
33-
from typing_extensions import Self
3419

3520
from .config import (
3621
DEFAULT_PRE_COMMIT_MESSAGE,
@@ -39,64 +24,14 @@
3924
DOCUMENT_SIZE_THRESHOLD_BYTES,
4025
MULTI_DOCUMENT_LIMIT,
4126
)
42-
43-
44-
if TYPE_CHECKING:
45-
import requests
46-
47-
48-
class ToDictMixin:
49-
"""
50-
Provides a type-safe `to_dict()` method for classes using Marshmallow
51-
"""
52-
53-
SCHEMA: ClassVar[Schema]
54-
55-
def to_dict(self) -> Dict[str, Any]:
56-
return cast(Dict[str, Any], self.SCHEMA.dump(self))
57-
58-
59-
class FromDictMixin:
60-
"""This class must be used as an additional base class for all classes whose schema
61-
implements a `post_load` function turning the received dict into a class instance.
62-
63-
It makes it possible to deserialize an object using `MyClass.from_dict(dct)` instead
64-
of `MyClass.SCHEMA.load(dct)`. The `from_dict()` method is shorter, but more
65-
importantly, type-safe: its return type is an instance of `MyClass`, not
66-
`list[Any] | Any`.
67-
68-
Reference: https://marshmallow.readthedocs.io/en/stable/quickstart.html#deserializing-to-objects E501
69-
"""
70-
71-
SCHEMA: ClassVar[Schema]
72-
73-
@classmethod
74-
def from_dict(cls, dct: Dict[str, Any]) -> Self:
75-
return cast(Self, cls.SCHEMA.load(dct))
76-
77-
78-
class BaseSchema(Schema):
79-
class Meta:
80-
ordered = True
81-
unknown = EXCLUDE
82-
83-
84-
class Base(ToDictMixin):
85-
def __init__(self, status_code: Optional[int] = None) -> None:
86-
self.status_code = status_code
87-
88-
def to_json(self) -> str:
89-
"""
90-
to_json converts model to JSON string.
91-
"""
92-
return cast(str, self.SCHEMA.dumps(self))
93-
94-
@property
95-
def success(self) -> bool:
96-
return self.__bool__()
97-
98-
def __bool__(self) -> bool:
99-
return self.status_code == 200
27+
from .models_utils import (
28+
Base,
29+
BaseSchema,
30+
FromDictMixin,
31+
PaginationParameter,
32+
SearchParameter,
33+
ToDictMixin,
34+
)
10035

10136

10237
class DocumentSchema(BaseSchema):
@@ -1148,44 +1083,6 @@ class AccessLevel(str, Enum):
11481083
RESTRICTED = "restricted"
11491084

11501085

1151-
class PaginationParameter(ToDictMixin):
1152-
"""Pagination mixin used for endpoints that support pagination."""
1153-
1154-
cursor: str = ""
1155-
per_page: int = 20
1156-
1157-
1158-
class SearchParameter(ToDictMixin):
1159-
search: Optional[str] = None
1160-
1161-
1162-
PaginatedData = TypeVar("PaginatedData", bound=FromDictMixin)
1163-
1164-
1165-
@dataclass
1166-
class CursorPaginatedResponse(Generic[PaginatedData]):
1167-
status_code: int
1168-
data: List[PaginatedData]
1169-
prev: Optional[str] = None
1170-
next: Optional[str] = None
1171-
1172-
@classmethod
1173-
def from_response(
1174-
cls, response: "requests.Response", data_type: Type[PaginatedData]
1175-
) -> "CursorPaginatedResponse[PaginatedData]":
1176-
data = cast(
1177-
List[PaginatedData], [data_type.from_dict(obj) for obj in response.json()]
1178-
)
1179-
paginated_response = cls(status_code=response.status_code, data=data)
1180-
1181-
if previous_page := response.links.get("prev"):
1182-
paginated_response.prev = previous_page["url"]
1183-
if next_page := response.links.get("next"):
1184-
paginated_response.prev = next_page["url"]
1185-
1186-
return paginated_response
1187-
1188-
11891086
@dataclass
11901087
class MembersParameters(PaginationParameter, SearchParameter, ToDictMixin):
11911088
"""
@@ -1228,6 +1125,11 @@ class Member(Base, FromDictMixin):
12281125

12291126

12301127
class MemberSchema(BaseSchema):
1128+
"""
1129+
This schema cannot be done through marshmallow_dataclass as we want to use the
1130+
values of the AccessLevel enum to create the enum field
1131+
"""
1132+
12311133
id = fields.Int(required=True)
12321134
access_level = fields.Enum(AccessLevel, by_value=True, required=True)
12331135
email = fields.Str(required=True)
@@ -1249,6 +1151,11 @@ def return_member(
12491151

12501152

12511153
class UpdateMemberSchema(BaseSchema):
1154+
"""
1155+
This schema cannot be done through marshmallow_dataclass as we want to use the
1156+
values of the AccessLevel enum to create the enum field
1157+
"""
1158+
12521159
id = fields.Int(required=True)
12531160
access_level = fields.Enum(AccessLevel, by_value=True, allow_none=True)
12541161
active = fields.Bool(allow_none=True)
@@ -1323,14 +1230,10 @@ class CreateTeam(Base, FromDictMixin):
13231230
description: Optional[str] = ""
13241231

13251232

1326-
class CreateTeamSchema(BaseSchema):
1327-
many = False
1328-
1329-
name = fields.Str(required=True)
1330-
description = fields.Str(allow_none=True)
1331-
1332-
class Meta:
1333-
exclude_none = True
1233+
CreateTeamSchema = cast(
1234+
Type[BaseSchema],
1235+
marshmallow_dataclass.class_schema(CreateTeam, base_schema=BaseSchema),
1236+
)
13341237

13351238

13361239
CreateTeam.SCHEMA = CreateTeamSchema()

pygitguardian/models_utils.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# pyright: reportIncompatibleVariableOverride=false
2+
# Disable this check because of multiple non-dangerous violations (SCHEMA variables,
3+
# BaseSchema.Meta class)
4+
from dataclasses import dataclass
5+
from typing import (
6+
TYPE_CHECKING,
7+
Any,
8+
ClassVar,
9+
Dict,
10+
Generic,
11+
List,
12+
Optional,
13+
Type,
14+
TypeVar,
15+
cast,
16+
)
17+
18+
from marshmallow import EXCLUDE, Schema
19+
from typing_extensions import Self
20+
21+
22+
if TYPE_CHECKING:
23+
import requests
24+
25+
26+
class ToDictMixin:
27+
"""
28+
Provides a type-safe `to_dict()` method for classes using Marshmallow
29+
"""
30+
31+
SCHEMA: ClassVar[Schema]
32+
33+
def to_dict(self) -> Dict[str, Any]:
34+
return cast(Dict[str, Any], self.SCHEMA.dump(self))
35+
36+
37+
class FromDictMixin:
38+
"""This class must be used as an additional base class for all classes whose schema
39+
implements a `post_load` function turning the received dict into a class instance.
40+
41+
It makes it possible to deserialize an object using `MyClass.from_dict(dct)` instead
42+
of `MyClass.SCHEMA.load(dct)`. The `from_dict()` method is shorter, but more
43+
importantly, type-safe: its return type is an instance of `MyClass`, not
44+
`list[Any] | Any`.
45+
46+
Reference: https://marshmallow.readthedocs.io/en/stable/quickstart.html#deserializing-to-objects E501
47+
"""
48+
49+
SCHEMA: ClassVar[Schema]
50+
51+
@classmethod
52+
def from_dict(cls, dct: Dict[str, Any]) -> Self:
53+
return cast(Self, cls.SCHEMA.load(dct))
54+
55+
56+
class BaseSchema(Schema):
57+
class Meta:
58+
ordered = True
59+
unknown = EXCLUDE
60+
61+
62+
class Base(ToDictMixin):
63+
def __init__(self, status_code: Optional[int] = None) -> None:
64+
self.status_code = status_code
65+
66+
def to_json(self) -> str:
67+
"""
68+
to_json converts model to JSON string.
69+
"""
70+
return cast(str, self.SCHEMA.dumps(self))
71+
72+
@property
73+
def success(self) -> bool:
74+
return self.__bool__()
75+
76+
def __bool__(self) -> bool:
77+
return self.status_code == 200
78+
79+
80+
class PaginationParameter(ToDictMixin):
81+
"""Pagination mixin used for endpoints that support pagination."""
82+
83+
cursor: str = ""
84+
per_page: int = 20
85+
86+
87+
class SearchParameter(ToDictMixin):
88+
search: Optional[str] = None
89+
90+
91+
PaginatedData = TypeVar("PaginatedData", bound=FromDictMixin)
92+
93+
94+
@dataclass
95+
class CursorPaginatedResponse(Generic[PaginatedData]):
96+
status_code: int
97+
data: List[PaginatedData]
98+
prev: Optional[str] = None
99+
next: Optional[str] = None
100+
101+
@classmethod
102+
def from_response(
103+
cls, response: "requests.Response", data_type: Type[PaginatedData]
104+
) -> "CursorPaginatedResponse[PaginatedData]":
105+
data = cast(
106+
List[PaginatedData], [data_type.from_dict(obj) for obj in response.json()]
107+
)
108+
paginated_response = cls(status_code=response.status_code, data=data)
109+
110+
if previous_page := response.links.get("prev"):
111+
paginated_response.prev = previous_page["url"]
112+
if next_page := response.links.get("next"):
113+
paginated_response.prev = next_page["url"]
114+
115+
return paginated_response

tests/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
CreateTeamInvitation,
3232
CreateTeamMember,
3333
CreateTeamMemberParameter,
34-
CursorPaginatedResponse,
3534
DeleteMember,
3635
Detail,
3736
HoneytokenResponse,
@@ -58,6 +57,7 @@
5857
UpdateTeam,
5958
UpdateTeamSource,
6059
)
60+
from pygitguardian.models_utils import CursorPaginatedResponse
6161
from pygitguardian.sca_models import (
6262
ComputeSCAFilesResult,
6363
SCAScanAllOutput,

0 commit comments

Comments
 (0)