From f5ec28384c851e25daeb1832e5bc259956a6ffda Mon Sep 17 00:00:00 2001 From: Adriaan Mulder Date: Thu, 17 Apr 2025 08:07:43 -0700 Subject: [PATCH] Add PatchSchema --- ninja/patch_schema.py | 143 ++++++++++++++++++++++++++++ tests/test_patch_schema.py | 186 +++++++++++++++++++++++++++++++++++++ 2 files changed, 329 insertions(+) create mode 100644 ninja/patch_schema.py create mode 100644 tests/test_patch_schema.py diff --git a/ninja/patch_schema.py b/ninja/patch_schema.py new file mode 100644 index 000000000..a63221b7c --- /dev/null +++ b/ninja/patch_schema.py @@ -0,0 +1,143 @@ +from typing import ( + Any, + Dict, + Generic, + Optional, + Type, + TypeVar, + Union, + get_args, + get_origin, +) + +from pydantic import BaseModel, create_model, model_validator + +ModelT = TypeVar("ModelT", bound=BaseModel) + +# Type alias for patched models to help with type checking +# This allows using cast(PatchedModel, PatchSchema[SomeModel]()) to properly type the model_dump() method +PatchedModel = BaseModel + + +class PatchSchema(Generic[ModelT]): + """Generate a patchable version of a Pydantic model. + + Makes all fields optional, but doesn't allow None values unless the field was originally defined as Optional. + This allows for partial updates where fields can be omitted or provided with legitimate values. + + Example: + Given a schema: + class ExampleSchema(BaseModel): + example_field: str + optional_field: Optional[str] = None + + PatchSchema[ExampleSchema] will allow: + - PatchSchema[ExampleSchema]() (no fields provided) + - PatchSchema[ExampleSchema](example_field="example") (field provided) + - PatchSchema[ExampleSchema](optional_field=None) (None allowed for originally optional fields) + + But will not allow: + - PatchSchema[ExampleSchema](example_field=None) (None is not allowed for non-optional fields) + + Usage: + # Define a regular schema + class UserSchema(BaseModel): + name: str + email: str + avatar_url: Optional[str] = None + + # Create a patchable version that allows partial updates + PatchUserSchema = PatchSchema[UserSchema] + + # Use the patched schema for partial updates + patch_data = PatchUserSchema(name="New Name") # Only updates the name field + patch_data = PatchUserSchema(avatar_url=None) # Can set avatar_url to None since it's optional + """ + + def __new__( + cls, + *args: Any, + **kwargs: Any, + ) -> "PatchSchema[ModelT]": + """Cannot instantiate directly.""" + raise TypeError("Cannot instantiate abstract PatchSchema class.") + + def __init_subclass__( + cls, + *args: Any, + **kwargs: Any, + ) -> None: + """Cannot subclass.""" + raise TypeError(f"Cannot subclass {cls.__module__}.PatchSchema") + + @classmethod + def _is_optional_type(cls, annotation): + """Check if a type annotation is Optional[X].""" + if get_origin(annotation) is Union: + args = get_args(annotation) + return type(None) in args + return False + + @classmethod + def __class_getitem__( + cls, + wrapped_class: Type[ModelT], + ) -> Type[PatchedModel]: + """Convert model to a patchable model where fields are optional but can't be None unless originally Optional.""" + + # Create field definitions for the new model + fields = {} + # Keep track of which fields were originally optional + originally_optional_fields = set() + + # Access model_fields through instance property, not class attribute + model_fields = getattr(wrapped_class, "model_fields", {}) + for field_name, field_info in model_fields.items(): + # Make the field optional by setting a default value + annotation = field_info.annotation + + # Check if the field was originally optional + if cls._is_optional_type(annotation): + originally_optional_fields.add(field_name) + + fields[field_name] = (Optional[annotation], None) + + # Create the new model class + class PatchModel(BaseModel): + model_config = {"extra": "ignore", "arbitrary_types_allowed": True} + + @model_validator(mode="before") + @classmethod + def validate_no_none_values(cls, data): + if isinstance(data, dict): + # Check for explicit None values and raise error for non-optional fields + for key, value in list(data.items()): + if value is None and key not in originally_optional_fields: + raise ValueError(f"Field '{key}' cannot be None") + # Keep only non-None values + return { + k: v + for k, v in data.items() + if v is not None or k in originally_optional_fields + } + return data + + # We don't need a custom schema generator anymore since Pydantic v2 uses anyOf for optional fields + + def model_dump(self, **kwargs) -> Dict[str, Any]: + # Filter out None values from the serialized object + # Only include fields that were explicitly set (not default None values) + dump = super().model_dump(**kwargs) + # Get fields that were explicitly set (excluding default None values) + fields_set = self.model_fields_set + return {k: v for k, v in dump.items() if k in fields_set} + + patched_model = create_model( + f"Patched{wrapped_class.__name__}", __base__=PatchModel, **fields + ) + + # Pass the originally_optional_fields to the patched model + patched_model._originally_optional_fields = originally_optional_fields + + # Fix return type by using explicit cast to match the declared return type + return patched_model # type: ignore diff --git a/tests/test_patch_schema.py b/tests/test_patch_schema.py new file mode 100644 index 000000000..7ad61e59d --- /dev/null +++ b/tests/test_patch_schema.py @@ -0,0 +1,186 @@ +from typing import Optional, cast + +import pytest +from pydantic import BaseModel, ValidationError + +from ninja import NinjaAPI, Schema +from ninja.patch_schema import PatchedModel, PatchSchema +from ninja.testing import TestClient + + +class UserSchema(BaseModel): + name: str + email: str + age: int + avatar_url: Optional[str] = None + + +class UserSchemaWithDefault(BaseModel): + name: str = "Default" + email: str + age: int + avatar_url: Optional[str] = None + + +def test_patch_schema_basic(): + PatchUserSchema = PatchSchema[UserSchema] + + # Test empty initialization + patch_data = cast(PatchedModel, PatchUserSchema()) + assert patch_data.model_dump() == {} + + # Test field initialization + patch_data = cast(PatchedModel, PatchUserSchema(name="New Name")) + assert patch_data.model_dump() == {"name": "New Name"} + + # Test multiple fields + patch_data = cast( + PatchedModel, PatchUserSchema(name="New Name", email="new@example.com") + ) + assert patch_data.model_dump() == {"name": "New Name", "email": "new@example.com"} + + +def test_patch_schema_with_optional_fields(): + PatchUserSchema = PatchSchema[UserSchema] + + # Test setting optional field to None + patch_data = cast(PatchedModel, PatchUserSchema(avatar_url=None)) + assert patch_data.model_dump() == {"avatar_url": None} + + # Test setting optional field to value + patch_data = cast( + PatchedModel, PatchUserSchema(avatar_url="https://example.com/avatar.png") + ) + assert patch_data.model_dump() == {"avatar_url": "https://example.com/avatar.png"} + + +def test_patch_schema_none_validation(): + PatchUserSchema = PatchSchema[UserSchema] + + # Non-optional fields should not allow None + with pytest.raises(ValidationError) as exc_info: + PatchUserSchema(name=None) + + assert "Field 'name' cannot be None" in str(exc_info.value) + + with pytest.raises(ValidationError) as exc_info: + PatchUserSchema(email=None) + + assert "Field 'email' cannot be None" in str(exc_info.value) + + +def test_patch_schema_with_defaults(): + PatchUserSchemaWithDefault = PatchSchema[UserSchemaWithDefault] + + # Default values should not be included in the output unless explicitly set + patch_data = cast(PatchedModel, PatchUserSchemaWithDefault()) + assert patch_data.model_dump() == {} + + patch_data = cast(PatchedModel, PatchUserSchemaWithDefault(name="Custom Name")) + assert patch_data.model_dump() == {"name": "Custom Name"} + + +# API integration tests +api = NinjaAPI() +client = TestClient(api) + + +class UserSchemaAPI(Schema): + name: str + email: str + age: int + avatar_url: Optional[str] = None + + +@api.post("/users") +def create_user(request, data: UserSchemaAPI): + return data + + +@api.patch("/users/{user_id}") +def update_user(request, user_id: int, data: PatchSchema[UserSchemaAPI]): + # Return the data and its type to verify it's working correctly + return {"id": user_id, "data": data.model_dump(), "type": str(type(data).__name__)} + + +def test_api_integration(): + # First create a user + create_response = client.post( + "/users", json={"name": "Test User", "email": "test@example.com", "age": 30} + ) + assert create_response.status_code == 200 + + # Test partial update with patch + patch_response = client.patch("/users/1", json={"name": "Updated Name"}) + assert patch_response.status_code == 200 + assert patch_response.json() == { + "id": 1, + "data": {"name": "Updated Name"}, + "type": f"Patched{UserSchemaAPI.__name__}", + } + + # Test multiple fields update + patch_response = client.patch("/users/1", json={"name": "New Name", "age": 31}) + assert patch_response.status_code == 200 + assert patch_response.json() == { + "id": 1, + "data": {"name": "New Name", "age": 31}, + "type": f"Patched{UserSchemaAPI.__name__}", + } + + # Test optional field set to null + patch_response = client.patch("/users/1", json={"avatar_url": None}) + assert patch_response.status_code == 200 + assert patch_response.json() == { + "id": 1, + "data": {"avatar_url": None}, + "type": f"Patched{UserSchemaAPI.__name__}", + } + + # Test validation error when setting non-optional field to null + error_response = client.patch("/users/1", json={"name": None}) + assert error_response.status_code == 422 # Validation error + + +def test_direct_instantiation_error(): + with pytest.raises(TypeError) as exc_info: + PatchSchema() + + assert "Cannot instantiate abstract PatchSchema class" in str(exc_info.value) + + +def test_subclass_error(): + with pytest.raises(TypeError) as exc_info: + + class MyPatchSchema(PatchSchema): + pass + + assert "Cannot subclass" in str(exc_info.value) + + +def test_openapi_schema(): + """Test that the OpenAPI schema for a patched model is correctly generated.""" + schema = api.get_openapi_schema() + patched_schema = schema["components"]["schemas"][f"Patched{UserSchemaAPI.__name__}"] + + assert patched_schema["type"] == "object" + assert "properties" in patched_schema + + # Check that name is optional in the schema + assert "name" in patched_schema["properties"] + + # In Pydantic v2, optional fields use anyOf with multiple types including null + name_prop = patched_schema["properties"]["name"] + assert "anyOf" in name_prop + assert any(item["type"] == "string" for item in name_prop["anyOf"]) + assert any(item["type"] == "null" for item in name_prop["anyOf"]) + + # No required fields in patched schema + assert "required" not in patched_schema or "name" not in patched_schema["required"] + + # Check that avatar_url is still optional + assert "avatar_url" in patched_schema["properties"] + avatar_prop = patched_schema["properties"]["avatar_url"] + assert "anyOf" in avatar_prop + assert any(item["type"] == "string" for item in avatar_prop["anyOf"]) + assert any(item["type"] == "null" for item in avatar_prop["anyOf"])