Skip to content

Commit 7a0bb22

Browse files
authored
feature: saving data into model from ModelSchema (#24)
## Missing Tests ## Generic Type Support ```python class UserSchema(ModelSchema[User]): class Config: model = User include = ["id", "email", "profile"] ``` Inference type of Django model is supported in the schema class. This allows for better IDE support and type checking. The `ModelSchema` class can be used with any Django model, and the type of the model can be specified as a generic type parameter. inference in save method: ```python from typing import TypeVar from djantic import ModelSchema from myapp.models import User class UserSchema(ModelSchema[User]): class Config: model = User include = ("username", "email", "first_name", "last_name", "is_staff") serialized_user = UserSchema( username="myusername", email="my@email.com", first_name="My First Name", last_name="My Last Name", is_staff=True, ) new_user = serialized_user.save() ``` is optional, but it is recommended to use the `ModelSchema` class with the Django model type as a generic type parameter. This allows for better IDE support and type checking. Also with generic type support, now it's not necessary to define `model` in the `Config` class. The `ModelSchema` class will automatically infer the model type from the generic type parameter. This allows to get model type from the schema class itself. ```python class UserSchema(ModelSchema[User]): class Config: include = ("username", "email", "first_name", "last_name", "is_staff") serialized_user = UserSchema( username="myusername", email="my@email.com", first_name="My First Name", last_name="My Last Name", is_staff=True, ) new_user = serialized_user.save() ```
1 parent 038500a commit 7a0bb22

File tree

5 files changed

+335
-19
lines changed

5 files changed

+335
-19
lines changed

djantic/main.py

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from enum import Enum
44
from functools import reduce
55
from itertools import chain
6-
from typing import Any, Dict, List, Optional, Union, no_type_check
6+
from typing import Any, Dict, List, Optional, TypeVar, Union, no_type_check
77

88
from django.core.serializers.json import DjangoJSONEncoder
99
from django.db.models import Manager, Model
@@ -21,10 +21,16 @@
2121
else:
2222
from typing import Union as UnionType
2323

24+
25+
from django.db.models import Model as DjangoModel
26+
2427
from .fields import ModelSchemaField
28+
from .mixin import ModelSchemaMixin
2529

2630
_is_base_model_class_defined = False
2731

32+
_M = TypeVar("_M", bound=DjangoModel)
33+
2834

2935
class ModelSchemaJSONEncoder(DjangoJSONEncoder):
3036
@no_type_check
@@ -48,14 +54,32 @@ class ModelSchemaMetaclass(ModelMetaclass):
4854
@no_type_check
4955
def __new__(mcs, name: str, bases: tuple, namespace: dict, **kwargs):
5056
cls = super().__new__(mcs, name, bases, namespace, **kwargs)
57+
58+
config = namespace.get("model_config", {})
59+
5160
for base in reversed(bases):
5261
if (
5362
_is_base_model_class_defined
63+
and not (config is None or config == {})
5464
and issubclass(base, ModelSchema)
55-
and base == ModelSchema
65+
and
66+
## Start to ensure generic origin is ModelSchema
67+
# When schema is inherited from another class with generic
68+
# origin, we need to check if the base class is ModelSchema
69+
(
70+
(
71+
hasattr(base, "__pydantic_generic_metadata__")
72+
and (
73+
issubclass(
74+
base.__pydantic_generic_metadata__.get("origin"),
75+
ModelSchema,
76+
)
77+
)
78+
)
79+
or base == ModelSchema
80+
)
5681
):
57-
58-
config = namespace["model_config"]
82+
## Finish to ensure generic origin is ModelSchema
5983
include = config.get("include", None)
6084
exclude = config.get("exclude", None)
6185

@@ -69,12 +93,33 @@ def __new__(mcs, name: str, bases: tuple, namespace: dict, **kwargs):
6993
annotations = namespace.get("__annotations__", {})
7094

7195
try:
96+
## Get from generic metadata if available
97+
#
98+
if "model" not in config:
99+
if hasattr(
100+
base, "__pydantic_generic_metadata__"
101+
) and base.__pydantic_generic_metadata__.get("args"):
102+
config["model"] = base.__pydantic_generic_metadata__.get(
103+
"args"
104+
)[0]
105+
72106
fields = config["model"]._meta.get_fields()
107+
73108
except (AttributeError, KeyError) as exc:
74109
raise PydanticUserError(
75-
f'{exc} (Is `model_config["model"]` a valid Django model class?)',
110+
(
111+
f'{exc} (Is model_config["model"] a valid Django model class?) '
112+
'\nPlease set the model_config["model"] or a generic type with a '
113+
"Django model class as the first argument. \n\n"
114+
"Example: \n\n"
115+
"- class MyModelSchema(ModelSchema):\n"
116+
'\n model_config = {"model": MyModel}\n\n'
117+
"or\n\n"
118+
"- class MyModelSchema(ModelSchema[MyModel]):\n"
119+
" ...\n"
120+
),
76121
code="class-not-valid",
77-
)
122+
) from None
78123

79124
if include == "__annotations__":
80125
include = list(annotations.keys())
@@ -103,7 +148,6 @@ def __new__(mcs, name: str, bases: tuple, namespace: dict, **kwargs):
103148
python_type = None
104149
pydantic_field = None
105150
if field_name in annotations and field_name in namespace:
106-
107151
python_type = annotations.pop(field_name)
108152
pydantic_field = namespace[field_name]
109153
if (
@@ -135,6 +179,7 @@ def __new__(mcs, name: str, bases: tuple, namespace: dict, **kwargs):
135179
__doc__=cls.__doc__,
136180
**field_values,
137181
)
182+
138183
return model_schema
139184

140185
return cls
@@ -143,10 +188,10 @@ def __new__(mcs, name: str, bases: tuple, namespace: dict, **kwargs):
143188
def _is_optional_field(annotation) -> bool:
144189
args = get_args(annotation)
145190
return (
146-
(get_origin(annotation) is Union or get_origin(annotation) is UnionType)
147-
and type(None) in args
148-
and len(args) == 2
149-
and any(inspect.isclass(arg) and issubclass(arg, ModelSchema) for arg in args)
191+
(get_origin(annotation) is Union or get_origin(annotation) is UnionType)
192+
and type(None) in args
193+
and len(args) == 2
194+
and any(inspect.isclass(arg) and issubclass(arg, ModelSchema) for arg in args)
150195
)
151196

152197

@@ -157,7 +202,7 @@ def __init__(self, obj: Any, schema_class):
157202

158203
def get(self, key: Any, default: Any = None) -> Any:
159204
if "__" in key:
160-
# Allow double underscores aliases: `first_name: str = Field(alias="user__first_name")`
205+
# Allow double underscores aliases: first_name: str = Field(alias="user__first_name")
161206
keys_map = key.split("__")
162207
attr = reduce(lambda a, b: getattr(a, b, default), keys_map, self._obj)
163208
else:
@@ -221,7 +266,9 @@ def dict(self) -> dict:
221266
non_none_type_annotation = next(
222267
arg for arg in get_args(annotation) if arg is not type(None)
223268
)
224-
data[key] = self._get_annotation_objects(value, non_none_type_annotation)
269+
data[key] = self._get_annotation_objects(
270+
value, non_none_type_annotation
271+
)
225272

226273
elif inspect.isclass(annotation) and issubclass(annotation, ModelSchema):
227274
data[key] = self._get_annotation_objects(self.get(key), annotation)
@@ -231,8 +278,7 @@ def dict(self) -> dict:
231278
return data
232279

233280

234-
class ModelSchema(BaseModel, metaclass=ModelSchemaMetaclass):
235-
281+
class ModelSchema(BaseModel, ModelSchemaMixin[_M], metaclass=ModelSchemaMetaclass):
236282
def __eq__(self, other: Any) -> bool:
237283
result = super().__eq__(other)
238284
if isinstance(result, bool):

djantic/mixin.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from typing import Any, Generic, Optional, TypeVar, Union
2+
3+
from django.db.models import Model as DjangoModel
4+
5+
_M = TypeVar("_M", bound=DjangoModel)
6+
7+
8+
class ModelSchemaMixin(Generic[_M]):
9+
def create(self, *args: Any, **kwargs: Any) -> _M:
10+
ModelDjangoClass: type[_M] = self.model_config["model"]
11+
12+
record: _M = ModelDjangoClass._default_manager.create(**self.model_dump())
13+
14+
return record
15+
16+
def update(
17+
self, instance: _M, partial: Optional[bool] = None, *args: Any, **kwargs: Any
18+
) -> _M:
19+
if not isinstance(instance, self.model_config["model"]):
20+
raise TypeError(
21+
"instance is not of the type {0}".format(self.model_config["model"]) # noqa
22+
)
23+
24+
data = self.model_dump() if not partial else self.model_dump(exclude_unset=True)
25+
26+
if instance:
27+
# Update the existing instance with the new data
28+
for key, value in data.items():
29+
if hasattr(instance, key):
30+
setattr(instance, key, value)
31+
else:
32+
raise ValueError(f"Field {key} does not exist on the model.")
33+
instance.save(*args, **kwargs)
34+
35+
return instance
36+
37+
def save(
38+
self,
39+
instance: Optional[_M] = None,
40+
partial: Union[bool, None] = None,
41+
*args: Any,
42+
**kwargs: Any,
43+
) -> _M:
44+
"""Save the model instance to the database.
45+
46+
This method saves the current model data to the database. If an instance is
47+
provided, it updates the existing instance with the data from the current object.
48+
If no instance is provided, it creates a new instance in the database.
49+
50+
Args:
51+
instance: An optional model instance to update. If provided, the instance will be
52+
updated with the current model data. If None, a new instance will be created.
53+
partial: If True, only fields that have been explicitly set will be updated.
54+
If Unset, all fields will be updated/saved.
55+
*args: Additional positional arguments to pass to the model's save method.
56+
**kwargs: Additional keyword arguments to pass to the model's save method.
57+
58+
Returns:
59+
The saved model instance.
60+
61+
Raises:
62+
ValueError: If a field in the model data does not exist on the provided instance.
63+
"""
64+
65+
if instance:
66+
record = self.update(instance, partial, *args, *kwargs)
67+
assert record is not None, "`update()` did not return an object instance."
68+
else:
69+
record = self.create(*args, **kwargs)
70+
assert record is not None, "`create()` did not return an object instance."
71+
return record

docs/usage.md

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,3 +456,65 @@ Output:
456456
"updated_at": "2021-04-04T08:47:39.567455+00:00"
457457
}
458458
```
459+
## Generic Type Support
460+
461+
```python
462+
class UserSchema(ModelSchema[User]):
463+
class Config:
464+
model = User
465+
include = ["id", "email", "profile"]
466+
```
467+
468+
Inference type of Django model is supported in the schema class. This allows for better IDE support and type checking. The `ModelSchema` class can be used with any Django model, and the type of the model can be specified as a generic type parameter.
469+
470+
inference in save method:
471+
472+
```python
473+
from typing import TypeVar
474+
475+
from djantic import ModelSchema
476+
from myapp.models import User
477+
478+
479+
class UserSchema(ModelSchema[User]):
480+
481+
class Config:
482+
model = User
483+
include = ("username", "email", "first_name", "last_name", "is_staff")
484+
485+
486+
serialized_user = UserSchema(
487+
username="myusername",
488+
email="my@email.com",
489+
first_name="My First Name",
490+
last_name="My Last Name",
491+
is_staff=True,
492+
)
493+
494+
new_user = serialized_user.save()
495+
```
496+
497+
is optional, but it is recommended to use the `ModelSchema` class with the Django model type as a generic type parameter. This allows for better IDE support and type checking.
498+
499+
Also with generic type support, now it's not necessary to define `model` in the `Config` class. The `ModelSchema` class will automatically infer the model type from the generic type parameter. This allows to get model type from the schema class itself.
500+
501+
```python
502+
503+
class UserSchema(ModelSchema[User]):
504+
505+
class Config:
506+
include = ("username", "email", "first_name", "last_name", "is_staff")
507+
508+
509+
serialized_user = UserSchema(
510+
username="myusername",
511+
email="my@email.com",
512+
first_name="My First Name",
513+
last_name="My Last Name",
514+
is_staff=True,
515+
)
516+
517+
new_user = serialized_user.save()
518+
```
519+
520+
IDE SUPPORT

tests/test_main.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
import pytest
2+
from pydantic import ConfigDict
23
from pydantic.errors import PydanticUserError
3-
from testapp.models import User
44

5-
from pydantic import ConfigDict
65
from djantic import ModelSchema
6+
from testapp.models import User
77

88

99
@pytest.mark.django_db
1010
def test_model_config_contains_valid_model():
11-
error_msg = r"(Is `model_config\[\"model\"\]` a valid Django model class?)"
12-
with pytest.raises(PydanticUserError, match=error_msg):
11+
error_msg = "(Is model_config[\"model\"] a valid Django model class?)" # fmt: skip
12+
with pytest.raises(PydanticUserError) as exc_info:
1313

1414
class InvalidModelErrorSchema2(ModelSchema):
1515
model_config = ConfigDict(model="Ok")
1616

17+
assert error_msg in str(exc_info.value)
18+
1719

1820
@pytest.mark.django_db
1921
def test_include_and_exclude_mutually_exclusive():

0 commit comments

Comments
 (0)