Skip to content

Commit f5ec283

Browse files
committed
Add PatchSchema
1 parent 932d22b commit f5ec283

File tree

2 files changed

+329
-0
lines changed

2 files changed

+329
-0
lines changed

ninja/patch_schema.py

+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
from typing import (
2+
Any,
3+
Dict,
4+
Generic,
5+
Optional,
6+
Type,
7+
TypeVar,
8+
Union,
9+
get_args,
10+
get_origin,
11+
)
12+
13+
from pydantic import BaseModel, create_model, model_validator
14+
15+
ModelT = TypeVar("ModelT", bound=BaseModel)
16+
17+
# Type alias for patched models to help with type checking
18+
# This allows using cast(PatchedModel, PatchSchema[SomeModel]()) to properly type the model_dump() method
19+
PatchedModel = BaseModel
20+
21+
22+
class PatchSchema(Generic[ModelT]):
23+
"""Generate a patchable version of a Pydantic model.
24+
25+
Makes all fields optional, but doesn't allow None values unless the field was originally defined as Optional.
26+
This allows for partial updates where fields can be omitted or provided with legitimate values.
27+
28+
Example:
29+
Given a schema:
30+
class ExampleSchema(BaseModel):
31+
example_field: str
32+
optional_field: Optional[str] = None
33+
34+
PatchSchema[ExampleSchema] will allow:
35+
- PatchSchema[ExampleSchema]() (no fields provided)
36+
- PatchSchema[ExampleSchema](example_field="example") (field provided)
37+
- PatchSchema[ExampleSchema](optional_field=None) (None allowed for originally optional fields)
38+
39+
But will not allow:
40+
- PatchSchema[ExampleSchema](example_field=None) (None is not allowed for non-optional fields)
41+
42+
Usage:
43+
# Define a regular schema
44+
class UserSchema(BaseModel):
45+
name: str
46+
email: str
47+
avatar_url: Optional[str] = None
48+
49+
# Create a patchable version that allows partial updates
50+
PatchUserSchema = PatchSchema[UserSchema]
51+
52+
# Use the patched schema for partial updates
53+
patch_data = PatchUserSchema(name="New Name") # Only updates the name field
54+
patch_data = PatchUserSchema(avatar_url=None) # Can set avatar_url to None since it's optional
55+
"""
56+
57+
def __new__(
58+
cls,
59+
*args: Any,
60+
**kwargs: Any,
61+
) -> "PatchSchema[ModelT]":
62+
"""Cannot instantiate directly."""
63+
raise TypeError("Cannot instantiate abstract PatchSchema class.")
64+
65+
def __init_subclass__(
66+
cls,
67+
*args: Any,
68+
**kwargs: Any,
69+
) -> None:
70+
"""Cannot subclass."""
71+
raise TypeError(f"Cannot subclass {cls.__module__}.PatchSchema")
72+
73+
@classmethod
74+
def _is_optional_type(cls, annotation):
75+
"""Check if a type annotation is Optional[X]."""
76+
if get_origin(annotation) is Union:
77+
args = get_args(annotation)
78+
return type(None) in args
79+
return False
80+
81+
@classmethod
82+
def __class_getitem__(
83+
cls,
84+
wrapped_class: Type[ModelT],
85+
) -> Type[PatchedModel]:
86+
"""Convert model to a patchable model where fields are optional but can't be None unless originally Optional."""
87+
88+
# Create field definitions for the new model
89+
fields = {}
90+
# Keep track of which fields were originally optional
91+
originally_optional_fields = set()
92+
93+
# Access model_fields through instance property, not class attribute
94+
model_fields = getattr(wrapped_class, "model_fields", {})
95+
for field_name, field_info in model_fields.items():
96+
# Make the field optional by setting a default value
97+
annotation = field_info.annotation
98+
99+
# Check if the field was originally optional
100+
if cls._is_optional_type(annotation):
101+
originally_optional_fields.add(field_name)
102+
103+
fields[field_name] = (Optional[annotation], None)
104+
105+
# Create the new model class
106+
class PatchModel(BaseModel):
107+
model_config = {"extra": "ignore", "arbitrary_types_allowed": True}
108+
109+
@model_validator(mode="before")
110+
@classmethod
111+
def validate_no_none_values(cls, data):
112+
if isinstance(data, dict):
113+
# Check for explicit None values and raise error for non-optional fields
114+
for key, value in list(data.items()):
115+
if value is None and key not in originally_optional_fields:
116+
raise ValueError(f"Field '{key}' cannot be None")
117+
# Keep only non-None values
118+
return {
119+
k: v
120+
for k, v in data.items()
121+
if v is not None or k in originally_optional_fields
122+
}
123+
return data
124+
125+
# We don't need a custom schema generator anymore since Pydantic v2 uses anyOf for optional fields
126+
127+
def model_dump(self, **kwargs) -> Dict[str, Any]:
128+
# Filter out None values from the serialized object
129+
# Only include fields that were explicitly set (not default None values)
130+
dump = super().model_dump(**kwargs)
131+
# Get fields that were explicitly set (excluding default None values)
132+
fields_set = self.model_fields_set
133+
return {k: v for k, v in dump.items() if k in fields_set}
134+
135+
patched_model = create_model(
136+
f"Patched{wrapped_class.__name__}", __base__=PatchModel, **fields
137+
)
138+
139+
# Pass the originally_optional_fields to the patched model
140+
patched_model._originally_optional_fields = originally_optional_fields
141+
142+
# Fix return type by using explicit cast to match the declared return type
143+
return patched_model # type: ignore

tests/test_patch_schema.py

+186
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
from typing import Optional, cast
2+
3+
import pytest
4+
from pydantic import BaseModel, ValidationError
5+
6+
from ninja import NinjaAPI, Schema
7+
from ninja.patch_schema import PatchedModel, PatchSchema
8+
from ninja.testing import TestClient
9+
10+
11+
class UserSchema(BaseModel):
12+
name: str
13+
email: str
14+
age: int
15+
avatar_url: Optional[str] = None
16+
17+
18+
class UserSchemaWithDefault(BaseModel):
19+
name: str = "Default"
20+
email: str
21+
age: int
22+
avatar_url: Optional[str] = None
23+
24+
25+
def test_patch_schema_basic():
26+
PatchUserSchema = PatchSchema[UserSchema]
27+
28+
# Test empty initialization
29+
patch_data = cast(PatchedModel, PatchUserSchema())
30+
assert patch_data.model_dump() == {}
31+
32+
# Test field initialization
33+
patch_data = cast(PatchedModel, PatchUserSchema(name="New Name"))
34+
assert patch_data.model_dump() == {"name": "New Name"}
35+
36+
# Test multiple fields
37+
patch_data = cast(
38+
PatchedModel, PatchUserSchema(name="New Name", email="[email protected]")
39+
)
40+
assert patch_data.model_dump() == {"name": "New Name", "email": "[email protected]"}
41+
42+
43+
def test_patch_schema_with_optional_fields():
44+
PatchUserSchema = PatchSchema[UserSchema]
45+
46+
# Test setting optional field to None
47+
patch_data = cast(PatchedModel, PatchUserSchema(avatar_url=None))
48+
assert patch_data.model_dump() == {"avatar_url": None}
49+
50+
# Test setting optional field to value
51+
patch_data = cast(
52+
PatchedModel, PatchUserSchema(avatar_url="https://example.com/avatar.png")
53+
)
54+
assert patch_data.model_dump() == {"avatar_url": "https://example.com/avatar.png"}
55+
56+
57+
def test_patch_schema_none_validation():
58+
PatchUserSchema = PatchSchema[UserSchema]
59+
60+
# Non-optional fields should not allow None
61+
with pytest.raises(ValidationError) as exc_info:
62+
PatchUserSchema(name=None)
63+
64+
assert "Field 'name' cannot be None" in str(exc_info.value)
65+
66+
with pytest.raises(ValidationError) as exc_info:
67+
PatchUserSchema(email=None)
68+
69+
assert "Field 'email' cannot be None" in str(exc_info.value)
70+
71+
72+
def test_patch_schema_with_defaults():
73+
PatchUserSchemaWithDefault = PatchSchema[UserSchemaWithDefault]
74+
75+
# Default values should not be included in the output unless explicitly set
76+
patch_data = cast(PatchedModel, PatchUserSchemaWithDefault())
77+
assert patch_data.model_dump() == {}
78+
79+
patch_data = cast(PatchedModel, PatchUserSchemaWithDefault(name="Custom Name"))
80+
assert patch_data.model_dump() == {"name": "Custom Name"}
81+
82+
83+
# API integration tests
84+
api = NinjaAPI()
85+
client = TestClient(api)
86+
87+
88+
class UserSchemaAPI(Schema):
89+
name: str
90+
email: str
91+
age: int
92+
avatar_url: Optional[str] = None
93+
94+
95+
@api.post("/users")
96+
def create_user(request, data: UserSchemaAPI):
97+
return data
98+
99+
100+
@api.patch("/users/{user_id}")
101+
def update_user(request, user_id: int, data: PatchSchema[UserSchemaAPI]):
102+
# Return the data and its type to verify it's working correctly
103+
return {"id": user_id, "data": data.model_dump(), "type": str(type(data).__name__)}
104+
105+
106+
def test_api_integration():
107+
# First create a user
108+
create_response = client.post(
109+
"/users", json={"name": "Test User", "email": "[email protected]", "age": 30}
110+
)
111+
assert create_response.status_code == 200
112+
113+
# Test partial update with patch
114+
patch_response = client.patch("/users/1", json={"name": "Updated Name"})
115+
assert patch_response.status_code == 200
116+
assert patch_response.json() == {
117+
"id": 1,
118+
"data": {"name": "Updated Name"},
119+
"type": f"Patched{UserSchemaAPI.__name__}",
120+
}
121+
122+
# Test multiple fields update
123+
patch_response = client.patch("/users/1", json={"name": "New Name", "age": 31})
124+
assert patch_response.status_code == 200
125+
assert patch_response.json() == {
126+
"id": 1,
127+
"data": {"name": "New Name", "age": 31},
128+
"type": f"Patched{UserSchemaAPI.__name__}",
129+
}
130+
131+
# Test optional field set to null
132+
patch_response = client.patch("/users/1", json={"avatar_url": None})
133+
assert patch_response.status_code == 200
134+
assert patch_response.json() == {
135+
"id": 1,
136+
"data": {"avatar_url": None},
137+
"type": f"Patched{UserSchemaAPI.__name__}",
138+
}
139+
140+
# Test validation error when setting non-optional field to null
141+
error_response = client.patch("/users/1", json={"name": None})
142+
assert error_response.status_code == 422 # Validation error
143+
144+
145+
def test_direct_instantiation_error():
146+
with pytest.raises(TypeError) as exc_info:
147+
PatchSchema()
148+
149+
assert "Cannot instantiate abstract PatchSchema class" in str(exc_info.value)
150+
151+
152+
def test_subclass_error():
153+
with pytest.raises(TypeError) as exc_info:
154+
155+
class MyPatchSchema(PatchSchema):
156+
pass
157+
158+
assert "Cannot subclass" in str(exc_info.value)
159+
160+
161+
def test_openapi_schema():
162+
"""Test that the OpenAPI schema for a patched model is correctly generated."""
163+
schema = api.get_openapi_schema()
164+
patched_schema = schema["components"]["schemas"][f"Patched{UserSchemaAPI.__name__}"]
165+
166+
assert patched_schema["type"] == "object"
167+
assert "properties" in patched_schema
168+
169+
# Check that name is optional in the schema
170+
assert "name" in patched_schema["properties"]
171+
172+
# In Pydantic v2, optional fields use anyOf with multiple types including null
173+
name_prop = patched_schema["properties"]["name"]
174+
assert "anyOf" in name_prop
175+
assert any(item["type"] == "string" for item in name_prop["anyOf"])
176+
assert any(item["type"] == "null" for item in name_prop["anyOf"])
177+
178+
# No required fields in patched schema
179+
assert "required" not in patched_schema or "name" not in patched_schema["required"]
180+
181+
# Check that avatar_url is still optional
182+
assert "avatar_url" in patched_schema["properties"]
183+
avatar_prop = patched_schema["properties"]["avatar_url"]
184+
assert "anyOf" in avatar_prop
185+
assert any(item["type"] == "string" for item in avatar_prop["anyOf"])
186+
assert any(item["type"] == "null" for item in avatar_prop["anyOf"])

0 commit comments

Comments
 (0)