Skip to content

Commit bb86a85

Browse files
committed
Add PatchSchema
1 parent 932d22b commit bb86a85

File tree

2 files changed

+326
-0
lines changed

2 files changed

+326
-0
lines changed

ninja/patch_schema.py

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from typing import TypeVar, Generic, Any, Type, Optional, Union, get_origin, get_args, Dict
2+
from pydantic import BaseModel, create_model, model_validator
3+
4+
5+
ModelT = TypeVar("ModelT", bound=BaseModel)
6+
7+
# Type alias for patched models to help with type checking
8+
# This allows using cast(PatchedModel, PatchSchema[SomeModel]()) to properly type the model_dump() method
9+
PatchedModel = BaseModel
10+
11+
12+
class PatchSchema(Generic[ModelT]):
13+
"""Generate a patchable version of a Pydantic model.
14+
15+
Makes all fields optional, but doesn't allow None values unless the field was originally defined as Optional.
16+
This allows for partial updates where fields can be omitted or provided with legitimate values.
17+
18+
Example:
19+
Given a schema:
20+
class ExampleSchema(BaseModel):
21+
example_field: str
22+
optional_field: Optional[str] = None
23+
24+
PatchSchema[ExampleSchema] will allow:
25+
- PatchSchema[ExampleSchema]() (no fields provided)
26+
- PatchSchema[ExampleSchema](example_field="example") (field provided)
27+
- PatchSchema[ExampleSchema](optional_field=None) (None allowed for originally optional fields)
28+
29+
But will not allow:
30+
- PatchSchema[ExampleSchema](example_field=None) (None is not allowed for non-optional fields)
31+
32+
Usage:
33+
# Define a regular schema
34+
class UserSchema(BaseModel):
35+
name: str
36+
email: str
37+
avatar_url: Optional[str] = None
38+
39+
# Create a patchable version that allows partial updates
40+
PatchUserSchema = PatchSchema[UserSchema]
41+
42+
# Use the patched schema for partial updates
43+
patch_data = PatchUserSchema(name="New Name") # Only updates the name field
44+
patch_data = PatchUserSchema(avatar_url=None) # Can set avatar_url to None since it's optional
45+
"""
46+
47+
def __new__(
48+
cls,
49+
*args: Any,
50+
**kwargs: Any,
51+
) -> "PatchSchema[ModelT]":
52+
"""Cannot instantiate directly."""
53+
raise TypeError("Cannot instantiate abstract PatchSchema class.")
54+
55+
def __init_subclass__(
56+
cls,
57+
*args: Any,
58+
**kwargs: Any,
59+
) -> None:
60+
"""Cannot subclass."""
61+
raise TypeError(f"Cannot subclass {cls.__module__}.PatchSchema")
62+
63+
@classmethod
64+
def _is_optional_type(cls, annotation):
65+
"""Check if a type annotation is Optional[X]."""
66+
if get_origin(annotation) is Union:
67+
args = get_args(annotation)
68+
return type(None) in args
69+
return False
70+
71+
@classmethod
72+
def __class_getitem__(
73+
cls,
74+
wrapped_class: Type[ModelT],
75+
) -> Type[PatchedModel]:
76+
"""Convert model to a patchable model where fields are optional but can't be None unless originally Optional."""
77+
78+
# Create field definitions for the new model
79+
fields = {}
80+
# Keep track of which fields were originally optional
81+
originally_optional_fields = set()
82+
83+
# Access model_fields through instance property, not class attribute
84+
model_fields = getattr(wrapped_class, "model_fields", {})
85+
for field_name, field_info in model_fields.items():
86+
# Make the field optional by setting a default value
87+
annotation = field_info.annotation
88+
89+
# Check if the field was originally optional
90+
if cls._is_optional_type(annotation):
91+
originally_optional_fields.add(field_name)
92+
93+
fields[field_name] = (Optional[annotation], None)
94+
95+
# Create the new model class
96+
class PatchModel(BaseModel):
97+
model_config = {"extra": "ignore", "arbitrary_types_allowed": True}
98+
99+
@model_validator(mode='before')
100+
@classmethod
101+
def validate_no_none_values(cls, data):
102+
if isinstance(data, dict):
103+
# Check for explicit None values and raise error for non-optional fields
104+
for key, value in list(data.items()):
105+
if value is None and key not in originally_optional_fields:
106+
raise ValueError(f"Field '{key}' cannot be None")
107+
# Keep only non-None values
108+
return {k: v for k, v in data.items() if v is not None or k in originally_optional_fields}
109+
return data
110+
111+
# We don't need a custom schema generator anymore since Pydantic v2 uses anyOf for optional fields
112+
113+
def model_dump(self, **kwargs) -> Dict[str, Any]:
114+
# Filter out None values from the serialized object
115+
# Only include fields that were explicitly set (not default None values)
116+
dump = super().model_dump(**kwargs)
117+
# Get fields that were explicitly set (excluding default None values)
118+
fields_set = self.model_fields_set
119+
return {k: v for k, v in dump.items() if k in fields_set}
120+
121+
patched_model = create_model(
122+
f"Patched{wrapped_class.__name__}",
123+
__base__=PatchModel,
124+
**fields
125+
)
126+
127+
# Pass the originally_optional_fields to the patched model
128+
setattr(patched_model, "_originally_optional_fields", originally_optional_fields)
129+
130+
# Fix return type by using explicit cast to match the declared return type
131+
return patched_model # type: ignore
132+

tests/test_patch_schema.py

+194
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
from typing import Optional, cast
2+
3+
import pytest
4+
from pydantic import BaseModel, ValidationError
5+
6+
from ninja import NinjaAPI, Schema
7+
from ninja.patch_schema import PatchSchema, PatchedModel
8+
from ninja.testing import TestClient
9+
10+
11+
class UserSchema(BaseModel):
12+
name: str
13+
email: str
14+
age: int
15+
avatar_url: Optional[str] = None
16+
17+
18+
class UserSchemaWithDefault(BaseModel):
19+
name: str = "Default"
20+
email: str
21+
age: int
22+
avatar_url: Optional[str] = None
23+
24+
25+
def test_patch_schema_basic():
26+
PatchUserSchema = PatchSchema[UserSchema]
27+
28+
# Test empty initialization
29+
patch_data = cast(PatchedModel, PatchUserSchema())
30+
assert patch_data.model_dump() == {}
31+
32+
# Test field initialization
33+
patch_data = cast(PatchedModel, PatchUserSchema(name="New Name"))
34+
assert patch_data.model_dump() == {"name": "New Name"}
35+
36+
# Test multiple fields
37+
patch_data = cast(PatchedModel, PatchUserSchema(name="New Name", email="[email protected]"))
38+
assert patch_data.model_dump() == {"name": "New Name", "email": "[email protected]"}
39+
40+
41+
def test_patch_schema_with_optional_fields():
42+
PatchUserSchema = PatchSchema[UserSchema]
43+
44+
# Test setting optional field to None
45+
patch_data = cast(PatchedModel, PatchUserSchema(avatar_url=None))
46+
assert patch_data.model_dump() == {"avatar_url": None}
47+
48+
# Test setting optional field to value
49+
patch_data = cast(PatchedModel, PatchUserSchema(avatar_url="https://example.com/avatar.png"))
50+
assert patch_data.model_dump() == {"avatar_url": "https://example.com/avatar.png"}
51+
52+
53+
def test_patch_schema_none_validation():
54+
PatchUserSchema = PatchSchema[UserSchema]
55+
56+
# Non-optional fields should not allow None
57+
with pytest.raises(ValidationError) as exc_info:
58+
PatchUserSchema(name=None)
59+
60+
assert "Field 'name' cannot be None" in str(exc_info.value)
61+
62+
with pytest.raises(ValidationError) as exc_info:
63+
PatchUserSchema(email=None)
64+
65+
assert "Field 'email' cannot be None" in str(exc_info.value)
66+
67+
68+
def test_patch_schema_with_defaults():
69+
PatchUserSchemaWithDefault = PatchSchema[UserSchemaWithDefault]
70+
71+
# Default values should not be included in the output unless explicitly set
72+
patch_data = cast(PatchedModel, PatchUserSchemaWithDefault())
73+
assert patch_data.model_dump() == {}
74+
75+
patch_data = cast(PatchedModel, PatchUserSchemaWithDefault(name="Custom Name"))
76+
assert patch_data.model_dump() == {"name": "Custom Name"}
77+
78+
79+
# API integration tests
80+
api = NinjaAPI()
81+
client = TestClient(api)
82+
83+
84+
class UserSchemaAPI(Schema):
85+
name: str
86+
email: str
87+
age: int
88+
avatar_url: Optional[str] = None
89+
90+
91+
@api.post("/users")
92+
def create_user(request, data: UserSchemaAPI):
93+
return data
94+
95+
96+
@api.patch("/users/{user_id}")
97+
def update_user(request, user_id: int, data: PatchSchema[UserSchemaAPI]):
98+
# Return the data and its type to verify it's working correctly
99+
return {"id": user_id, "data": data.model_dump(), "type": str(type(data).__name__)}
100+
101+
102+
def test_api_integration():
103+
# First create a user
104+
create_response = client.post(
105+
"/users",
106+
json={"name": "Test User", "email": "[email protected]", "age": 30}
107+
)
108+
assert create_response.status_code == 200
109+
110+
# Test partial update with patch
111+
patch_response = client.patch(
112+
"/users/1",
113+
json={"name": "Updated Name"}
114+
)
115+
assert patch_response.status_code == 200
116+
assert patch_response.json() == {
117+
"id": 1,
118+
"data": {"name": "Updated Name"},
119+
"type": f"Patched{UserSchemaAPI.__name__}"
120+
}
121+
122+
# Test multiple fields update
123+
patch_response = client.patch(
124+
"/users/1",
125+
json={"name": "New Name", "age": 31}
126+
)
127+
assert patch_response.status_code == 200
128+
assert patch_response.json() == {
129+
"id": 1,
130+
"data": {"name": "New Name", "age": 31},
131+
"type": f"Patched{UserSchemaAPI.__name__}"
132+
}
133+
134+
# Test optional field set to null
135+
patch_response = client.patch(
136+
"/users/1",
137+
json={"avatar_url": None}
138+
)
139+
assert patch_response.status_code == 200
140+
assert patch_response.json() == {
141+
"id": 1,
142+
"data": {"avatar_url": None},
143+
"type": f"Patched{UserSchemaAPI.__name__}"
144+
}
145+
146+
# Test validation error when setting non-optional field to null
147+
error_response = client.patch(
148+
"/users/1",
149+
json={"name": None}
150+
)
151+
assert error_response.status_code == 422 # Validation error
152+
153+
154+
def test_direct_instantiation_error():
155+
with pytest.raises(TypeError) as exc_info:
156+
PatchSchema()
157+
158+
assert "Cannot instantiate abstract PatchSchema class" in str(exc_info.value)
159+
160+
161+
def test_subclass_error():
162+
with pytest.raises(TypeError) as exc_info:
163+
class MyPatchSchema(PatchSchema):
164+
pass
165+
166+
assert "Cannot subclass" in str(exc_info.value)
167+
168+
169+
def test_openapi_schema():
170+
"""Test that the OpenAPI schema for a patched model is correctly generated."""
171+
schema = api.get_openapi_schema()
172+
patched_schema = schema["components"]["schemas"][f"Patched{UserSchemaAPI.__name__}"]
173+
174+
assert patched_schema["type"] == "object"
175+
assert "properties" in patched_schema
176+
177+
# Check that name is optional in the schema
178+
assert "name" in patched_schema["properties"]
179+
180+
# In Pydantic v2, optional fields use anyOf with multiple types including null
181+
name_prop = patched_schema["properties"]["name"]
182+
assert "anyOf" in name_prop
183+
assert any(item["type"] == "string" for item in name_prop["anyOf"])
184+
assert any(item["type"] == "null" for item in name_prop["anyOf"])
185+
186+
# No required fields in patched schema
187+
assert "required" not in patched_schema or "name" not in patched_schema["required"]
188+
189+
# Check that avatar_url is still optional
190+
assert "avatar_url" in patched_schema["properties"]
191+
avatar_prop = patched_schema["properties"]["avatar_url"]
192+
assert "anyOf" in avatar_prop
193+
assert any(item["type"] == "string" for item in avatar_prop["anyOf"])
194+
assert any(item["type"] == "null" for item in avatar_prop["anyOf"])

0 commit comments

Comments
 (0)