Skip to content

Commit

Permalink
refactor(models): move schema, parameters and paginated response to m…
Browse files Browse the repository at this point in the history
…odels_utils
  • Loading branch information
GG-Yanne committed Dec 11, 2024
1 parent 8a071e4 commit 51ea4cc
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 122 deletions.
2 changes: 1 addition & 1 deletion pygitguardian/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
CreateTeamInvitation,
CreateTeamMember,
CreateTeamMemberParameter,
CursorPaginatedResponse,
DeleteMember,
Detail,
Document,
Expand Down Expand Up @@ -66,6 +65,7 @@
UpdateTeam,
UpdateTeamSource,
)
from .models_utils import CursorPaginatedResponse
from .sca_models import (
ComputeSCAFilesResult,
SCAScanAllOutput,
Expand Down
143 changes: 23 additions & 120 deletions pygitguardian/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,18 @@
from dataclasses import dataclass, field
from datetime import date, datetime
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Generic,
List,
Literal,
Optional,
Type,
TypeVar,
cast,
)
from typing import Any, Dict, List, Literal, Optional, Type, cast
from uuid import UUID

import marshmallow_dataclass
from marshmallow import (
EXCLUDE,
Schema,
ValidationError,
fields,
post_dump,
post_load,
pre_load,
validate,
)
from typing_extensions import Self

from .config import (
DEFAULT_PRE_COMMIT_MESSAGE,
Expand All @@ -39,64 +24,14 @@
DOCUMENT_SIZE_THRESHOLD_BYTES,
MULTI_DOCUMENT_LIMIT,
)


if TYPE_CHECKING:
import requests


class ToDictMixin:
"""
Provides a type-safe `to_dict()` method for classes using Marshmallow
"""

SCHEMA: ClassVar[Schema]

def to_dict(self) -> Dict[str, Any]:
return cast(Dict[str, Any], self.SCHEMA.dump(self))


class FromDictMixin:
"""This class must be used as an additional base class for all classes whose schema
implements a `post_load` function turning the received dict into a class instance.
It makes it possible to deserialize an object using `MyClass.from_dict(dct)` instead
of `MyClass.SCHEMA.load(dct)`. The `from_dict()` method is shorter, but more
importantly, type-safe: its return type is an instance of `MyClass`, not
`list[Any] | Any`.
Reference: https://marshmallow.readthedocs.io/en/stable/quickstart.html#deserializing-to-objects E501
"""

SCHEMA: ClassVar[Schema]

@classmethod
def from_dict(cls, dct: Dict[str, Any]) -> Self:
return cast(Self, cls.SCHEMA.load(dct))


class BaseSchema(Schema):
class Meta:
ordered = True
unknown = EXCLUDE


class Base(ToDictMixin):
def __init__(self, status_code: Optional[int] = None) -> None:
self.status_code = status_code

def to_json(self) -> str:
"""
to_json converts model to JSON string.
"""
return cast(str, self.SCHEMA.dumps(self))

@property
def success(self) -> bool:
return self.__bool__()

def __bool__(self) -> bool:
return self.status_code == 200
from .models_utils import (
Base,
BaseSchema,
FromDictMixin,
PaginationParameter,
SearchParameter,
ToDictMixin,
)


class DocumentSchema(BaseSchema):
Expand Down Expand Up @@ -1148,44 +1083,6 @@ class AccessLevel(str, Enum):
RESTRICTED = "restricted"


class PaginationParameter(ToDictMixin):
"""Pagination mixin used for endpoints that support pagination."""

cursor: str = ""
per_page: int = 20


class SearchParameter(ToDictMixin):
search: Optional[str] = None


PaginatedData = TypeVar("PaginatedData", bound=FromDictMixin)


@dataclass
class CursorPaginatedResponse(Generic[PaginatedData]):
status_code: int
data: List[PaginatedData]
prev: Optional[str] = None
next: Optional[str] = None

@classmethod
def from_response(
cls, response: "requests.Response", data_type: Type[PaginatedData]
) -> "CursorPaginatedResponse[PaginatedData]":
data = cast(
List[PaginatedData], [data_type.from_dict(obj) for obj in response.json()]
)
paginated_response = cls(status_code=response.status_code, data=data)

if previous_page := response.links.get("prev"):
paginated_response.prev = previous_page["url"]
if next_page := response.links.get("next"):
paginated_response.prev = next_page["url"]

return paginated_response


@dataclass
class MembersParameters(PaginationParameter, SearchParameter, ToDictMixin):
"""
Expand Down Expand Up @@ -1228,6 +1125,11 @@ class Member(Base, FromDictMixin):


class MemberSchema(BaseSchema):
"""
This schema cannot be done through marshmallow_dataclass as we want to use the
values of the AccessLevel enum to create the enum field
"""

id = fields.Int(required=True)
access_level = fields.Enum(AccessLevel, by_value=True, required=True)
email = fields.Str(required=True)
Expand All @@ -1249,6 +1151,11 @@ def return_member(


class UpdateMemberSchema(BaseSchema):
"""
This schema cannot be done through marshmallow_dataclass as we want to use the
values of the AccessLevel enum to create the enum field
"""

id = fields.Int(required=True)
access_level = fields.Enum(AccessLevel, by_value=True, allow_none=True)
active = fields.Bool(allow_none=True)
Expand Down Expand Up @@ -1323,14 +1230,10 @@ class CreateTeam(Base, FromDictMixin):
description: Optional[str] = ""


class CreateTeamSchema(BaseSchema):
many = False

name = fields.Str(required=True)
description = fields.Str(allow_none=True)

class Meta:
exclude_none = True
CreateTeamSchema = cast(
Type[BaseSchema],
marshmallow_dataclass.class_schema(CreateTeam, base_schema=BaseSchema),
)


CreateTeam.SCHEMA = CreateTeamSchema()
Expand Down
115 changes: 115 additions & 0 deletions pygitguardian/models_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# pyright: reportIncompatibleVariableOverride=false
# Disable this check because of multiple non-dangerous violations (SCHEMA variables,
# BaseSchema.Meta class)
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Generic,
List,
Optional,
Type,
TypeVar,
cast,
)

from marshmallow import EXCLUDE, Schema
from typing_extensions import Self


if TYPE_CHECKING:
import requests


class ToDictMixin:
"""
Provides a type-safe `to_dict()` method for classes using Marshmallow
"""

SCHEMA: ClassVar[Schema]

def to_dict(self) -> Dict[str, Any]:
return cast(Dict[str, Any], self.SCHEMA.dump(self))


class FromDictMixin:
"""This class must be used as an additional base class for all classes whose schema
implements a `post_load` function turning the received dict into a class instance.
It makes it possible to deserialize an object using `MyClass.from_dict(dct)` instead
of `MyClass.SCHEMA.load(dct)`. The `from_dict()` method is shorter, but more
importantly, type-safe: its return type is an instance of `MyClass`, not
`list[Any] | Any`.
Reference: https://marshmallow.readthedocs.io/en/stable/quickstart.html#deserializing-to-objects E501
"""

SCHEMA: ClassVar[Schema]

@classmethod
def from_dict(cls, dct: Dict[str, Any]) -> Self:
return cast(Self, cls.SCHEMA.load(dct))


class BaseSchema(Schema):
class Meta:
ordered = True
unknown = EXCLUDE


class Base(ToDictMixin):
def __init__(self, status_code: Optional[int] = None) -> None:
self.status_code = status_code

def to_json(self) -> str:
"""
to_json converts model to JSON string.
"""
return cast(str, self.SCHEMA.dumps(self))

@property
def success(self) -> bool:
return self.__bool__()

def __bool__(self) -> bool:
return self.status_code == 200


class PaginationParameter(ToDictMixin):
"""Pagination mixin used for endpoints that support pagination."""

cursor: str = ""
per_page: int = 20


class SearchParameter(ToDictMixin):
search: Optional[str] = None


PaginatedData = TypeVar("PaginatedData", bound=FromDictMixin)


@dataclass
class CursorPaginatedResponse(Generic[PaginatedData]):
status_code: int
data: List[PaginatedData]
prev: Optional[str] = None
next: Optional[str] = None

@classmethod
def from_response(
cls, response: "requests.Response", data_type: Type[PaginatedData]
) -> "CursorPaginatedResponse[PaginatedData]":
data = cast(
List[PaginatedData], [data_type.from_dict(obj) for obj in response.json()]
)
paginated_response = cls(status_code=response.status_code, data=data)

if previous_page := response.links.get("prev"):
paginated_response.prev = previous_page["url"]
if next_page := response.links.get("next"):
paginated_response.prev = next_page["url"]

return paginated_response
2 changes: 1 addition & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
CreateTeamInvitation,
CreateTeamMember,
CreateTeamMemberParameter,
CursorPaginatedResponse,
DeleteMember,
Detail,
HoneytokenResponse,
Expand All @@ -58,6 +57,7 @@
UpdateTeam,
UpdateTeamSource,
)
from pygitguardian.models_utils import CursorPaginatedResponse
from pygitguardian.sca_models import (
ComputeSCAFilesResult,
SCAScanAllOutput,
Expand Down

0 comments on commit 51ea4cc

Please sign in to comment.